// SPDX-License-Identifier: MIT OR Apache-2.0
//! Feature scaling transformers.

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;
use crate::sparse::CscMatrix;
/// Standardize features by removing the mean and scaling to unit variance.
///
/// Each feature is transformed as: `x' = (x - mean) / std`.
/// Features with zero variance are left unchanged.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct StandardScaler {
    // Per-feature means computed during `fit`; empty until fitted.
    means: Vec<f64>,
    // Per-feature standard deviations (population, divisor n); empty until fitted.
    stds: Vec<f64>,
    // Set to true once `fit`/`fit_sparse` succeeds.
    fitted: bool,
    // Schema version stamped at construction and checked in `transform`
    // to reject models serialized by an incompatible library version.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
23
24impl StandardScaler {
25    /// Create a new unfitted scaler.
26    pub fn new() -> Self {
27        Self {
28            means: Vec::new(),
29            stds: Vec::new(),
30            fitted: false,
31            _schema_version: crate::version::SCHEMA_VERSION,
32        }
33    }
34}
35
36impl StandardScaler {
37    /// Fit on sparse features.
38    ///
39    /// Computes mean and std from sparse columns, correctly accounting for
40    /// zero entries: `mean = sum_nonzero / n_total`.
41    pub fn fit_sparse(&mut self, features: &CscMatrix) -> Result<()> {
42        let n = features.n_rows();
43        if n == 0 {
44            return Err(ScryLearnError::EmptyDataset);
45        }
46        let n_f64 = n as f64;
47        self.means = Vec::with_capacity(features.n_cols());
48        self.stds = Vec::with_capacity(features.n_cols());
49
50        for j in 0..features.n_cols() {
51            let col = features.col(j);
52            let sum: f64 = col.iter().map(|(_, v)| v).sum();
53            let mean = sum / n_f64;
54            let mut var = 0.0;
55            let mut nnz_count = 0usize;
56            for (_, val) in col.iter() {
57                var += (val - mean).powi(2);
58                nnz_count += 1;
59            }
60            // Zero entries contribute (0 - mean)² each.
61            let n_zeros = n - nnz_count;
62            var += n_zeros as f64 * mean * mean;
63            var /= n_f64;
64            self.means.push(mean);
65            self.stds.push(var.sqrt());
66        }
67        self.fitted = true;
68        Ok(())
69    }
70
71    /// Transform sparse features, returning a new `CscMatrix`.
72    ///
73    /// Only scales by std (no centering) to preserve sparsity.
74    /// Centering would make all zeros become `-mean`, destroying sparsity.
75    pub fn transform_sparse(&self, features: &CscMatrix) -> Result<CscMatrix> {
76        if !self.fitted {
77            return Err(ScryLearnError::NotFitted);
78        }
79        // Build new CscMatrix with scaled values.
80        let mut cols: Vec<Vec<f64>> = Vec::with_capacity(features.n_cols());
81        for j in 0..features.n_cols() {
82            let std = self.stds[j];
83            let mut col = vec![0.0; features.n_rows()];
84            if std > 1e-12 {
85                for (row_idx, val) in features.col(j).iter() {
86                    col[row_idx] = val / std;
87                }
88            }
89            cols.push(col);
90        }
91        Ok(CscMatrix::from_dense(&cols))
92    }
93}
94
95impl StandardScaler {
96    /// Whether the scaler has been fitted.
97    pub fn is_fitted(&self) -> bool {
98        self.fitted
99    }
100
101    /// Per-feature means computed during fit.
102    pub fn means(&self) -> &[f64] {
103        &self.means
104    }
105
106    /// Per-feature standard deviations computed during fit.
107    pub fn stds(&self) -> &[f64] {
108        &self.stds
109    }
110}
111
112impl Default for StandardScaler {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl Transformer for StandardScaler {
119    fn fit(&mut self, data: &Dataset) -> Result<()> {
120        data.validate_finite()?;
121        if let Some(csc) = data.sparse_csc() {
122            return self.fit_sparse(csc);
123        }
124        let n = data.n_samples() as f64;
125        if n == 0.0 {
126            return Err(ScryLearnError::EmptyDataset);
127        }
128        let mat = data.matrix();
129        self.means = Vec::with_capacity(data.n_features());
130        self.stds = Vec::with_capacity(data.n_features());
131
132        for j in 0..data.n_features() {
133            let col = mat.col(j);
134            let mean = col.iter().sum::<f64>() / n;
135            let var = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
136            self.means.push(mean);
137            self.stds.push(var.sqrt());
138        }
139        self.fitted = true;
140        Ok(())
141    }
142
143    fn transform(&self, data: &mut Dataset) -> Result<()> {
144        crate::version::check_schema_version(self._schema_version)?;
145        if !self.fitted {
146            return Err(ScryLearnError::NotFitted);
147        }
148        for (j, col) in data.features.iter_mut().enumerate() {
149            let mean = self.means[j];
150            let std = self.stds[j];
151            if std > 1e-12 {
152                for x in col.iter_mut() {
153                    *x = (*x - mean) / std;
154                }
155            }
156        }
157        data.sync_matrix();
158        Ok(())
159    }
160
161    fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
162        if !self.fitted {
163            return Err(ScryLearnError::NotFitted);
164        }
165        for (j, col) in data.features.iter_mut().enumerate() {
166            let mean = self.means[j];
167            let std = self.stds[j];
168            if std > 1e-12 {
169                for x in col.iter_mut() {
170                    *x = *x * std + mean;
171                }
172            }
173            // When std <= 1e-12, transform left values unchanged,
174            // so inverse_transform must also leave them unchanged.
175        }
176        data.sync_matrix();
177        Ok(())
178    }
179}
180
/// Scale features to a [0, 1] range.
///
/// Each feature is transformed as: `x' = (x - min) / (max - min)`.
/// Features with zero range are set to 0.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MinMaxScaler {
    // Per-feature minima computed during `fit`; empty until fitted.
    mins: Vec<f64>,
    // Per-feature maxima computed during `fit`; empty until fitted.
    maxs: Vec<f64>,
    // Set to true once `fit`/`fit_sparse` succeeds.
    fitted: bool,
    // Schema version stamped at construction; checked in `transform`
    // to reject models serialized by an incompatible library version.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
195
196impl MinMaxScaler {
197    /// Create a new unfitted scaler.
198    pub fn new() -> Self {
199        Self {
200            mins: Vec::new(),
201            maxs: Vec::new(),
202            fitted: false,
203            _schema_version: crate::version::SCHEMA_VERSION,
204        }
205    }
206}
207
208impl MinMaxScaler {
209    /// Fit on sparse features.
210    ///
211    /// Computes min/max from sparse columns, accounting for implicit zeros.
212    pub fn fit_sparse(&mut self, features: &CscMatrix) -> Result<()> {
213        let n = features.n_rows();
214        if n == 0 {
215            return Err(ScryLearnError::EmptyDataset);
216        }
217        self.mins = Vec::with_capacity(features.n_cols());
218        self.maxs = Vec::with_capacity(features.n_cols());
219
220        for j in 0..features.n_cols() {
221            let col = features.col(j);
222            let nnz = col.nnz();
223            if nnz == 0 {
224                // All zeros.
225                self.mins.push(0.0);
226                self.maxs.push(0.0);
227            } else {
228                let mut min = f64::INFINITY;
229                let mut max = f64::NEG_INFINITY;
230                for (_, val) in col.iter() {
231                    if val < min {
232                        min = val;
233                    }
234                    if val > max {
235                        max = val;
236                    }
237                }
238                // Account for implicit zeros.
239                if nnz < n {
240                    if 0.0 < min {
241                        min = 0.0;
242                    }
243                    if 0.0 > max {
244                        max = 0.0;
245                    }
246                }
247                self.mins.push(min);
248                self.maxs.push(max);
249            }
250        }
251        self.fitted = true;
252        Ok(())
253    }
254
255    /// Transform sparse features, returning a new `CscMatrix`.
256    pub fn transform_sparse(&self, features: &CscMatrix) -> Result<CscMatrix> {
257        if !self.fitted {
258            return Err(ScryLearnError::NotFitted);
259        }
260        let mut cols: Vec<Vec<f64>> = Vec::with_capacity(features.n_cols());
261        for j in 0..features.n_cols() {
262            let min = self.mins[j];
263            let range = self.maxs[j] - min;
264            let mut col = vec![0.0; features.n_rows()];
265            if range > 1e-12 {
266                // Zero entries map to (0 - min) / range.
267                let zero_mapped = (0.0 - min) / range;
268                col.fill(zero_mapped);
269                for (row_idx, val) in features.col(j).iter() {
270                    col[row_idx] = (val - min) / range;
271                }
272            }
273            cols.push(col);
274        }
275        Ok(CscMatrix::from_dense(&cols))
276    }
277}
278
279impl Default for MinMaxScaler {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
285impl Transformer for MinMaxScaler {
286    fn fit(&mut self, data: &Dataset) -> Result<()> {
287        data.validate_finite()?;
288        if let Some(csc) = data.sparse_csc() {
289            return self.fit_sparse(csc);
290        }
291        if data.n_samples() == 0 {
292            return Err(ScryLearnError::EmptyDataset);
293        }
294        let mat = data.matrix();
295        self.mins = Vec::with_capacity(data.n_features());
296        self.maxs = Vec::with_capacity(data.n_features());
297
298        for j in 0..data.n_features() {
299            let col = mat.col(j);
300            let min = col.iter().copied().fold(f64::INFINITY, f64::min);
301            let max = col.iter().copied().fold(f64::NEG_INFINITY, f64::max);
302            self.mins.push(min);
303            self.maxs.push(max);
304        }
305        self.fitted = true;
306        Ok(())
307    }
308
309    fn transform(&self, data: &mut Dataset) -> Result<()> {
310        crate::version::check_schema_version(self._schema_version)?;
311        if !self.fitted {
312            return Err(ScryLearnError::NotFitted);
313        }
314        for (j, col) in data.features.iter_mut().enumerate() {
315            let min = self.mins[j];
316            let range = self.maxs[j] - min;
317            if range > 1e-12 {
318                for x in col.iter_mut() {
319                    *x = (*x - min) / range;
320                }
321            } else {
322                for x in col.iter_mut() {
323                    *x = 0.0;
324                }
325            }
326        }
327        data.sync_matrix();
328        Ok(())
329    }
330
331    fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
332        if !self.fitted {
333            return Err(ScryLearnError::NotFitted);
334        }
335        for (j, col) in data.features.iter_mut().enumerate() {
336            let min = self.mins[j];
337            let range = self.maxs[j] - min;
338            for x in col.iter_mut() {
339                *x = *x * range + min;
340            }
341        }
342        data.sync_matrix();
343        Ok(())
344    }
345}
346
347// ── helpers ──────────────────────────────────────────────────────
348
/// Compute the `q`-quantile of a sorted slice using linear interpolation
/// between the two nearest order statistics.
///
/// `q` is clamped into `[0, 1]`; previously an out-of-range `q` produced an
/// out-of-bounds index and panicked. An empty slice is still a programmer
/// error (debug-asserted).
///
/// # Panics
///
/// Debug builds assert that `sorted` is non-empty.
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    debug_assert!(!sorted.is_empty());
    if sorted.len() == 1 {
        return sorted[0];
    }
    // Robustness: tolerate q slightly outside [0, 1] instead of panicking.
    let q = q.clamp(0.0, 1.0);
    let pos = q * (sorted.len() - 1) as f64;
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    let frac = pos - lo as f64;
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}
361
/// Scale features using the median and inter-quartile range (IQR).
///
/// Each feature is transformed as: `x' = (x - median) / IQR`.
/// Features with zero IQR are centered by the median but not scaled
/// (the transform subtracts the median and skips the division).
///
/// `RobustScaler` is less sensitive to outliers than [`StandardScaler`]
/// because it uses the median and quartiles rather than mean and std.
///
/// # Example
///
/// ```ignore
/// let mut scaler = RobustScaler::new();
/// scaler.fit_transform(&mut ds)?;
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct RobustScaler {
    // Per-feature medians computed during `fit`; empty until fitted.
    medians: Vec<f64>,
    // Per-feature inter-quartile ranges (Q3 - Q1); empty until fitted.
    iqrs: Vec<f64>,
    // Set to true once `fit` succeeds.
    fitted: bool,
    // Schema version stamped at construction; checked in `transform`
    // to reject models serialized by an incompatible library version.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
386
387impl RobustScaler {
388    /// Create a new unfitted robust scaler.
389    pub fn new() -> Self {
390        Self {
391            medians: Vec::new(),
392            iqrs: Vec::new(),
393            fitted: false,
394            _schema_version: crate::version::SCHEMA_VERSION,
395        }
396    }
397}
398
399impl Default for RobustScaler {
400    fn default() -> Self {
401        Self::new()
402    }
403}
404
405impl Transformer for RobustScaler {
406    fn fit(&mut self, data: &Dataset) -> Result<()> {
407        data.validate_finite()?;
408        if data.n_samples() == 0 {
409            return Err(ScryLearnError::EmptyDataset);
410        }
411        let mat = data.matrix();
412        self.medians = Vec::with_capacity(data.n_features());
413        self.iqrs = Vec::with_capacity(data.n_features());
414
415        for j in 0..data.n_features() {
416            let col = mat.col(j);
417            let mut sorted = col.to_vec();
418            sorted.sort_unstable_by(|a, b| a.total_cmp(b));
419            let median = quantile_sorted(&sorted, 0.5);
420            let q1 = quantile_sorted(&sorted, 0.25);
421            let q3 = quantile_sorted(&sorted, 0.75);
422            self.medians.push(median);
423            self.iqrs.push(q3 - q1);
424        }
425        self.fitted = true;
426        Ok(())
427    }
428
429    fn transform(&self, data: &mut Dataset) -> Result<()> {
430        crate::version::check_schema_version(self._schema_version)?;
431        if !self.fitted {
432            return Err(ScryLearnError::NotFitted);
433        }
434        for (j, col) in data.features.iter_mut().enumerate() {
435            let median = self.medians[j];
436            let iqr = self.iqrs[j];
437            if iqr > 1e-12 {
438                for x in col.iter_mut() {
439                    *x = (*x - median) / iqr;
440                }
441            } else {
442                for x in col.iter_mut() {
443                    *x -= median;
444                }
445            }
446        }
447        data.sync_matrix();
448        Ok(())
449    }
450
451    fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
452        if !self.fitted {
453            return Err(ScryLearnError::NotFitted);
454        }
455        for (j, col) in data.features.iter_mut().enumerate() {
456            let median = self.medians[j];
457            let iqr = self.iqrs[j];
458            if iqr > 1e-12 {
459                for x in col.iter_mut() {
460                    *x = *x * iqr + median;
461                }
462            } else {
463                // When IQR <= 1e-12, transform only subtracted median,
464                // so inverse must only add it back.
465                for x in col.iter_mut() {
466                    *x += median;
467                }
468            }
469        }
470        data.sync_matrix();
471        Ok(())
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    // StandardScaler should produce mean ~0 and population variance ~1.
    #[test]
    fn test_standard_scaler_zero_mean() {
        let mut ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
            vec![0.0; 5],
            vec!["x".into()],
            "y",
        );
        let mut scaler = StandardScaler::new();
        scaler.fit_transform(&mut ds).unwrap();

        let mean: f64 = ds.features[0].iter().sum::<f64>() / 5.0;
        assert!((mean).abs() < 1e-10, "mean should be ~0, got {mean}");

        let var: f64 = ds.features[0].iter().map(|x| x.powi(2)).sum::<f64>() / 5.0;
        assert!(
            (var - 1.0).abs() < 1e-10,
            "variance should be ~1, got {var}"
        );
    }

    // MinMaxScaler maps the column min to 0 and the max to 1.
    #[test]
    fn test_minmax_scaler_range() {
        let mut ds = Dataset::new(
            vec![vec![10.0, 20.0, 30.0]],
            vec![0.0; 3],
            vec!["x".into()],
            "y",
        );
        let mut scaler = MinMaxScaler::new();
        scaler.fit_transform(&mut ds).unwrap();

        assert!((ds.features[0][0]).abs() < 1e-10);
        assert!((ds.features[0][2] - 1.0).abs() < 1e-10);
    }

    // transform before fit must fail with NotFitted.
    #[test]
    fn test_standard_scaler_not_fitted() {
        let scaler = StandardScaler::new();
        let mut ds = Dataset::new(vec![vec![1.0]], vec![0.0], vec!["x".into()], "y");
        assert!(scaler.transform(&mut ds).is_err());
    }

    // transform followed by inverse_transform recovers the original values.
    #[test]
    fn test_standard_scaler_roundtrip() {
        let original = vec![2.0, 4.0, 6.0, 8.0];
        let mut ds = Dataset::new(vec![original.clone()], vec![0.0; 4], vec!["x".into()], "y");
        let mut scaler = StandardScaler::new();
        scaler.fit_transform(&mut ds).unwrap();
        scaler.inverse_transform(&mut ds).unwrap();

        for (a, b) in ds.features[0].iter().zip(original.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_robust_scaler_median_centering() {
        // [1, 2, 3, 4, 5] with quantile_sorted's linear interpolation:
        // median=3, Q1=2, Q3=4, IQR=2.
        // (The previous comment's Q1=1.5 / Q3=4.5 / IQR=3 matched a
        // different quantile convention, not the one implemented here.)
        let mut ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
            vec![0.0; 5],
            vec!["x".into()],
            "y",
        );
        let mut scaler = RobustScaler::new();
        scaler.fit_transform(&mut ds).unwrap();

        // median value should map to 0
        assert!(
            ds.features[0][2].abs() < 1e-10,
            "median should map to 0, got {}",
            ds.features[0][2]
        );
    }

    #[test]
    fn test_robust_scaler_outlier_tolerance() {
        // Data with an extreme outlier: [1, 2, 3, 4, 1000]
        let data = vec![1.0, 2.0, 3.0, 4.0, 1000.0];

        // StandardScaler: the outlier heavily influences mean/std
        let mut ds_std = Dataset::new(vec![data.clone()], vec![0.0; 5], vec!["x".into()], "y");
        let mut std_scaler = StandardScaler::new();
        std_scaler.fit_transform(&mut ds_std).unwrap();

        // RobustScaler: outlier has minimal effect on median/IQR
        let mut ds_rob = Dataset::new(vec![data], vec![0.0; 5], vec!["x".into()], "y");
        let mut rob_scaler = RobustScaler::new();
        rob_scaler.fit_transform(&mut ds_rob).unwrap();

        // In StandardScaler, the non-outlier values are squished near 0
        // because std is dominated by the outlier.
        // In RobustScaler, the non-outlier values have reasonable spread.
        let robust_range = ds_rob.features[0][3] - ds_rob.features[0][0];
        let std_range = ds_std.features[0][3] - ds_std.features[0][0];
        assert!(
            robust_range > std_range,
            "RobustScaler should give wider spread to non-outliers: robust={robust_range:.4} vs std={std_range:.4}"
        );
    }

    // RobustScaler transform/inverse_transform is a lossless roundtrip.
    #[test]
    fn test_robust_scaler_roundtrip() {
        let original = vec![2.0, 4.0, 6.0, 8.0];
        let mut ds = Dataset::new(vec![original.clone()], vec![0.0; 4], vec!["x".into()], "y");
        let mut scaler = RobustScaler::new();
        scaler.fit_transform(&mut ds).unwrap();
        scaler.inverse_transform(&mut ds).unwrap();

        for (a, b) in ds.features[0].iter().zip(original.iter()) {
            assert!((a - b).abs() < 1e-10, "roundtrip failed: {a} != {b}");
        }
    }

    // Sparse and dense fits must agree on the computed statistics.
    #[test]
    fn test_standard_scaler_sparse_fit() {
        let cols = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let csc = CscMatrix::from_dense(&cols);

        let mut scaler = StandardScaler::new();
        scaler.fit_sparse(&csc).unwrap();

        // Also fit dense for comparison.
        let ds = Dataset::new(cols, vec![0.0; 5], vec!["x".into()], "y");
        let mut scaler_d = StandardScaler::new();
        scaler_d.fit(&ds).unwrap();

        // Means should match.
        assert!(
            (scaler.means[0] - scaler_d.means[0]).abs() < 1e-10,
            "Sparse mean={} vs Dense mean={}",
            scaler.means[0],
            scaler_d.means[0]
        );
    }

    // Sparse fit must fold implicit zeros into the min/max.
    #[test]
    fn test_minmax_scaler_sparse_fit() {
        let cols = vec![vec![0.0, 5.0, 0.0, 10.0, 0.0]];
        let csc = CscMatrix::from_dense(&cols);

        let mut scaler = MinMaxScaler::new();
        scaler.fit_sparse(&csc).unwrap();

        assert!((scaler.mins[0] - 0.0).abs() < 1e-10);
        assert!((scaler.maxs[0] - 10.0).abs() < 1e-10);
    }
}