rstsr_dtype_traits/
isclose.rs

1use crate::*;
2use core::ops::*;
3use derive_builder::Builder;
4
5/// Arguments for `isclose` function.
6///
7/// For type [f64], you can also use the `From` trait implementations to create
8/// `IsCloseArgs<f64>` from simpler types:
9/// - `f64`: specifies `rtol`;
10/// - `(f64, f64)`: specifies `rtol` and `atol`;
11/// - `(f64, f64, bool)`: specifies `rtol`, `atol`, and `equal_nan`;
12/// - `None`: uses default values.
13///
14/// # See also
15///
16/// [`isclose`](isclose())
17#[derive(Builder, Clone, PartialEq, Eq, Debug)]
18pub struct IsCloseArgs<TE: 'static> {
19    /// Relative tolerance. For type [f64], the default is `1.0e-5`.
20    #[builder(default = "default_rtol()?")]
21    pub rtol: TE,
22    /// Absolute tolerance. For type [f64], the default is `1.0e-8`.
23    #[builder(default = "default_atol()?")]
24    pub atol: TE,
25    /// Whether to consider NaNs as equal. For type [f64], the default is `false`.
26    #[builder(default = "false")]
27    pub equal_nan: bool,
28}
29
/// Default relative tolerance used by [`IsCloseArgsBuilder`] when `rtol` is not set.
///
/// Returns `Ok(1.0e-5)` when `TE` is [f64]. For any other type, it returns an error message
/// asking the caller to specify `rtol` explicitly.
fn default_rtol<TE: 'static>() -> Result<TE, String> {
    use core::any::{type_name, Any};
    // Stage the f64 default in an `Option` so that it can be moved out through a safe `dyn Any`
    // downcast. The downcast to `Option<TE>` succeeds only when `TE == f64`, which replaces the
    // previous `unsafe` `transmute_copy` guarded by a `TypeId` comparison.
    let mut slot: Option<f64> = Some(1.0e-5);
    (&mut slot as &mut dyn Any).downcast_mut::<Option<TE>>().and_then(Option::take).ok_or_else(|| {
        format!("default rtol is not defined for type `{}`. Please specify `rtol` explicitly.", type_name::<TE>())
    })
}
39
/// Default absolute tolerance used by [`IsCloseArgsBuilder`] when `atol` is not set.
///
/// Returns `Ok(1.0e-8)` when `TE` is [f64]. For any other type, it returns an error message
/// asking the caller to specify `atol` explicitly.
fn default_atol<TE: 'static>() -> Result<TE, String> {
    use core::any::{type_name, Any};
    // Stage the f64 default in an `Option` so that it can be moved out through a safe `dyn Any`
    // downcast. The downcast to `Option<TE>` succeeds only when `TE == f64`, which replaces the
    // previous `unsafe` `transmute_copy` guarded by a `TypeId` comparison.
    let mut slot: Option<f64> = Some(1.0e-8);
    (&mut slot as &mut dyn Any).downcast_mut::<Option<TE>>().and_then(Option::take).ok_or_else(|| {
        format!("default atol is not defined for type `{}`. Please specify `atol` explicitly.", type_name::<TE>())
    })
}
49
50impl Default for IsCloseArgs<f64> {
51    fn default() -> Self {
52        Self { rtol: 1.0e-5, atol: 1.0e-8, equal_nan: false }
53    }
54}
55
56/// Checks whether two numbers are close to each other within a given tolerance.
57///
58/// # Notes to definition of closeness
59///
60/// For finite values, isclose uses the following equation to test whether two floating point values
61/// are equivalent:
62///
63/// ```text
64/// |a - b| <= atol + rtol * |b|
65/// ```
66///
67/// Note that this equation is not symmetric in `a` and `b`: it assumes that `b` is the reference
68/// value; so that `isclose(a, b, args)` may not be the same as `isclose(b, a, args)`.
69///
70/// # Notes to arguments
71///
72/// The argument `args` in this function should be usually of type [f64]. You can create it
73/// - manually by specifying all fields;
74/// - by using the `From` trait implementations for type [`IsCloseArgs<f64>`]:
75///   - `f64`: specifies `rtol`;
76///   - `(f64, f64)`: specifies `rtol` and `atol`;
77///   - `(f64, f64, bool)`: specifies `rtol`, `atol`, and `equal_nan`;
78///   - `None`: uses default values.
79/// - by using the builder pattern:
80///
81/// ```rust
82/// # use rstsr_dtype_traits::{isclose, IsCloseArgs, IsCloseArgsBuilder};
83/// let args_by_builder = IsCloseArgsBuilder::<f64>::default()
84///     .rtol(1.0e-6)
85///     .atol(1.0e-9)
86///     .equal_nan(true)
87///     .build().unwrap();
88/// let args_by_tuple: IsCloseArgs<f64> = (1.0e-6, 1.0e-9, true).into();
89/// assert_eq!(args_by_builder, args_by_tuple);
90/// ```
91#[inline]
92pub fn isclose<TA, TB, TE>(a: &TA, b: &TB, args: &IsCloseArgs<TE>) -> bool
93where
94    TA: Clone + DTypePromoteAPI<TB>,
95    TB: Clone,
96    <TA as DTypePromoteAPI<TB>>::Res: ExtNum<AbsOut: DTypeCastAPI<TE>>,
97    TE: ExtFloat + Add<TE, Output = TE> + Mul<TE, Output = TE> + PartialOrd + Clone,
98{
99    let IsCloseArgs { rtol, atol, equal_nan } = args;
100    let (a, b) = TA::promote_pair(a.clone(), b.clone());
101    let diff: TE = a.clone().ext_abs_diff(b.clone()).into_cast();
102    let abs_b: TE = b.clone().ext_abs().into_cast();
103    let comp = diff <= atol.clone() + rtol.clone() * abs_b;
104    let nan_check = *equal_nan && a.is_nan() && b.is_nan();
105    comp || nan_check
106}
107
108impl From<f64> for IsCloseArgs<f64> {
109    #[inline]
110    fn from(rtol: f64) -> Self {
111        Self { rtol, atol: 1.0e-8, equal_nan: false }
112    }
113}
114
115impl From<(f64,)> for IsCloseArgs<f64> {
116    #[inline]
117    fn from(v: (f64,)) -> Self {
118        let (rtol,) = v;
119        Self { rtol, atol: 1.0e-8, equal_nan: false }
120    }
121}
122
123impl From<(f64, f64)> for IsCloseArgs<f64> {
124    #[inline]
125    fn from(v: (f64, f64)) -> Self {
126        let (rtol, atol) = v;
127        Self { rtol, atol, equal_nan: false }
128    }
129}
130
131impl From<(f64, f64, bool)> for IsCloseArgs<f64> {
132    #[inline]
133    fn from(v: (f64, f64, bool)) -> Self {
134        let (rtol, atol, equal_nan) = v;
135        Self { rtol, atol, equal_nan }
136    }
137}
138
139impl From<Option<f64>> for IsCloseArgs<f64> {
140    #[inline]
141    fn from(rtol: Option<f64>) -> Self {
142        match rtol {
143            Some(rtol) => Self { rtol, atol: 1.0e-8, equal_nan: false },
144            None => Self { rtol: 1.0e-5, atol: 1.0e-8, equal_nan: false },
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_isclose_f64() {
        let a = 1.00001_f64;
        let b = 1.00002_f64;
        let args = None.into();
        assert!(isclose(&a, &b, &args));
        let args = IsCloseArgsBuilder::default().rtol(1.0e-6).atol(1.0e-9).equal_nan(false).build().unwrap();
        assert!(!isclose(&a, &b, &args));
    }

    #[test]
    fn test_isclose_usize() {
        let a: usize = 100;
        let b: usize = 102;
        let args = None.into();
        assert!(!isclose(&a, &b, &args));
    }

    #[test]
    fn test_isclose_usize_c32() {
        use num::Complex;
        let a: usize = 100;
        let b: Complex<f32> = Complex::new(100.0, 0.0);
        let args = None.into();
        assert!(isclose(&a, &b, &args));
        let c: Complex<f32> = Complex::new(100.01, 0.0);
        assert!(!isclose(&a, &c, &args));
    }

    /// NaNs only count as close when `equal_nan` is requested, and only to other NaNs.
    #[test]
    fn test_isclose_equal_nan() {
        let nan = f64::NAN;
        let strict = IsCloseArgs::<f64>::default();
        assert!(!isclose(&nan, &nan, &strict));
        let lenient: IsCloseArgs<f64> = (1.0e-5, 1.0e-8, true).into();
        assert!(isclose(&nan, &nan, &lenient));
        assert!(!isclose(&nan, &1.0_f64, &lenient));
    }

    /// Every `From` conversion fills unspecified fields with the `Default` values.
    #[test]
    fn test_from_conversions() {
        let defaults = IsCloseArgs::<f64>::default();
        assert_eq!(IsCloseArgs::<f64>::from(None::<f64>), defaults);
        assert_eq!(IsCloseArgs::<f64>::from(1.0e-5_f64), defaults);
        assert_eq!(IsCloseArgs::<f64>::from((1.0e-5_f64,)), defaults);
        assert_eq!(IsCloseArgs::<f64>::from((1.0e-5_f64, 1.0e-8_f64)), defaults);
        assert_eq!(IsCloseArgs::<f64>::from((1.0e-5_f64, 1.0e-8_f64, false)), defaults);
        let custom = IsCloseArgs::<f64>::from(Some(1.0e-3_f64));
        assert_eq!(custom, IsCloseArgs { rtol: 1.0e-3, atol: 1.0e-8, equal_nan: false });
    }

    /// Builder defaults exist only for `f64`; other element types must set both tolerances.
    #[test]
    fn test_builder_defaults_only_for_f64() {
        let built = IsCloseArgsBuilder::<f64>::default().build().unwrap();
        assert_eq!(built, IsCloseArgs::<f64>::default());
        assert!(IsCloseArgsBuilder::<f32>::default().build().is_err());
        assert!(IsCloseArgsBuilder::<f32>::default().rtol(1.0e-3).build().is_err());
        let explicit = IsCloseArgsBuilder::<f32>::default().rtol(1.0e-3).atol(1.0e-6).build().unwrap();
        assert_eq!(explicit, IsCloseArgs { rtol: 1.0e-3_f32, atol: 1.0e-6_f32, equal_nan: false });
    }
}