1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
9pub struct StandardScaler {
10 pub with_mean: bool,
12 pub with_std: bool,
14}
15
16impl StandardScaler {
17 pub fn new() -> Self {
19 Self {
20 with_mean: true,
21 with_std: true,
22 }
23 }
24
25 pub fn with_mean(mut self, with_mean: bool) -> Self {
27 self.with_mean = with_mean;
28 self
29 }
30
31 pub fn with_std(mut self, with_std: bool) -> Self {
33 self.with_std = with_std;
34 self
35 }
36}
37
38impl Default for StandardScaler {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
46#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
47pub struct FittedStandardScaler<F: Float> {
48 mean: Array1<F>,
49 std: Array1<F>,
50 with_mean: bool,
51 with_std: bool,
52}
53
54impl<F: Float> FitUnsupervised<F> for StandardScaler {
55 type Fitted = FittedStandardScaler<F>;
56
57 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
58 if x.is_empty() {
59 return Err(RustMlError::EmptyInput("input array is empty".into()));
60 }
61
62 let n = F::from_usize(x.nrows()).unwrap();
63 let mean = x.sum_axis(Axis(0)) / n;
64
65 let std = if self.with_std {
66 let mut s = Array1::<F>::zeros(x.ncols());
68 for row in x.rows() {
69 for (s_j, (&val, &m)) in s.iter_mut().zip(row.iter().zip(mean.iter())) {
70 let d = val - m;
71 *s_j += d * d;
72 }
73 }
74 s.mapv(|v| (v / n).sqrt())
75 } else {
76 Array1::ones(x.ncols())
77 };
78
79 Ok(FittedStandardScaler {
80 mean,
81 std,
82 with_mean: self.with_mean,
83 with_std: self.with_std,
84 })
85 }
86}
87
88impl<F: Float> Transform<F> for FittedStandardScaler<F> {
89 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
90 if x.ncols() != self.mean.len() {
91 return Err(RustMlError::ShapeMismatch(format!(
92 "expected {} features, got {}",
93 self.mean.len(),
94 x.ncols()
95 )));
96 }
97
98 let mut result = x.to_owned();
99 for mut row in result.rows_mut() {
100 for (j, val) in row.iter_mut().enumerate() {
101 if self.with_mean {
102 *val -= self.mean[j];
103 }
104 if self.with_std && self.std[j] > F::from_f64(1e-15).unwrap() {
105 *val /= self.std[j];
106 }
107 }
108 }
109 Ok(result)
110 }
111}
112
113impl<F: Float> InverseTransform<F> for FittedStandardScaler<F> {
114 fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
115 if x.ncols() != self.mean.len() {
116 return Err(RustMlError::ShapeMismatch(format!(
117 "expected {} features, got {}",
118 self.mean.len(),
119 x.ncols()
120 )));
121 }
122
123 let mut result = x.to_owned();
124 for mut row in result.rows_mut() {
125 for (j, val) in row.iter_mut().enumerate() {
126 if self.with_std && self.std[j] > F::from_f64(1e-15).unwrap() {
127 *val *= self.std[j];
128 }
129 if self.with_mean {
130 *val += self.mean[j];
131 }
132 }
133 }
134 Ok(result)
135 }
136}
137
138impl<F: Float> FittedStandardScaler<F> {
139 pub fn mean(&self) -> &Array1<F> {
140 &self.mean
141 }
142
143 pub fn std(&self) -> &Array1<F> {
144 &self.std
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits a default `StandardScaler` on `x` and returns the fitted scaler
    /// together with the transformed copy of `x`.
    fn fit_and_transform(x: &Array2<f64>) -> (FittedStandardScaler<f64>, Array2<f64>) {
        let fitted = FitUnsupervised::<f64>::fit(&StandardScaler::default(), x).unwrap();
        let transformed = fitted.transform(x).unwrap();
        (fitted, transformed)
    }

    /// After standardization every column mean should be (numerically) zero.
    #[test]
    fn test_fit_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (_, transformed) = fit_and_transform(&x);

        let col_means = transformed.sum_axis(Axis(0)) / 3.0;
        for col in 0..2 {
            assert_abs_diff_eq!(col_means[col], 0.0, epsilon = 1e-10);
        }
    }

    /// `inverse_transform` must undo `transform`.
    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (fitted, transformed) = fit_and_transform(&x);
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (original, restored) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(original, restored, epsilon = 1e-10);
        }
    }

    /// Without centering, a positive input stays positive after scaling.
    #[test]
    fn test_without_mean() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let scaler = StandardScaler {
            with_std: true,
            with_mean: false,
        };
        let fitted = FitUnsupervised::<f64>::fit(&scaler, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        let first = transformed[[0, 0]];
        assert!(first > 0.0);
    }

    /// Large magnitudes must not overflow or lose the zero-mean property.
    #[test]
    fn test_large_values() {
        let x = array![[1e10, -1e10], [2e10, -2e10], [3e10, -3e10], [4e10, -4e10],];
        let (_, transformed) = fit_and_transform(&x);

        transformed.iter().for_each(|&v| {
            assert!(
                v.is_finite(),
                "transformed value should be finite, got {}",
                v
            );
        });
        let col_means = transformed.sum_axis(Axis(0)) / 4.0;
        assert_abs_diff_eq!(col_means[0], 0.0, epsilon = 1e-8);
    }

    /// Tiny magnitudes (well above the 1e-15 std threshold) still round-trip.
    #[test]
    fn test_small_values() {
        let x = array![
            [1e-10, 2e-10],
            [3e-10, 4e-10],
            [5e-10, 6e-10],
            [7e-10, 8e-10],
        ];
        let (fitted, transformed) = fit_and_transform(&x);

        transformed.iter().for_each(|&v| {
            assert!(
                v.is_finite(),
                "transformed value should be finite, got {}",
                v
            );
        });
        let recovered = fitted.inverse_transform(&transformed).unwrap();
        for (original, restored) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(original, restored, epsilon = 1e-18);
        }
    }

    /// A column with (almost) no variance must not produce NaN or infinity.
    #[test]
    fn test_near_zero_variance_column() {
        let x = array![
            [1.0, 5.0],
            [2.0, 5.0 + 1e-15],
            [3.0, 5.0 - 1e-15],
            [4.0, 5.0],
        ];
        let (_, transformed) = fit_and_transform(&x);

        transformed.iter().for_each(|&v| {
            assert!(
                v.is_finite(),
                "near-zero variance column produced non-finite: {}",
                v
            );
        });
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        /// Builds a deterministic `rows x cols` matrix with entries in
        /// `[-10, 10]`, derived by hashing `seed` together with the flat
        /// element index.
        fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};

            let values: Vec<f64> = (0..(rows * cols))
                .map(|i| {
                    let mut hasher = DefaultHasher::new();
                    seed.hash(&mut hasher);
                    (i as u64).hash(&mut hasher);
                    let bits = hasher.finish();
                    (bits as f64 / u64::MAX as f64) * 20.0 - 10.0
                })
                .collect();
            Array2::from_shape_vec((rows, cols), values).unwrap()
        }

        proptest! {
            #[test]
            fn standard_scaler_roundtrip(
                rows in 2..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let (fitted, transformed) = fit_and_transform(&x);
                let recovered = fitted.inverse_transform(&transformed).unwrap();

                for (original, restored) in x.iter().zip(recovered.iter()) {
                    prop_assert!((original - restored).abs() < 1e-8,
                        "roundtrip failed: original={}, recovered={}", original, restored);
                }
            }

            #[test]
            fn standard_scaler_mean_zero(
                rows in 2..50usize,
                cols in 1..10usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let (fitted, transformed) = fit_and_transform(&x);

                let n = rows as f64;
                for col_idx in 0..cols {
                    let column = transformed.column(col_idx);

                    let col_mean: f64 = column.sum() / n;
                    prop_assert!(col_mean.abs() < 1e-8,
                        "column {} mean should be ~0, got {}", col_idx, col_mean);

                    let sq_dev_total = column
                        .iter()
                        .map(|&v| (v - col_mean) * (v - col_mean))
                        .sum::<f64>();
                    let col_std: f64 = (sq_dev_total / n).sqrt();
                    // Columns at or below the threshold are not rescaled by
                    // `transform`, so unit variance is only expected above it.
                    if fitted.std()[col_idx] > 1e-15 {
                        prop_assert!((col_std - 1.0).abs() < 1e-6,
                            "column {} std should be ~1, got {}", col_idx, col_std);
                    }
                }
            }
        }
    }
}