Skip to main content

frlearn_preprocess/
range_normaliser.rs

1use frlearn_core::{FrError, FrResult, Matrix};
2use frlearn_math::{clamp01, safe_divide};
3use ndarray::Array1;
4
5use crate::{TransformerModel, validate_feature_count};
6
/// Configuration for per-column range (min–max) normalisation.
///
/// Call [`RangeNormaliser::fit`] to learn column ranges from data; the
/// resulting model carries this tolerance with it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RangeNormaliser {
    /// Tolerance added to each column's `max - min` when scaling. Negative or
    /// NaN values are replaced by `0.0` when a model is fitted.
    pub eps: f64,
}
11
12impl Default for RangeNormaliser {
13    fn default() -> Self {
14        Self { eps: 1e-12 }
15    }
16}
17
18impl RangeNormaliser {
19    pub fn fit(&self, x: &Matrix) -> RangeNormaliserModel {
20        let n_cols = x.ncols();
21        let mut min = Array1::<f64>::zeros(n_cols);
22        let mut max = Array1::<f64>::zeros(n_cols);
23
24        if x.nrows() > 0 {
25            for col_idx in 0..n_cols {
26                let column = x.column(col_idx);
27                let min_value = column
28                    .iter()
29                    .copied()
30                    .filter(|value| value.is_finite())
31                    .fold(f64::INFINITY, f64::min);
32                let max_value = column
33                    .iter()
34                    .copied()
35                    .filter(|value| value.is_finite())
36                    .fold(f64::NEG_INFINITY, f64::max);
37
38                min[col_idx] = if min_value.is_finite() {
39                    min_value
40                } else {
41                    0.0
42                };
43                max[col_idx] = if max_value.is_finite() {
44                    max_value
45                } else {
46                    0.0
47                };
48            }
49        }
50
51        RangeNormaliserModel {
52            min,
53            max,
54            eps: self.eps.max(0.0),
55        }
56    }
57}
58
59#[derive(Debug, Clone)]
60pub struct RangeNormaliserModel {
61    pub min: Array1<f64>,
62    pub max: Array1<f64>,
63    pub eps: f64,
64}
65
66impl RangeNormaliserModel {
67    pub fn transform(&self, x: &Matrix) -> Matrix {
68        self.try_transform(x)
69            .unwrap_or_else(|_| Matrix::zeros((x.nrows(), x.ncols())))
70    }
71
72    fn try_transform(&self, x: &Matrix) -> FrResult<Matrix> {
73        if self.min.len() != self.max.len() {
74            return Err(FrError::InvalidInput(format!(
75                "model min/max length mismatch: {} vs {}",
76                self.min.len(),
77                self.max.len()
78            )));
79        }
80
81        validate_feature_count(x, self.min.len())?;
82
83        let mut output = Matrix::zeros((x.nrows(), x.ncols()));
84        for row_idx in 0..x.nrows() {
85            for col_idx in 0..x.ncols() {
86                let value = x[[row_idx, col_idx]];
87                let numerator = if value.is_finite() {
88                    value - self.min[col_idx]
89                } else {
90                    0.0
91                };
92                let denominator = self.max[col_idx] - self.min[col_idx] + self.eps;
93                let scaled = safe_divide(numerator, denominator, 0.0);
94                output[[row_idx, col_idx]] = clamp01(scaled);
95            }
96        }
97
98        Ok(output)
99    }
100}
101
102impl TransformerModel for RangeNormaliserModel {
103    fn transform(&self, x: &Matrix) -> FrResult<Matrix> {
104        self.try_transform(x)
105    }
106}