// anofox_ml_preprocessing/standard_scaler.rs

1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4/// Parameters for StandardScaler (unfitted state).
5///
6/// Standardizes features by removing the mean and scaling to unit variance:
7/// `z = (x - mean) / std`
8#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
9pub struct StandardScaler {
10    /// If true, center data to zero mean before scaling.
11    pub with_mean: bool,
12    /// If true, scale data to unit variance.
13    pub with_std: bool,
14}
15
16impl StandardScaler {
17    /// Create a new `StandardScaler` with defaults (both centering and scaling enabled).
18    pub fn new() -> Self {
19        Self {
20            with_mean: true,
21            with_std: true,
22        }
23    }
24
25    /// Set whether to center data to zero mean before scaling.
26    pub fn with_mean(mut self, with_mean: bool) -> Self {
27        self.with_mean = with_mean;
28        self
29    }
30
31    /// Set whether to scale data to unit variance.
32    pub fn with_std(mut self, with_std: bool) -> Self {
33        self.with_std = with_std;
34        self
35    }
36}
37
38impl Default for StandardScaler {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44/// Fitted StandardScaler — holds learned mean and std per feature.
45#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
46#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
47pub struct FittedStandardScaler<F: Float> {
48    mean: Array1<F>,
49    std: Array1<F>,
50    with_mean: bool,
51    with_std: bool,
52}
53
54impl<F: Float> FitUnsupervised<F> for StandardScaler {
55    type Fitted = FittedStandardScaler<F>;
56
57    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
58        if x.is_empty() {
59            return Err(RustMlError::EmptyInput("input array is empty".into()));
60        }
61
62        let n = F::from_usize(x.nrows()).unwrap();
63        let mean = x.sum_axis(Axis(0)) / n;
64
65        let std = if self.with_std {
66            // Single-pass variance: no intermediate array allocations.
67            let mut s = Array1::<F>::zeros(x.ncols());
68            for row in x.rows() {
69                for (s_j, (&val, &m)) in s.iter_mut().zip(row.iter().zip(mean.iter())) {
70                    let d = val - m;
71                    *s_j += d * d;
72                }
73            }
74            s.mapv(|v| (v / n).sqrt())
75        } else {
76            Array1::ones(x.ncols())
77        };
78
79        Ok(FittedStandardScaler {
80            mean,
81            std,
82            with_mean: self.with_mean,
83            with_std: self.with_std,
84        })
85    }
86}
87
88impl<F: Float> Transform<F> for FittedStandardScaler<F> {
89    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
90        if x.ncols() != self.mean.len() {
91            return Err(RustMlError::ShapeMismatch(format!(
92                "expected {} features, got {}",
93                self.mean.len(),
94                x.ncols()
95            )));
96        }
97
98        let mut result = x.to_owned();
99        for mut row in result.rows_mut() {
100            for (j, val) in row.iter_mut().enumerate() {
101                if self.with_mean {
102                    *val -= self.mean[j];
103                }
104                if self.with_std && self.std[j] > F::from_f64(1e-15).unwrap() {
105                    *val /= self.std[j];
106                }
107            }
108        }
109        Ok(result)
110    }
111}
112
113impl<F: Float> InverseTransform<F> for FittedStandardScaler<F> {
114    fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
115        if x.ncols() != self.mean.len() {
116            return Err(RustMlError::ShapeMismatch(format!(
117                "expected {} features, got {}",
118                self.mean.len(),
119                x.ncols()
120            )));
121        }
122
123        let mut result = x.to_owned();
124        for mut row in result.rows_mut() {
125            for (j, val) in row.iter_mut().enumerate() {
126                if self.with_std && self.std[j] > F::from_f64(1e-15).unwrap() {
127                    *val *= self.std[j];
128                }
129                if self.with_mean {
130                    *val += self.mean[j];
131                }
132            }
133        }
134        Ok(result)
135    }
136}
137
138impl<F: Float> FittedStandardScaler<F> {
139    pub fn mean(&self) -> &Array1<F> {
140        &self.mean
141    }
142
143    pub fn std(&self) -> &Array1<F> {
144        &self.std
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let scaler = StandardScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // Mean of each column should be ~0
        let col_means = transformed.sum_axis(Axis(0)) / 3.0;
        assert_abs_diff_eq!(col_means[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(col_means[1], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let scaler = StandardScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_without_mean() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let scaler = StandardScaler {
            with_mean: false,
            with_std: true,
        };
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // Without centering, values should still be positive
        assert!(transformed[[0, 0]] > 0.0);
        // ...and equal to x / std, where the first column's population std is sqrt(8/3)
        assert_abs_diff_eq!(
            transformed[[0, 0]],
            1.0 / (8.0f64 / 3.0).sqrt(),
            epsilon = 1e-12
        );
    }

    #[test]
    fn test_without_std_only_centers() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let scaler = StandardScaler::new().with_std(false);
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();

        // Scaling disabled: stored std is all ones and data is only centered.
        assert_eq!(fitted.std(), &array![1.0, 1.0]);
        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed, array![[-2.0, -2.0], [0.0, 0.0], [2.0, 2.0]]);
    }

    #[test]
    fn test_disabled_mean_and_std_is_identity() {
        let x = array![[1.0, -2.0], [3.5, 4.0]];
        let scaler = StandardScaler::new().with_mean(false).with_std(false);
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();

        assert_eq!(fitted.transform(&x).unwrap(), x);
        assert_eq!(fitted.inverse_transform(&x).unwrap(), x);
    }

    #[test]
    fn test_constant_column_is_centered_not_scaled() {
        let x = array![[1.0, 7.0], [3.0, 7.0]];
        let fitted = FitUnsupervised::<f64>::fit(&StandardScaler::default(), &x).unwrap();
        assert_eq!(fitted.std()[1], 0.0);

        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed.column(0).to_vec(), vec![-1.0, 1.0]);
        assert_eq!(transformed.column(1).to_vec(), vec![0.0, 0.0]);

        let recovered = fitted.inverse_transform(&transformed).unwrap();
        assert_eq!(recovered, x);
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));
        let result = FitUnsupervised::<f64>::fit(&StandardScaler::default(), &x);
        assert!(matches!(result, Err(RustMlError::EmptyInput(_))));
    }

    #[test]
    fn test_rejects_wrong_feature_count() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = FitUnsupervised::<f64>::fit(&StandardScaler::default(), &x).unwrap();
        let wrong: Array2<f64> = array![[1.0, 2.0, 3.0]];

        assert!(matches!(
            fitted.transform(&wrong),
            Err(RustMlError::ShapeMismatch(_))
        ));
        assert!(matches!(
            fitted.inverse_transform(&wrong),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }

    #[test]
    fn test_large_values() {
        // Very large feature values should not produce NaN/Inf
        let x = array![[1e10, -1e10], [2e10, -2e10], [3e10, -3e10], [4e10, -4e10],];
        let scaler = StandardScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(
                v.is_finite(),
                "transformed value should be finite, got {}",
                v
            );
        }
        // Mean should be ~0
        let col_means = transformed.sum_axis(Axis(0)) / 4.0;
        assert_abs_diff_eq!(col_means[0], 0.0, epsilon = 1e-8);
    }

    #[test]
    fn test_small_values() {
        // Very small feature values should not lose precision
        let x = array![
            [1e-10, 2e-10],
            [3e-10, 4e-10],
            [5e-10, 6e-10],
            [7e-10, 8e-10],
        ];
        let scaler = StandardScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(
                v.is_finite(),
                "transformed value should be finite, got {}",
                v
            );
        }
        // Roundtrip should preserve values
        let recovered = fitted.inverse_transform(&transformed).unwrap();
        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-18);
        }
    }

    #[test]
    fn test_near_zero_variance_column() {
        // One column has near-zero variance; scaler should handle without NaN
        let x = array![
            [1.0, 5.0],
            [2.0, 5.0 + 1e-15],
            [3.0, 5.0 - 1e-15],
            [4.0, 5.0],
        ];
        let scaler = StandardScaler::default();
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(
                v.is_finite(),
                "near-zero variance column produced non-finite: {}",
                v
            );
        }
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        /// Generate a deterministic Array2<f64> from dimensions and a seed.
        fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};
            let mut values = Vec::with_capacity(rows * cols);
            for i in 0..(rows * cols) {
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                let bits = h.finish();
                // Map to a reasonable f64 range [-10, 10]
                let v = (bits as f64 / u64::MAX as f64) * 20.0 - 10.0;
                values.push(v);
            }
            Array2::from_shape_vec((rows, cols), values).unwrap()
        }

        proptest! {
            #[test]
            fn standard_scaler_roundtrip(
                rows in 2..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);

                let scaler = StandardScaler::default();
                let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
                let transformed = fitted.transform(&x).unwrap();
                let recovered = fitted.inverse_transform(&transformed).unwrap();

                for (a, b) in x.iter().zip(recovered.iter()) {
                    prop_assert!((a - b).abs() < 1e-8,
                        "roundtrip failed: original={}, recovered={}", a, b);
                }
            }

            #[test]
            fn standard_scaler_mean_zero(
                rows in 2..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);

                let scaler = StandardScaler::default();
                let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
                let transformed = fitted.transform(&x).unwrap();

                let n = rows as f64;
                for col_idx in 0..cols {
                    let col_mean: f64 = transformed.column(col_idx).sum() / n;
                    prop_assert!(col_mean.abs() < 1e-8,
                        "column {} mean should be ~0, got {}", col_idx, col_mean);

                    // Standard deviation should be ~1 (if original std > 0)
                    let col_std: f64 = (transformed.column(col_idx)
                        .iter()
                        .map(|&v| (v - col_mean) * (v - col_mean))
                        .sum::<f64>() / n)
                        .sqrt();
                    if fitted.std()[col_idx] > 1e-15 {
                        prop_assert!((col_std - 1.0).abs() < 1e-6,
                            "column {} std should be ~1, got {}", col_idx, col_std);
                    }
                }
            }
        }
    }
}