rs_stats/utils/
numeric.rs1use num_traits::NumCast;
2use std::fmt::Debug;
4
/// Computes the natural logarithm of `x`, rejecting inputs outside its domain.
///
/// `+∞` is accepted and yields `Ok(f64::INFINITY)`.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when `x` is not strictly positive.
/// This covers `0.0`, `-0.0`, negative values, and `NaN`. `NaN` would
/// otherwise slip through as `Ok(NaN)`, because every comparison involving
/// `NaN` is false.
pub fn safe_log(x: f64) -> Result<f64, String> {
    // Test the valid domain positively (`x > 0.0`) instead of the invalid one
    // (`x <= 0.0`): NaN fails both comparisons, so only this form rejects it.
    if x > 0.0 {
        Ok(x.ln())
    } else {
        Err("Logarithm is only defined for positive numbers.".to_string())
    }
}
22
23pub fn approx_equal<T, U>(a: T, b: U, epsilon: Option<f64>) -> bool
37where
38 T: NumCast + Copy + Debug,
39 U: NumCast + Copy + Debug,
40{
41 let a_f64 = match T::to_f64(&a) {
43 Some(val) => val,
44 None => return false, };
46
47 let b_f64 = match U::to_f64(&b) {
48 Some(val) => val,
49 None => return false, };
51 let eps = epsilon.unwrap_or(1e-10);
52
53 if a_f64.is_nan() || b_f64.is_nan() {
55 return false;
56 }
57
58 if a_f64.is_infinite() && b_f64.is_infinite() {
59 return (a_f64 > 0.0 && b_f64 > 0.0) || (a_f64 < 0.0 && b_f64 < 0.0);
60 }
61
62 let abs_diff = (a_f64 - b_f64).abs();
64
65 if a_f64.abs() < eps || b_f64.abs() < eps {
67 return abs_diff <= eps;
68 }
69
70 let rel_diff = abs_diff / f64::max(a_f64.abs(), b_f64.abs());
72 rel_diff <= eps
73}
74
75pub fn approx_eq<T, U>(a: T, b: U) -> bool
77where
78 T: NumCast + Copy + Debug,
79 U: NumCast + Copy + Debug,
80{
81 approx_equal(a, b, None)
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_float_equality() {
        assert!(approx_equal(1.0, 1.0, None));
        assert!(approx_equal(1.0, 1.0000000001, Some(1e-9)));
        // The difference (~1e-9) is an order of magnitude above the 1e-10
        // tolerance, so the outcome does not depend on how the literal rounds
        // to f64.
        assert!(!approx_equal(1.0, 1.000000001, Some(1e-10)));
    }

    #[test]
    fn test_integer_equality() {
        assert!(approx_equal(1i32, 1i32, None));
        assert!(approx_equal(1000i32, 1000, None));
        assert!(approx_equal(1000u64, 1000.0001, Some(1e-6)));
        assert!(!approx_equal(1000i32, 1001i32, None));
    }

    #[test]
    fn test_mixed_type_equality() {
        assert!(approx_equal(1i32, 1.0f64, None));
        assert!(approx_equal(1000u16, 1000.0f32, None));
        assert!(approx_equal(0i8, 0.0, None));
        assert!(!approx_equal(5u8, 5.1f64, None));
    }

    #[test]
    fn test_edge_cases() {
        assert!(!approx_equal(f64::NAN, f64::NAN, None));
        assert!(approx_equal(f64::INFINITY, f64::INFINITY, None));
        assert!(approx_equal(f64::NEG_INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, 1e100, None));
    }

    #[test]
    fn test_near_zero() {
        assert!(approx_equal(0.0, 1e-11, None));
        assert!(!approx_equal(0.0, 1e-9, None));
    }

    #[test]
    fn test_approx_eq_uses_default_epsilon() {
        assert!(approx_eq(2.5, 2.5));
        assert!(approx_eq(7u32, 7.0f64));
        assert!(approx_eq(0.0, 1e-11));
        assert!(!approx_eq(0.0, 1e-9));
        assert!(!approx_eq(1.0, 1.1));
    }

    #[test]
    fn test_safe_log() {
        assert_eq!(safe_log(1.0), Ok(0.0));
        assert!(approx_eq(safe_log(std::f64::consts::E).unwrap(), 1.0));
        assert!(safe_log(0.0).is_err());
        assert!(safe_log(-1.0).is_err());
    }
}