candle_approx/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
use candle_core::{Result, Tensor};

pub const DEFAULT_EPSILON: f64 = 1e-4;
pub const DEFAULT_MAX_RELATIVE: f64 = 1e-4;

pub fn abs_diff_eq(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<Tensor> {
    a.sub(b)?.abs()?.le(epsilon)
}

pub fn relative_eq(a: &Tensor, b: &Tensor, epsilon: f64, max_relative: f64) -> Result<Tensor> {
    let norm2 = a.abs()?.add(&b.abs()?)?;
    let diff = a.sub(b)?.abs()?;
    diff.le(&(norm2 * (0.5 * max_relative))?.maximum(epsilon)?)
}

pub fn all(a: &Tensor) -> Result<bool> {
    Ok(a.flatten_all()?.min(0)?.to_scalar::<u8>()? != 0)
}

#[macro_export]
macro_rules! assert_eq {
    ($a:expr, $b:expr $(,)?) => {{
        let (a, b) = (&$a, &$b);
        assert!(
            $crate::all(&a.eq(b).unwrap()).unwrap(),
            "assert_eq failed:\n{}:\n{}\n{}:\n{}",
            stringify!($a),
            a,
            stringify!($b),
            b,
        );
    }};
}

#[macro_export]
macro_rules! assert_abs_diff_eq {
    ($a:expr, $b:expr, $eps:expr $(,)?) => {{
        let (a, b) = (&$a, &$b);
        let eps = $eps;
        assert!(
            $crate::all(&$crate::abs_diff_eq(a, b, eps).unwrap()).unwrap(),
            "assert_abs_diff_eq failed (epsilon = {:.0e}):\n{}:\n{}\n{}:\n{}",
            eps,
            stringify!($a),
            a,
            stringify!($b),
            b,
        );
    }};
    ($a:expr, $b:expr $(,)?) => {{
        assert_abs_diff_eq!($a, $b, $crate::DEFAULT_EPSILON);
    }};
}

#[macro_export]
macro_rules! assert_relative_eq {
    ($a:expr, $b:expr, $eps:expr, $max_rel:expr $(,)?) => {{
        let (a, b) = (&$a, &$b);
        let (eps, max_rel) = ($eps, $max_rel);
        assert!(
            $crate::all(&$crate::relative_eq(a, b, eps, max_rel).unwrap()).unwrap(),
            "assert_relative_eq failed (epsilon = {:.0e}, max_relative = {:.0e}):\n{}:\n{}\n{}:\n{}",
            eps,
            max_rel,
            stringify!($a),
            a,
            stringify!($b),
            b,
        );
    }};
    ($a:expr, $b:expr $(,)?) => {{
        assert_relative_eq!(
            $a,
            $b,
            $crate::DEFAULT_EPSILON,
            $crate::DEFAULT_MAX_RELATIVE,
        );
    }};
}

#[cfg(test)]
mod tests;