use crate::{Result, TreeBoostError};
/// Per-feature Yeo-Johnson power transform for making skewed data more
/// Gaussian-like.
///
/// Fit one lambda per feature column with `fit` (found by maximizing the
/// profile log-likelihood via golden-section search), then apply with
/// `transform` / `inverse_transform`. Serializable so fitted lambdas can be
/// stored alongside a model.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct YeoJohnsonTransform {
    // Optimal power parameter per feature column; valid once `fitted` is true.
    lambdas: Vec<f32>,
    // True after `fit` has run (or lambdas were supplied via `with_lambdas`).
    fitted: bool,
    // Maximum golden-section iterations when searching for each lambda.
    max_iter: usize,
    // Search-interval width at which the lambda search stops.
    tolerance: f32,
}
impl Default for YeoJohnsonTransform {
fn default() -> Self {
Self::new()
}
}
impl YeoJohnsonTransform {
    /// Creates an unfitted transform with the default lambda-search settings
    /// (100 golden-section iterations, interval tolerance 1e-6).
    pub fn new() -> Self {
        Self {
            lambdas: Vec::new(),
            fitted: false,
            max_iter: 100,
            tolerance: 1e-6,
        }
    }

    /// Sets the maximum number of golden-section iterations used when
    /// searching for each feature's lambda.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the search-interval width at which the lambda search stops.
    pub fn with_tolerance(mut self, tolerance: f32) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Builds an already-fitted transform from precomputed per-feature
    /// lambdas (e.g. loaded from a saved model).
    pub fn with_lambdas(lambdas: Vec<f32>) -> Self {
        Self {
            lambdas,
            fitted: true,
            max_iter: 100,
            tolerance: 1e-6,
        }
    }

    /// Fits one lambda per feature column by maximizing the Yeo-Johnson
    /// profile log-likelihood over row-major `data` with `num_features`
    /// columns.
    ///
    /// NaN entries are ignored while fitting; a column containing only NaNs
    /// gets the identity lambda (1.0).
    ///
    /// # Errors
    /// Returns `TreeBoostError::Data` if `data` is empty, `num_features` is
    /// zero, or `data.len()` is not a multiple of `num_features`.
    pub fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if data.is_empty() {
            return Err(TreeBoostError::Data("Cannot fit on empty data".into()));
        }
        // Guard the division below: `data.len() / 0` would panic.
        if num_features == 0 {
            return Err(TreeBoostError::Data(
                "num_features must be greater than zero".into(),
            ));
        }
        let num_rows = data.len() / num_features;
        if data.len() != num_rows * num_features {
            return Err(TreeBoostError::Data(format!(
                "Data length {} is not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        self.lambdas = Vec::with_capacity(num_features);
        for col in 0..num_features {
            // Collect the non-NaN values of this column.
            let values: Vec<f32> = (0..num_rows)
                .map(|row| data[row * num_features + col])
                .filter(|v| !v.is_nan())
                .collect();
            if values.is_empty() {
                // No usable values: fall back to the identity transform.
                self.lambdas.push(1.0);
                continue;
            }
            let optimal_lambda = self.find_optimal_lambda(&values);
            self.lambdas.push(optimal_lambda);
        }
        self.fitted = true;
        Ok(())
    }

    /// Applies the fitted transform to row-major `data` in place. NaN
    /// entries are left untouched.
    ///
    /// # Errors
    /// Returns `TreeBoostError::Config` if unfitted or if `num_features`
    /// differs from the fitted feature count, and `TreeBoostError::Data` if
    /// `data.len()` is not a multiple of `num_features`.
    pub fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        self.check_shape(data, num_features)?;
        for (idx, value) in data.iter_mut().enumerate() {
            if !value.is_nan() {
                // Row-major layout: column index is idx mod num_features.
                *value = yeo_johnson_transform(*value, self.lambdas[idx % num_features]);
            }
        }
        Ok(())
    }

    /// Applies the inverse transform to row-major `data` in place, mapping
    /// transformed values back to the original scale. NaN entries are left
    /// untouched.
    ///
    /// # Errors
    /// Same conditions as [`Self::transform`].
    pub fn inverse_transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        self.check_shape(data, num_features)?;
        for (idx, value) in data.iter_mut().enumerate() {
            if !value.is_nan() {
                *value = yeo_johnson_inverse(*value, self.lambdas[idx % num_features]);
            }
        }
        Ok(())
    }

    /// Validates fitted state and data shape for `transform` /
    /// `inverse_transform`.
    fn check_shape(&self, data: &[f32], num_features: usize) -> Result<()> {
        if !self.fitted {
            return Err(TreeBoostError::Config(
                "YeoJohnsonTransform not fitted. Call fit() first.".into(),
            ));
        }
        if self.lambdas.len() != num_features {
            return Err(TreeBoostError::Config(format!(
                "Feature count mismatch: fitted with {} features, got {}",
                self.lambdas.len(),
                num_features
            )));
        }
        // Previously a trailing partial row was silently skipped (and
        // num_features == 0 could panic on division by zero); reject both
        // explicitly so shape bugs surface at the call site, mirroring `fit`.
        if num_features == 0 || data.len() % num_features != 0 {
            return Err(TreeBoostError::Data(format!(
                "Data length {} is not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        Ok(())
    }

    /// Convenience: fit on `data`, then transform it in place.
    pub fn fit_transform(&mut self, data: &mut [f32], num_features: usize) -> Result<()> {
        self.fit(data, num_features)?;
        self.transform(data, num_features)
    }

    /// Returns true once `fit` has run (or lambdas were supplied).
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }

    /// The fitted per-feature lambdas (empty before fitting).
    pub fn lambdas(&self) -> &[f32] {
        &self.lambdas
    }

    /// Golden-section search over lambda in [-5, 5] for the value that
    /// maximizes the profile log-likelihood (implemented as minimizing its
    /// negation).
    fn find_optimal_lambda(&self, values: &[f32]) -> f32 {
        let mut a = -5.0f32;
        let mut b = 5.0f32;
        // Golden ratio; resphi = 2 - phi ≈ 0.382 positions the two probes.
        let phi = (1.0 + 5.0f32.sqrt()) / 2.0;
        let resphi = 2.0 - phi;
        let mut x1 = a + resphi * (b - a);
        let mut x2 = b - resphi * (b - a);
        let mut f1 = -self.log_likelihood(values, x1);
        let mut f2 = -self.log_likelihood(values, x2);
        for _ in 0..self.max_iter {
            if (b - a).abs() < self.tolerance {
                break;
            }
            if f1 < f2 {
                // Minimum lies in [a, x2]; reuse x1 as the new upper probe.
                b = x2;
                x2 = x1;
                f2 = f1;
                x1 = a + resphi * (b - a);
                f1 = -self.log_likelihood(values, x1);
            } else {
                // Minimum lies in [x1, b]; reuse x2 as the new lower probe.
                a = x1;
                x1 = x2;
                f1 = f2;
                x2 = b - resphi * (b - a);
                f2 = -self.log_likelihood(values, x2);
            }
        }
        (a + b) / 2.0
    }

    /// Profile log-likelihood of the Yeo-Johnson transform at `lambda`, up
    /// to an additive constant:
    /// -n/2 * ln(variance) + (lambda - 1) * sum(sign(x) * ln(|x| + 1)).
    fn log_likelihood(&self, values: &[f32], lambda: f32) -> f32 {
        let n = values.len() as f32;
        if n == 0.0 {
            return f32::NEG_INFINITY;
        }
        let transformed: Vec<f32> = values
            .iter()
            .map(|&x| yeo_johnson_transform(x, lambda))
            .collect();
        let mean: f32 = transformed.iter().sum::<f32>() / n;
        let variance: f32 = transformed.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;
        // Degenerate (constant or non-finite) columns cannot be scored.
        if variance <= 0.0 || variance.is_nan() {
            return f32::NEG_INFINITY;
        }
        // Log of the Jacobian of the transform, summed over observations.
        let jacobian_term: f32 = values
            .iter()
            .map(|&x| {
                let sign = if x >= 0.0 { 1.0 } else { -1.0 };
                (lambda - 1.0) * sign * (x.abs() + 1.0).ln()
            })
            .sum();
        -0.5 * n * variance.ln() + jacobian_term
    }
}
#[inline]
pub fn yeo_johnson_transform(x: f32, lambda: f32) -> f32 {
    /// Threshold below which an exponent is treated as the log-limit case.
    const EPS: f32 = 1e-10;
    // Piecewise definition of the Yeo-Johnson transform, split on the sign
    // of x, with the analytic log limits at lambda = 0 (positive branch)
    // and lambda = 2 (negative branch).
    if x < 0.0 {
        let shifted = 1.0 - x; // equals |x| + 1 here
        let exponent = 2.0 - lambda;
        if exponent.abs() <= EPS {
            -shifted.ln()
        } else {
            (1.0 - shifted.powf(exponent)) / exponent
        }
    } else {
        let shifted = x + 1.0;
        if lambda.abs() <= EPS {
            shifted.ln()
        } else {
            (shifted.powf(lambda) - 1.0) / lambda
        }
    }
}
#[inline]
pub fn yeo_johnson_inverse(y: f32, lambda: f32) -> f32 {
    /// Threshold below which an exponent is treated as the log-limit case.
    const EPS: f32 = 1e-10;
    // Inverse of the piecewise Yeo-Johnson transform: y >= 0 came from the
    // non-negative branch, y < 0 from the negative branch.
    if y >= 0.0 {
        match lambda.abs() > EPS {
            true => (lambda * y + 1.0).powf(lambda.recip()) - 1.0,
            false => y.exp() - 1.0, // lambda = 0 limit: inverse of ln(x + 1)
        }
    } else {
        let exponent = 2.0 - lambda;
        match exponent.abs() > EPS {
            true => 1.0 - (exponent * -y + 1.0).powf(exponent.recip()),
            false => 1.0 - (-y).exp(), // lambda = 2 limit: inverse of -ln(1 - x)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-difference check used by the numeric assertions below.
    fn close(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_yeo_johnson_transform_positive() {
        let y = yeo_johnson_transform(5.0, 0.5);
        assert!(y > 0.0 && y.is_finite());
    }

    #[test]
    fn test_yeo_johnson_transform_negative() {
        let y = yeo_johnson_transform(-5.0, 0.5);
        assert!(y < 0.0 && y.is_finite());
    }

    #[test]
    fn test_yeo_johnson_transform_zero_lambda() {
        // At lambda = 0 the positive branch is ln(x + 1).
        assert!(close(yeo_johnson_transform(5.0, 0.0), 6.0f32.ln(), 1e-5));
    }

    #[test]
    fn test_yeo_johnson_transform_lambda_two() {
        // At lambda = 2 the negative branch is -ln(|x| + 1).
        assert!(close(yeo_johnson_transform(-5.0, 2.0), -(6.0f32.ln()), 1e-5));
    }

    #[test]
    fn test_yeo_johnson_inverse_positive() {
        let recovered = yeo_johnson_inverse(yeo_johnson_transform(5.0, 0.5), 0.5);
        assert!(close(recovered, 5.0, 1e-4));
    }

    #[test]
    fn test_yeo_johnson_inverse_negative() {
        let recovered = yeo_johnson_inverse(yeo_johnson_transform(-5.0, 0.5), 0.5);
        assert!(close(recovered, -5.0, 1e-4));
    }

    #[test]
    fn test_yeo_johnson_fit() {
        let data = [0.1, 10.0, 1.0, 100.0, 5.0, 1000.0];
        let mut yj = YeoJohnsonTransform::new();
        yj.fit(&data, 2).unwrap();
        assert!(yj.is_fitted());
        assert_eq!(yj.lambdas().len(), 2);
        // Every fitted lambda stays strictly inside the search interval.
        assert!(yj.lambdas().iter().all(|&l| l > -5.0 && l < 5.0));
    }

    #[test]
    fn test_yeo_johnson_fit_transform() {
        let mut data = [0.1, 1.0, 10.0, 100.0];
        let mut yj = YeoJohnsonTransform::new();
        yj.fit_transform(&mut data, 2).unwrap();
        assert!(yj.is_fitted());
        assert!(data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_yeo_johnson_inverse_transform() {
        let original = [0.5, 2.0, 5.0, 10.0];
        let mut data = original;
        let mut yj = YeoJohnsonTransform::new();
        yj.fit_transform(&mut data, 2).unwrap();
        yj.inverse_transform(&mut data, 2).unwrap();
        // The forward/inverse pair should recover the inputs closely.
        for (expected, got) in original.iter().zip(data.iter()) {
            assert!(
                close(*got, *expected, 0.01),
                "orig={}, recovered={}",
                expected,
                got
            );
        }
    }

    #[test]
    fn test_yeo_johnson_with_lambdas() {
        let yj = YeoJohnsonTransform::with_lambdas(vec![0.5, 1.0]);
        assert!(yj.is_fitted());
        assert_eq!(yj.lambdas(), &[0.5, 1.0]);
    }

    #[test]
    fn test_yeo_johnson_not_fitted_error() {
        let mut data = vec![1.0, 2.0];
        let err = YeoJohnsonTransform::new().transform(&mut data, 2).unwrap_err();
        assert!(err.to_string().contains("not fitted"));
    }

    #[test]
    fn test_yeo_johnson_with_nan() {
        let mut data = [1.0, f32::NAN, 5.0, 10.0];
        YeoJohnsonTransform::new().fit_transform(&mut data, 2).unwrap();
        // NaNs pass through untouched; every other entry is transformed.
        assert!(data[1].is_nan());
        for idx in [0usize, 2, 3] {
            assert!(!data[idx].is_nan());
        }
    }

    #[test]
    fn test_yeo_johnson_serialization() {
        let mut yj = YeoJohnsonTransform::new();
        yj.fit(&[1.0, 2.0, 3.0, 4.0], 2).unwrap();
        let encoded = serde_json::to_string(&yj).unwrap();
        let decoded: YeoJohnsonTransform = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.is_fitted());
        assert_eq!(decoded.lambdas(), yj.lambdas());
    }

    #[test]
    fn test_yeo_johnson_identity_lambda_one() {
        // lambda = 1 leaves values unchanged.
        assert!(close(yeo_johnson_transform(5.0, 1.0), 5.0, 1e-5));
    }

    #[test]
    fn test_yeo_johnson_all_nan_column() {
        // A column of only NaNs still fits (identity lambda fallback).
        let data = [f32::NAN, 1.0, f32::NAN, 2.0];
        let mut yj = YeoJohnsonTransform::new();
        yj.fit(&data, 2).unwrap();
        assert!(yj.is_fitted());
        assert_eq!(yj.lambdas().len(), 2);
    }
}