1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//! # Licensing
//! This Source Code is subject to the terms of the Mozilla Public License
//! version 2.0 (the "License"). You can obtain a copy of the License at
//! http://mozilla.org/MPL/2.0/.

use crate::AbsError;
use crate::ApproEqError;
use crate::ApproEqResult;
use crate::RelError;
use ndarray::{ArrayBase, Axis, Data, Dimension};

/// Returns the largest value produced by `iter`, propagating the first error.
///
/// Each item is an `ApproEqResult<D>`:
/// - the first `Err` encountered is returned immediately; later items are not
///   pulled from the iterator (so any work done lazily by the iterator, such as
///   a per-element error computation in a `map`, is skipped);
/// - `Ok(None)` items are ignored;
/// - among `Ok(Some(_))` items the maximum is kept, ties keeping the earlier one.
///
/// If no item is `Ok(Some(_))` (including an empty iterator) the result is
/// `Ok(None)`.
///
/// Comparison uses `>`, so a value that is incomparable with the current
/// maximum (e.g. a floating-point NaN) never replaces it; conversely, if the
/// first value kept is NaN, no later value can replace it.
/// NOTE(review): NaN handling is therefore order-dependent — confirm whether
/// callers rely on this.
fn max<D: PartialOrd, T: Iterator<Item = ApproEqResult<D>>>(iter: T) -> ApproEqResult<D> {
    let mut best: Option<D> = None;
    for item in iter {
        // `?` hands back the first error right away instead of folding over
        // (and evaluating) every remaining element.
        if let Some(candidate) = item? {
            // Strict `>`: ties (and incomparable values) keep the current maximum.
            let is_new_max = match best {
                Some(ref current) => candidate > *current,
                None => true,
            };
            if is_new_max {
                best = Some(candidate);
            }
        }
    }
    Ok(best)
}

#[cfg_attr(feature = "docs", stable(feature = "ndarray", since = "0.1.0"))]
impl<A: Data, B: PartialOrd, C: Data, D: Dimension> AbsError<ArrayBase<A, D>, B> for ArrayBase<C, D>
where
    C::Elem: AbsError<A::Elem, B> + Sized,
{
    fn abs_error(&self, expected: &ArrayBase<A, D>) -> ApproEqResult<B> {
        if self.ndim() != expected.ndim() {
            return Err(ApproEqError::LengthMismatch);
        }
        for n in 0..self.ndim() {
            if self.len_of(Axis(n)) != expected.len_of(Axis(n)) {
                return Err(ApproEqError::LengthMismatch);
            }
        }

        max(self
            .iter()
            .zip(expected.iter())
            .map(move |(i, j)| i.abs_error(j)))
    }
}

#[cfg_attr(feature = "docs", stable(feature = "ndarray", since = "0.1.0"))]
impl<A: Data, B: PartialOrd, C: Data, D: Dimension> RelError<ArrayBase<A, D>, B> for ArrayBase<C, D>
where
    C::Elem: RelError<A::Elem, B> + Sized,
{
    fn rel_error(&self, expected: &ArrayBase<A, D>) -> ApproEqResult<B> {
        if self.ndim() != expected.ndim() {
            return Err(ApproEqError::LengthMismatch);
        }
        for n in 0..self.ndim() {
            if self.len_of(Axis(n)) != expected.len_of(Axis(n)) {
                return Err(ApproEqError::LengthMismatch);
            }
        }

        max(self
            .iter()
            .zip(expected.iter())
            .map(move |(i, j)| i.rel_error(j)))
    }
}