use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;
/// Per-sample norm used by the normalizer to rescale each row.
///
/// The selected variant decides which quantity a row is divided by during
/// `transform`.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Norm {
/// Sum of the absolute values of the row's features.
L1,
/// Square root of the sum of the squared features of the row.
L2,
/// Largest absolute feature value in the row.
Max,
}
/// Row-wise scaler: every sample is divided by its own norm.
///
/// The norm is selected with [`Norm`]. Rows whose norm is not greater than
/// `1e-12` are left untouched by `transform`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Normalizer {
// Which per-row norm `transform` divides by.
norm: Norm,
// Stamped with `crate::version::SCHEMA_VERSION` by the constructors and
// checked at the start of `transform`. With the `serde` feature enabled the
// field falls back to `u32::default()` when it is absent from the input.
#[cfg_attr(feature = "serde", serde(default))]
_schema_version: u32,
}
impl Normalizer {
    /// Builds a normalizer that divides each sample by the chosen `norm`.
    ///
    /// The instance is stamped with the crate's current schema version.
    pub fn new(norm: Norm) -> Self {
        Self {
            _schema_version: crate::version::SCHEMA_VERSION,
            norm,
        }
    }

    /// Shorthand for [`Normalizer::new`] with [`Norm::L2`].
    pub fn l2() -> Self {
        Self::new(Norm::L2)
    }
}
impl Default for Normalizer {
fn default() -> Self {
Self::l2()
}
}
impl Transformer for Normalizer {
    /// Validates `data`; no state is learned or stored on `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Dataset::validate_finite`, and returns
    /// [`ScryLearnError::EmptyDataset`] when `data` has zero samples.
    fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        Ok(())
    }

    /// Divides every sample (row) of `data` by its own norm, in place.
    ///
    /// `data.features` is indexed as `features[feature][sample]`. Rows whose
    /// norm is not greater than `1e-12` (including all-zero rows and rows whose
    /// norm evaluates to NaN) are left unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from `crate::version::check_schema_version`.
    fn transform(&self, data: &mut Dataset) -> Result<()> {
        /// Rows with a norm at or below this value are not rescaled.
        const MIN_NORM: f64 = 1e-12;

        crate::version::check_schema_version(self._schema_version)?;
        let n = data.n_samples();
        let m = data.n_features();

        // Pass 1: build one norm per row while walking each column
        // sequentially. Columns are visited in their stored order, so the
        // floating-point accumulation order for every row matches a plain
        // row-by-row loop. `col[..n]` keeps a panic (instead of silent
        // truncation) if a column is shorter than `n_samples()`.
        let mut norms = vec![0.0_f64; n];
        match self.norm {
            Norm::L1 => {
                for col in &data.features {
                    for (acc, &value) in norms.iter_mut().zip(&col[..n]) {
                        *acc += value.abs();
                    }
                }
            }
            Norm::L2 => {
                for col in &data.features {
                    for (acc, &value) in norms.iter_mut().zip(&col[..n]) {
                        *acc += value * value;
                    }
                }
                for acc in &mut norms {
                    *acc = acc.sqrt();
                }
            }
            Norm::Max => {
                for col in &data.features {
                    for (acc, &value) in norms.iter_mut().zip(&col[..n]) {
                        *acc = acc.max(value.abs());
                    }
                }
            }
        }

        // Pass 2: rescale the first `n_features()` columns, skipping rows
        // whose norm is too small to divide by.
        for col in &mut data.features[..m] {
            for (value, &norm_val) in col[..n].iter_mut().zip(&norms) {
                if norm_val > MIN_NORM {
                    *value /= norm_val;
                }
            }
        }

        // NOTE(review): presumably refreshes a cached matrix representation
        // after `features` was mutated — confirm against `Dataset`.
        data.sync_matrix();
        Ok(())
    }

    /// Always fails: the original row norms are discarded by `transform`.
    ///
    /// # Errors
    ///
    /// Always returns [`ScryLearnError::InvalidParameter`]; `_data` is not
    /// modified.
    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
        Err(ScryLearnError::InvalidParameter(
            "Normalizer is not invertible (row norms are lost)".into(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dataset from row-major `rows`, transposing them into the
    /// `features[feature][sample]` layout and using an all-zero target.
    fn make_ds(rows: &[Vec<f64>]) -> Dataset {
        let n = rows.len();
        // `first()` avoids an index panic when called with no rows.
        let m = rows.first().map_or(0, Vec::len);
        let mut features = vec![vec![0.0; n]; m];
        for (i, row) in rows.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                features[j][i] = val;
            }
        }
        let names: Vec<String> = (0..m).map(|j| format!("f{j}")).collect();
        Dataset::new(features, vec![0.0; n], names, "y")
    }

    #[test]
    fn test_normalizer_l2_unit_norm() {
        let mut ds = make_ds(&[vec![3.0, 4.0], vec![1.0, 0.0]]);
        let mut norm = Normalizer::new(Norm::L2);
        norm.fit_transform(&mut ds).unwrap();
        assert!((ds.features[0][0] - 0.6).abs() < 1e-10);
        assert!((ds.features[1][0] - 0.8).abs() < 1e-10);
        for i in 0..ds.n_samples() {
            let sq_sum: f64 = ds.features.iter().map(|col| col[i] * col[i]).sum();
            assert!(
                (sq_sum - 1.0).abs() < 1e-10,
                "row {i} L2 norm² = {sq_sum}, expected 1.0"
            );
        }
    }

    #[test]
    fn test_normalizer_l1() {
        let mut ds = make_ds(&[vec![1.0, 2.0, 3.0]]);
        let mut norm = Normalizer::new(Norm::L1);
        norm.fit_transform(&mut ds).unwrap();
        let abs_sum: f64 = ds.features.iter().map(|c| c[0].abs()).sum();
        assert!(
            (abs_sum - 1.0).abs() < 1e-10,
            "L1 norm should be 1.0, got {abs_sum}"
        );
    }

    #[test]
    fn test_normalizer_max() {
        let mut ds = make_ds(&[vec![-5.0, 2.0, 3.0]]);
        let mut norm = Normalizer::new(Norm::Max);
        norm.fit_transform(&mut ds).unwrap();
        assert!((ds.features[0][0] - (-1.0)).abs() < 1e-10);
        let max_abs: f64 = ds
            .features
            .iter()
            .map(|c| c[0].abs())
            .fold(0.0_f64, f64::max);
        assert!(
            (max_abs - 1.0).abs() < 1e-10,
            "Max norm should be 1.0, got {max_abs}"
        );
    }

    #[test]
    fn test_normalizer_zero_row() {
        let mut ds = make_ds(&[vec![0.0, 0.0]]);
        let mut norm = Normalizer::new(Norm::L2);
        norm.fit_transform(&mut ds).unwrap();
        for col in &ds.features {
            assert!((col[0]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_normalizer_near_zero_row_untouched() {
        // A norm of 1e-13 is below the 1e-12 guard, so the row must pass
        // through unchanged instead of being blown up to unit length.
        let mut ds = make_ds(&[vec![1e-13, 0.0]]);
        let mut norm = Normalizer::new(Norm::L2);
        norm.fit_transform(&mut ds).unwrap();
        assert!((ds.features[0][0] - 1e-13).abs() < 1e-20);
        assert!(ds.features[1][0].abs() < 1e-20);
    }

    #[test]
    fn test_normalizer_default_and_l2_use_l2_norm() {
        assert_eq!(Normalizer::default().norm, Norm::L2);
        assert_eq!(Normalizer::l2().norm, Norm::L2);
    }

    #[test]
    fn test_normalizer_inverse_transform_is_rejected() {
        let mut ds = make_ds(&[vec![3.0, 4.0]]);
        let norm = Normalizer::default();
        let result = norm.inverse_transform(&mut ds);
        assert!(matches!(result, Err(ScryLearnError::InvalidParameter(_))));
        // The failed call must not have modified the data.
        assert!((ds.features[0][0] - 3.0).abs() < f64::EPSILON);
        assert!((ds.features[1][0] - 4.0).abs() < f64::EPSILON);
    }
}