Skip to main content

anofox_ml_preprocessing/
normalizer.rs

1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4/// The type of norm used to normalize each sample (row).
5#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
6pub enum NormType {
7    /// L1 norm: sum of absolute values.
8    L1,
9    /// L2 norm (Euclidean): square root of sum of squares.
10    L2,
11    /// Max norm: maximum absolute value.
12    Max,
13}
14
15/// Parameters for Normalizer (unfitted state).
16///
17/// Normalizes each **row** (sample) independently so that it has unit norm
18/// according to the chosen [`NormType`]. Unlike most scalers this operates
19/// per-sample, not per-feature.
20#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
21pub struct Normalizer {
22    /// The type of norm to apply.
23    pub norm: NormType,
24}
25
26impl Normalizer {
27    /// Create a new `Normalizer` with the default L2 norm.
28    pub fn new() -> Self {
29        Self { norm: NormType::L2 }
30    }
31
32    /// Set the norm type.
33    pub fn with_norm(mut self, norm: NormType) -> Self {
34        self.norm = norm;
35        self
36    }
37}
38
39impl Default for Normalizer {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45/// Fitted Normalizer — stateless (fit is a validation-only no-op).
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
48pub struct FittedNormalizer<F: Float> {
49    norm: NormType,
50    _marker: std::marker::PhantomData<F>,
51}
52
53impl<F: Float> FitUnsupervised<F> for Normalizer {
54    type Fitted = FittedNormalizer<F>;
55
56    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
57        if x.is_empty() {
58            return Err(RustMlError::EmptyInput("input array is empty".into()));
59        }
60
61        Ok(FittedNormalizer {
62            norm: self.norm,
63            _marker: std::marker::PhantomData,
64        })
65    }
66}
67
68impl<F: Float> Transform<F> for FittedNormalizer<F> {
69    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
70        let eps = F::from_f64(1e-15).unwrap();
71        let mut result = x.to_owned();
72
73        for mut row in result.rows_mut() {
74            let norm = match self.norm {
75                NormType::L1 => {
76                    let mut s = F::zero();
77                    for &v in row.iter() {
78                        s = s + v.abs();
79                    }
80                    s
81                }
82                NormType::L2 => {
83                    let mut s = F::zero();
84                    for &v in row.iter() {
85                        s = s + v * v;
86                    }
87                    s.sqrt()
88                }
89                NormType::Max => {
90                    let mut m = F::zero();
91                    for &v in row.iter() {
92                        let a = v.abs();
93                        if a > m {
94                            m = a;
95                        }
96                    }
97                    m
98                }
99            };
100
101            if norm > eps {
102                for val in row.iter_mut() {
103                    *val = *val / norm;
104                }
105            }
106        }
107        Ok(result)
108    }
109}
110
111impl<F: Float> FittedNormalizer<F> {
112    /// Return the norm type used for normalization.
113    pub fn norm(&self) -> NormType {
114        self.norm
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fit `normalizer` on `x` and immediately transform the same data.
    fn fit_then_transform(normalizer: &Normalizer, x: &Array2<f64>) -> Array2<f64> {
        let fitted = FitUnsupervised::<f64>::fit(normalizer, x).unwrap();
        fitted.transform(x).unwrap()
    }

    #[test]
    fn test_l2_unit_norm() {
        let x = array![[3.0, 4.0], [1.0, 0.0], [0.0, 0.0]];
        let out = fit_then_transform(&Normalizer::new(), &x);

        // First row has norm 5, so it becomes [3/5, 4/5].
        assert_abs_diff_eq!(out[[0, 0]], 0.6, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.8, epsilon = 1e-10);

        // Only the first two rows are non-zero; each must have unit L2 norm.
        for row in out.rows().into_iter().take(2) {
            let sum_of_squares: f64 = row.iter().map(|v| v * v).sum();
            assert_abs_diff_eq!(sum_of_squares.sqrt(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_l1_norm() {
        let x = array![[1.0, -2.0, 3.0], [4.0, 0.0, -1.0]];
        let out = fit_then_transform(&Normalizer::new().with_norm(NormType::L1), &x);

        // The absolute values of every row must add up to 1.
        for row in out.rows() {
            let abs_sum: f64 = row.iter().map(|v| v.abs()).sum();
            assert_abs_diff_eq!(abs_sum, 1.0, epsilon = 1e-10);
        }

        // First row: |1| + |-2| + |3| = 6, giving [1/6, -2/6, 3/6].
        assert_abs_diff_eq!(out[[0, 0]], 1.0 / 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], -2.0 / 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_max_norm() {
        let x = array![[1.0, -3.0, 2.0], [0.5, 0.0, -4.0]];
        let out = fit_then_transform(&Normalizer::new().with_norm(NormType::Max), &x);

        // The largest absolute value in every row must be 1.
        for row in out.rows() {
            let largest = row.iter().fold(0.0_f64, |best, v| best.max(v.abs()));
            assert_abs_diff_eq!(largest, 1.0, epsilon = 1e-10);
        }

        // First row: largest magnitude is 3, giving [1/3, -1, 2/3].
        assert_abs_diff_eq!(out[[0, 1]], -1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_zero_row_handled() {
        // An all-zero row must stay all-zero instead of being divided by zero.
        let x = array![[0.0, 0.0], [3.0, 4.0]];
        let out = fit_then_transform(&Normalizer::new(), &x);

        for col in 0..2 {
            assert_abs_diff_eq!(out[[0, col]], 0.0, epsilon = 1e-10);
        }

        for v in out.iter() {
            assert!(v.is_finite(), "zero row produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_f32_support() {
        let x = array![[3.0f32, 4.0], [1.0, 0.0]];
        let fitted = FitUnsupervised::<f32>::fit(&Normalizer::new(), &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        let sum_of_squares: f32 = out.row(0).iter().map(|v| v * v).sum();
        assert_abs_diff_eq!(sum_of_squares.sqrt(), 1.0f32, epsilon = 1e-5);
    }

    #[test]
    fn test_empty_input() {
        let x = Array2::<f64>::zeros((0, 0));
        let outcome = FitUnsupervised::<f64>::fit(&Normalizer::new(), &x);
        assert!(outcome.is_err());
    }
}