use crate::error::ModelError;
use ndarray::{Array, ArrayBase, ArrayViewMut1, Axis, Data, Dimension};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rayon::prelude::IntoParallelRefMutIterator;
/// Size at or above which reductions switch from sequential iterators to
/// rayon's parallel iterators.
///
/// Compared against element counts for whole-array and per-lane reductions;
/// lane-level parallelism uses one hundredth of this value as its lane-count cutoff.
const STANDARDIZE_PARALLEL_THRESHOLD: usize = 10_000;
/// Selects which groups of elements share a mean and standard deviation
/// when standardizing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StandardizationAxis {
    /// Every 1-D lane along the last axis (each row of a 2-D matrix) is
    /// standardized independently.
    Row,
    /// Every 1-D lane along the second-to-last axis (each column of a 2-D
    /// matrix) is standardized independently.
    Column,
    /// A single mean and standard deviation are computed over all elements.
    Global,
}
/// Returns a standardized copy of `data`: values are centered on their mean
/// and divided by a regularized standard deviation.
///
/// `axis` chooses whether statistics are computed over the whole array
/// ([`StandardizationAxis::Global`]), per row, or per column. `epsilon`
/// regularizes the divisor so constant groups do not cause a division by zero.
/// The input is never modified; a new owned array is returned.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] if `data` is empty, if any
/// element is NaN or infinite, or if `epsilon` is not positive and finite.
/// Errors from the axis-specific routine (for example, too few dimensions for
/// row or column standardization) are propagated unchanged.
pub fn standardize<S, D>(
    data: &ArrayBase<S, D>,
    axis: StandardizationAxis,
    epsilon: f64,
) -> Result<Array<f64, D>, ModelError>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    // Checks run in order; the finiteness scan is skipped for empty input.
    let rejection = if data.is_empty() {
        Some("Cannot standardize empty array")
    } else if !data.iter().all(|value| value.is_finite()) {
        Some("Input contains NaN or infinite values")
    } else if !(epsilon.is_finite() && epsilon > 0.0) {
        Some("Epsilon must be positive and finite")
    } else {
        None
    };
    if let Some(message) = rejection {
        return Err(ModelError::InputValidationError(message.to_string()));
    }

    let mut standardized = data.to_owned();
    match axis {
        StandardizationAxis::Global => standardize_global(&mut standardized, epsilon),
        StandardizationAxis::Row => standardize_by_rows(&mut standardized, epsilon),
        StandardizationAxis::Column => standardize_by_columns(&mut standardized, epsilon),
    }?;
    Ok(standardized)
}
/// Standardizes every element of `data` in place using one mean and one
/// standard deviation computed over the entire array.
///
/// The variance is the population variance (sum of squared deviations divided
/// by the element count). The divisor is `sqrt(variance + epsilon^2)`, so a
/// constant array maps to all zeros instead of dividing by zero.
///
/// Arrays with at least `STANDARDIZE_PARALLEL_THRESHOLD` elements are reduced
/// and updated with rayon; smaller arrays are processed sequentially.
///
/// # Errors
///
/// Returns [`ModelError::ProcessingError`] if `data` has no elements.
fn standardize_global<D>(data: &mut Array<f64, D>, epsilon: f64) -> Result<(), ModelError>
where
    D: Dimension,
{
    let len = data.len();
    // Defensive: `standardize` already rejects empty input before calling this.
    if len == 0 {
        return Err(ModelError::ProcessingError(
            "No values to standardize".to_string(),
        ));
    }
    let n = len as f64;
    let parallel = len >= STANDARDIZE_PARALLEL_THRESHOLD;

    let mean = if parallel {
        data.par_iter().sum::<f64>() / n
    } else {
        data.iter().sum::<f64>() / n
    };
    let variance = if parallel {
        data.par_iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n
    } else {
        data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n
    };

    // `hypot` yields sqrt(variance + epsilon^2) without forming epsilon^2:
    // squaring underflows to 0 for epsilon below ~1.5e-162 (constant input
    // would then give 0/0 = NaN) and overflows to inf above ~1.3e154.
    let std_dev = variance.sqrt().hypot(epsilon);

    if parallel {
        data.par_mapv_inplace(|x| (x - mean) / std_dev);
    } else {
        data.mapv_inplace(|x| (x - mean) / std_dev);
    }
    Ok(())
}
/// Returns `(mean, std_dev)` for `values`, where `std_dev` is the regularized
/// population standard deviation `sqrt(variance + epsilon^2)`.
///
/// The variance divides by `values.len()` (population, not sample, variance).
/// An empty slice yields `(0.0, epsilon)`, which is the same formula with a
/// variance of zero.
///
/// Slices with at least `STANDARDIZE_PARALLEL_THRESHOLD` elements are reduced
/// with rayon; shorter ones sequentially.
fn compute_mean_and_std(values: &[f64], epsilon: f64) -> (f64, f64) {
    if values.is_empty() {
        return (0.0, epsilon);
    }
    let n = values.len() as f64;
    let (mean, variance) = if values.len() >= STANDARDIZE_PARALLEL_THRESHOLD {
        let mean = values.par_iter().sum::<f64>() / n;
        let variance = values.par_iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
        (mean, variance)
    } else {
        let mean = values.iter().sum::<f64>() / n;
        let variance = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
        (mean, variance)
    };
    // `hypot` avoids forming epsilon^2, which underflows to 0 for very small
    // epsilon (leading to 0/0 = NaN on constant input) or overflows to inf for
    // very large epsilon.
    (mean, variance.sqrt().hypot(epsilon))
}
/// Standardizes each 1-D lane of `data` along `axis` in place; every lane is
/// centered and scaled by its own statistics from `compute_mean_and_std`.
///
/// Once there are at least `STANDARDIZE_PARALLEL_THRESHOLD / 100` lanes, they
/// are processed in parallel with rayon; otherwise sequentially.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`], with a message beginning with
/// `operation_name`, if `data` has fewer than 2 dimensions.
fn standardize_lanes<D>(
    data: &mut Array<f64, D>,
    axis: Axis,
    epsilon: f64,
    operation_name: &str,
) -> Result<(), ModelError>
where
    D: Dimension,
{
    let ndim = data.ndim();
    if ndim < 2 {
        return Err(ModelError::InputValidationError(format!(
            "{} requires at least 2 dimensions",
            operation_name
        )));
    }
    let mut lanes: Vec<_> = data.lanes_mut(axis).into_iter().collect();
    // Per-lane work cannot fail, so no Result plumbing is needed here.
    let standardize_lane = |lane: &mut ArrayViewMut1<f64>| {
        // Contiguous lanes are read in place; strided lanes (e.g. columns of a
        // row-major array) are copied into a temporary buffer first.
        let (mean, std_dev) = match lane.as_slice() {
            Some(values) => compute_mean_and_std(values, epsilon),
            None => compute_mean_and_std(&lane.to_vec(), epsilon),
        };
        lane.mapv_inplace(|x| (x - mean) / std_dev);
    };
    if lanes.len() >= STANDARDIZE_PARALLEL_THRESHOLD / 100 {
        lanes.par_iter_mut().for_each(standardize_lane);
    } else {
        lanes.iter_mut().for_each(standardize_lane);
    }
    Ok(())
}
fn standardize_by_rows<D>(data: &mut Array<f64, D>, epsilon: f64) -> Result<(), ModelError>
where
D: Dimension,
{
let ndim = data.ndim();
let last_axis = Axis(ndim - 1);
standardize_lanes(data, last_axis, epsilon, "Row standardization")
}
/// Standardizes each lane along the second-to-last axis (each column of a 2-D
/// matrix) independently, in place.
///
/// # Errors
///
/// Returns [`ModelError::InputValidationError`] if `data` has fewer than 2
/// dimensions.
fn standardize_by_columns<D>(data: &mut Array<f64, D>, epsilon: f64) -> Result<(), ModelError>
where
    D: Dimension,
{
    // `checked_sub` returns None exactly when there are fewer than 2 axes.
    let column_axis = match data.ndim().checked_sub(2) {
        Some(index) => Axis(index),
        None => {
            return Err(ModelError::InputValidationError(
                "Column standardization requires at least 2 dimensions".to_string(),
            ));
        }
    };
    standardize_lanes(data, column_axis, epsilon, "Column standardization")
}