rs_stats/utils/
numeric.rs

1use num_traits::NumCast;
2/// Provides numerical utility functions for statistical calculations.
3use std::fmt::Debug;
4
/// Computes the natural logarithm of `x`, rejecting inputs outside its domain.
///
/// # Arguments
/// * `x` - The input number.
///
/// # Returns
/// * `Result<f64, String>` - `Ok(ln(x))` when `x` is strictly positive
///   (positive infinity yields `Ok(f64::INFINITY)`), otherwise an error message.
///
/// # Errors
/// Returns an error if `x` is less than or equal to 0, or if `x` is NaN.
pub fn safe_log(x: f64) -> Result<f64, String> {
    // `x > 0.0` is false for NaN as well as for zero and negative values, so a
    // NaN input is reported as an error instead of propagating as `Ok(NaN)`.
    if x > 0.0 {
        Ok(x.ln())
    } else {
        Err("Logarithm is only defined for positive numbers.".to_string())
    }
}
22
23/// Check if two numeric values are approximately equal within a specified epsilon.
24///
25/// This function works with any floating point or integer type that can be converted to a float.
26/// For integer types, it converts them to f64 for comparison.
27///
28/// # Arguments
29/// * `a` - First value to compare
30/// * `b` - Second value to compare
31/// * `epsilon` - Tolerance for equality comparison (defaults to 1e-10 if not specified)
32///
33/// # Returns
34/// * `bool` - True if the values are approximately equal, false otherwise
35///
36pub fn approx_equal<T, U>(a: T, b: U, epsilon: Option<f64>) -> bool
37where
38    T: NumCast + Copy + Debug,
39    U: NumCast + Copy + Debug,
40{
41    // Convert to f64 for comparison
42    let a_f64 = match T::to_f64(&a) {
43        Some(val) => val,
44        None => return false, // Can't compare if conversion fails
45    };
46
47    let b_f64 = match U::to_f64(&b) {
48        Some(val) => val,
49        None => return false, // Can't compare if conversion fails
50    };
51    let eps = epsilon.unwrap_or(1e-10);
52
53    // Handle special casesc
54    if a_f64.is_nan() || b_f64.is_nan() {
55        return false;
56    }
57
58    if a_f64.is_infinite() && b_f64.is_infinite() {
59        return (a_f64 > 0.0 && b_f64 > 0.0) || (a_f64 < 0.0 && b_f64 < 0.0);
60    }
61
62    // Calculate absolute and relative differences
63    let abs_diff = (a_f64 - b_f64).abs();
64
65    // For values close to zero, use absolute difference
66    if a_f64.abs() < eps || b_f64.abs() < eps {
67        return abs_diff <= eps;
68    }
69
70    // Otherwise use relative difference
71    let rel_diff = abs_diff / f64::max(a_f64.abs(), b_f64.abs());
72    rel_diff <= eps
73}
74
75/// Simpler interface for approximate equality with default epsilon
76pub fn approx_eq<T, U>(a: T, b: U) -> bool
77where
78    T: NumCast + Copy + Debug,
79    U: NumCast + Copy + Debug,
80{
81    approx_equal(a, b, None)
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// `safe_log` succeeds on positive input and rejects zero and negatives.
    #[test]
    fn test_safe_log() {
        assert_eq!(safe_log(1.0), Ok(0.0));
        assert!(approx_equal(
            safe_log(std::f64::consts::E).unwrap(),
            1.0,
            None
        ));
        assert!(safe_log(0.0).is_err());
        assert!(safe_log(-1.0).is_err());
    }

    /// Float operands: identical values, and differences inside and outside the tolerance.
    #[test]
    fn test_float_equality() {
        assert!(approx_equal(1.0, 1.0, None));
        assert!(approx_equal(1.0, 1.0000000001, Some(1e-9)));
        // Use a difference ten times the tolerance so the outcome does not
        // hinge on how the literal happens to round in binary floating point.
        assert!(!approx_equal(1.0, 1.000000001, Some(1e-10)));
    }

    /// Integer operands are compared after conversion to f64.
    #[test]
    fn test_integer_equality() {
        assert!(approx_equal(1i32, 1i32, None));
        assert!(approx_equal(1000i32, 1000, None));
        assert!(approx_equal(1000u64, 1000.0001, Some(1e-6)));
        assert!(!approx_equal(1000i32, 1001i32, None));
    }

    /// The two operands may be of different numeric types.
    #[test]
    fn test_mixed_type_equality() {
        assert!(approx_equal(1i32, 1.0f64, None));
        assert!(approx_equal(1000u16, 1000.0f32, None));
        assert!(approx_equal(0i8, 0.0, None));
        assert!(!approx_equal(5u8, 5.1f64, None));
    }

    /// NaN never matches; infinities match only when they share a sign.
    #[test]
    fn test_edge_cases() {
        assert!(!approx_equal(f64::NAN, f64::NAN, None));
        assert!(approx_equal(f64::INFINITY, f64::INFINITY, None));
        assert!(approx_equal(f64::NEG_INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, 1e100, None));
    }

    /// Near zero the comparison is made on the absolute difference.
    #[test]
    fn test_near_zero() {
        assert!(approx_equal(0.0, 1e-11, None));
        assert!(!approx_equal(0.0, 1e-9, None));
    }

    /// `approx_eq` behaves like `approx_equal` with no explicit epsilon.
    #[test]
    fn test_approx_eq_default_epsilon() {
        assert!(approx_eq(1i32, 1.0f64));
        assert!(approx_eq(0.0, 1e-11));
        assert!(!approx_eq(1000i32, 1001i32));
        assert!(!approx_eq(f64::NAN, f64::NAN));
    }
}