use crate::error::ModelError;
use ndarray::{Array, ArrayBase, ArrayViewMut1, Axis, Data, Dimension};
/// Norms at or below this value are treated as zero: the vector is not divided by them.
const NORM_ZERO_THRESHOLD: f64 = 1.0e-15;
/// Minimum element count at which in-place division switches to the parallel
/// `par_mapv_inplace` path instead of the sequential `mapv_inplace`.
const NORMALIZE_PARALLEL_THRESHOLD: usize = 10_000;
/// Selects which slices of the input are normalized independently.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormalizationAxis {
    /// Normalize each 1-D lane along the last axis (the rows of a 2-D matrix).
    Row,
    /// Normalize each 1-D lane along the second-to-last axis (the columns of a 2-D matrix).
    Column,
    /// Treat the entire array as a single vector and divide by its norm.
    Global,
}
/// The vector norm used as the divisor during normalization.
// Only `PartialEq` (no `Eq`): the `Lp` payload is an `f64`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NormalizationOrder {
    /// Sum of absolute values, `Σ|x|`.
    L1,
    /// Euclidean norm, `sqrt(Σx²)`.
    L2,
    /// Largest absolute value, `max|x|`.
    Max,
    /// General p-norm `(Σ|x|^p)^(1/p)`; `normalize` rejects a `p` that is
    /// not positive and finite.
    Lp(f64),
}
/// Normalizes `data` by the chosen vector norm and returns the result as a new array.
///
/// The input is copied first; `data` itself is never modified.
///
/// # Arguments
/// * `data` - the `f64` values to normalize.
/// * `axis` - `Global` treats the whole array as one vector; `Row` normalizes each
///   lane along the last axis; `Column` normalizes each lane along the
///   second-to-last axis. `Row` and `Column` need at least 2 dimensions.
/// * `order` - which norm to divide by (`L1`, `L2`, `Max`, or `Lp(p)`).
///
/// # Errors
/// * `ModelError::InputValidationError` in these cases:
///   - `data` is empty or contains NaN or infinite values.
///   - `order` is `Lp(p)` with a `p` that is not positive and finite.
///   - Row/column normalization is requested on an array with fewer than 2 dimensions.
/// * Any error raised while computing a norm, e.g. `ModelError::ProcessingError`
///   when a norm turns out to be non-finite.
pub fn normalize<S, D>(
    data: &ArrayBase<S, D>,
    axis: NormalizationAxis,
    order: NormalizationOrder,
) -> Result<Array<f64, D>, ModelError>
where
    S: Data<Elem = f64>,
    D: Dimension,
{
    // Checks run in a fixed order, so the first problem found is the one reported.
    if data.is_empty() {
        return Err(ModelError::InputValidationError(
            "Cannot normalize empty array".to_string(),
        ));
    }

    let has_non_finite = data.iter().any(|value| !value.is_finite());
    if has_non_finite {
        return Err(ModelError::InputValidationError(
            "Input contains NaN or infinite values".to_string(),
        ));
    }

    // `!p.is_finite()` also catches a NaN exponent.
    let bad_exponent =
        matches!(order, NormalizationOrder::Lp(p) if p <= 0.0 || !p.is_finite());
    if bad_exponent {
        return Err(ModelError::InputValidationError(
            "Lp norm parameter must be positive and finite".to_string(),
        ));
    }

    let mut normalized = data.to_owned();
    let status = match axis {
        NormalizationAxis::Global => normalize_global(&mut normalized, order),
        NormalizationAxis::Row => normalize_by_rows(&mut normalized, order),
        NormalizationAxis::Column => normalize_by_columns(&mut normalized, order),
    };
    status.map(|()| normalized)
}
/// Divides every element of a single 1-D lane by `norm`, in place.
///
/// If `norm` is not greater than `NORM_ZERO_THRESHOLD` (which includes NaN),
/// the lane is overwritten with zeros instead. Lanes with at least
/// `NORMALIZE_PARALLEL_THRESHOLD` elements are divided with `par_mapv_inplace`.
fn normalize_array_view(mut array: ArrayViewMut1<f64>, norm: f64) {
    let divisible = norm > NORM_ZERO_THRESHOLD;
    let large = array.len() >= NORMALIZE_PARALLEL_THRESHOLD;
    match (divisible, large) {
        (false, _) => array.fill(0.0),
        (true, true) => array.par_mapv_inplace(|value| value / norm),
        (true, false) => array.mapv_inplace(|value| value / norm),
    }
}
/// Divides every element of `data` by the norm of the whole array, in place.
///
/// If the norm is at or below `NORM_ZERO_THRESHOLD`, the array is treated as a
/// zero vector and filled with `0.0`. This matches how `normalize_array_view`
/// handles individual row/column lanes. Previously the global path returned
/// such data unscaled, so the output kept a tiny, non-unit norm.
/// Arrays with at least `NORMALIZE_PARALLEL_THRESHOLD` elements are divided
/// with `par_mapv_inplace`.
///
/// # Errors
/// Propagates any error returned by `compute_norm`.
fn normalize_global<D>(
    data: &mut Array<f64, D>,
    order: NormalizationOrder,
) -> Result<(), ModelError>
where
    D: Dimension,
{
    let norm = compute_norm(data.iter().copied(), order)?;
    if norm > NORM_ZERO_THRESHOLD {
        if data.len() >= NORMALIZE_PARALLEL_THRESHOLD {
            data.par_mapv_inplace(|x| x / norm);
        } else {
            data.mapv_inplace(|x| x / norm);
        }
    } else {
        // Same convention as the per-lane path: an (effectively) zero-norm
        // vector normalizes to the zero vector.
        data.fill(0.0);
    }
    Ok(())
}
/// Normalizes each lane along the last axis (the rows of a 2-D matrix), in place.
///
/// Each row is divided by its own norm via `normalize_array_view`.
///
/// # Errors
/// * `ModelError::InputValidationError` if `data` has fewer than 2 dimensions.
/// * Any error returned by `compute_norm` for a row; processing stops at the first one.
fn normalize_by_rows<D>(
    data: &mut Array<f64, D>,
    order: NormalizationOrder,
) -> Result<(), ModelError>
where
    D: Dimension,
{
    let row_axis = match data.ndim() {
        0 | 1 => {
            return Err(ModelError::InputValidationError(
                "Row normalization requires at least 2 dimensions".to_string(),
            ));
        }
        dims => Axis(dims - 1),
    };
    data.lanes_mut(row_axis)
        .into_iter()
        .try_for_each(|row| -> Result<(), ModelError> {
            let row_norm = compute_norm(row.iter().copied(), order)?;
            normalize_array_view(row, row_norm);
            Ok(())
        })
}
/// Normalizes each lane along the second-to-last axis (the columns of a 2-D
/// matrix), in place.
///
/// Each column is divided by its own norm via `normalize_array_view`.
///
/// # Errors
/// * `ModelError::InputValidationError` if `data` has fewer than 2 dimensions.
/// * Any error returned by `compute_norm` for a column; processing stops at the first one.
fn normalize_by_columns<D>(
    data: &mut Array<f64, D>,
    order: NormalizationOrder,
) -> Result<(), ModelError>
where
    D: Dimension,
{
    let dims = data.ndim();
    if dims >= 2 {
        for column in data.lanes_mut(Axis(dims - 2)) {
            let column_norm = compute_norm(column.iter().copied(), order)?;
            normalize_array_view(column, column_norm);
        }
        Ok(())
    } else {
        Err(ModelError::InputValidationError(
            "Column normalization requires at least 2 dimensions".to_string(),
        ))
    }
}
/// Computes the norm of `values` selected by `order`.
///
/// * `L1`: `Σ|x|`
/// * `L2`: `sqrt(Σx²)`
/// * `Max`: `max|x|` (`0.0` for an empty sequence; NaN elements are ignored)
/// * `Lp(p)`: `(Σ|x|^p)^(1/p)`
///
/// `L2` and `Lp` accumulate with a running rescale by the largest magnitude
/// seen so far (the technique used by LAPACK's `dnrm2`). This avoids two
/// failures of the naive sum for finite inputs:
/// * Overflow: `[1e200, 1e200]` under `L2`, or `[1e4, 1e4]` under `Lp(100.0)`,
///   would otherwise overflow to infinity and be rejected.
/// * Underflow: tiny inputs such as `[1e-200, 1e-200]` would otherwise
///   collapse to a norm of zero.
///
/// # Errors
/// Returns `ModelError::ProcessingError` if the resulting norm is not finite.
/// This happens when the input contains infinities, when it contains NaN under
/// `L1`/`L2`/`Lp`, or when the true norm exceeds `f64::MAX`.
fn compute_norm<I>(values: I, order: NormalizationOrder) -> Result<f64, ModelError>
where
    I: Iterator<Item = f64>,
{
    /// Overflow/underflow-safe `(Σ|x|^p)^(1/p)`; returns NaN if any element is NaN.
    fn scaled_power_norm<It: Iterator<Item = f64>>(values: It, p: f64) -> f64 {
        // Invariant: Σ|x|^p over the elements seen so far == scale^p * sum,
        // where `scale` is the largest |x| seen (0.0 until a non-zero arrives).
        let mut scale = 0.0_f64;
        let mut sum = 1.0_f64;
        for x in values {
            let a = x.abs();
            if a.is_nan() {
                return f64::NAN;
            }
            if a == 0.0 {
                continue;
            }
            if a > scale {
                // Re-express the running sum relative to the new, larger scale.
                sum = 1.0 + sum * (scale / a).powf(p);
                scale = a;
            } else {
                sum += (a / scale).powf(p);
            }
        }
        if scale == 0.0 {
            0.0
        } else {
            scale * sum.powf(p.recip())
        }
    }

    match order {
        NormalizationOrder::L1 => {
            let norm: f64 = values.map(f64::abs).sum();
            if norm.is_finite() {
                Ok(norm)
            } else {
                Err(ModelError::ProcessingError(
                    "L1 norm computation resulted in non-finite value".to_string(),
                ))
            }
        }
        NormalizationOrder::L2 => {
            let norm = scaled_power_norm(values, 2.0);
            if norm.is_finite() {
                Ok(norm)
            } else {
                Err(ModelError::ProcessingError(
                    "L2 norm computation resulted in non-finite value".to_string(),
                ))
            }
        }
        NormalizationOrder::Max => {
            // |x| >= 0, so folding from 0.0 also covers the empty sequence.
            let norm = values.map(f64::abs).fold(0.0, f64::max);
            if norm.is_finite() {
                Ok(norm)
            } else {
                Err(ModelError::ProcessingError(
                    "Max norm computation resulted in non-finite value".to_string(),
                ))
            }
        }
        NormalizationOrder::Lp(p) => {
            let norm = scaled_power_norm(values, p);
            if norm.is_finite() {
                Ok(norm)
            } else {
                Err(ModelError::ProcessingError(format!(
                    "Lp norm (p={}) computation resulted in non-finite value",
                    p
                )))
            }
        }
    }
}