rs_stats/utils/
numeric.rs

1use num_traits::NumCast;
2/// Provides numerical utility functions for statistical calculations.
3use std::fmt::Debug;
4
5use crate::error::{StatsError, StatsResult};
6
7/// Computes the natural logarithm of x, handling edge cases safely.
8///
9/// # Arguments
10/// * `x` - The input number.
11///
12/// # Returns
13/// * `StatsResult<f64>` - The natural logarithm of x, or an error if x is invalid.
14///
15/// # Errors
16/// Returns `StatsError::InvalidInput` if x is less than or equal to 0.
17///
18/// # Examples
19/// ```
20/// use rs_stats::utils::safe_log;
21///
22/// let result = safe_log(2.71828).unwrap();
23/// assert!((result - 1.0).abs() < 1e-5);
24///
25/// // Error case
26/// let result = safe_log(0.0);
27/// assert!(result.is_err());
28/// ```
29pub fn safe_log(x: f64) -> StatsResult<f64> {
30    if x <= 0.0 {
31        Err(StatsError::invalid_input(
32            "Logarithm is only defined for positive numbers.",
33        ))
34    } else {
35        Ok(x.ln())
36    }
37}
38
39/// Check if two numeric values are approximately equal within a specified epsilon.
40///
41/// This function works with any floating point or integer type that can be converted to a float.
42/// For integer types, it converts them to f64 for comparison.
43///
44/// # Arguments
45/// * `a` - First value to compare
46/// * `b` - Second value to compare
47/// * `epsilon` - Tolerance for equality comparison (defaults to 1e-10 if not specified)
48///
49/// # Returns
50/// * `bool` - True if the values are approximately equal, false otherwise
51///
52pub fn approx_equal<T, U>(a: T, b: U, epsilon: Option<f64>) -> bool
53where
54    T: NumCast + Copy + Debug,
55    U: NumCast + Copy + Debug,
56{
57    // Convert to f64 for comparison
58    let a_f64 = match T::to_f64(&a) {
59        Some(val) => val,
60        None => return false, // Can't compare if conversion fails
61    };
62
63    let b_f64 = match U::to_f64(&b) {
64        Some(val) => val,
65        None => return false, // Can't compare if conversion fails
66    };
67    let eps = epsilon.unwrap_or(1e-10);
68
69    // Handle special casesc
70    if a_f64.is_nan() || b_f64.is_nan() {
71        return false;
72    }
73
74    if a_f64.is_infinite() && b_f64.is_infinite() {
75        return (a_f64 > 0.0 && b_f64 > 0.0) || (a_f64 < 0.0 && b_f64 < 0.0);
76    }
77
78    // Calculate absolute and relative differences
79    let abs_diff = (a_f64 - b_f64).abs();
80
81    // For values close to zero, use absolute difference
82    if a_f64.abs() < eps || b_f64.abs() < eps {
83        return abs_diff <= eps;
84    }
85
86    // Otherwise use relative difference
87    let rel_diff = abs_diff / f64::max(a_f64.abs(), b_f64.abs());
88    rel_diff <= eps
89}
90
91/// Simpler interface for approximate equality with default epsilon
92pub fn approx_eq<T, U>(a: T, b: U) -> bool
93where
94    T: NumCast + Copy + Debug,
95    U: NumCast + Copy + Debug,
96{
97    approx_equal(a, b, None)
98}
99
/// Unit tests for the numeric helpers: `safe_log`, `approx_equal` and `approx_eq`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_float_equality() {
        assert!(approx_equal(1.0, 1.0, None));
        assert!(approx_equal(1.0, 1.0000000001, Some(1e-9)));
        assert!(!approx_equal(1.0, 1.0000000001, Some(1e-10)));
    }

    #[test]
    fn test_integer_equality() {
        assert!(approx_equal(1i32, 1i32, None));
        assert!(approx_equal(1000i32, 1000, None));
        assert!(approx_equal(1000u64, 1000.0001, Some(1e-6)));
        assert!(!approx_equal(1000i32, 1001i32, None));
    }

    #[test]
    fn test_mixed_type_equality() {
        assert!(approx_equal(1i32, 1.0f64, None));
        assert!(approx_equal(1000u16, 1000.0f32, None));
        assert!(approx_equal(0i8, 0.0, None));
        assert!(!approx_equal(5u8, 5.1f64, None));
    }

    #[test]
    fn test_edge_cases() {
        assert!(!approx_equal(f64::NAN, f64::NAN, None));
        assert!(approx_equal(f64::INFINITY, f64::INFINITY, None));
        assert!(approx_equal(f64::NEG_INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, 1e100, None));
    }

    #[test]
    fn test_near_zero() {
        assert!(approx_equal(0.0, 1e-11, None));
        assert!(!approx_equal(0.0, 1e-9, None));
    }

    // Note: the conversion-failure path in approx_equal is not exercised here.
    // `to_f64()` (which NumCast inherits from its ToPrimitive supertrait) succeeds for the
    // standard numeric types, so hitting that path would need a custom type whose
    // conversion returns `None`. The edge cases that are reachable with standard
    // types are covered below.

    #[test]
    fn test_approx_equal_infinity_combinations() {
        // Test all infinity combinations
        assert!(approx_equal(f64::INFINITY, f64::INFINITY, None));
        assert!(approx_equal(f64::NEG_INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, f64::NEG_INFINITY, None));
        assert!(!approx_equal(f64::NEG_INFINITY, f64::INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, 0.0, None));
        assert!(!approx_equal(f64::NEG_INFINITY, 0.0, None));
    }

    #[test]
    fn test_approx_equal_nan_combinations() {
        // Test NaN combinations
        assert!(!approx_equal(f64::NAN, f64::NAN, None));
        assert!(!approx_equal(f64::NAN, 0.0, None));
        assert!(!approx_equal(0.0, f64::NAN, None));
        assert!(!approx_equal(f64::NAN, f64::INFINITY, None));
        assert!(!approx_equal(f64::INFINITY, f64::NAN, None));
    }

    #[test]
    fn test_approx_equal_relative_difference() {
        // Test relative difference calculation (for values not near zero)
        // Relative diff = |1000.0 - 1000.1| / max(|1000.0|, |1000.1|) = 0.1 / 1000.1 ≈ 1e-4 < 1e-3
        assert!(approx_equal(1000.0, 1000.1, Some(1e-3)));
        // Relative diff = |1000.0 - 1001.0| / max(|1000.0|, |1001.0|) = 1.0 / 1001.0 ≈ 9.99e-4,
        // which is just under 1e-3, so the values count as equal
        assert!(approx_equal(1000.0, 1001.0, Some(1e-3)));
        // But with stricter epsilon, they should not be equal
        assert!(!approx_equal(1000.0, 1001.0, Some(1e-4)));
    }

    #[test]
    fn test_approx_equal_absolute_difference_near_zero() {
        // Test absolute difference calculation (for values near zero)
        assert!(approx_equal(1e-11, 0.0, None));
        assert!(approx_equal(0.0, 1e-11, None));
        assert!(!approx_equal(1e-9, 0.0, None));
    }

    #[test]
    fn test_approx_eq_uses_default_epsilon() {
        // approx_eq forwards to approx_equal with no explicit epsilon (1e-10 default)
        assert!(approx_eq(1.0, 1.0));
        assert!(approx_eq(3i32, 3.0f64));
        assert!(approx_eq(0.0, 1e-11));
        assert!(!approx_eq(0.0, 1e-9));
        assert!(!approx_eq(1.0, 1.001));
        assert!(!approx_eq(f64::NAN, f64::NAN));
    }

    #[test]
    fn test_safe_log_positive() {
        let result = safe_log(1.0);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0.0);
    }

    #[test]
    fn test_safe_log_zero() {
        let result = safe_log(0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_safe_log_negative() {
        let result = safe_log(-1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_safe_log_known_value() {
        // ln(e) = 1
        let result = safe_log(std::f64::consts::E);
        assert!(result.is_ok());
        assert!((result.unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_safe_log_positive_infinity() {
        // Positive infinity is inside the accepted domain; ln(+inf) = +inf
        let result = safe_log(f64::INFINITY);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), f64::INFINITY);
    }
}