use crate::linalg::Matrix;
/// Record of the transform applied by `standardize_xy`, used to map
/// coefficients fit on the standardized problem back to the original scale
/// (see `unstandardize_coefficients`).
#[derive(Clone, Debug)]
pub struct StandardizationInfo {
    // Weighted mean of each X column (0.0 for the intercept column and for
    // every column when no intercept is fit).
    pub x_mean: Vec<f64>,
    // Per-column divisor applied to X; 1.0 for columns that were not scaled.
    pub x_scale: Vec<f64>,
    // Squared-norm bookkeeping for each standardized column; exact meaning
    // depends on the branch taken in `standardize_xy`.
    pub column_squared_norms: Vec<f64>,
    // Weighted mean of y (0.0 when no intercept is fit).
    pub y_mean: f64,
    // Norm y was divided by after sqrt-weight row scaling; `None` only on
    // the degenerate weight-length-mismatch path.
    pub y_scale: Option<f64>,
    // Norm of centered y *before* the sqrt-weight row scaling was applied
    // (equals `y_scale` when no intercept is fit).
    pub y_scale_before_sqrt_weights_normalized: Option<f64>,
    // Flags echoing the `StandardizeOptions` the transform was built with.
    pub intercept: bool,
    pub standardized_x: bool,
    pub standardized_y: bool,
}
/// Options controlling `standardize_xy`.
#[derive(Clone, Debug)]
pub struct StandardizeOptions {
    // Treat column 0 of X as an intercept column and center X and y by
    // their weighted means.
    pub intercept: bool,
    // Scale each penalized X column (see `standardize_xy` for the exact
    // scale used with/without an intercept).
    pub standardize_x: bool,
    // NOTE(review): within this file this flag is only copied into the
    // returned info; y is always rescaled regardless — confirm how callers
    // elsewhere interpret it.
    pub standardize_y: bool,
    // Optional non-negative observation weights, one per row of X.
    pub weights: Option<Vec<f64>>,
}
impl Default for StandardizeOptions {
fn default() -> Self {
StandardizeOptions {
intercept: true,
standardize_x: true,
standardize_y: false,
weights: None,
}
}
}
/// Builds the standardized design matrix and response used by the solver.
///
/// Every row of `x` and `y` is scaled by the square root of its normalized
/// observation weight, so plain sums of squares on the outputs correspond to
/// weighted sums on the inputs. With `options.intercept`, column 0 is treated
/// as the intercept column (reported as mean 0 / scale 1) and the remaining
/// columns and `y` are first centered by their weighted means. With
/// `options.standardize_x`, each penalized column is additionally scaled
/// (unit norm with an intercept; centered standard deviation without one).
/// `y` is always rescaled to unit norm when its norm is positive. All scale
/// factors are recorded in the returned `StandardizationInfo` so
/// `unstandardize_coefficients` can invert the transform.
///
/// Degenerate inputs: a weight vector of the wrong length returns zeroed
/// outputs (no panic); all-zero weights zero out everything.
///
/// # Panics
/// Panics if any weight is negative.
#[allow(clippy::needless_range_loop)]
pub fn standardize_xy(
    x: &Matrix,
    y: &[f64],
    options: &StandardizeOptions,
) -> (Matrix, Vec<f64>, StandardizationInfo) {
    let n = x.rows;
    let p = x.cols;
    if let Some(ref w) = options.weights {
        // Length mismatch: return zeroed data with a neutral info struct
        // instead of panicking. NOTE(review): silent — callers must check.
        if w.len() != n {
            return (
                Matrix::new(n, p, vec![0.0; n * p]),
                vec![0.0; n],
                StandardizationInfo {
                    x_mean: vec![0.0; p],
                    x_scale: vec![1.0; p],
                    column_squared_norms: vec![0.0; p],
                    y_mean: 0.0,
                    y_scale: None,
                    y_scale_before_sqrt_weights_normalized: None,
                    intercept: options.intercept,
                    standardized_x: options.standardize_x,
                    standardized_y: options.standardize_y,
                },
            );
        }
        if w.iter().any(|&wi| wi < 0.0) {
            panic!("Weights must be non-negative");
        }
    }
    // Normalize weights to sum to one and precompute their square roots
    // (folded into the rows below). No weights => uniform 1/n; all-zero
    // weights => zero vectors, which zeroes the standardized outputs.
    let (weights_normalized, sqrt_weights_normalized): (Vec<f64>, Vec<f64>) = if let Some(ref w) = options.weights {
        let w_sum: f64 = w.iter().sum();
        if w_sum > 0.0 {
            let weights_normalized_vec: Vec<f64> = w.iter().map(|&wi| wi / w_sum).collect();
            let sqrt_weights_normalized_vec: Vec<f64> = weights_normalized_vec.iter().map(|&wi| wi.sqrt()).collect();
            (weights_normalized_vec, sqrt_weights_normalized_vec)
        } else {
            (vec![0.0; n], vec![0.0; n])
        }
    } else {
        let w_uniform = vec![1.0 / n as f64; n];
        let sqrt_weights_normalized_uniform = vec![1.0 / (n as f64).sqrt(); n];
        (w_uniform, sqrt_weights_normalized_uniform)
    };
    let mut x_standardized = x.clone();
    let mut y_standardized = y.to_vec();
    let mut x_mean = vec![0.0; p];
    let mut x_scale = vec![1.0; p];
    let mut column_squared_norms = vec![0.0; p];
    // Weighted mean of y; only meaningful when fitting an intercept.
    let y_mean = if options.intercept && !y.is_empty() {
        weights_normalized.iter().zip(y.iter()).map(|(&w, &yi)| w * yi).sum()
    } else {
        0.0
    };
    let (y_scale, y_scale_before_sqrt_weights_normalized) = if options.intercept {
        // Record the *unweighted* norm of centered y before the sqrt-weight
        // row scaling, then center + weight y and rescale it to unit norm.
        let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
        let y_ss_before_sqrt_weights_normalized: f64 = y_centered.iter().map(|&yi| yi * yi).sum();
        let y_scale_before_sqrt_weights_normalized_val = y_ss_before_sqrt_weights_normalized.sqrt();
        for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
            *yi = sqrt_weight * (*yi - y_mean);
        }
        let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
        let y_scale_val = y_ss.sqrt();
        // Constant y (zero norm after centering) stays the all-zero vector.
        if y_scale_val > 0.0 {
            for yi in y_standardized.iter_mut() {
                *yi /= y_scale_val;
            }
        }
        (Some(y_scale_val), Some(y_scale_before_sqrt_weights_normalized_val))
    } else {
        // No intercept: weight rows and rescale to unit norm; both recorded
        // scale fields coincide in this branch.
        for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
            *yi *= sqrt_weight;
        }
        let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
        let y_scale_val = y_ss.sqrt();
        if y_scale_val > 0.0 {
            for yi in y_standardized.iter_mut() {
                *yi /= y_scale_val;
            }
        }
        (Some(y_scale_val), Some(y_scale_val))
    };
    // Column 0 is skipped whenever it is the intercept column.
    let first_penalized_column_index = if options.intercept { 1 } else { 0 };
    if options.intercept {
        for j in first_penalized_column_index..p {
            // Weighted column mean, then center and fold in the sqrt-weights.
            let col_mean: f64 = (0..n)
                .map(|i| x_standardized.get(i, j) * weights_normalized[i])
                .sum();
            x_mean[j] = col_mean;
            for i in 0..n {
                let val = sqrt_weights_normalized[i] * (x_standardized.get(i, j) - col_mean);
                x_standardized.set(i, j, val);
            }
            let col_squared_norm_val: f64 = (0..n)
                .map(|i| {
                    let val = x_standardized.get(i, j);
                    val * val
                })
                .sum();
            if options.standardize_x {
                // Scale the centered column to unit norm; a constant
                // (zero-norm) column is left untouched with scale 1 and
                // recorded squared norm 0.
                let col_scale = col_squared_norm_val.sqrt();
                if col_scale > 0.0 {
                    for i in 0..n {
                        let val = x_standardized.get(i, j) / col_scale;
                        x_standardized.set(i, j, val);
                    }
                    x_scale[j] = col_scale;
                    column_squared_norms[j] = 1.0;
                }
            } else {
                column_squared_norms[j] = col_squared_norm_val;
                x_scale[j] = 1.0;
            }
        }
    } else {
        for j in first_penalized_column_index..p {
            // No centering without an intercept; only fold in sqrt-weights.
            x_mean[j] = 0.0;
            for i in 0..n {
                let val = sqrt_weights_normalized[i] * x_standardized.get(i, j);
                x_standardized.set(i, j, val);
            }
            let col_squared_norm_val: f64 = (0..n)
                .map(|i| {
                    let val = x_standardized.get(i, j);
                    val * val
                })
                .sum();
            if options.standardize_x {
                // Entries already carry sqrt(w), so this sum is the squared
                // weighted mean of the original column; col_squared_norm_val
                // is its weighted mean square; their difference is the
                // weighted variance about the mean.
                let x_squared_mean: f64 = (0..n)
                    .map(|i| sqrt_weights_normalized[i] * x_standardized.get(i, j))
                    .sum::<f64>().powi(2);
                let x_centered_variance = col_squared_norm_val - x_squared_mean;
                if x_centered_variance > 0.0 {
                    // Divide by the centered standard deviation (without
                    // actually centering the column).
                    let col_scale = x_centered_variance.sqrt();
                    for i in 0..n {
                        let val = x_standardized.get(i, j) / col_scale;
                        x_standardized.set(i, j, val);
                    }
                    x_scale[j] = col_scale;
                    column_squared_norms[j] = 1.0 + x_squared_mean / x_centered_variance;
                } else {
                    column_squared_norms[j] = 1.0;
                    x_scale[j] = 1.0;
                }
            } else {
                column_squared_norms[j] = col_squared_norm_val;
                x_scale[j] = 1.0;
            }
        }
    }
    // The intercept column itself is reported as already standardized.
    if options.intercept && p > 0 {
        x_scale[0] = 1.0;
        x_mean[0] = 0.0;
        column_squared_norms[0] = 1.0;
    }
    let info = StandardizationInfo {
        x_mean,
        x_scale,
        column_squared_norms,
        y_mean,
        y_scale,
        y_scale_before_sqrt_weights_normalized,
        intercept: options.intercept,
        standardized_x: options.standardize_x,
        standardized_y: options.standardize_y,
    };
    (x_standardized, y_standardized, info)
}
/// Maps coefficients fit on the standardized problem back to the original
/// data scale.
///
/// `coefficients_standardized` is laid out like the standardized design
/// matrix: when `info.intercept` is set, index 0 corresponds to the
/// intercept column and slopes start at index 1; otherwise every entry is a
/// slope. Each slope is rescaled by `y_scale / x_scale[j]` (a missing
/// `y_scale` is treated as 1.0), and the intercept is recovered as
/// `y_mean - Σ x_mean[j] * slope[j]` (0.0 when no intercept was fit).
///
/// Returns `(intercept, slopes)`.
pub fn unstandardize_coefficients(coefficients_standardized: &[f64], info: &StandardizationInfo) -> (f64, Vec<f64>) {
    let p = coefficients_standardized.len();
    let y_scale = info.y_scale.unwrap_or(1.0);
    let start_idx = if info.intercept { 1 } else { 0 };
    // saturating_sub guards the empty-slice-with-intercept case, which would
    // otherwise underflow the usize subtraction and panic.
    let n_slopes = p.saturating_sub(start_idx);
    let mut beta_slopes = vec![0.0; n_slopes];
    for (slope_idx, j) in (start_idx..p).enumerate() {
        beta_slopes[slope_idx] = (y_scale * coefficients_standardized[j]) / info.x_scale[j];
    }
    let beta0 = if info.intercept {
        // Undo the centering: shift by the weighted means recorded for X.
        let correction: f64 = beta_slopes
            .iter()
            .enumerate()
            .map(|(idx, &b)| info.x_mean[idx + 1] * b)
            .sum();
        info.y_mean - correction
    } else {
        0.0
    };
    (beta0, beta_slopes)
}
/// Computes linear predictions `beta0 + X·beta` for each row of `x_new`.
///
/// If `beta` has exactly one fewer entry than `x_new` has columns, column 0
/// is assumed to be an intercept column and is skipped (its contribution is
/// carried by `beta0`). Entries of `beta` beyond the available columns are
/// ignored.
pub fn predict(x_new: &Matrix, beta0: f64, beta: &[f64]) -> Vec<f64> {
    let n = x_new.rows;
    let p = x_new.cols;
    // `beta.len() + 1 == p` is equivalent to the intended `beta.len() == p - 1`
    // but avoids the usize underflow (panic) when `x_new` has zero columns.
    let has_intercept_col = beta.len() + 1 == p;
    let first_penalized_column_index = if has_intercept_col { 1 } else { 0 };
    let mut predictions = Vec::with_capacity(n);
    for i in 0..n {
        let mut sum = beta0;
        for (j, &beta_j) in beta.iter().enumerate() {
            let col = first_penalized_column_index + j;
            // Ignore coefficients that have no matching column.
            if col < p {
                sum += x_new.get(i, col) * beta_j;
            }
        }
        predictions.push(sum);
    }
    predictions
}
// Unit tests for the standardization pipeline. Each builds a small dense
// `Matrix` and pins exact numeric outputs.
#[cfg(test)]
mod tests {
    use super::*;

    // Intercept column passes through unchanged; y is centered at its mean
    // and scaled to unit norm; recorded x means equal the column means.
    #[test]
    fn test_standardize_xy_with_intercept() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 1.0, 6.0, 9.0];
        let x = Matrix::new(3, 3, x_data);
        let y = vec![3.0, 5.0, 7.0];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: None,
        };
        let (x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
        assert_eq!(x_standardized.get(0, 0), 1.0);
        assert_eq!(x_standardized.get(1, 0), 1.0);
        assert_eq!(x_standardized.get(2, 0), 1.0);
        // Centered y = [-2, 0, 2]; after weighting and unit-norm scaling the
        // entries are -1/sqrt(2), 0, +1/sqrt(2).
        let inv_sqrt2 = 1.0 / (2.0_f64).sqrt();
        assert!((y_standardized[0] - (-inv_sqrt2)).abs() < 1e-10);
        assert!((y_standardized[1] - 0.0).abs() < 1e-10);
        assert!((y_standardized[2] - inv_sqrt2).abs() < 1e-10);
        assert_eq!(info.x_mean[0], 0.0);
        assert!((info.x_mean[1] - 4.0).abs() < 1e-10);
        assert!((info.x_mean[2] - 6.0).abs() < 1e-10);
    }

    // Slopes rescale by y_scale/x_scale -> [1, 4/3]; intercept is
    // y_mean - (4*1 + 6*4/3) = 5 - 12 = -7.
    #[test]
    fn test_unstandardize_coefficients() {
        let x_mean = vec![0.0, 4.0, 6.0];
        let x_scale = vec![1.0, 2.0, 3.0];
        let column_squared_norms = vec![1.0, 1.0, 1.0];
        let y_mean = 5.0;
        let y_scale = Some(2.0);
        let info = StandardizationInfo {
            x_mean: x_mean.clone(),
            x_scale: x_scale.clone(),
            column_squared_norms,
            y_mean,
            y_scale,
            y_scale_before_sqrt_weights_normalized: None,
            intercept: true,
            standardized_x: true,
            standardized_y: true,
        };
        let coefficients_standardized = vec![0.0, 1.0, 2.0];
        let (beta0, beta_slopes) = unstandardize_coefficients(&coefficients_standardized, &info);
        assert!((beta_slopes[0] - 1.0).abs() < 1e-10);
        assert!((beta_slopes[1] - 4.0 / 3.0).abs() < 1e-10);
        assert!((beta0 - (-7.0)).abs() < 1e-10);
        assert_eq!(beta_slopes.len(), 2);
    }

    // beta.len() == cols - 1, so column 0 is skipped as the intercept
    // column: 1 + 2*2 + 3*3 = 14 and 1 + 4*2 + 6*3 = 27.
    #[test]
    fn test_predict() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
        let x = Matrix::new(2, 3, x_data);
        let beta0 = 1.0;
        let beta = vec![2.0, 3.0];
        let preds = predict(&x, beta0, &beta);
        assert!((preds[0] - 14.0).abs() < 1e-10);
        assert!((preds[1] - 27.0).abs() < 1e-10);
    }

    // Weighted mean of y = (3 + 2*5 + 7)/4 = 5; standardized y stays
    // symmetric (+-1/sqrt(2) at the ends, 0 in the middle).
    #[test]
    fn test_weighted_standardize_xy() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 1.0, 6.0, 9.0];
        let x = Matrix::new(3, 3, x_data);
        let y = vec![3.0, 5.0, 7.0];
        let weights = vec![1.0, 2.0, 1.0];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: Some(weights),
        };
        let (x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
        assert_eq!(x_standardized.get(0, 0), 1.0);
        assert_eq!(x_standardized.get(1, 0), 1.0);
        assert_eq!(x_standardized.get(2, 0), 1.0);
        assert!((info.y_mean - 5.0).abs() < 1e-10);
        let expected_0 = -1.0 / (2.0_f64).sqrt();
        assert!((y_standardized[0] - expected_0).abs() < 1e-10);
        assert!((y_standardized[1] - 0.0).abs() < 1e-10);
        assert!((y_standardized[2] + expected_0).abs() < 1e-10);
    }

    // Explicit uniform weights must reproduce the unweighted path exactly.
    #[test]
    fn test_weighted_standardize_uniform_weights() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
        let x = Matrix::new(2, 3, x_data);
        let y = vec![3.0, 5.0];
        let weights = vec![1.0, 1.0];
        let options_with_weights = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: Some(weights),
        };
        let options_no_weights = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: None,
        };
        let (_x_standardized_w, y_standardized_w, info_w) = standardize_xy(&x, &y, &options_with_weights);
        let (_x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options_no_weights);
        assert_eq!(info_w.y_mean, info.y_mean);
        for i in 0..2 {
            assert!((y_standardized_w[i] - y_standardized[i]).abs() < 1e-10);
        }
    }

    // A weight vector of the wrong length does not panic: the function
    // returns zeroed outputs and echoes the option flags.
    #[test]
    fn test_standardize_xy_weights_dimension_mismatch() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
        let x = Matrix::new(2, 3, x_data);
        let y = vec![3.0, 5.0];
        let weights = vec![1.0, 1.0, 1.0];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: Some(weights),
        };
        let (x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
        assert_eq!(x_standardized.rows, 2);
        assert_eq!(x_standardized.cols, 3);
        assert_eq!(y_standardized, vec![0.0, 0.0]);
        assert!(!info.standardized_y);
        assert!(info.intercept);
        assert!(info.standardized_x);
    }

    // Negative weights are rejected with a panic.
    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_standardize_xy_negative_weights_panics() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
        let x = Matrix::new(2, 3, x_data);
        let y = vec![3.0, 5.0];
        let weights = vec![1.0, -0.5];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: Some(weights),
        };
        let _ = standardize_xy(&x, &y, &options);
    }

    // All-zero weights degenerate to zeroed outputs (normalized weights are
    // the zero vector).
    #[test]
    fn test_standardize_xy_zero_sum_weights() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
        let x = Matrix::new(2, 3, x_data);
        let y = vec![3.0, 5.0];
        let weights = vec![0.0, 0.0];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: Some(weights),
        };
        let (_x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
        assert_eq!(info.y_mean, 0.0);
        assert_eq!(y_standardized, vec![0.0, 0.0]);
    }

    // Without an intercept, y is not centered but is still scaled to unit
    // norm after the sqrt-weight row scaling.
    #[test]
    fn test_standardize_xy_without_intercept() {
        let x_data = vec![2.0, 3.0, 4.0, 6.0, 8.0, 9.0];
        let x = Matrix::new(2, 3, x_data);
        let y = vec![3.0, 5.0];
        let options = StandardizeOptions {
            intercept: false,
            standardize_x: true,
            standardize_y: false,
            weights: None,
        };
        let (_x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
        assert_eq!(info.y_mean, 0.0);
        assert!(!info.intercept);
        let y_norm: f64 = y_standardized.iter().map(|&v| v * v).sum::<f64>().sqrt();
        assert!((y_norm - 1.0).abs() < 1e-10);
    }

    // Constant y centers to the zero vector; its scale is recorded as 0 and
    // no division is attempted.
    #[test]
    fn test_standardize_xy_constant_y() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
        let x = Matrix::new(2, 3, x_data);
        let y = vec![5.0, 5.0];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: None,
        };
        let (_x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
        assert_eq!(y_standardized, vec![0.0, 0.0]);
        assert_eq!(info.y_mean, 5.0);
        assert!(info.y_scale.unwrap_or(0.0) == 0.0);
    }

    // Without an intercept every coefficient is a slope: y_scale * c / x_scale
    // gives [2, 2, 2] and the returned intercept is 0.
    #[test]
    fn test_unstandardize_coefficients_no_intercept() {
        let x_mean = vec![0.0, 4.0, 6.0];
        let x_scale = vec![1.0, 2.0, 3.0];
        let column_squared_norms = vec![1.0, 1.0, 1.0];
        let y_mean = 0.0;
        let y_scale = Some(2.0);
        let info = StandardizationInfo {
            x_mean: x_mean.clone(),
            x_scale: x_scale.clone(),
            column_squared_norms,
            y_mean,
            y_scale,
            y_scale_before_sqrt_weights_normalized: None,
            intercept: false,
            standardized_x: true,
            standardized_y: true,
        };
        let coefficients_standardized = vec![1.0, 2.0, 3.0];
        let (beta0, beta_slopes) = unstandardize_coefficients(&coefficients_standardized, &info);
        assert_eq!(beta0, 0.0);
        assert_eq!(beta_slopes.len(), 3);
        assert!((beta_slopes[0] - 2.0).abs() < 1e-10);
        assert!((beta_slopes[1] - (2.0 * 2.0 / 2.0)).abs() < 1e-10);
        assert!((beta_slopes[2] - (2.0 * 3.0 / 3.0)).abs() < 1e-10);
    }

    // A missing y_scale is treated as 1.0: first slope is 1/2.
    #[test]
    fn test_unstandardize_coefficients_no_y_scale() {
        let x_mean = vec![0.0, 4.0, 6.0];
        let x_scale = vec![1.0, 2.0, 3.0];
        let column_squared_norms = vec![1.0, 1.0, 1.0];
        let y_mean = 5.0;
        let y_scale = None;
        let info = StandardizationInfo {
            x_mean: x_mean.clone(),
            x_scale: x_scale.clone(),
            column_squared_norms,
            y_mean,
            y_scale,
            y_scale_before_sqrt_weights_normalized: None,
            intercept: true,
            standardized_x: true,
            standardized_y: false,
        };
        let coefficients_standardized = vec![0.0, 1.0, 2.0];
        let (_beta0, beta_slopes) = unstandardize_coefficients(&coefficients_standardized, &info);
        assert!((beta_slopes[0] - 0.5).abs() < 1e-10);
    }

    // beta.len() == cols, so no intercept column is assumed and every
    // column is used: 1 + 2*2 + 3*3 = 14, 1 + 4*2 + 6*3 = 27.
    #[test]
    fn test_predict_no_intercept_column() {
        let x_data = vec![2.0, 3.0, 4.0, 6.0];
        let x = Matrix::new(2, 2, x_data);
        let beta0 = 1.0;
        let beta = vec![2.0, 3.0];
        let preds = predict(&x, beta0, &beta);
        assert!((preds[0] - 14.0).abs() < 1e-10);
        assert!((preds[1] - 27.0).abs() < 1e-10);
    }

    // Coefficients beyond the available columns are ignored:
    // 5 + 1*1 + 2*2 + 3*3 = 19 (beta[3] unused).
    #[test]
    fn test_predict_beta_longer_than_columns() {
        let x_data = vec![1.0, 2.0, 3.0];
        let x = Matrix::new(1, 3, x_data);
        let beta0 = 5.0;
        let beta = vec![1.0, 2.0, 3.0, 4.0];
        let preds = predict(&x, beta0, &beta);
        assert!((preds[0] - 19.0).abs() < 1e-10);
    }

    // With standardize_x off, penalized columns are centered and weighted
    // but not scaled (x_scale stays 1); column 1 becomes -/+ sqrt(1/2).
    #[test]
    fn test_standardize_xy_no_standardize_x() {
        let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
        let x = Matrix::new(2, 3, x_data);
        let y = vec![3.0, 5.0];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: false,
            standardize_y: false,
            weights: None,
        };
        let (x_standardized, _y_standardized, info) = standardize_xy(&x, &y, &options);
        assert_eq!(x_standardized.get(0, 0), 1.0);
        assert_eq!(x_standardized.get(1, 0), 1.0);
        let sqrt_half = (0.5_f64).sqrt();
        assert!((x_standardized.get(0, 1) - (-sqrt_half)).abs() < 1e-10);
        assert!((x_standardized.get(1, 1) - sqrt_half).abs() < 1e-10);
        assert_eq!(info.x_scale[1], 1.0);
        assert_eq!(info.x_scale[2], 1.0);
    }
}