use num_traits::NumCast;
use std::fmt::Debug;
use crate::error::{StatsError, StatsResult};
/// Natural logarithm that reports domain errors instead of returning NaN or `-inf`.
///
/// Positive finite inputs return `Ok(x.ln())`. Positive infinity is accepted and
/// yields `Ok(f64::INFINITY)`.
///
/// # Errors
///
/// Returns the error built by `StatsError::invalid_input` when `x` is zero, negative,
/// or NaN.
pub fn safe_log(x: f64) -> StatsResult<f64> {
    // Every ordered comparison with NaN is false, so `x <= 0.0` alone would let NaN
    // through and produce `Ok(NaN)`. Reject it explicitly.
    if x.is_nan() || x <= 0.0 {
        Err(StatsError::invalid_input(
            "Logarithm is only defined for positive numbers.",
        ))
    } else {
        Ok(x.ln())
    }
}
/// Compares two numeric values for approximate equality after converting both to `f64`.
///
/// Rules, applied in order:
/// - If either value cannot be converted to `f64`, the result is `false`.
/// - NaN is never equal to anything, including another NaN.
/// - Exactly equal values are always equal, whatever `epsilon` is. This includes
///   same-signed infinities and `0.0` vs `-0.0`.
/// - Any remaining infinity is unequal to everything.
/// - If either magnitude is below `epsilon`, an absolute tolerance is used:
///   `|a - b| <= epsilon`.
/// - Otherwise a relative tolerance is used: `|a - b| / max(|a|, |b|) <= epsilon`.
///
/// `epsilon` defaults to `1e-10` when `None`. A zero, negative, or NaN tolerance
/// accepts only exactly equal values.
pub fn approx_equal<T, U>(a: T, b: U, epsilon: Option<f64>) -> bool
where
    T: NumCast + Copy + Debug,
    U: NumCast + Copy + Debug,
{
    let (a_f64, b_f64) = match (T::to_f64(&a), U::to_f64(&b)) {
        (Some(x), Some(y)) => (x, y),
        // A value with no `f64` representation cannot be compared meaningfully.
        _ => return false,
    };
    let eps = epsilon.unwrap_or(1e-10);

    if a_f64.is_nan() || b_f64.is_nan() {
        return false;
    }
    // Identical values match regardless of tolerance. Checking this first keeps the
    // relative-difference division below from evaluating `0.0 / 0.0` (NaN) when both
    // values are zero and `eps <= 0`.
    if a_f64 == b_f64 {
        return true;
    }
    // The values differ, so an infinity cannot be within any finite tolerance.
    if a_f64.is_infinite() || b_f64.is_infinite() {
        return false;
    }

    let abs_diff = (a_f64 - b_f64).abs();
    // Relative error blows up near zero, so switch to an absolute tolerance there.
    if a_f64.abs() < eps || b_f64.abs() < eps {
        return abs_diff <= eps;
    }
    abs_diff / a_f64.abs().max(b_f64.abs()) <= eps
}
/// Shorthand for `approx_equal` with its default tolerance (`epsilon = None`).
///
/// Either argument may be any numeric type convertible to `f64`, so mixed
/// comparisons such as `approx_eq(3i32, 3.0f64)` are accepted.
pub fn approx_eq<T: NumCast + Copy + Debug, U: NumCast + Copy + Debug>(a: T, b: U) -> bool {
    let default_tolerance: Option<f64> = None;
    approx_equal::<T, U>(a, b, default_tolerance)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One `approx_equal` case on plain `f64` inputs: `(a, b, epsilon, expected)`.
    type Case = (f64, f64, Option<f64>, bool);

    /// Runs each case through `approx_equal` and reports the offending inputs on failure.
    fn check(cases: &[Case]) {
        for &(a, b, epsilon, expected) in cases {
            assert_eq!(
                approx_equal(a, b, epsilon),
                expected,
                "approx_equal({:?}, {:?}, {:?})",
                a,
                b,
                epsilon
            );
        }
    }

    /// Asserts that `safe_log(x)` fails with an `InvalidInput` error.
    fn expect_invalid_input(x: f64) {
        let outcome = safe_log(x);
        assert!(
            matches!(outcome, Err(StatsError::InvalidInput { .. })),
            "safe_log({:?}) should be rejected as invalid input",
            x
        );
    }

    #[test]
    fn test_float_equality() {
        check(&[
            (1.0, 1.0, None, true),
            (1.0, 1.0000000001, Some(1e-9), true),
            (1.0, 1.0000000001, Some(1e-10), false),
        ]);
    }

    #[test]
    fn test_integer_equality() {
        assert!(approx_equal(1i32, 1i32, None), "identical i32 values");
        assert!(approx_equal(1000i32, 1000, None), "identical i32 values, inferred rhs");
        assert!(
            approx_equal(1000u64, 1000.0001, Some(1e-6)),
            "u64 vs float within 1e-6 relative tolerance"
        );
        assert!(!approx_equal(1000i32, 1001i32, None), "integers one apart");
    }

    #[test]
    fn test_mixed_type_equality() {
        assert!(approx_equal(1i32, 1.0f64, None), "i32 vs f64");
        assert!(approx_equal(1000u16, 1000.0f32, None), "u16 vs f32");
        assert!(approx_equal(0i8, 0.0, None), "i8 zero vs float zero");
        assert!(!approx_equal(5u8, 5.1f64, None), "u8 vs clearly different f64");
    }

    #[test]
    fn test_edge_cases() {
        check(&[
            (f64::NAN, f64::NAN, None, false),
            (f64::INFINITY, f64::INFINITY, None, true),
            (f64::NEG_INFINITY, f64::NEG_INFINITY, None, true),
            (f64::INFINITY, f64::NEG_INFINITY, None, false),
            (f64::INFINITY, 1e100, None, false),
        ]);
    }

    #[test]
    fn test_near_zero() {
        check(&[(0.0, 1e-11, None, true), (0.0, 1e-9, None, false)]);
    }

    #[test]
    fn test_approx_equal_infinity_combinations() {
        check(&[
            (f64::INFINITY, f64::INFINITY, None, true),
            (f64::NEG_INFINITY, f64::NEG_INFINITY, None, true),
            (f64::INFINITY, f64::NEG_INFINITY, None, false),
            (f64::NEG_INFINITY, f64::INFINITY, None, false),
            (f64::INFINITY, 0.0, None, false),
            (f64::NEG_INFINITY, 0.0, None, false),
        ]);
    }

    #[test]
    fn test_approx_equal_nan_combinations() {
        check(&[
            (f64::NAN, f64::NAN, None, false),
            (f64::NAN, 0.0, None, false),
            (0.0, f64::NAN, None, false),
            (f64::NAN, f64::INFINITY, None, false),
            (f64::INFINITY, f64::NAN, None, false),
        ]);
    }

    #[test]
    fn test_approx_equal_relative_difference() {
        check(&[
            (1000.0, 1000.1, Some(1e-3), true),
            (1000.0, 1001.0, Some(1e-3), true),
            (1000.0, 1001.0, Some(1e-4), false),
        ]);
    }

    #[test]
    fn test_approx_equal_absolute_difference_near_zero() {
        check(&[
            (1e-11, 0.0, None, true),
            (0.0, 1e-11, None, true),
            (1e-9, 0.0, None, false),
        ]);
    }

    #[test]
    fn test_safe_log_positive() {
        let value = safe_log(1.0).expect("ln(1) must be defined");
        assert_eq!(value, 0.0);
    }

    #[test]
    fn test_safe_log_zero() {
        expect_invalid_input(0.0);
    }

    #[test]
    fn test_safe_log_negative() {
        expect_invalid_input(-1.0);
    }

    #[test]
    fn test_safe_log_known_value() {
        let value = safe_log(std::f64::consts::E).expect("ln(e) must be defined");
        assert!((value - 1.0).abs() < 1e-10);
    }
}