1use ferrolearn_core::error::FerroError;
17use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
18use ferrolearn_core::traits::Transform;
19use ndarray::{Array1, Array2};
20use num_traits::Float;
21
/// Vector norm used by [`Normalizer`] to rescale each sample (row).
///
/// The default variant is [`NormType::L2`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum NormType {
    /// Sum of the absolute values of the row's entries.
    L1,
    /// Square root of the sum of the squared entries (Euclidean length).
    #[default]
    L2,
    /// Largest absolute value among the row's entries.
    Max,
}
37
38#[derive(Debug, Clone)]
63pub struct Normalizer<F> {
64 pub(crate) norm: NormType,
66 _marker: std::marker::PhantomData<F>,
67}
68
69impl<F: Float + Send + Sync + 'static> Normalizer<F> {
70 #[must_use]
72 pub fn new(norm: NormType) -> Self {
73 Self {
74 norm,
75 _marker: std::marker::PhantomData,
76 }
77 }
78
79 #[must_use]
81 pub fn l2() -> Self {
82 Self::new(NormType::L2)
83 }
84
85 #[must_use]
87 pub fn l1() -> Self {
88 Self::new(NormType::L1)
89 }
90
91 #[must_use]
93 pub fn max() -> Self {
94 Self::new(NormType::Max)
95 }
96
97 #[must_use]
99 pub fn norm(&self) -> NormType {
100 self.norm
101 }
102}
103
104impl<F: Float + Send + Sync + 'static> Default for Normalizer<F> {
105 fn default() -> Self {
106 Self::new(NormType::L2)
107 }
108}
109
110impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for Normalizer<F> {
115 type Output = Array2<F>;
116 type Error = FerroError;
117
118 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
127 let mut out = x.to_owned();
128 for mut row in out.rows_mut() {
129 let norm_val =
130 match self.norm {
131 NormType::L1 => row.iter().copied().fold(F::zero(), |acc, v| acc + v.abs()),
132 NormType::L2 => row
133 .iter()
134 .copied()
135 .fold(F::zero(), |acc, v| acc + v * v)
136 .sqrt(),
137 NormType::Max => row.iter().copied().fold(F::zero(), |acc, v| {
138 if v.abs() > acc { v.abs() } else { acc }
139 }),
140 };
141 if norm_val == F::zero() {
142 continue;
144 }
145 for v in row.iter_mut() {
146 *v = *v / norm_val;
147 }
148 }
149 Ok(out)
150 }
151}
152
153impl PipelineTransformer for Normalizer<f64> {
158 fn fit_pipeline(
167 &self,
168 _x: &Array2<f64>,
169 _y: &Array1<f64>,
170 ) -> Result<Box<dyn FittedPipelineTransformer>, FerroError> {
171 Ok(Box::new(self.clone()))
172 }
173}
174
175impl FittedPipelineTransformer for Normalizer<f64> {
176 fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
182 self.transform(x)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Absolute tolerance used by most `f64` comparisons in this module.
    const TOL: f64 = 1e-10;

    /// Asserts that the leading entries of row `row` in `actual` equal `expected`
    /// to within the absolute tolerance `tol`.
    fn assert_row_close(actual: &Array2<f64>, row: usize, expected: &[f64], tol: f64) {
        for (col, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(actual[[row, col]], want, epsilon = tol);
        }
    }

    /// A (3, 4) row has Euclidean length 5, so it scales to (0.6, 0.8).
    #[test]
    fn test_l2_norm_basic() {
        let input = array![[3.0, 4.0]];
        let scaled = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&scaled, 0, &[0.6, 0.8], TOL);
    }

    /// After L2 normalization every row has Euclidean length 1.
    #[test]
    fn test_l2_unit_norm_after_transform() {
        let input = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let scaled = Normalizer::<f64>::l2().transform(&input).unwrap();
        for sample in scaled.rows() {
            let length: f64 = sample.iter().map(|v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(length, 1.0, epsilon = TOL);
        }
    }

    /// A (1, 2, 3) row has L1 norm 6, so each entry is divided by 6.
    #[test]
    fn test_l1_norm_basic() {
        let input = array![[1.0, 2.0, 3.0]];
        let scaled = Normalizer::<f64>::l1().transform(&input).unwrap();
        assert_row_close(&scaled, 0, &[1.0 / 6.0, 2.0 / 6.0, 3.0 / 6.0], TOL);
    }

    /// After L1 normalization the absolute values of every row sum to 1.
    #[test]
    fn test_l1_unit_norm_after_transform() {
        let input = array![[1.0, 2.0, 3.0], [-4.0, 5.0, 6.0]];
        let scaled = Normalizer::<f64>::l1().transform(&input).unwrap();
        for sample in scaled.rows() {
            let total: f64 = sample.iter().map(|v| v.abs()).sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = TOL);
        }
    }

    /// The max norm divides by the largest magnitude (5 here) and keeps signs.
    #[test]
    fn test_max_norm_basic() {
        let input = array![[-5.0, 3.0, 1.0]];
        let scaled = Normalizer::<f64>::max().transform(&input).unwrap();
        assert_row_close(&scaled, 0, &[-1.0, 0.6, 0.2], TOL);
    }

    /// An all-zero row must come back unchanged rather than being divided by zero.
    #[test]
    fn test_zero_row_unchanged() {
        let input = array![[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]];
        let scaled = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&scaled, 0, &[0.0, 0.0, 0.0], 1e-15);
    }

    /// Negative entries keep their sign after L2 normalization.
    #[test]
    fn test_negative_values_l2() {
        let input = array![[-3.0, -4.0]];
        let scaled = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&scaled, 0, &[-0.6, -0.8], TOL);
    }

    /// `Default` yields an L2 normalizer.
    #[test]
    fn test_default_is_l2() {
        let normalizer = Normalizer::<f64>::default();
        assert_eq!(normalizer.norm(), NormType::L2);
    }

    /// Each row is normalized by its own norm, independently of the others.
    #[test]
    fn test_multiple_rows_independent() {
        let input = array![[3.0, 4.0], [0.0, 5.0]];
        let scaled = Normalizer::<f64>::l2().transform(&input).unwrap();
        assert_row_close(&scaled, 0, &[0.6, 0.8], TOL);
        assert_row_close(&scaled, 1, &[0.0, 1.0], TOL);
    }

    /// The normalizer works through the pipeline fit/transform interface.
    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let normalizer = Normalizer::<f64>::l2();
        let features = array![[3.0, 4.0], [0.0, 2.0]];
        let targets = Array1::<f64>::zeros(2);
        let fitted = normalizer.fit_pipeline(&features, &targets).unwrap();
        let scaled = fitted.transform_pipeline(&features).unwrap();
        assert_row_close(&scaled, 0, &[0.6, 0.8], TOL);
    }

    /// The generic implementation also works for `f32` matrices.
    #[test]
    fn test_f32_normalizer() {
        let input: Array2<f32> = array![[3.0f32, 4.0]];
        let scaled = Normalizer::<f32>::l2().transform(&input).unwrap();
        let expected = [0.6f32, 0.8f32];
        for (col, &want) in expected.iter().enumerate() {
            assert!((scaled[[0, col]] - want).abs() < 1e-6);
        }
    }
}