use crate::error::{CoreError, ErrorContext};
use crate::validation::check_finite;
use num_traits::Float;
use std::fmt::{Debug, Display};
/// Divides `numerator` by `denominator`, refusing zero and sub-epsilon divisors.
///
/// The near-zero test is an *absolute* threshold (`|denominator| < T::epsilon()`),
/// independent of the numerator's magnitude.
///
/// # Errors
///
/// * [`CoreError::DomainError`] when `denominator` compares equal to zero
///   (including `-0.0`), or when its magnitude is below `T::epsilon()`.
/// * [`CoreError::ComputationError`] when `check_finite` rejects the quotient.
#[inline]
#[allow(dead_code)]
pub fn safe_divide<T>(numerator: T, denominator: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug,
{
    // Exact zero is tested first only so it gets the clearer message; the
    // epsilon test below would reject it as well.
    if denominator == T::zero() {
        let message = format!("Division by zero: {numerator} / 0");
        return Err(CoreError::DomainError(ErrorContext::new(message)));
    }

    let epsilon = T::epsilon();
    // A NaN denominator fails this comparison too; its NaN quotient is left
    // to the finiteness check below.
    if denominator.abs() < epsilon {
        let message = format!(
            "Division by near-zero value: {numerator} / {denominator} (threshold: {epsilon})"
        );
        return Err(CoreError::DomainError(ErrorContext::new(message)));
    }

    let result = numerator / denominator;
    if check_finite(result, "division result").is_err() {
        let message = format!(
            "Division produced non-finite result: {numerator} / {denominator} = {result:?}"
        );
        return Err(CoreError::ComputationError(ErrorContext::new(message)));
    }
    Ok(result)
}
/// Returns the square root of `value`, rejecting negative inputs.
///
/// `-0.0` is not less than zero, so it is passed straight through to `sqrt`.
///
/// # Errors
///
/// * [`CoreError::DomainError`] when `value < 0`.
/// * [`CoreError::ComputationError`] when `check_finite` rejects the result
///   (presumably an infinite or NaN input, which the sign test lets through).
#[inline]
#[allow(dead_code)]
pub fn safe_sqrt<T>(value: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug,
{
    let is_negative = value < T::zero();
    if is_negative {
        let message = format!("Cannot compute sqrt of negative value: {value}");
        return Err(CoreError::DomainError(ErrorContext::new(message)));
    }

    let result = value.sqrt();
    check_finite(result, "sqrt result")
        .map(|_| result)
        .map_err(|_| {
            CoreError::ComputationError(ErrorContext::new(format!(
                "Square root produced non-finite result: sqrt({value}) = {result:?}"
            )))
        })
}
/// Returns the natural logarithm of `value`, rejecting non-positive inputs.
///
/// # Errors
///
/// * [`CoreError::DomainError`] when `value <= 0` (which includes `-0.0`).
/// * [`CoreError::ComputationError`] when `check_finite` rejects the result;
///   a NaN input is not caught by the `<=` test and ends up here.
#[inline]
#[allow(dead_code)]
pub fn safelog<T>(value: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug,
{
    // Keep the comparison as `<=`: rewriting it as `!(value > 0)` would also
    // route NaN to the domain error.
    if value <= T::zero() {
        return Err(CoreError::DomainError(ErrorContext::new(format!(
            "Cannot compute log of non-positive value: {value}"
        ))));
    }

    let result = value.ln();
    match check_finite(result, "log result") {
        Ok(_) => Ok(result),
        Err(_) => Err(CoreError::ComputationError(ErrorContext::new(format!(
            "Logarithm produced non-finite result: ln({value}) = {result:?}"
        )))),
    }
}
/// Returns the base-10 logarithm of `value`, rejecting non-positive inputs.
///
/// # Errors
///
/// * [`CoreError::DomainError`] when `value <= 0`.
/// * [`CoreError::ComputationError`] when `check_finite` rejects the result;
///   a NaN input is not caught by the `<=` test and ends up here.
#[inline]
#[allow(dead_code)]
pub fn safelog10<T>(value: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug,
{
    let non_positive = value <= T::zero();
    if non_positive {
        let message = format!("Cannot compute log10 of non-positive value: {value}");
        return Err(CoreError::DomainError(ErrorContext::new(message)));
    }

    let result = value.log10();
    if check_finite(result, "log10 result").is_err() {
        let message = format!(
            "Base-10 logarithm produced non-finite result: log10({value}) = {result:?}"
        );
        return Err(CoreError::ComputationError(ErrorContext::new(message)));
    }
    Ok(result)
}
/// Computes `base` raised to `exponent` (via `powf`), rejecting inputs with no real result.
///
/// # Errors
///
/// * [`CoreError::DomainError`] when `base < 0` and `exponent.fract()` is not
///   zero. A NaN `fract()` also counts as not zero.
/// * [`CoreError::DomainError`] when `base == 0` (including `-0.0`) and
///   `exponent < 0`.
/// * [`CoreError::ComputationError`] when `check_finite` rejects the result
///   (e.g. `10^1000` overflowing).
#[inline]
#[allow(dead_code)]
pub fn safe_pow<T>(base: T, exponent: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug,
{
    let zero = T::zero();
    // The two domain checks are mutually exclusive (a base cannot be both
    // negative and zero), so they can be expressed as one guard chain.
    if base < zero {
        // A negative base only has a real power for integral exponents.
        if exponent.fract() != zero {
            return Err(CoreError::DomainError(ErrorContext::new(format!(
                "Cannot compute fractional power of negative number: {base}^{exponent}"
            ))));
        }
    } else if base == zero && exponent < zero {
        // 0 raised to a negative power amounts to dividing by zero.
        return Err(CoreError::DomainError(ErrorContext::new(format!(
            "Cannot compute negative power of zero: 0^{exponent}"
        ))));
    }

    let result = base.powf(exponent);
    let finite = check_finite(result, "power result").is_ok();
    if !finite {
        return Err(CoreError::ComputationError(ErrorContext::new(format!(
            "Power operation produced non-finite result: {base}^{exponent} = {result:?}"
        ))));
    }
    Ok(result)
}
/// Computes `e^value`, rejecting inputs whose exponential would overflow `T`.
///
/// The cutoff is `ln(T::max_value())` (≈ 709.78 for `f64`, ≈ 88.72 for
/// `f32`). Every input whose exponential is representable is therefore
/// accepted, and anything larger is refused before calling `exp`.
///
/// # Errors
///
/// * [`CoreError::ComputationError`] when `value` exceeds the cutoff.
/// * [`CoreError::ComputationError`] when `check_finite` rejects the result.
///   Examples are a NaN input, which passes the comparison, or an input
///   right at the cutoff that rounds up to infinity.
#[inline]
#[allow(dead_code)]
pub fn safe_exp<T>(value: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug,
{
    // Derive the cutoff from the type instead of a hard-coded 700. The old
    // constant rejected valid f64 inputs in (700, ~709.78] with a misleading
    // "would overflow" message, and it had no meaning for f32.
    let max_exp = T::max_value().ln();
    if value > max_exp {
        return Err(CoreError::ComputationError(ErrorContext::new(format!(
            "Exponential would overflow: exp({value}) > exp({max_exp})"
        ))));
    }
    let result = value.exp();
    // Still verify the result: `max_exp` is itself rounded, and NaN inputs
    // are not caught by the comparison above.
    check_finite(result, "exp result").map_err(|_| {
        CoreError::ComputationError(ErrorContext::new(format!(
            "Exponential produced non-finite result: exp({value}) = {result:?}"
        )))
    })?;
    Ok(result)
}
/// Divides `value` by `norm`, defining the indeterminate `0 / 0` as `0`.
///
/// Only the case where *both* arguments compare equal to zero is special-cased
/// (returning positive zero). Every other pair is forwarded to
/// [`safe_divide`]. Any other combination with a zero or sub-epsilon `norm`,
/// such as `1 / 0` or `0 / 1e-300`, is therefore still an error.
///
/// # Errors
///
/// Propagates any error returned by [`safe_divide`].
#[inline]
#[allow(dead_code)]
pub fn safe_normalize<T>(value: T, norm: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug,
{
    let zero = T::zero();
    match (value == zero, norm == zero) {
        (true, true) => Ok(zero),
        _ => safe_divide(value, norm),
    }
}
/// Arithmetic mean of `values`, computed as `sum / len` through [`safe_divide`].
///
/// The sum is accumulated directly with `Sum`. A very large input can
/// therefore overflow to a non-finite sum, and fail, even when the true mean
/// would be representable.
///
/// # Errors
///
/// * [`CoreError::InvalidArgument`] when `values` is empty.
/// * [`CoreError::ComputationError`] when the length cannot be converted to `T`.
/// * Any error from [`safe_divide`], e.g. a non-finite sum.
#[allow(dead_code)]
pub fn safe_mean<T>(values: &[T]) -> Result<T, CoreError>
where
    T: Float + Display + Debug + std::iter::Sum,
{
    let len = values.len();
    if len == 0 {
        return Err(CoreError::InvalidArgument(ErrorContext::new(
            "Cannot compute mean of empty array",
        )));
    }

    let total = values.iter().copied().sum::<T>();
    let count = match T::from(len) {
        Some(converted) => converted,
        None => {
            return Err(CoreError::ComputationError(ErrorContext::new(format!(
                "Failed to convert array length {len} to numeric type"
            ))));
        }
    };
    safe_divide(total, count)
}
/// Sample variance of `values` about a caller-supplied `mean`.
///
/// The sum of squared deviations is divided by `n - 1` (Bessel's correction).
/// `mean` is used as given and is not checked against `values`. Pass the mean
/// of `values` (for example from [`safe_mean`]) to get the usual sample
/// variance.
///
/// # Errors
///
/// * [`CoreError::InvalidArgument`] when fewer than two values are supplied.
/// * [`CoreError::ComputationError`] when `n - 1` cannot be converted to `T`.
/// * Any error from [`safe_divide`], e.g. when the sum of squares is not finite.
#[allow(dead_code)]
pub fn safe_variance<T>(values: &[T], mean: T) -> Result<T, CoreError>
where
    T: Float + Display + Debug + std::iter::Sum,
{
    let len = values.len();
    if len < 2 {
        return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
            "Cannot compute variance with {len} values (need at least 2)"
        ))));
    }

    let sum_sq_diff: T = values
        .iter()
        .map(|&x| x - mean)
        .map(|deviation| deviation * deviation)
        .sum();

    // Bessel's correction: divide by n - 1 rather than n.
    let count = len - 1;
    let n_minus_1 = match T::from(count) {
        Some(converted) => converted,
        None => {
            return Err(CoreError::ComputationError(ErrorContext::new(format!(
                "Failed to convert count {count} to numeric type"
            ))));
        }
    };
    safe_divide(sum_sq_diff, n_minus_1)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that are not exactly representable.
    const TOL: f64 = 1e-10;

    #[test]
    fn test_safe_divide() {
        let quotient = safe_divide(10.0, 2.0).expect("Operation failed");
        assert_eq!(quotient, 5.0);
        let negative = safe_divide(-10.0, 2.0).expect("Operation failed");
        assert_eq!(negative, -5.0);

        // Exact-zero and sub-epsilon denominators are rejected
        // (f64::MIN_POSITIVE is also below epsilon).
        assert!(safe_divide(10.0, 0.0).is_err());
        assert!(safe_divide(10.0, 1e-100).is_err());
        assert!(safe_divide(f64::MAX, f64::MIN_POSITIVE).is_err());
    }

    #[test]
    fn test_safe_sqrt() {
        let root = safe_sqrt(4.0).expect("Operation failed");
        assert_eq!(root, 2.0);
        let zero_root = safe_sqrt(0.0).expect("Operation failed");
        assert_eq!(zero_root, 0.0);

        assert!(safe_sqrt(-1.0).is_err());
        assert!(safe_sqrt(-1e-10).is_err());
    }

    #[test]
    fn test_safelog() {
        let ln_e = safelog(std::f64::consts::E).expect("Operation failed");
        assert!((ln_e - 1.0).abs() < TOL);
        let ln_one = safelog(1.0).expect("Operation failed");
        assert_eq!(ln_one, 0.0);

        assert!(safelog(0.0).is_err());
        assert!(safelog(-1.0).is_err());
    }

    #[test]
    fn test_safe_pow() {
        let cube = safe_pow(2.0, 3.0).expect("Operation failed");
        assert_eq!(cube, 8.0);
        let square_root = safe_pow(4.0, 0.5).expect("Operation failed");
        assert_eq!(square_root, 2.0);

        assert!(safe_pow(-2.0, 0.5).is_err());
        assert!(safe_pow(0.0, -1.0).is_err());
        assert!(safe_pow(10.0, 1000.0).is_err());
    }

    #[test]
    fn test_safe_exp() {
        let e = safe_exp(1.0).expect("Operation failed");
        assert!((e - std::f64::consts::E).abs() < TOL);
        let one = safe_exp(0.0).expect("Operation failed");
        assert_eq!(one, 1.0);

        assert!(safe_exp(1000.0).is_err());
    }

    #[test]
    fn test_safe_mean() {
        let mean = safe_mean(&[1.0, 2.0, 3.0]).expect("Operation failed");
        assert_eq!(mean, 2.0);
        assert!(safe_mean::<f64>(&[]).is_err());
        let single = safe_mean(&[5.0]).expect("Operation failed");
        assert_eq!(single, 5.0);
    }

    #[test]
    fn test_safe_variance() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = 3.0;
        let variance = safe_variance(&values, mean).expect("Operation failed");
        assert!((variance - 2.5).abs() < TOL);

        // Fewer than two samples is an error.
        assert!(safe_variance(&[1.0], 1.0).is_err());
        assert!(safe_variance::<f64>(&[], 0.0).is_err());
    }

    #[test]
    fn test_safe_normalize() {
        let ratio = safe_normalize(3.0, 4.0).expect("Operation failed");
        assert_eq!(ratio, 0.75);
        assert!(safe_normalize(1.0, 0.0).is_err());
        let zero_over_zero = safe_normalize(0.0, 0.0).expect("Operation failed");
        assert_eq!(zero_over_zero, 0.0);
    }
}