1use num_traits::NumCast;
2use std::fmt::Debug;
4
5use crate::error::{StatsError, StatsResult};
6
7pub fn safe_log(x: f64) -> StatsResult<f64> {
30 if x <= 0.0 {
31 Err(StatsError::invalid_input(
32 "Logarithm is only defined for positive numbers.",
33 ))
34 } else {
35 Ok(x.ln())
36 }
37}
38
39pub fn approx_equal<T, U>(a: T, b: U, epsilon: Option<f64>) -> bool
53where
54 T: NumCast + Copy + Debug,
55 U: NumCast + Copy + Debug,
56{
57 let a_f64 = match T::to_f64(&a) {
59 Some(val) => val,
60 None => return false, };
62
63 let b_f64 = match U::to_f64(&b) {
64 Some(val) => val,
65 None => return false, };
67 let eps = epsilon.unwrap_or(1e-10);
68
69 if a_f64.is_nan() || b_f64.is_nan() {
71 return false;
72 }
73
74 if a_f64.is_infinite() && b_f64.is_infinite() {
75 return (a_f64 > 0.0 && b_f64 > 0.0) || (a_f64 < 0.0 && b_f64 < 0.0);
76 }
77
78 let abs_diff = (a_f64 - b_f64).abs();
80
81 if a_f64.abs() < eps || b_f64.abs() < eps {
83 return abs_diff <= eps;
84 }
85
86 let rel_diff = abs_diff / f64::max(a_f64.abs(), b_f64.abs());
88 rel_diff <= eps
89}
90
91pub fn approx_eq<T, U>(a: T, b: U) -> bool
93where
94 T: NumCast + Copy + Debug,
95 U: NumCast + Copy + Debug,
96{
97 approx_equal(a, b, None)
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // approx_equal: plain and mixed numeric types
    // ---------------------------------------------------------------

    #[test]
    fn test_float_equality() {
        let loose = Some(1e-9);
        let tight = Some(1e-10);
        assert!(approx_equal(1.0, 1.0, None));
        assert!(approx_equal(1.0, 1.0000000001, loose));
        assert!(!approx_equal(1.0, 1.0000000001, tight));
    }

    #[test]
    fn test_integer_equality() {
        assert!(approx_equal(1i32, 1i32, None));
        assert!(approx_equal(1000i32, 1000, None));
        assert!(approx_equal(1000u64, 1000.0001, Some(1e-6)));
        assert!(!approx_equal(1000i32, 1001i32, None));
    }

    #[test]
    fn test_mixed_type_equality() {
        assert!(approx_equal(1i32, 1.0f64, None));
        assert!(approx_equal(1000u16, 1000.0f32, None));
        assert!(approx_equal(0i8, 0.0, None));
        assert!(!approx_equal(5u8, 5.1f64, None));
    }

    // ---------------------------------------------------------------
    // approx_equal: special values (NaN, infinities) and zero
    // ---------------------------------------------------------------

    #[test]
    fn test_edge_cases() {
        let cases: [(f64, f64, bool); 5] = [
            (f64::NAN, f64::NAN, false),
            (f64::INFINITY, f64::INFINITY, true),
            (f64::NEG_INFINITY, f64::NEG_INFINITY, true),
            (f64::INFINITY, f64::NEG_INFINITY, false),
            (f64::INFINITY, 1e100, false),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(
                approx_equal(lhs, rhs, None),
                expected,
                "approx_equal({}, {})",
                lhs,
                rhs
            );
        }
    }

    #[test]
    fn test_near_zero() {
        let cases: [(f64, f64, bool); 2] = [(0.0, 1e-11, true), (0.0, 1e-9, false)];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(
                approx_equal(lhs, rhs, None),
                expected,
                "approx_equal({}, {})",
                lhs,
                rhs
            );
        }
    }

    #[test]
    fn test_approx_equal_infinity_combinations() {
        let cases: [(f64, f64, bool); 6] = [
            (f64::INFINITY, f64::INFINITY, true),
            (f64::NEG_INFINITY, f64::NEG_INFINITY, true),
            (f64::INFINITY, f64::NEG_INFINITY, false),
            (f64::NEG_INFINITY, f64::INFINITY, false),
            (f64::INFINITY, 0.0, false),
            (f64::NEG_INFINITY, 0.0, false),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(
                approx_equal(lhs, rhs, None),
                expected,
                "approx_equal({}, {})",
                lhs,
                rhs
            );
        }
    }

    #[test]
    fn test_approx_equal_nan_combinations() {
        let pairs: [(f64, f64); 5] = [
            (f64::NAN, f64::NAN),
            (f64::NAN, 0.0),
            (0.0, f64::NAN),
            (f64::NAN, f64::INFINITY),
            (f64::INFINITY, f64::NAN),
        ];
        // NaN must never compare approximately equal to anything.
        for &(lhs, rhs) in pairs.iter() {
            assert!(
                !approx_equal(lhs, rhs, None),
                "approx_equal({}, {})",
                lhs,
                rhs
            );
        }
    }

    #[test]
    fn test_approx_equal_relative_difference() {
        let cases: [(f64, f64, f64, bool); 3] = [
            (1000.0, 1000.1, 1e-3, true),
            (1000.0, 1001.0, 1e-3, true),
            (1000.0, 1001.0, 1e-4, false),
        ];
        for &(lhs, rhs, tolerance, expected) in cases.iter() {
            assert_eq!(
                approx_equal(lhs, rhs, Some(tolerance)),
                expected,
                "approx_equal({}, {}, {})",
                lhs,
                rhs,
                tolerance
            );
        }
    }

    #[test]
    fn test_approx_equal_absolute_difference_near_zero() {
        let cases: [(f64, f64, bool); 3] = [
            (1e-11, 0.0, true),
            (0.0, 1e-11, true),
            (1e-9, 0.0, false),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(
                approx_equal(lhs, rhs, None),
                expected,
                "approx_equal({}, {})",
                lhs,
                rhs
            );
        }
    }

    // ---------------------------------------------------------------
    // safe_log
    // ---------------------------------------------------------------

    #[test]
    fn test_safe_log_positive() {
        let value = safe_log(1.0).expect("ln(1) should be defined");
        assert_eq!(value, 0.0);
    }

    #[test]
    fn test_safe_log_zero() {
        let outcome = safe_log(0.0);
        assert!(matches!(outcome, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_safe_log_negative() {
        let outcome = safe_log(-1.0);
        assert!(matches!(outcome, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_safe_log_known_value() {
        let ln_e = safe_log(std::f64::consts::E).expect("ln(e) should be defined");
        assert!((ln_e - 1.0).abs() < 1e-10);
    }
}