Skip to main content

ferrolearn_preprocess/
standard_scaler.rs

1//! Standard scaler: zero-mean, unit-variance scaling.
2//!
3//! Each feature is transformed as `(x - mean) / std`. Zero-variance
4//! columns are left unchanged (the feature value is not modified).
5
6use ferrolearn_core::error::FerroError;
7use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
8use ferrolearn_core::traits::{Fit, FitTransform, Transform};
9use ndarray::{Array1, Array2};
10use num_traits::Float;
11
12// ---------------------------------------------------------------------------
13// StandardScaler (unfitted)
14// ---------------------------------------------------------------------------
15
/// Configuration handle for standard (z-score) scaling, before any data has
/// been seen.
///
/// The type carries no state of its own; the element type `F` is tracked only
/// through a zero-sized marker. Call [`Fit::fit`] to learn per-column means
/// and standard deviations, which yields a [`FittedStandardScaler`] that can
/// then transform data.
///
/// Columns whose fitted standard deviation is zero are passed through
/// unmodified by the fitted scaler.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::StandardScaler;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let scaler = StandardScaler::<f64>::new();
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let fitted = scaler.fit(&x, &()).unwrap();
/// let scaled = fitted.transform(&x).unwrap();
/// // Each column is now centred on ~0 with a standard deviation of ~1.
/// assert!(scaled.column(0).sum().abs() < 1e-10);
/// ```
#[derive(Clone, Debug)]
pub struct StandardScaler<F> {
    /// Zero-sized marker tying the scaler to its float type.
    _marker: std::marker::PhantomData<F>,
}
40
41impl<F: Float + Send + Sync + 'static> StandardScaler<F> {
42    /// Create a new `StandardScaler` with default configuration.
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            _marker: std::marker::PhantomData,
47        }
48    }
49}
50
51impl<F: Float + Send + Sync + 'static> Default for StandardScaler<F> {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57// ---------------------------------------------------------------------------
58// FittedStandardScaler
59// ---------------------------------------------------------------------------
60
61/// A fitted standard scaler holding per-column means and standard deviations.
62///
63/// Created by calling [`Fit::fit`] on a [`StandardScaler`].
64#[derive(Debug, Clone)]
65pub struct FittedStandardScaler<F> {
66    /// Per-column means learned during fitting.
67    pub(crate) mean: Array1<F>,
68    /// Per-column standard deviations learned during fitting.
69    pub(crate) std: Array1<F>,
70}
71
72impl<F: Float + Send + Sync + 'static> FittedStandardScaler<F> {
73    /// Return the per-column means learned during fitting.
74    #[must_use]
75    pub fn mean(&self) -> &Array1<F> {
76        &self.mean
77    }
78
79    /// Return the per-column standard deviations learned during fitting.
80    #[must_use]
81    pub fn std(&self) -> &Array1<F> {
82        &self.std
83    }
84
85    /// Inverse-transform scaled data back to original space.
86    ///
87    /// Applies `x_orig = x_scaled * std + mean` per column.
88    ///
89    /// # Errors
90    ///
91    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
92    /// match the number of features seen during fitting.
93    pub fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
94        let n_features = self.mean.len();
95        if x.ncols() != n_features {
96            return Err(FerroError::ShapeMismatch {
97                expected: vec![x.nrows(), n_features],
98                actual: vec![x.nrows(), x.ncols()],
99                context: "FittedStandardScaler::inverse_transform".into(),
100            });
101        }
102        let mut out = x.to_owned();
103        for (mut col, (&m, &s)) in out
104            .columns_mut()
105            .into_iter()
106            .zip(self.mean.iter().zip(self.std.iter()))
107        {
108            for v in col.iter_mut() {
109                *v = *v * s + m;
110            }
111        }
112        Ok(out)
113    }
114}
115
116// ---------------------------------------------------------------------------
117// Trait implementations
118// ---------------------------------------------------------------------------
119
120impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for StandardScaler<F> {
121    type Fitted = FittedStandardScaler<F>;
122    type Error = FerroError;
123
124    /// Fit the scaler by computing per-column means and standard deviations.
125    ///
126    /// # Errors
127    ///
128    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
129    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedStandardScaler<F>, FerroError> {
130        let n_samples = x.nrows();
131        if n_samples == 0 {
132            return Err(FerroError::InsufficientSamples {
133                required: 1,
134                actual: 0,
135                context: "StandardScaler::fit".into(),
136            });
137        }
138
139        let n = F::from(n_samples).unwrap_or(F::one());
140        let n_features = x.ncols();
141        let mut mean = Array1::zeros(n_features);
142        let mut std_arr = Array1::zeros(n_features);
143
144        for j in 0..n_features {
145            let col = x.column(j);
146            let m = col.iter().copied().fold(F::zero(), |acc, v| acc + v) / n;
147            let variance = col
148                .iter()
149                .copied()
150                .map(|v| (v - m) * (v - m))
151                .fold(F::zero(), |acc, v| acc + v)
152                / n;
153            mean[j] = m;
154            std_arr[j] = variance.sqrt();
155        }
156
157        Ok(FittedStandardScaler { mean, std: std_arr })
158    }
159}
160
161impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedStandardScaler<F> {
162    type Output = Array2<F>;
163    type Error = FerroError;
164
165    /// Transform data by subtracting the mean and dividing by the standard deviation.
166    ///
167    /// Columns with zero standard deviation are left unchanged.
168    ///
169    /// # Errors
170    ///
171    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
172    /// match the number of features seen during fitting.
173    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
174        let n_features = self.mean.len();
175        if x.ncols() != n_features {
176            return Err(FerroError::ShapeMismatch {
177                expected: vec![x.nrows(), n_features],
178                actual: vec![x.nrows(), x.ncols()],
179                context: "FittedStandardScaler::transform".into(),
180            });
181        }
182        let mut out = x.to_owned();
183        for (mut col, (&m, &s)) in out
184            .columns_mut()
185            .into_iter()
186            .zip(self.mean.iter().zip(self.std.iter()))
187        {
188            if s == F::zero() {
189                // Zero-variance column: leave unchanged.
190                continue;
191            }
192            for v in col.iter_mut() {
193                *v = (*v - m) / s;
194            }
195        }
196        Ok(out)
197    }
198}
199
200/// Implement `Transform` on the unfitted scaler to satisfy the `FitTransform: Transform`
201/// supertrait bound. Calling `transform` on an unfitted scaler always returns an error
202/// because no statistics have been learned yet.
203impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for StandardScaler<F> {
204    type Output = Array2<F>;
205    type Error = FerroError;
206
207    /// Always returns an error — the scaler must be fitted first.
208    ///
209    /// Use [`Fit::fit`] to produce a [`FittedStandardScaler`], then call
210    /// [`Transform::transform`] on that.
211    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
212        Err(FerroError::InvalidParameter {
213            name: "StandardScaler".into(),
214            reason: "scaler must be fitted before calling transform; use fit() first".into(),
215        })
216    }
217}
218
219impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for StandardScaler<F> {
220    type FitError = FerroError;
221
222    /// Fit the scaler on `x` and return the scaled output in one step.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if fitting fails (e.g., zero rows).
227    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
228        let fitted = self.fit(x, &())?;
229        fitted.transform(x)
230    }
231}
232
233// ---------------------------------------------------------------------------
234// Pipeline integration
235// ---------------------------------------------------------------------------
236
237impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for StandardScaler<F> {
238    /// Fit the scaler using the pipeline interface.
239    ///
240    /// The `y` argument is ignored; it exists only for API compatibility.
241    ///
242    /// # Errors
243    ///
244    /// Propagates errors from [`Fit::fit`].
245    fn fit_pipeline(
246        &self,
247        x: &Array2<F>,
248        _y: &Array1<F>,
249    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
250        let fitted = self.fit(x, &())?;
251        Ok(Box::new(fitted))
252    }
253}
254
255impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedStandardScaler<F> {
256    /// Transform data using the pipeline interface.
257    ///
258    /// # Errors
259    ///
260    /// Propagates errors from [`Transform::transform`].
261    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
262        self.transform(x)
263    }
264}
265
266// ---------------------------------------------------------------------------
267// Tests
268// ---------------------------------------------------------------------------
269
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Population mean and population variance of column `j` of `data`.
    fn column_moments(data: &Array2<f64>, j: usize) -> (f64, f64) {
        let rows = data.nrows() as f64;
        let mean = data.column(j).iter().sum::<f64>() / rows;
        let variance = data
            .column(j)
            .iter()
            .map(|&v| (v - mean).powi(2))
            .sum::<f64>()
            / rows;
        (mean, variance)
    }

    /// After scaling, every column has mean ~0 and population variance ~1.
    #[test]
    fn test_standard_scaler_zero_mean_unit_variance() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();

        for j in 0..scaled.ncols() {
            let (col_mean, col_variance) = column_moments(&scaled, j);
            assert_abs_diff_eq!(col_mean, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(col_variance, 1.0, epsilon = 1e-10);
        }
    }

    /// `inverse_transform(transform(x))` reproduces `x`.
    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x, &()).unwrap();
        let recovered = fitted
            .inverse_transform(&fitted.transform(&x).unwrap())
            .unwrap();

        for (original, restored) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(*original, *restored, epsilon = 1e-10);
        }
    }

    /// A constant column gets std 0 and passes through `transform` untouched.
    #[test]
    fn test_zero_variance_column_unchanged() {
        // The second column is constant, so its standard deviation is zero.
        let x = array![[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.std()[1], 0.0, epsilon = 1e-15);

        let scaled = fitted.transform(&x).unwrap();
        for row in 0..3 {
            assert_abs_diff_eq!(scaled[[row, 1]], 5.0, epsilon = 1e-10);
        }
    }

    /// `fit_transform` matches a separate `fit` followed by `transform`.
    #[test]
    fn test_fit_transform_equivalence() {
        let scaler = StandardScaler::<f64>::new();
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let one_step = scaler.fit_transform(&x).unwrap();
        let two_step = scaler.fit(&x, &()).unwrap().transform(&x).unwrap();

        for (lhs, rhs) in one_step.iter().zip(two_step.iter()) {
            assert_abs_diff_eq!(*lhs, *rhs, epsilon = 1e-15);
        }
    }

    /// Transforming data with the wrong number of columns is an error.
    #[test]
    fn test_shape_mismatch_error() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = StandardScaler::<f64>::new().fit(&x_train, &()).unwrap();

        let x_bad = array![[1.0, 2.0, 3.0]];
        let outcome = fitted.transform(&x_bad);
        assert!(outcome.is_err());
    }

    /// Fitting on an input with zero rows is an error.
    #[test]
    fn test_insufficient_samples_error() {
        let empty: Array2<f64> = Array2::zeros((0, 3));
        let outcome = StandardScaler::<f64>::new().fit(&empty, &());
        assert!(outcome.is_err());
    }

    /// The scaler also works with `f32` data.
    #[test]
    fn test_f32_scaler() {
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = StandardScaler::<f32>::new().fit(&x, &()).unwrap();
        let scaled = fitted.transform(&x).unwrap();

        let col0_mean: f32 = scaled.column(0).iter().sum::<f32>() / 3.0;
        assert!(col0_mean.abs() < 1e-6);
    }
}