1#[cfg(not(feature = "std"))]
15#[allow(unused_imports)]
16use num_traits::Float;
17
/// Per-feature statistics consumed by [`Normalizer::transform`].
///
/// `MinMax` normalization reads `min`/`max`; `ZScore` normalization reads
/// `mean`/`std_dev`.
// Field order is part of the wire layout for positional serde formats such as
// postcard, so the fields must not be reordered.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NormParams {
    /// Value subtracted from the feature under `ZScore`.
    pub mean: f64,
    /// Divisor applied after centering under `ZScore`.
    pub std_dev: f64,
    /// Lower bound of the range mapped onto `0.0` under `MinMax`.
    pub min: f64,
    /// Upper bound of the range mapped onto `1.0` under `MinMax`.
    pub max: f64,
}

impl Default for NormParams {
    /// Same as [`NormParams::identity`].
    fn default() -> Self {
        Self::identity()
    }
}

impl NormParams {
    /// Neutral parameters: `mean = 0`, `std_dev = 1`, `min = 0`, `max = 1`.
    ///
    /// With these values `ZScore` returns its input unchanged and `MinMax`
    /// clamps its input to the unit interval.
    #[must_use]
    pub const fn identity() -> Self {
        Self {
            min: 0.0,
            max: 1.0,
            mean: 0.0,
            std_dev: 1.0,
        }
    }
}
55
/// How [`Normalizer::transform`] rescales each feature.
///
/// `Hash` is derived alongside `Eq` so strategies can be used as map/set keys.
// Variant order determines serde variant indices; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum NormStrategy {
    /// Map `[min, max]` onto `[0, 1]`, clamping out-of-range values; a range
    /// whose magnitude is below `f64::EPSILON` yields `0.5`.
    MinMax,
    /// Compute `(value - mean) / std_dev`; a `std_dev` whose magnitude is
    /// below `f64::EPSILON` yields `0.0`.
    ZScore,
    /// Return values unchanged.
    None,
}
67
68#[derive(Debug, Clone, Copy, PartialEq)]
82#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
83pub struct Normalizer<const D: usize> {
84 pub strategy: NormStrategy,
86 #[cfg_attr(feature = "serde", serde(with = "serde_arrays"))]
88 pub params: [NormParams; D],
89}
90
91impl<const D: usize> Normalizer<D> {
92 #[must_use]
96 pub const fn identity(strategy: NormStrategy) -> Self {
97 Self {
98 strategy,
99 params: [NormParams::identity(); D],
100 }
101 }
102
103 #[must_use]
110 #[allow(clippy::cast_precision_loss)]
111 pub fn fit(strategy: NormStrategy, samples: &[[f64; D]]) -> Self {
112 if samples.is_empty() {
113 return Self::identity(strategy);
114 }
115
116 let n = samples.len() as f64;
117 let mut params = [NormParams {
118 mean: 0.0,
119 std_dev: 0.0,
120 min: f64::MAX,
121 max: f64::MIN,
122 }; D];
123
124 for sample in samples {
125 for (i, &value) in sample.iter().enumerate() {
126 params[i].mean += value;
127 if value < params[i].min {
128 params[i].min = value;
129 }
130 if value > params[i].max {
131 params[i].max = value;
132 }
133 }
134 }
135
136 for p in &mut params {
137 p.mean /= n;
138 }
139
140 for sample in samples {
141 for (i, &value) in sample.iter().enumerate() {
142 let diff = value - params[i].mean;
143 params[i].std_dev += diff * diff;
144 }
145 }
146
147 for p in &mut params {
148 p.std_dev = (p.std_dev / n).sqrt();
149 }
150
151 Self { strategy, params }
152 }
153
154 #[must_use]
158 pub fn transform(&self, input: &[f64; D]) -> [f64; D] {
159 let mut out = [0.0_f64; D];
160 for (i, &value) in input.iter().enumerate() {
161 let p = &self.params[i];
162 out[i] = match self.strategy {
163 NormStrategy::MinMax => {
164 let range = p.max - p.min;
165 if range.abs() < f64::EPSILON {
166 0.5
167 } else {
168 ((value - p.min) / range).clamp(0.0, 1.0)
169 }
170 }
171 NormStrategy::ZScore => {
172 if p.std_dev.abs() < f64::EPSILON {
173 0.0
174 } else {
175 (value - p.mean) / p.std_dev
176 }
177 }
178 NormStrategy::None => value,
179 };
180 }
181 out
182 }
183}
184
/// `#[serde(with = "serde_arrays")]` helpers for `[NormParams; D]`.
///
/// The array is written as a plain sequence and read back through a `Vec`.
/// Input whose length differs from `D` is rejected.
#[cfg(feature = "serde")]
mod serde_arrays {
    use super::NormParams;
    use alloc::vec::Vec;
    use core::fmt;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// `Expected` description that names the concrete required length, so
    /// errors read "invalid length N, expected D normalization parameter entries".
    struct ExpectedLen(usize);

    impl serde::de::Expected for ExpectedLen {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{} normalization parameter entries", self.0)
        }
    }

    /// Serializes the array as a sequence of `D` elements.
    pub fn serialize<S: Serializer, const D: usize>(
        params: &[NormParams; D],
        s: S,
    ) -> Result<S::Ok, S::Error> {
        params.as_slice().serialize(s)
    }

    /// Deserializes a sequence into `[NormParams; D]`.
    ///
    /// # Errors
    /// Returns serde's `invalid_length` error when the sequence does not hold
    /// exactly `D` entries, and propagates any error from the deserializer.
    pub fn deserialize<'de, DSer: Deserializer<'de>, const D: usize>(
        d: DSer,
    ) -> Result<[NormParams; D], DSer::Error> {
        let v: Vec<NormParams> = Vec::deserialize(d)?;
        if v.len() != D {
            // Report the real required length instead of a placeholder.
            return Err(serde::de::Error::invalid_length(
                v.len(),
                &ExpectedLen(D),
            ));
        }
        let mut out = [NormParams::identity(); D];
        // Lengths were checked above, so this cannot panic.
        out.copy_from_slice(&v);
        Ok(out)
    }
}
220
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the approximate assertions below.
    const TOL: f64 = 1e-12;

    /// True when `actual` lies strictly within [`TOL`] of `expected`.
    fn near(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Builds a one-feature normalizer from hand-written statistics.
    fn single(strategy: NormStrategy, params: NormParams) -> Normalizer<1> {
        Normalizer {
            strategy,
            params: [params],
        }
    }

    #[test]
    fn identity_passthrough_under_none() {
        let input = [10.0, 5000.0, -2.5];
        let normalizer = Normalizer::<3>::identity(NormStrategy::None);
        assert_eq!(normalizer.transform(&input), input);
    }

    #[test]
    fn minmax_rescales_mid_point() {
        let params = NormParams {
            mean: 0.0,
            std_dev: 1.0,
            min: 0.0,
            max: 100.0,
        };
        let out = single(NormStrategy::MinMax, params).transform(&[50.0]);
        assert!(near(out[0], 0.5));
    }

    #[test]
    fn minmax_clamps_above_max() {
        let params = NormParams {
            mean: 0.0,
            std_dev: 1.0,
            min: 0.0,
            max: 10.0,
        };
        let out = single(NormStrategy::MinMax, params).transform(&[20.0]);
        assert!(near(out[0], 1.0));
    }

    #[test]
    fn minmax_clamps_below_min() {
        let params = NormParams {
            mean: 0.0,
            std_dev: 1.0,
            min: 0.0,
            max: 10.0,
        };
        let out = single(NormStrategy::MinMax, params).transform(&[-5.0]);
        assert!(near(out[0], 0.0));
    }

    #[test]
    fn minmax_zero_range_returns_mid() {
        let params = NormParams {
            mean: 5.0,
            std_dev: 0.0,
            min: 5.0,
            max: 5.0,
        };
        let out = single(NormStrategy::MinMax, params).transform(&[5.0]);
        assert!(near(out[0], 0.5));
    }

    #[test]
    fn zscore_centers_and_scales() {
        let params = NormParams {
            mean: 50.0,
            std_dev: 10.0,
            min: 0.0,
            max: 100.0,
        };
        let out = single(NormStrategy::ZScore, params).transform(&[70.0]);
        assert!(near(out[0], 2.0));
    }

    #[test]
    fn zscore_zero_std_returns_zero() {
        let params = NormParams {
            mean: 5.0,
            std_dev: 0.0,
            min: 5.0,
            max: 5.0,
        };
        let out = single(NormStrategy::ZScore, params).transform(&[10.0]);
        assert_eq!(out[0], 0.0);
    }

    #[test]
    fn fit_learns_mean_min_max() {
        let samples = [[10.0, 1000.0], [20.0, 2000.0], [30.0, 3000.0]];
        let fitted = Normalizer::<2>::fit(NormStrategy::MinMax, &samples);
        let [first, second] = fitted.params;
        assert!(near(first.min, 10.0));
        assert!(near(first.max, 30.0));
        assert!(near(first.mean, 20.0));
        assert!(near(second.mean, 2000.0));
    }

    #[test]
    fn fit_then_transform_rescales_correctly() {
        let samples = [[0.0, 0.0], [100.0, 1000.0]];
        let out = Normalizer::<2>::fit(NormStrategy::MinMax, &samples).transform(&[50.0, 500.0]);
        assert!(near(out[0], 0.5));
        assert!(near(out[1], 0.5));
    }

    #[test]
    fn fit_empty_falls_back_to_identity() {
        let fitted: Normalizer<5> = Normalizer::fit(NormStrategy::ZScore, &[]);
        for p in fitted.params.iter() {
            assert_eq!(p.mean, 0.0);
            assert_eq!(p.std_dev, 1.0);
        }
    }

    #[test]
    fn fit_single_sample_has_zero_std() {
        let fitted = Normalizer::<1>::fit(NormStrategy::ZScore, &[[5.0_f64]]);
        assert_eq!(fitted.params[0].mean, 5.0);
        assert_eq!(fitted.params[0].std_dev, 0.0);
    }

    #[cfg(all(feature = "serde", feature = "postcard"))]
    #[test]
    fn postcard_roundtrip_preserves_transform() {
        let samples = [[0.0_f64, 0.0], [100.0, 1000.0]];
        let original = Normalizer::<2>::fit(NormStrategy::MinMax, &samples);
        let bytes = postcard::to_allocvec(&original).expect("serde ok");
        let restored: Normalizer<2> = postcard::from_bytes(&bytes).expect("serde ok");
        let probe = [50.0, 500.0];
        assert_eq!(restored.transform(&probe), original.transform(&probe));
    }
}