1use core::f64;
4
5use augurs_core::{FloatIterExt, NanMinMaxResult};
6
7use super::{Error, Transformer};
8
/// A closed numeric interval `[min, max]`, used both for the observed data
/// range and for the target output range of a [`MinMaxScaler`].
#[derive(Debug, Clone, Copy)]
struct MinMax {
    min: f64,
    max: f64,
}

impl MinMax {
    /// The default output range: `[0, 1]` shrunk inward by `f64::EPSILON` at
    /// each end.
    ///
    /// NOTE(review): presumably the inset keeps scaled values strictly inside
    /// `(0, 1)` for downstream transforms that are undefined at the endpoints
    /// — TODO confirm.
    fn zero_one() -> Self {
        Self {
            min: 0.0 + f64::EPSILON,
            max: 1.0 - f64::EPSILON,
        }
    }
}

/// Parameters of a fitted min-max scaler.
///
/// They describe the affine map `y = x * scale_factor + offset`, which sends
/// `input_scale.min` to the output range's `min` and `input_scale.max` to its
/// `max`.
#[derive(Debug, Clone)]
struct FittedMinMaxScalerParams {
    /// The data range the parameters were fitted to.
    input_scale: MinMax,
    /// Multiplicative factor of the affine map.
    scale_factor: f64,
    /// Additive offset of the affine map.
    offset: f64,
}

impl FittedMinMaxScalerParams {
    /// Computes the affine map from `input_scale` onto `output_scale`.
    ///
    /// If the input range is degenerate (`min == max`, e.g. constant data),
    /// its width is treated as `1.0`. The constant is then mapped to
    /// `output_scale.min` rather than producing NaN or infinite parameters.
    fn new(input_scale: MinMax, output_scale: MinMax) -> Self {
        let input_width = input_scale.max - input_scale.min;
        // Dividing by a zero width would give an infinite scale factor and
        // a NaN/infinite offset, so every transformed value would be NaN.
        // Using 1.0 instead (as scikit-learn's `MinMaxScaler` does) keeps the
        // map finite and, for a non-empty output range, invertible.
        let input_width = if input_width == 0.0 { 1.0 } else { input_width };
        let scale_factor = (output_scale.max - output_scale.min) / input_width;
        Self {
            input_scale,
            scale_factor,
            offset: output_scale.min - (input_scale.min * scale_factor),
        }
    }
}
53
54#[derive(Debug, Clone)]
56pub struct MinMaxScaler {
57 output_scale: MinMax,
58 params: Option<FittedMinMaxScalerParams>,
61}
62
63impl Default for MinMaxScaler {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl MinMaxScaler {
70 pub fn new() -> Self {
72 Self {
73 output_scale: MinMax::zero_one(),
74 params: None,
75 }
76 }
77
78 pub fn with_scaled_range(mut self, min: f64, max: f64) -> Self {
80 self.output_scale = MinMax { min, max };
81 self.params.iter_mut().for_each(|p| {
82 let input_scale = p.input_scale;
83 *p = FittedMinMaxScalerParams::new(input_scale, self.output_scale);
84 });
85 self
86 }
87
88 pub fn with_data_range(mut self, min: f64, max: f64) -> Self {
96 let data_range = MinMax { min, max };
97 self.params = Some(FittedMinMaxScalerParams::new(data_range, self.output_scale));
98 self
99 }
100}
101
102impl Transformer for MinMaxScaler {
103 fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
108 let params = match data.iter().copied().nanminmax(true) {
109 NanMinMaxResult::NaN => unreachable!(),
110 e @ NanMinMaxResult::NoElements | e @ NanMinMaxResult::OneElement(_) => {
111 return Err(e.into())
112 }
113 NanMinMaxResult::MinMax(min, max) => {
114 FittedMinMaxScalerParams::new(MinMax { min, max }, self.output_scale)
115 }
116 };
117 self.params = Some(params);
118 Ok(())
119 }
120
121 fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
123 let params = self.params.as_ref().ok_or(Error::NotFitted)?;
124 data.iter_mut()
125 .for_each(|x| *x = *x * params.scale_factor + params.offset);
126 Ok(())
127 }
128
129 fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> {
131 let params = self.params.as_ref().ok_or(Error::NotFitted)?;
132 data.iter_mut()
133 .for_each(|x| *x = (*x - params.offset) / params.scale_factor);
134 Ok(())
135 }
136}
137
/// Parameters for standard scaling: each value `x` is mapped to
/// `(x - mean) / std_dev`.
#[derive(Debug, Clone)]
pub struct StandardScaleParams {
    /// The mean subtracted from each value.
    pub mean: f64,
    /// The standard deviation each centred value is divided by.
    pub std_dev: f64,
}

impl StandardScaleParams {
    /// Creates parameters from an explicit mean and standard deviation.
    pub fn new(mean: f64, std_dev: f64) -> Self {
        Self { mean, std_dev }
    }

    /// Computes the mean and population standard deviation (dividing by `n`,
    /// not `n - 1`) of `data` in a single pass using Welford's algorithm.
    ///
    /// An empty iterator yields `mean = 0.0`, `std_dev = 1.0`. A single value
    /// yields `std_dev = 0.0`. NaN values are not filtered and propagate
    /// into the result.
    pub fn from_data<T>(data: T) -> Self
    where
        T: Iterator<Item = f64>,
    {
        let (n, mean, sum_sq_dev) = data.fold((0_u64, 0.0_f64, 0.0_f64), |(n, mean, m2), x| {
            let n = n + 1;
            // Deviation from the mean before and after updating it.
            let before = x - mean;
            let mean = mean + before / n as f64;
            let after = x - mean;
            (n, mean, m2 + before * after)
        });

        if n == 0 {
            // No data: fall back to parameters that leave values unchanged.
            Self::new(0.0, 1.0)
        } else {
            Self::new(mean, (sum_sq_dev / n as f64).sqrt())
        }
    }

    /// Like [`from_data`](Self::from_data), but skips NaN values first.
    pub fn from_data_ignoring_nans<T: Iterator<Item = f64>>(data: T) -> Self {
        let finite_or_inf = data.filter(|value| !value.is_nan());
        Self::from_data(finite_or_inf)
    }
}
199
200#[derive(Debug, Clone, Default)]
233pub struct StandardScaler {
234 params: Option<StandardScaleParams>,
235 ignore_nans: bool,
236}
237
238impl StandardScaler {
239 pub fn new() -> Self {
241 Self::default()
242 }
243
244 pub fn with_parameters(mut self, params: StandardScaleParams) -> Self {
251 self.params = Some(params);
252 self
253 }
254
255 pub fn ignore_nans(mut self, ignore_nans: bool) -> Self {
263 self.ignore_nans = ignore_nans;
264 self
265 }
266}
267
268impl Transformer for StandardScaler {
269 fn fit(&mut self, data: &[f64]) -> Result<(), Error> {
270 self.params = Some(if self.ignore_nans {
271 StandardScaleParams::from_data_ignoring_nans(data.iter().copied())
272 } else {
273 StandardScaleParams::from_data(data.iter().copied())
274 });
275 Ok(())
276 }
277
278 fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
279 let params = self.params.as_ref().ok_or(Error::NotFitted)?;
280 data.iter_mut()
281 .for_each(|x| *x = (*x - params.mean) / params.std_dev);
282 Ok(())
283 }
284
285 fn inverse_transform(&self, data: &mut [f64]) -> Result<(), Error> {
286 let params = self.params.as_ref().ok_or(Error::NotFitted)?;
287 data.iter_mut()
288 .for_each(|x| *x = (*x * params.std_dev) + params.mean);
289 Ok(())
290 }
291}
292
#[cfg(test)]
mod test {
    use augurs_testing::{assert_all_close, assert_approx_eq};

    use super::*;

    #[test]
    fn min_max_scale() {
        let mut data = vec![1.0, 2.0, 3.0];
        let expected = vec![0.0, 0.5, 1.0];
        let mut scaler = MinMaxScaler::new();
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn min_max_scale_custom() {
        let mut data = vec![1.0, 2.0, 3.0];
        let expected = vec![0.0, 5.0, 10.0];
        let mut scaler = MinMaxScaler::new().with_scaled_range(0.0, 10.0);
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_min_max_scale() {
        let mut data = vec![0.0, 0.5, 1.0];
        let expected = vec![1.0, 2.0, 3.0];
        let scaler = MinMaxScaler::new().with_data_range(1.0, 3.0);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_min_max_scale_custom() {
        let mut data = vec![0.0, 5.0, 10.0];
        let expected = vec![1.0, 2.0, 3.0];
        let scaler = MinMaxScaler::new()
            .with_scaled_range(0.0, 10.0)
            .with_data_range(1.0, 3.0);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale() {
        let mut data = vec![1.0, 2.0, 3.0];
        // Population std dev of [1, 2, 3] is sqrt(2/3); 1 / sqrt(2/3) = sqrt(1.5).
        let expected = vec![-1.224744871391589, 0.0, 1.224744871391589];
        let mut scaler = StandardScaler::new();
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_custom() {
        let mut data = vec![1.0, 2.0, 3.0];
        let expected = vec![-1.0, 0.0, 1.0];
        let params = StandardScaleParams::new(2.0, 1.0);
        let scaler = StandardScaler::new().with_parameters(params);
        scaler.transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_standard_scale() {
        let mut data = vec![-1.0, 0.0, 1.0];
        let expected = vec![1.0, 2.0, 3.0];
        let params = StandardScaleParams::new(2.0, 1.0);
        let scaler = StandardScaler::new().with_parameters(params);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_params_from_data() {
        let data = vec![1.0, 2.0, 3.0];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 2.0);
        assert_approx_eq!(params.std_dev, 0.816496580927726);

        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 5.0);
        assert_approx_eq!(params.std_dev, 2.0);

        // Empty input falls back to mean 0, std dev 1.
        let data: Vec<f64> = vec![];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 0.0);
        assert_approx_eq!(params.std_dev, 1.0);

        // A single value has zero (population) spread.
        let data = vec![42.0];
        let params = StandardScaleParams::from_data(data.into_iter());
        assert_approx_eq!(params.mean, 42.0);
        assert_approx_eq!(params.std_dev, 0.0);
    }

    #[test]
    fn min_max_scale_with_nan() {
        let mut data = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let expected = vec![0.0, f64::NAN, 0.5, 1.0, f64::NAN];
        let mut scaler = MinMaxScaler::new();
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn inverse_min_max_scale_with_nan() {
        let mut data = vec![0.0, f64::NAN, 0.5, 1.0, f64::NAN];
        let expected = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let scaler = MinMaxScaler::new().with_data_range(1.0, 3.0);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_with_nan() {
        let mut data = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let expected = vec![
            -1.224744871391589,
            f64::NAN,
            0.0,
            1.224744871391589,
            f64::NAN,
        ];
        let mut scaler = StandardScaler::new().ignore_nans(true);
        scaler.fit_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }

    #[test]
    fn standard_scale_params_from_data_with_nan() {
        let data = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let params = StandardScaleParams::from_data_ignoring_nans(data.into_iter());
        assert_approx_eq!(params.mean, 2.0);
        assert_approx_eq!(params.std_dev, 0.816496580927726);
    }

    #[test]
    fn inverse_standard_scale_with_nan() {
        let mut data = vec![-1.0, f64::NAN, 0.0, 1.0, f64::NAN];
        let expected = vec![1.0, f64::NAN, 2.0, 3.0, f64::NAN];
        let params = StandardScaleParams::new(2.0, 1.0);
        let scaler = StandardScaler::new().with_parameters(params);
        scaler.inverse_transform(&mut data).unwrap();
        assert_all_close(&expected, &data);
    }
}