use std::fmt;
use crate::stats::population_variance;
/// Errors produced by the Box-Cox helpers in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum TransformError {
    /// The input failed the positivity check; the Box-Cox transform is only
    /// defined for `y > 0`.
    NonPositiveData,
    /// Fewer than two data points were supplied. [`estimate_lambda`] also
    /// reports an invalid search interval with this variant.
    InsufficientData,
    /// [`inverse_box_cox`] could not map the transformed values back to
    /// finite original-scale values.
    InvalidInverse,
}

impl fmt::Display for TransformError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::NonPositiveData => "Box-Cox requires all y > 0",
            Self::InsufficientData => "need at least 2 data points",
            Self::InvalidInverse => "inverse transformation produced non-finite values",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TransformError {}
/// Checks that `y` is acceptable input for the Box-Cox transform.
///
/// # Errors
///
/// * [`TransformError::InsufficientData`] if `y` has fewer than two elements.
/// * [`TransformError::NonPositiveData`] if any element is not strictly
///   positive. NaN is treated as non-positive.
fn validate_positive_slice(y: &[f64]) -> Result<(), TransformError> {
    if y.len() < 2 {
        return Err(TransformError::InsufficientData);
    }
    // Every comparison involving NaN is false, so `v <= 0.0` alone would let NaN
    // through and silently turn every downstream result into NaN.
    if y.iter().any(|&v| v.is_nan() || v <= 0.0) {
        return Err(TransformError::NonPositiveData);
    }
    Ok(())
}
/// Applies the Box-Cox power transform element-wise.
///
/// Each value `v` becomes `(v^lambda - 1) / lambda`. When `|lambda| < 1e-10`,
/// the limiting form `ln(v)` is used instead.
///
/// # Errors
///
/// Propagates the error from `validate_positive_slice`:
/// [`TransformError::InsufficientData`] when `y` has fewer than two elements,
/// or [`TransformError::NonPositiveData`] when it fails the positivity check.
pub fn box_cox(y: &[f64], lambda: f64) -> Result<Vec<f64>, TransformError> {
    validate_positive_slice(y)?;
    let use_log = lambda.abs() < 1e-10;
    let transformed = y
        .iter()
        .map(|&v| {
            if use_log {
                v.ln()
            } else {
                (v.powf(lambda) - 1.0) / lambda
            }
        })
        .collect();
    Ok(transformed)
}
/// Inverts [`box_cox`], mapping transformed values `y_t` back to the original
/// scale.
///
/// When `|lambda| < 1e-10` the inverse is `exp(y_t)`. Otherwise it is
/// `(lambda * y_t + 1)^(1 / lambda)`. An empty `y_t` yields an empty vector.
///
/// # Errors
///
/// Returns [`TransformError::InvalidInverse`] in two cases:
/// * For `|lambda| >= 1e-10`, some `lambda * y_t + 1` is not strictly
///   positive (including NaN). Such values lie outside the range of the
///   forward transform.
/// * Any recovered value is not finite.
pub fn inverse_box_cox(y_t: &[f64], lambda: f64) -> Result<Vec<f64>, TransformError> {
    let result: Vec<f64> = if lambda.abs() < 1e-10 {
        y_t.iter().map(|&v| v.exp()).collect()
    } else {
        y_t.iter()
            .map(|&v| {
                // For any y > 0, box_cox produces lambda * y_t + 1 == y^lambda > 0, so a
                // non-positive base cannot come from valid data. `powf` does not reliably
                // flag this: it returns NaN only for non-integer exponents. For example,
                // lambda = 0.5 (exponent 2) or lambda = 1 (exponent 1) would otherwise
                // yield a finite but wrong value, possibly zero or negative.
                let base = v * lambda + 1.0;
                if base > 0.0 {
                    Ok(base.powf(1.0 / lambda))
                } else {
                    Err(TransformError::InvalidInverse)
                }
            })
            .collect::<Result<Vec<f64>, TransformError>>()?
    };
    // exp/powf can overflow to +inf, and a NaN input makes exp return NaN.
    if result.iter().any(|v| !v.is_finite()) {
        return Err(TransformError::InvalidInverse);
    }
    Ok(result)
}
/// Estimates the Box-Cox `lambda` that maximizes the profile log-likelihood
/// of `y` over the interval `[lambda_min, lambda_max]`.
///
/// The objective for a candidate `lambda` is
/// `-(n / 2) * ln(var(T(y))) + (lambda - 1) * sum(ln y)`, where `T` is the
/// Box-Cox transform and `var` is `population_variance`.
///
/// The objective is maximized with a golden-section search. The search stops
/// once the bracket is narrower than `1e-6` or after 100 iterations, and
/// returns the midpoint of the final bracket.
///
/// NOTE(review): golden-section search assumes the objective is unimodal on
/// the interval; otherwise it may settle on a local maximum.
///
/// # Errors
///
/// * [`TransformError::InsufficientData`] if the search interval is invalid:
///   either bound is NaN or infinite, or `lambda_min >= lambda_max`. This
///   variant is reused because no dedicated range variant exists.
/// * Any error returned by `validate_positive_slice` for `y`.
///
/// # Panics
///
/// Panics if `population_variance` fails. The code assumes it cannot fail
/// for a slice of two or more elements.
pub fn estimate_lambda(
    y: &[f64],
    lambda_min: f64,
    lambda_max: f64,
) -> Result<f64, TransformError> {
    // A NaN bound fails every comparison, so `lambda_min >= lambda_max` alone would
    // accept it. Infinite bounds make the golden-section points NaN. Either way the
    // search would silently return NaN.
    if !lambda_min.is_finite() || !lambda_max.is_finite() || lambda_min >= lambda_max {
        return Err(TransformError::InsufficientData);
    }
    validate_positive_slice(y)?;

    let n = y.len() as f64;
    // Jacobian term of the likelihood. It does not depend on lambda, so compute it once.
    let log_sum: f64 = y.iter().map(|&v| v.ln()).sum();

    let profile_ll = |lambda: f64| -> f64 {
        let y_t: Vec<f64> = if lambda.abs() < 1e-10 {
            y.iter().map(|&v| v.ln()).collect()
        } else {
            y.iter().map(|&v| (v.powf(lambda) - 1.0) / lambda).collect()
        };
        let var = population_variance(&y_t)
            .expect("slice has >= 2 elements — variance is defined");
        // Zero variance, or NaN/inf caused by `powf` overflowing at extreme lambda,
        // makes this candidate unusable. Map it to -inf so the search moves away from
        // it rather than comparing against NaN, which is always false.
        if !var.is_finite() || var <= 0.0 {
            return f64::NEG_INFINITY;
        }
        -(n / 2.0) * var.ln() + (lambda - 1.0) * log_sum
    };

    // Reciprocal of the golden ratio. Each step shrinks the bracket by this factor
    // and reuses one interior point, so only one new evaluation is needed.
    const PHI: f64 = 0.618_033_988_749_895;
    let mut a = lambda_min;
    let mut b = lambda_max;
    let mut x1 = b - PHI * (b - a);
    let mut x2 = a + PHI * (b - a);
    let mut f1 = profile_ll(x1);
    let mut f2 = profile_ll(x2);
    for _ in 0..100 {
        if (b - a).abs() < 1e-6 {
            break;
        }
        if f1 < f2 {
            // The maximum lies in [x1, b].
            a = x1;
            x1 = x2;
            f1 = f2;
            x2 = a + PHI * (b - a);
            f2 = profile_ll(x2);
        } else {
            // The maximum lies in [a, x2].
            b = x2;
            x2 = x1;
            f2 = f1;
            x1 = b - PHI * (b - a);
            f1 = profile_ll(x1);
        }
    }
    Ok((a + b) / 2.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` has the same length as `expected` and matches it
    /// element-wise to within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn box_cox_log_transform() {
        let y = [1.0, std::f64::consts::E, std::f64::consts::E.powi(2)];
        assert_close(&box_cox(&y, 0.0).unwrap(), &[0.0, 1.0, 2.0], 1e-9);
    }

    #[test]
    fn box_cox_identity_lambda_1() {
        // lambda = 1 is a pure shift: y - 1.
        let y = [2.0, 5.0, 10.0];
        assert_close(&box_cox(&y, 1.0).unwrap(), &[1.0, 4.0, 9.0], 1e-10);
    }

    #[test]
    fn box_cox_sqrt_lambda_half() {
        // lambda = 0.5 gives 2 * (sqrt(y) - 1).
        let y = [4.0, 9.0];
        assert_close(&box_cox(&y, 0.5).unwrap(), &[2.0, 4.0], 1e-10);
    }

    #[test]
    fn inverse_roundtrip_multiple_lambdas() {
        let y = [1.5, 2.3, 4.7, 8.1, 15.2];
        for &lambda in &[-0.5_f64, 0.0, 0.5, 1.0, 2.0] {
            let y_t = box_cox(&y, lambda).unwrap();
            let y_rec = inverse_box_cox(&y_t, lambda).unwrap();
            assert_eq!(y_rec.len(), y.len(), "lambda={lambda}");
            for (orig, rec) in y.iter().zip(&y_rec) {
                assert!(
                    (orig - rec).abs() < 1e-9,
                    "lambda={lambda} orig={orig} rec={rec}"
                );
            }
        }
    }

    #[test]
    fn estimate_lambda_near_zero_for_exponential() {
        let y: Vec<f64> = (1..=30).map(|i| (i as f64 * 0.2).exp()).collect();
        let lambda = estimate_lambda(&y, -2.0, 2.0).unwrap();
        assert!(lambda.abs() < 0.3, "Expected lambda near 0, got {lambda}");
    }

    #[test]
    fn estimate_lambda_near_half_for_quadratic() {
        let y: Vec<f64> = (1..=20).map(|i| (i as f64).powi(2)).collect();
        let lambda = estimate_lambda(&y, -2.0, 2.0).unwrap();
        assert!(
            lambda > 0.2 && lambda < 0.8,
            "Expected lambda ~0.5, got {lambda}"
        );
    }

    #[test]
    fn non_positive_returns_error() {
        assert_eq!(
            box_cox(&[1.0, -1.0, 2.0], 0.5),
            Err(TransformError::NonPositiveData)
        );
        assert_eq!(
            box_cox(&[0.0, 1.0, 2.0], 0.5),
            Err(TransformError::NonPositiveData)
        );
    }

    #[test]
    fn insufficient_data_returns_error() {
        assert_eq!(box_cox(&[], 0.5), Err(TransformError::InsufficientData));
        assert_eq!(box_cox(&[1.0], 0.5), Err(TransformError::InsufficientData));
        assert_eq!(
            estimate_lambda(&[1.0], -2.0, 2.0),
            Err(TransformError::InsufficientData)
        );
    }

    #[test]
    fn inverse_invalid_returns_error() {
        // With lambda = 2, lambda * y_t + 1 < 0, which no positive y can produce.
        assert_eq!(
            inverse_box_cox(&[-1.0, -0.8], 2.0),
            Err(TransformError::InvalidInverse)
        );
    }

    #[test]
    fn estimate_lambda_invalid_range() {
        let y = [1.0, 2.0, 3.0, 4.0];
        assert!(estimate_lambda(&y, 1.0, 0.0).is_err());
        assert!(estimate_lambda(&y, 0.5, 0.5).is_err());
    }

    #[test]
    fn box_cox_negative_lambda() {
        // lambda = -1 gives 1 - 1/y.
        let y = [2.0, 4.0];
        assert_close(&box_cox(&y, -1.0).unwrap(), &[0.5, 0.75], 1e-10);
    }
}