use crate::error::{Error, Result};
use crate::feature_importance::types::StandardizedCoefficientsOutput;
use crate::stats::stddev;
pub fn standardized_coefficients(
coefficients: &[f64],
x_vars: &[Vec<f64>],
) -> Result<StandardizedCoefficientsOutput> {
let n_predictors = x_vars.len();
if coefficients.len() != n_predictors + 1 {
return Err(Error::InvalidInput(format!(
"coefficients length ({}) must equal x_vars length + 1 ({})",
coefficients.len(),
n_predictors + 1
)));
}
for (_i, var) in x_vars.iter().enumerate() {
if var.len() < 2 {
return Err(Error::InsufficientData {
required: 2,
available: var.len(),
});
}
}
let y_std = 1.0;
let mut standardized_coefficients = Vec::with_capacity(n_predictors);
let mut variable_names = Vec::with_capacity(n_predictors);
for (i, x_col) in x_vars.iter().enumerate() {
let x_std = stddev(x_col);
if !x_std.is_finite() || x_std == 0.0 {
return Err(Error::InvalidInput(format!(
"Predictor X{} has zero or invalid standard deviation",
i + 1
)));
}
let coef = coefficients[i + 1];
let std_coef = coef * (x_std / y_std);
standardized_coefficients.push(std_coef);
variable_names.push(format!("X{}", i + 1));
}
Ok(StandardizedCoefficientsOutput {
variable_names,
standardized_coefficients,
y_std,
})
}
pub fn standardized_coefficients_named(
coefficients: &[f64],
x_vars: &[Vec<f64>],
variable_names: &[String],
y_std: f64,
) -> Result<StandardizedCoefficientsOutput> {
let n_predictors = x_vars.len();
if coefficients.len() != n_predictors + 1 {
return Err(Error::InvalidInput(format!(
"coefficients length ({}) must equal x_vars length + 1 ({})",
coefficients.len(),
n_predictors + 1
)));
}
if variable_names.len() != n_predictors {
return Err(Error::InvalidInput(format!(
"variable_names length ({}) must equal x_vars length ({})",
variable_names.len(),
n_predictors
)));
}
if y_std <= 0.0 || !y_std.is_finite() {
return Err(Error::InvalidInput(
"y_std must be positive and finite".to_string(),
));
}
for (_i, var) in x_vars.iter().enumerate() {
if var.len() < 2 {
return Err(Error::InsufficientData {
required: 2,
available: var.len(),
});
}
}
let mut standardized_coefficients = Vec::with_capacity(n_predictors);
for (i, x_col) in x_vars.iter().enumerate() {
let x_std = stddev(x_col);
if !x_std.is_finite() || x_std == 0.0 {
return Err(Error::InvalidInput(format!(
"Predictor {} has zero or invalid standard deviation",
variable_names[i]
)));
}
let coef = coefficients[i + 1];
let std_coef = coef * (x_std / y_std);
standardized_coefficients.push(std_coef);
}
Ok(StandardizedCoefficientsOutput {
variable_names: variable_names.to_vec(),
standardized_coefficients,
y_std,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Predictor column with unit steps: 1, 2, 3, 4, 5.
    fn unit_steps() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0]
    }

    /// Predictor column with steps of ten: 10, 20, 30, 40, 50.
    fn ten_steps() -> Vec<f64> {
        vec![10.0, 20.0, 30.0, 40.0, 50.0]
    }

    #[test]
    fn test_standardized_coefficients_basic() {
        let predictors = [unit_steps(), ten_steps()];
        let output = standardized_coefficients(&[1.0, 2.0, 0.5], &predictors).unwrap();

        assert_eq!(output.variable_names, vec!["X1", "X2"]);
        assert_eq!(output.standardized_coefficients.len(), 2);
        assert!(output.y_std > 0.0);

        // With y_std fixed at 1.0, each value is the slope times the
        // predictor's standard deviation.
        let spread_x1 = stddev(&predictors[0]);
        let spread_x2 = stddev(&predictors[1]);
        assert_eq!(output.standardized_coefficients[0], 2.0 * spread_x1);
        assert_eq!(output.standardized_coefficients[1], 0.5 * spread_x2);
    }

    #[test]
    fn test_standardized_coefficients_named() {
        let predictors = [unit_steps(), ten_steps()];
        let names = vec!["Temp".to_string(), "Pressure".to_string()];
        let y_std = 2.5;

        let output =
            standardized_coefficients_named(&[1.0, 2.0, 0.5], &predictors, &names, y_std)
                .unwrap();

        assert_eq!(output.variable_names, names);
        assert_eq!(output.y_std, y_std);

        let expected = 2.0 * stddev(&predictors[0]) / y_std;
        assert!((output.standardized_coefficients[0] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_standardized_coefficients_ranking() {
        // X1 and X2 share the same column; X3 has ten times the spread, so
        // its small raw slope (0.1) is expected to come first in the ranking.
        let predictors = [unit_steps(), unit_steps(), ten_steps()];
        let output = standardized_coefficients(&[1.0, 0.5, -0.8, 0.1], &predictors).unwrap();

        let ranking = output.ranking();
        assert_eq!(ranking[0].0, "X3");
        assert!(ranking[0].1 > ranking[1].1);
    }

    #[test]
    fn test_standardized_coefficients_invalid_input() {
        // Two coefficients cannot describe three predictors.
        let predictors: Vec<Vec<f64>> = vec![vec![1.0, 2.0, 3.0]; 3];
        let outcome = standardized_coefficients(&[1.0, 2.0], &predictors);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_standardized_coefficients_insufficient_data() {
        // A single observation is below the two-value minimum per predictor.
        let predictors: [Vec<f64>; 1] = [vec![1.0]];
        let outcome = standardized_coefficients(&[1.0, 2.0], &predictors);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_standardized_coefficients_constant_predictor() {
        // A constant column is expected to be rejected.
        let predictors: [Vec<f64>; 1] = [vec![5.0; 4]];
        let outcome = standardized_coefficients(&[1.0, 2.0], &predictors);
        assert!(outcome.is_err());
    }
}