rstsr_dtype_traits/
isclose.rs1use crate::*;
2use core::ops::*;
3use derive_builder::Builder;
4
5#[derive(Builder, Clone, PartialEq, Eq, Debug)]
18pub struct IsCloseArgs<TE: 'static> {
19 #[builder(default = "default_rtol()?")]
21 pub rtol: TE,
22 #[builder(default = "default_atol()?")]
24 pub atol: TE,
25 #[builder(default = "false")]
27 pub equal_nan: bool,
28}
29
/// Returns the builder default for `rtol` (`1.0e-5`) when `TE` is `f64`.
///
/// # Errors
///
/// For every other `TE` there is no sensible default, so an error message
/// asking the caller to specify `rtol` explicitly is returned.
fn default_rtol<TE: 'static>() -> Result<TE, String> {
    use core::any::{type_name, Any};

    // Safe replacement for a `TypeId` check + `transmute_copy`: the downcast
    // succeeds exactly when `Option<TE>` is `Option<f64>`, i.e. `TE == f64`,
    // and `take` then moves the value out without any allocation or `unsafe`.
    let mut slot = Some(1.0e-5_f64);
    (&mut slot as &mut dyn Any)
        .downcast_mut::<Option<TE>>()
        .and_then(|value| value.take())
        .ok_or_else(|| {
            format!(
                "default rtol is not defined for type `{}`. Please specify `rtol` explicitly.",
                type_name::<TE>()
            )
        })
}
39
/// Returns the builder default for `atol` (`1.0e-8`) when `TE` is `f64`.
///
/// # Errors
///
/// For every other `TE` there is no sensible default, so an error message
/// asking the caller to specify `atol` explicitly is returned.
fn default_atol<TE: 'static>() -> Result<TE, String> {
    use core::any::{type_name, Any};

    // Safe replacement for a `TypeId` check + `transmute_copy`: the downcast
    // succeeds exactly when `Option<TE>` is `Option<f64>`, i.e. `TE == f64`,
    // and `take` then moves the value out without any allocation or `unsafe`.
    let mut slot = Some(1.0e-8_f64);
    (&mut slot as &mut dyn Any)
        .downcast_mut::<Option<TE>>()
        .and_then(|value| value.take())
        .ok_or_else(|| {
            format!(
                "default atol is not defined for type `{}`. Please specify `atol` explicitly.",
                type_name::<TE>()
            )
        })
}
49
50impl Default for IsCloseArgs<f64> {
51 fn default() -> Self {
52 Self { rtol: 1.0e-5, atol: 1.0e-8, equal_nan: false }
53 }
54}
55
56#[inline]
92pub fn isclose<TA, TB, TE>(a: &TA, b: &TB, args: &IsCloseArgs<TE>) -> bool
93where
94 TA: Clone + DTypePromoteAPI<TB>,
95 TB: Clone,
96 <TA as DTypePromoteAPI<TB>>::Res: ExtNum<AbsOut: DTypeCastAPI<TE>>,
97 TE: ExtFloat + Add<TE, Output = TE> + Mul<TE, Output = TE> + PartialOrd + Clone,
98{
99 let IsCloseArgs { rtol, atol, equal_nan } = args;
100 let (a, b) = TA::promote_pair(a.clone(), b.clone());
101 let diff: TE = a.clone().ext_abs_diff(b.clone()).into_cast();
102 let abs_b: TE = b.clone().ext_abs().into_cast();
103 let comp = diff <= atol.clone() + rtol.clone() * abs_b;
104 let nan_check = *equal_nan && a.is_nan() && b.is_nan();
105 comp || nan_check
106}
107
108impl From<f64> for IsCloseArgs<f64> {
109 #[inline]
110 fn from(rtol: f64) -> Self {
111 Self { rtol, atol: 1.0e-8, equal_nan: false }
112 }
113}
114
115impl From<(f64,)> for IsCloseArgs<f64> {
116 #[inline]
117 fn from(v: (f64,)) -> Self {
118 let (rtol,) = v;
119 Self { rtol, atol: 1.0e-8, equal_nan: false }
120 }
121}
122
123impl From<(f64, f64)> for IsCloseArgs<f64> {
124 #[inline]
125 fn from(v: (f64, f64)) -> Self {
126 let (rtol, atol) = v;
127 Self { rtol, atol, equal_nan: false }
128 }
129}
130
131impl From<(f64, f64, bool)> for IsCloseArgs<f64> {
132 #[inline]
133 fn from(v: (f64, f64, bool)) -> Self {
134 let (rtol, atol, equal_nan) = v;
135 Self { rtol, atol, equal_nan }
136 }
137}
138
139impl From<Option<f64>> for IsCloseArgs<f64> {
140 #[inline]
141 fn from(rtol: Option<f64>) -> Self {
142 match rtol {
143 Some(rtol) => Self { rtol, atol: 1.0e-8, equal_nan: false },
144 None => Self { rtol: 1.0e-5, atol: 1.0e-8, equal_nan: false },
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Default tolerances accept a 1e-5 difference near 1.0; tighter
    /// builder-supplied tolerances reject it.
    #[test]
    fn test_isclose_f64() {
        let lhs = 1.00001_f64;
        let rhs = 1.00002_f64;

        let default_args: IsCloseArgs<f64> = None.into();
        assert!(isclose(&lhs, &rhs, &default_args));

        let strict_args = IsCloseArgsBuilder::default().rtol(1.0e-6).atol(1.0e-9).equal_nan(false).build().unwrap();
        assert!(!isclose(&lhs, &rhs, &strict_args));
    }

    /// Integer inputs that differ by 2 are not close under default tolerances.
    #[test]
    fn test_isclose_usize() {
        let lhs: usize = 100;
        let rhs: usize = 102;
        let default_args: IsCloseArgs<f64> = None.into();
        assert!(!isclose(&lhs, &rhs, &default_args));
    }

    /// Mixed integer / complex inputs are compared after type promotion.
    #[test]
    fn test_isclose_usize_c32() {
        use num::Complex;

        let int_value: usize = 100;
        let default_args: IsCloseArgs<f64> = None.into();

        let same: Complex<f32> = Complex::new(100.0, 0.0);
        assert!(isclose(&int_value, &same, &default_args));

        let shifted: Complex<f32> = Complex::new(100.01, 0.0);
        assert!(!isclose(&int_value, &shifted, &default_args));
    }
}