use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
use ferrolearn_core::traits::Transform;
use ndarray::{Array1, Array2};
use num_traits::Float;
/// Selects which vector norm each row is divided by during normalization.
///
/// The default variant is [`NormType::L2`].
///
/// The type is a plain field-less enum and derives the common value traits
/// (including `Hash`) so it can be copied, compared, and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum NormType {
    /// Sum of the absolute values of the row's entries.
    L1,
    /// Square root of the sum of the squared entries.
    #[default]
    L2,
    /// Largest absolute value among the row's entries.
    Max,
}
/// Stateless transformer that divides each row of a matrix by that row's norm.
///
/// Which norm is used is decided by the stored [`NormType`]. The float type
/// `F` is carried only at the type level; the struct holds no numeric data.
#[derive(Clone, Debug)]
pub struct Normalizer<F> {
    /// Norm that every row is divided by.
    pub(crate) norm: NormType,
    /// Zero-sized marker tying the normalizer to its element type `F`.
    _marker: std::marker::PhantomData<F>,
}
impl<F> Normalizer<F>
where
    F: Float + Send + Sync + 'static,
{
    /// Creates a normalizer that divides each row by the given `norm`.
    #[must_use]
    pub fn new(norm: NormType) -> Self {
        let marker = std::marker::PhantomData;
        Self {
            norm,
            _marker: marker,
        }
    }

    /// Shorthand for [`Normalizer::new`] with [`NormType::L2`].
    #[must_use]
    pub fn l2() -> Self {
        let kind = NormType::L2;
        Self::new(kind)
    }

    /// Shorthand for [`Normalizer::new`] with [`NormType::L1`].
    #[must_use]
    pub fn l1() -> Self {
        let kind = NormType::L1;
        Self::new(kind)
    }

    /// Shorthand for [`Normalizer::new`] with [`NormType::Max`].
    #[must_use]
    pub fn max() -> Self {
        let kind = NormType::Max;
        Self::new(kind)
    }

    /// Returns the norm type this normalizer was constructed with.
    #[must_use]
    pub fn norm(&self) -> NormType {
        let Self { norm, .. } = self;
        *norm
    }
}
impl<F> Default for Normalizer<F>
where
    F: Float + Send + Sync + 'static,
{
    /// Returns an L2 normalizer, the same value [`Normalizer::l2`] produces.
    fn default() -> Self {
        Self::l2()
    }
}
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for Normalizer<F> {
type Output = Array2<F>;
type Error = FerroError;
fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
let mut out = x.to_owned();
for mut row in out.rows_mut() {
let norm_val =
match self.norm {
NormType::L1 => row.iter().copied().fold(F::zero(), |acc, v| acc + v.abs()),
NormType::L2 => row
.iter()
.copied()
.fold(F::zero(), |acc, v| acc + v * v)
.sqrt(),
NormType::Max => row.iter().copied().fold(F::zero(), |acc, v| {
if v.abs() > acc { v.abs() } else { acc }
}),
};
if norm_val == F::zero() {
continue;
}
for v in row.iter_mut() {
*v = *v / norm_val;
}
}
Ok(out)
}
}
impl<F> PipelineTransformer<F> for Normalizer<F>
where
    F: Float + Send + Sync + 'static,
{
    /// Produces the fitted pipeline step for this normalizer.
    ///
    /// Neither `_x` nor `_y` is read; the fitted step is simply a boxed clone
    /// of `self`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn fit_pipeline(
        &self,
        _x: &Array2<F>,
        _y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
        let fitted: Box<dyn FittedPipelineTransformer<F>> = Box::new(self.clone());
        Ok(fitted)
    }
}
impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for Normalizer<F> {
fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
self.transform(x)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Asserts that the leading entries of row `r` in `actual` equal
    /// `expected`, entry by entry, within the absolute tolerance `eps`.
    fn assert_row_close(actual: &Array2<f64>, r: usize, expected: &[f64], eps: f64) {
        for (c, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(actual[[r, c]], want, epsilon = eps);
        }
    }

    // A 3-4-5 triangle row scales to (0.6, 0.8) under the L2 norm.
    #[test]
    fn test_l2_norm_basic() {
        let input = array![[3.0, 4.0]];
        let out = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&out, 0, &[0.6, 0.8], 1e-10);
    }

    // Every row has Euclidean length 1 after L2 normalization.
    #[test]
    fn test_l2_unit_norm_after_transform() {
        let input = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = Normalizer::<f64>::l2().transform(&input).unwrap();
        for row in out.rows() {
            let length: f64 = row.iter().map(|&v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(length, 1.0, epsilon = 1e-10);
        }
    }

    // Entries are divided by the sum of magnitudes under the L1 norm.
    #[test]
    fn test_l1_norm_basic() {
        let input = array![[1.0, 2.0, 3.0]];
        let out = Normalizer::<f64>::l1().transform(&input).unwrap();
        assert_row_close(&out, 0, &[1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0], 1e-10);
    }

    // Magnitudes of every row sum to 1 after L1 normalization, signs included.
    #[test]
    fn test_l1_unit_norm_after_transform() {
        let input = array![[1.0, 2.0, 3.0], [-4.0, 5.0, 6.0]];
        let out = Normalizer::<f64>::l1().transform(&input).unwrap();
        for row in out.rows() {
            let total: f64 = row.iter().copied().map(f64::abs).sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        }
    }

    // The max norm divides by the largest magnitude and keeps signs.
    #[test]
    fn test_max_norm_basic() {
        let input = array![[-5.0, 3.0, 1.0]];
        let out = Normalizer::<f64>::max().transform(&input).unwrap();
        assert_row_close(&out, 0, &[-1.0, 0.6, 0.2], 1e-10);
    }

    // An all-zero row is passed through rather than divided by zero.
    #[test]
    fn test_zero_row_unchanged() {
        let input = array![[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]];
        let out = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&out, 0, &[0.0, 0.0, 0.0], 1e-15);
    }

    // Negative entries keep their sign under the L2 norm.
    #[test]
    fn test_negative_values_l2() {
        let input = array![[-3.0, -4.0]];
        let out = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&out, 0, &[-0.6, -0.8], 1e-10);
    }

    // `Default` yields the L2 variant.
    #[test]
    fn test_default_is_l2() {
        assert_eq!(Normalizer::<f64>::default().norm(), NormType::L2);
    }

    // Each row is normalized using only its own entries.
    #[test]
    fn test_multiple_rows_independent() {
        let input = array![[3.0, 4.0], [0.0, 5.0]];
        let out = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&out, 0, &[0.6, 0.8], 1e-10);
        assert_row_close(&out, 1, &[0.0, 1.0], 1e-10);
    }

    // Fitting and transforming through the pipeline traits gives the same result.
    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let input = array![[3.0, 4.0], [0.0, 2.0]];
        let targets: Array1<f64> = Array1::zeros(2);
        let fitted = Normalizer::<f64>::l2()
            .fit_pipeline(&input, &targets)
            .unwrap();
        let result = fitted.transform_pipeline(&input).unwrap();
        assert_row_close(&result, 0, &[0.6, 0.8], 1e-10);
    }

    // The transformer also works with single-precision floats.
    #[test]
    fn test_f32_normalizer() {
        let input: Array2<f32> = array![[3.0f32, 4.0]];
        let out = Normalizer::<f32>::l2().transform(&input).unwrap();
        for (c, &want) in [0.6f32, 0.8f32].iter().enumerate() {
            assert!((out[[0, c]] - want).abs() < 1e-6);
        }
    }
}