1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
6pub enum NormType {
7 L1,
9 L2,
11 Max,
13}
14
15#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
21pub struct Normalizer {
22 pub norm: NormType,
24}
25
26impl Normalizer {
27 pub fn new() -> Self {
29 Self { norm: NormType::L2 }
30 }
31
32 pub fn with_norm(mut self, norm: NormType) -> Self {
34 self.norm = norm;
35 self
36 }
37}
38
39impl Default for Normalizer {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
/// Fitted counterpart of [`Normalizer`], returned by its `fit`.
///
/// Stores no statistics learned from data — only the configured [`NormType`];
/// `F` is the element type the transformer was fitted for.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
// Replaces the `Deserialize` bound serde would otherwise infer for `F`.
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedNormalizer<F: Float> {
    /// Norm applied to each row by `transform`.
    norm: NormType,
    /// Ties the struct to element type `F` without storing any `F` value.
    _marker: std::marker::PhantomData<F>,
}
52
53impl<F: Float> FitUnsupervised<F> for Normalizer {
54 type Fitted = FittedNormalizer<F>;
55
56 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
57 if x.is_empty() {
58 return Err(RustMlError::EmptyInput("input array is empty".into()));
59 }
60
61 Ok(FittedNormalizer {
62 norm: self.norm,
63 _marker: std::marker::PhantomData,
64 })
65 }
66}
67
68impl<F: Float> Transform<F> for FittedNormalizer<F> {
69 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
70 let eps = F::from_f64(1e-15).unwrap();
71 let mut result = x.to_owned();
72
73 for mut row in result.rows_mut() {
74 let norm = match self.norm {
75 NormType::L1 => {
76 let mut s = F::zero();
77 for &v in row.iter() {
78 s = s + v.abs();
79 }
80 s
81 }
82 NormType::L2 => {
83 let mut s = F::zero();
84 for &v in row.iter() {
85 s = s + v * v;
86 }
87 s.sqrt()
88 }
89 NormType::Max => {
90 let mut m = F::zero();
91 for &v in row.iter() {
92 let a = v.abs();
93 if a > m {
94 m = a;
95 }
96 }
97 m
98 }
99 };
100
101 if norm > eps {
102 for val in row.iter_mut() {
103 *val = *val / norm;
104 }
105 }
106 }
107 Ok(result)
108 }
109}
110
111impl<F: Float> FittedNormalizer<F> {
112 pub fn norm(&self) -> NormType {
114 self.norm
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_l2_unit_norm() {
        let x = array![[3.0, 4.0], [1.0, 0.0], [0.0, 0.0]];
        let normalizer = Normalizer::new();
        let fitted = FitUnsupervised::<f64>::fit(&normalizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.6, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 0.8, epsilon = 1e-10);

        for row_idx in 0..2 {
            let row = transformed.row(row_idx);
            let norm: f64 = row.iter().map(|&v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
        }

        // The all-zero third row must be passed through untouched.
        assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_preserves_sign() {
        let x = array![[-3.0, 4.0], [0.0, -2.0]];
        let normalizer = Normalizer::new();
        let fitted = FitUnsupervised::<f64>::fit(&normalizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], -0.6, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 0.8, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 1]], -1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l1_norm() {
        let x = array![[1.0, -2.0, 3.0], [4.0, 0.0, -1.0]];
        let normalizer = Normalizer::new().with_norm(NormType::L1);
        let fitted = FitUnsupervised::<f64>::fit(&normalizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for row_idx in 0..transformed.nrows() {
            let row = transformed.row(row_idx);
            let l1: f64 = row.iter().map(|&v| v.abs()).sum();
            assert_abs_diff_eq!(l1, 1.0, epsilon = 1e-10);
        }

        assert_abs_diff_eq!(transformed[[0, 0]], 1.0 / 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], -2.0 / 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_max_norm() {
        let x = array![[1.0, -3.0, 2.0], [0.5, 0.0, -4.0]];
        let normalizer = Normalizer::new().with_norm(NormType::Max);
        let fitted = FitUnsupervised::<f64>::fit(&normalizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for row_idx in 0..transformed.nrows() {
            let row = transformed.row(row_idx);
            let max_abs: f64 = row.iter().map(|&v| v.abs()).fold(0.0, f64::max);
            assert_abs_diff_eq!(max_abs, 1.0, epsilon = 1e-10);
        }

        assert_abs_diff_eq!(transformed[[0, 1]], -1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_zero_row_handled() {
        let x = array![[0.0, 0.0], [3.0, 4.0]];
        let normalizer = Normalizer::new();
        let fitted = FitUnsupervised::<f64>::fit(&normalizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10);

        for &v in transformed.iter() {
            assert!(v.is_finite(), "zero row produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_zero_row_handled_for_all_norms() {
        let x = array![[0.0, 0.0, 0.0], [1.0, -2.0, 2.0]];
        for &kind in &[NormType::L1, NormType::L2, NormType::Max] {
            let normalizer = Normalizer::new().with_norm(kind);
            let fitted = FitUnsupervised::<f64>::fit(&normalizer, &x).unwrap();
            let transformed = fitted.transform(&x).unwrap();

            for &v in transformed.row(0).iter() {
                assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
            }
            for &v in transformed.iter() {
                assert!(v.is_finite(), "{:?} produced non-finite: {}", kind, v);
            }
        }
    }

    #[test]
    fn test_f32_support() {
        let x = array![[3.0f32, 4.0], [1.0, 0.0]];
        let normalizer = Normalizer::new();
        let fitted = FitUnsupervised::<f32>::fit(&normalizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        let row = transformed.row(0);
        let norm: f32 = row.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert_abs_diff_eq!(norm, 1.0f32, epsilon = 1e-5);
    }

    #[test]
    fn test_fitted_reports_configured_norm() {
        let x = array![[1.0, 2.0]];
        let normalizer = Normalizer::new().with_norm(NormType::Max);
        let fitted = FitUnsupervised::<f64>::fit(&normalizer, &x).unwrap();
        assert_eq!(fitted.norm(), NormType::Max);
    }

    #[test]
    fn test_default_uses_l2() {
        assert_eq!(Normalizer::default().norm, NormType::L2);
        assert_eq!(Normalizer::new().norm, NormType::L2);
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let normalizer = Normalizer::new();
        let result = FitUnsupervised::<f64>::fit(&normalizer, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_rows_with_columns_rejected() {
        // Zero rows but a non-zero column count is still empty input.
        let x: Array2<f64> = Array2::zeros((0, 3));
        let normalizer = Normalizer::new();
        let result = FitUnsupervised::<f64>::fit(&normalizer, &x);
        assert!(result.is_err());
    }
}