// Source file: augurs_forecaster/transforms/scale.rs

1//! Scalers, including min-max and standard scalers.
2
3use core::f64;
4
5use augurs_core::{FloatIterExt, NanMinMaxResult};
6
7use super::{Error, Transformer};
8
/// Helper struct holding the min and max for use in a `MinMaxScaler`.
#[derive(Debug, Clone, Copy)]
struct MinMax {
    min: f64,
    max: f64,
}

impl MinMax {
    /// The default output range: [0, 1] inset by `f64::EPSILON` at each end.
    ///
    /// NOTE(review): the inset presumably keeps scaled values strictly inside
    /// (0, 1) so that downstream transforms undefined at the endpoints (e.g. a
    /// logit) stay finite — confirm before changing it.
    fn zero_one() -> Self {
        Self {
            min: 0.0 + f64::EPSILON,
            max: 1.0 - f64::EPSILON,
        }
    }
}

/// Parameters for the min-max scaler.
///
/// Created by the `MinMaxScaler` when it is fit to the data, or when it is
/// supplied with a custom data range.
///
/// We store the scale factor and offset to avoid having to recalculate them
/// every time the transform is applied.
///
/// We store the input scale as well so we can recalculate the scale factor
/// and offset if the user changes the output scale.
#[derive(Debug, Clone)]
struct FittedMinMaxScalerParams {
    input_scale: MinMax,
    scale_factor: f64,
    offset: f64,
}

impl FittedMinMaxScalerParams {
    /// Computes the affine map `y = x * scale_factor + offset` that takes
    /// `input_scale` onto `output_scale`.
    ///
    /// If the input range has zero width (e.g. the data is constant, or the
    /// user supplied `min == max`), the width is treated as 1, as
    /// scikit-learn's `MinMaxScaler` does. The scale factor and offset then
    /// stay finite, inputs equal to `input_scale.min` map to
    /// `output_scale.min`, and the map remains invertible.
    fn new(input_scale: MinMax, output_scale: MinMax) -> Self {
        let input_width = input_scale.max - input_scale.min;
        // Dividing by a zero width would yield an infinite scale factor and
        // turn every transformed value into NaN or +/-inf.
        let input_width = if input_width == 0.0 { 1.0 } else { input_width };
        let scale_factor = (output_scale.max - output_scale.min) / input_width;
        Self {
            input_scale,
            scale_factor,
            offset: output_scale.min - (input_scale.min * scale_factor),
        }
    }
}
53
54/// A transformer that scales each item to a custom range, defaulting to [0, 1].
55#[derive(Debug, Clone)]
56pub struct MinMaxScaler {
57    output_scale: MinMax,
58    // The parameters learned from the data and used to transform it.
59    // Not known until the transform method is called.
60    params: Option<FittedMinMaxScalerParams>,
61}
62
63impl Default for MinMaxScaler {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl MinMaxScaler {
70    /// Create a new `MinMaxScaler` with the default output range of [0, 1].
71    pub fn new() -> Self {
72        Self {
73            output_scale: MinMax::zero_one(),
74            params: None,
75        }
76    }
77
78    /// Set the output range for the transformation.
79    pub fn with_scaled_range(mut self, min: f64, max: f64) -> Self {
80        self.output_scale = MinMax { min, max };
81        self.params.iter_mut().for_each(|p| {
82            let input_scale = p.input_scale;
83            *p = FittedMinMaxScalerParams::new(input_scale, self.output_scale);
84        });
85        self
86    }
87
88    /// Manually set the input range for the transformation.
89    ///
90    /// This is useful if you know the input range in advance and want to avoid
91    /// the overhead of fitting the scaler to the data during the initial transform,
92    /// and instead want to set the input range manually.
93    ///
94    /// Note that this will override any previously set (or learned) parameters.
95    pub fn with_data_range(mut self, min: f64, max: f64) -> Self {
96        let data_range = MinMax { min, max };
97        self.params = Some(FittedMinMaxScalerParams::new(data_range, self.output_scale));
98        self
99    }
100}
101
102impl Transformer for MinMaxScaler {
103    /// Fit the scaler to the given data.
104    ///
105    /// This will compute the min and max values of the data and store them
106    /// in the `params` field of the scaler.
107    fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
108        let params = match data.iter().copied().nanminmax(true) {
109            NanMinMaxResult::NaN => unreachable!(),
110            e @ NanMinMaxResult::NoElements | e @ NanMinMaxResult::OneElement(_) => {
111                return Err(e.into())
112            }
113            NanMinMaxResult::MinMax(min, max) => {
114                FittedMinMaxScalerParams::new(MinMax { min, max }, self.output_scale)
115            }
116        };
117        self.params = Some(params);
118        Ok(())
119    }
120
121    /// Apply the scaler to the given data.
122    fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
123        let params = self.params.as_ref().ok_or(Error::NotFitted)?;
124        data.iter_mut()
125            .for_each(|x| *x = *x * params.scale_factor + params.offset);
126        Ok(())
127    }
128
129    /// Apply the inverse of the scaler to the given data.
130    fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> {
131        let params = self.params.as_ref().ok_or(Error::NotFitted)?;
132        data.iter_mut()
133            .for_each(|x| *x = (*x - params.offset) / params.scale_factor);
134        Ok(())
135    }
136}
137
/// Parameters for the standard scaler.
#[derive(Debug, Clone)]
pub struct StandardScaleParams {
    /// The mean of the data.
    pub mean: f64,
    /// The standard deviation of the data.
    pub std_dev: f64,
}

impl StandardScaleParams {
    /// Builds parameters from an already-known mean and standard deviation.
    pub fn new(mean: f64, std_dev: f64) -> Self {
        StandardScaleParams { mean, std_dev }
    }

    /// Estimates the mean and standard deviation of `data`.
    ///
    /// Note: Welford's online algorithm is used so that mean and variance are
    /// obtained in a single pass over the iterator. The standard deviation
    /// uses the biased (divide-by-`n`) estimator, for parity with the
    /// [scikit-learn implementation][sklearn].
    ///
    /// An empty iterator yields `mean = 0.0` and `std_dev = 1.0`. A single
    /// value yields that value as the mean and a `std_dev` of `0.0`.
    ///
    /// [sklearn]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html
    pub fn from_data<T>(data: T) -> Self
    where
        T: Iterator<Item = f64>,
    {
        // Welford state: (samples seen, running mean, running sum of squared
        // deviations from the mean).
        let (n, mean, sum_sq) = data.fold(
            (0_u64, 0.0_f64, 0.0_f64),
            |(n, mean, sum_sq), value| {
                let n = n + 1;
                let before = value - mean;
                let mean = mean + before / n as f64;
                let after = value - mean;
                (n, mean, sum_sq + before * after)
            },
        );

        if n == 0 {
            return Self::new(0.0, 1.0);
        }

        Self::new(mean, (sum_sq / n as f64).sqrt())
    }

    /// Like [`StandardScaleParams::from_data`], but NaN values are dropped
    /// before the statistics are computed.
    ///
    /// Use this when the data may contain NaNs that would otherwise poison
    /// the mean and standard deviation.
    pub fn from_data_ignoring_nans<T: Iterator<Item = f64>>(data: T) -> Self {
        let finite_or_inf = data.filter(|value| !value.is_nan());
        Self::from_data(finite_or_inf)
    }
}
199
200/// A transformer that scales items to have zero mean and unit standard deviation.
201///
202/// The standard score of a sample `x` is calculated as:
203///
204/// ```text
205/// z = (x - mean) / std_dev
206/// ```
207///
208/// where `mean` is the mean and s is the standard deviation of the data first passed to
209/// `transform` (or provided via `with_parameters`).
210///
211/// # Implementation
212///
213/// This transformer uses Welford's online algorithm to compute mean and variance in
214/// one pass over the data. The standard deviation is calculated using the biased
215/// estimator, for parity with the [scikit-learn implementation][sklearn].
216///
217/// [sklearn]: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/preprocessing/_data.py#L128
218///
219/// # Example
220///
221/// ## Using the default constructor
222///
223/// ```
224/// use augurs_forecaster::transforms::{StandardScaler, Transformer};
225///
226/// let mut data = vec![1.0, 2.0, 3.0];
227/// let mut scaler = StandardScaler::new();
228/// scaler.fit_transform(&mut data);
229///
230/// assert_eq!(data, vec![-1.224744871391589, 0.0, 1.224744871391589]);
231/// ```
232#[derive(Debug, Clone, Default)]
233pub struct StandardScaler {
234    params: Option<StandardScaleParams>,
235    ignore_nans: bool,
236}
237
238impl StandardScaler {
239    /// Create a new `StandardScaler`.
240    pub fn new() -> Self {
241        Self::default()
242    }
243
244    /// Set the parameters for the scaler.
245    ///
246    /// This is useful if you know the mean and standard deviation in advance
247    /// and want to avoid the overhead of fitting the scaler to the data
248    /// during the initial transform, and instead want to set the parameters
249    /// manually.
250    pub fn with_parameters(mut self, params: StandardScaleParams) -> Self {
251        self.params = Some(params);
252        self
253    }
254
255    /// Whether to ignore NaN values when calculating the scaler parameters.
256    ///
257    /// If `true`, NaN values will be ignored when calculating the scaler parameters.
258    /// This can be useful if you have NaN values in your data and want to avoid
259    /// errors when calculating the scaler parameters.
260    ///
261    /// Defaults to `false`.
262    pub fn ignore_nans(mut self, ignore_nans: bool) -> Self {
263        self.ignore_nans = ignore_nans;
264        self
265    }
266}
267
268impl Transformer for StandardScaler {
269    fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
270        self.params = Some(if self.ignore_nans {
271            StandardScaleParams::from_data_ignoring_nans(data.iter().copied())
272        } else {
273            StandardScaleParams::from_data(data.iter().copied())
274        });
275        Ok(())
276    }
277
278    fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
279        let params = self.params.as_ref().ok_or(Error::NotFitted)?;
280        data.iter_mut()
281            .for_each(|x| *x = (*x - params.mean) / params.std_dev);
282        Ok(())
283    }
284
285    fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> {
286        let params = self.params.as_ref().ok_or(Error::NotFitted)?;
287        data.iter_mut()
288            .for_each(|x| *x = (*x * params.std_dev) + params.mean);
289        Ok(())
290    }
291}
292
#[cfg(test)]
mod test {
    use augurs_testing::{assert_all_close, assert_approx_eq};

    use super::*;

    #[test]
    fn min_max_scale() {
        let mut data = vec![1.0, 2.0, 3.0];
        let expected = vec![0.0, 0.5, 1.0];
        let mut scaler = MinMaxScaler::new();
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn min_max_scale_custom() {
        let mut data = vec![1.0, 2.0, 3.0];
        let expected = vec![0.0, 5.0, 10.0];
        let mut scaler = MinMaxScaler::new().with_scaled_range(0.0, 10.0);
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_min_max_scale() {
        let mut data = vec![0.0, 0.5, 1.0];
        let expected = vec![1.0, 2.0, 3.0];
        let scaler = MinMaxScaler::new().with_data_range(1.0, 3.0);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_min_max_scale_custom() {
        let mut data = vec![0.0, 5.0, 10.0];
        let expected = vec![1.0, 2.0, 3.0];
        let scaler = MinMaxScaler::new()
            .with_scaled_range(0.0, 10.0)
            .with_data_range(1.0, 3.0);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale() {
        let mut data = vec![1.0, 2.0, 3.0];
        // We use the biased estimator for standard deviation so the result is
        // not necessarily obvious: the fitted mean is 2 and the fitted
        // std_dev is sqrt(2/3) ≈ 0.8165, so the ends map to ±1/0.8165.
        let expected = vec![-1.224744871391589, 0.0, 1.224744871391589];
        let mut scaler = StandardScaler::new();
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_custom() {
        let mut data = vec![1.0, 2.0, 3.0];
        let expected = vec![-1.0, 0.0, 1.0];
        let params = StandardScaleParams::new(2.0, 1.0); // mean=2, std=1
        let scaler = StandardScaler::new().with_parameters(params);
        scaler.transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_standard_scale() {
        let mut data = vec![-1.0, 0.0, 1.0];
        let expected = vec![1.0, 2.0, 3.0];
        let params = StandardScaleParams::new(2.0, 1.0); // mean=2, std=1
        let scaler = StandardScaler::new().with_parameters(params);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_params_from_data() {
        // Test case 1: Simple sequence
        let data = vec![1.0, 2.0, 3.0];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 2.0);
        assert_approx_eq!(params.std_dev, 0.816496580927726);

        // Test case 2: More complex data
        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 5.0);
        assert_approx_eq!(params.std_dev, 2.0);

        // Test case 3: Empty iterator should return default values
        let data: Vec<f64> = vec![];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 0.0);
        assert_approx_eq!(params.std_dev, 1.0);

        // Test case 4: Single value
        let data = vec![42.0];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 42.0);
        assert_approx_eq!(params.std_dev, 0.0); // technically undefined, but we return 0
    }

    #[test]
    fn min_max_scale_with_nan() {
        let mut data = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let expected = vec![0.0, f64::NAN, 0.5, 1.0, f64::NAN];
        let mut scaler = MinMaxScaler::new();
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_min_max_scale_with_nan() {
        let mut data = vec![0.0, f64::NAN, 0.5, 1.0, f64::NAN];
        let expected = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let scaler = MinMaxScaler::new().with_data_range(1.0, 3.0);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_with_nan() {
        let mut data = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let expected = vec![
            -1.224744871391589,
            f64::NAN,
            0.0,
            1.224744871391589,
            f64::NAN,
        ];
        let mut scaler = StandardScaler::new().ignore_nans(true);
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_params_from_data_with_nan() {
        // Test that NaN values are properly ignored in parameter calculation
        let data = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let params = StandardScaleParams::from_data_ignoring_nans(data.into_iter());
        assert_approx_eq!(params.mean, 2.0);
        assert_approx_eq!(params.std_dev, 0.816496580927726);
    }

    #[test]
    fn inverse_standard_scale_with_nan() {
        let mut data = vec![-1.0, f64::NAN, 0.0, 1.0, f64::NAN];
        let expected = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let params = StandardScaleParams::new(2.0, 1.0);
        let scaler = StandardScaler::new().with_parameters(params);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }
}