1use ferrolearn_core::error::FerroError;
7use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
8use ferrolearn_core::traits::{Fit, FitTransform, Transform};
9use ndarray::{Array1, Array2};
10use num_traits::Float;
11
/// An unfitted standard scaler, generic over the float type `F`.
///
/// The struct carries no data of its own; it only records which element type
/// it operates on. Fitting it produces a separate value holding the learned
/// per-column statistics.
#[derive(Clone, Debug)]
pub struct StandardScaler<F> {
    /// Zero-sized marker binding the scaler to its element type `F`.
    _marker: std::marker::PhantomData<F>,
}
40
41impl<F: Float + Send + Sync + 'static> StandardScaler<F> {
42 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 _marker: std::marker::PhantomData,
47 }
48 }
49}
50
51impl<F: Float + Send + Sync + 'static> Default for StandardScaler<F> {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57#[derive(Debug, Clone)]
65pub struct FittedStandardScaler<F> {
66 pub(crate) mean: Array1<F>,
68 pub(crate) std: Array1<F>,
70}
71
72impl<F: Float + Send + Sync + 'static> FittedStandardScaler<F> {
73 #[must_use]
75 pub fn mean(&self) -> &Array1<F> {
76 &self.mean
77 }
78
79 #[must_use]
81 pub fn std(&self) -> &Array1<F> {
82 &self.std
83 }
84
85 pub fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
94 let n_features = self.mean.len();
95 if x.ncols() != n_features {
96 return Err(FerroError::ShapeMismatch {
97 expected: vec![x.nrows(), n_features],
98 actual: vec![x.nrows(), x.ncols()],
99 context: "FittedStandardScaler::inverse_transform".into(),
100 });
101 }
102 let mut out = x.to_owned();
103 for (mut col, (&m, &s)) in out
104 .columns_mut()
105 .into_iter()
106 .zip(self.mean.iter().zip(self.std.iter()))
107 {
108 for v in col.iter_mut() {
109 *v = *v * s + m;
110 }
111 }
112 Ok(out)
113 }
114}
115
116impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for StandardScaler<F> {
121 type Fitted = FittedStandardScaler<F>;
122 type Error = FerroError;
123
124 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedStandardScaler<F>, FerroError> {
130 let n_samples = x.nrows();
131 if n_samples == 0 {
132 return Err(FerroError::InsufficientSamples {
133 required: 1,
134 actual: 0,
135 context: "StandardScaler::fit".into(),
136 });
137 }
138
139 let n = F::from(n_samples).unwrap_or(F::one());
140 let n_features = x.ncols();
141 let mut mean = Array1::zeros(n_features);
142 let mut std_arr = Array1::zeros(n_features);
143
144 for j in 0..n_features {
145 let col = x.column(j);
146 let m = col.iter().copied().fold(F::zero(), |acc, v| acc + v) / n;
147 let variance = col
148 .iter()
149 .copied()
150 .map(|v| (v - m) * (v - m))
151 .fold(F::zero(), |acc, v| acc + v)
152 / n;
153 mean[j] = m;
154 std_arr[j] = variance.sqrt();
155 }
156
157 Ok(FittedStandardScaler { mean, std: std_arr })
158 }
159}
160
161impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedStandardScaler<F> {
162 type Output = Array2<F>;
163 type Error = FerroError;
164
165 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
174 let n_features = self.mean.len();
175 if x.ncols() != n_features {
176 return Err(FerroError::ShapeMismatch {
177 expected: vec![x.nrows(), n_features],
178 actual: vec![x.nrows(), x.ncols()],
179 context: "FittedStandardScaler::transform".into(),
180 });
181 }
182 let mut out = x.to_owned();
183 for (mut col, (&m, &s)) in out
184 .columns_mut()
185 .into_iter()
186 .zip(self.mean.iter().zip(self.std.iter()))
187 {
188 if s == F::zero() {
189 continue;
191 }
192 for v in col.iter_mut() {
193 *v = (*v - m) / s;
194 }
195 }
196 Ok(out)
197 }
198}
199
200impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for StandardScaler<F> {
204 type Output = Array2<F>;
205 type Error = FerroError;
206
207 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
212 Err(FerroError::InvalidParameter {
213 name: "StandardScaler".into(),
214 reason: "scaler must be fitted before calling transform; use fit() first".into(),
215 })
216 }
217}
218
219impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for StandardScaler<F> {
220 type FitError = FerroError;
221
222 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
228 let fitted = self.fit(x, &())?;
229 fitted.transform(x)
230 }
231}
232
233impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for StandardScaler<F> {
238 fn fit_pipeline(
246 &self,
247 x: &Array2<F>,
248 _y: &Array1<F>,
249 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
250 let fitted = self.fit(x, &())?;
251 Ok(Box::new(fitted))
252 }
253}
254
255impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedStandardScaler<F> {
256 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
262 self.transform(x)
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Scaled columns have mean 0 and variance 1 (dividing by the row count).
    #[test]
    fn test_standard_scaler_zero_mean_unit_variance() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();
        let rows = scaled.nrows() as f64;

        for column in scaled.columns().into_iter() {
            let col_mean: f64 = column.iter().sum::<f64>() / rows;
            assert_abs_diff_eq!(col_mean, 0.0, epsilon = 1e-10);

            let variance: f64 = column
                .iter()
                .map(|&v| (v - col_mean).powi(2))
                .sum::<f64>()
                / rows;
            assert_abs_diff_eq!(variance, 1.0, epsilon = 1e-10);
        }
    }

    /// `inverse_transform` undoes `transform` on the training data.
    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x, &()).unwrap();
        let recovered = fitted
            .inverse_transform(&fitted.transform(&x).unwrap())
            .unwrap();

        for (original, restored) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(*original, *restored, epsilon = 1e-10);
        }
    }

    /// A constant column gets std 0 and is passed through untouched.
    #[test]
    fn test_zero_variance_column_unchanged() {
        let x = array![[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.std()[1], 0.0, epsilon = 1e-15);

        let scaled = fitted.transform(&x).unwrap();
        for value in scaled.column(1).iter() {
            assert_abs_diff_eq!(*value, 5.0, epsilon = 1e-10);
        }
    }

    /// `fit_transform` matches calling `fit` and then `transform`.
    #[test]
    fn test_fit_transform_equivalence() {
        let scaler = StandardScaler::<f64>::new();
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let combined = scaler.fit_transform(&x).unwrap();
        let stepwise = scaler.fit(&x, &()).unwrap().transform(&x).unwrap();

        for (lhs, rhs) in combined.iter().zip(stepwise.iter()) {
            assert_abs_diff_eq!(*lhs, *rhs, epsilon = 1e-15);
        }
    }

    /// Transforming data with the wrong number of columns is an error.
    #[test]
    fn test_shape_mismatch_error() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x_train, &()).unwrap();

        let x_bad = array![[1.0, 2.0, 3.0]];
        let outcome = fitted.transform(&x_bad);
        assert!(outcome.is_err());
    }

    /// Fitting on a matrix with zero rows is an error.
    #[test]
    fn test_insufficient_samples_error() {
        let empty: Array2<f64> = Array2::zeros((0, 3));
        let outcome = StandardScaler::<f64>::new().fit(&empty, &());
        assert!(outcome.is_err());
    }

    /// The scaler also works with `f32` elements.
    #[test]
    fn test_f32_scaler() {
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = StandardScaler::<f32>::new().fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();

        let col0_mean: f32 = scaled.column(0).iter().sum::<f32>() / 3.0;
        assert!(col0_mean.abs() < 1e-6);
    }
}