1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
9#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
10pub struct MinMaxScaler<F: Float> {
11 pub feature_min: F,
12 pub feature_max: F,
13}
14
15impl<F: Float> MinMaxScaler<F> {
16 pub fn new() -> Self {
18 Self {
19 feature_min: F::zero(),
20 feature_max: F::one(),
21 }
22 }
23
24 pub fn with_range(mut self, min: F, max: F) -> Self {
26 self.feature_min = min;
27 self.feature_max = max;
28 self
29 }
30}
31
32impl<F: Float> Default for MinMaxScaler<F> {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
41pub struct FittedMinMaxScaler<F: Float> {
42 data_min: Array1<F>,
43 data_max: Array1<F>,
44 feature_min: F,
45 feature_max: F,
46}
47
48impl<F: Float> FitUnsupervised<F> for MinMaxScaler<F> {
49 type Fitted = FittedMinMaxScaler<F>;
50
51 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
52 if x.is_empty() {
53 return Err(RustMlError::EmptyInput("input array is empty".into()));
54 }
55 if self.feature_min >= self.feature_max {
56 return Err(RustMlError::InvalidParameter(
57 "feature_min must be less than feature_max".into(),
58 ));
59 }
60
61 let data_min = x.fold_axis(Axis(0), F::infinity(), |&a, &b| a.min(b));
62 let data_max = x.fold_axis(Axis(0), F::neg_infinity(), |&a, &b| a.max(b));
63
64 Ok(FittedMinMaxScaler {
65 data_min,
66 data_max,
67 feature_min: self.feature_min,
68 feature_max: self.feature_max,
69 })
70 }
71}
72
73impl<F: Float> Transform<F> for FittedMinMaxScaler<F> {
74 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
75 if x.ncols() != self.data_min.len() {
76 return Err(RustMlError::ShapeMismatch(format!(
77 "expected {} features, got {}",
78 self.data_min.len(),
79 x.ncols()
80 )));
81 }
82
83 let range = self.feature_max - self.feature_min;
84 let mut result = x.to_owned();
85
86 for mut row in result.rows_mut() {
87 for (j, val) in row.iter_mut().enumerate() {
88 let data_range = self.data_max[j] - self.data_min[j];
89 if data_range > F::from_f64(1e-15).unwrap() {
90 *val = (*val - self.data_min[j]) / data_range * range + self.feature_min;
91 } else {
92 *val = self.feature_min;
93 }
94 }
95 }
96 Ok(result)
97 }
98}
99
100impl<F: Float> InverseTransform<F> for FittedMinMaxScaler<F> {
101 fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
102 if x.ncols() != self.data_min.len() {
103 return Err(RustMlError::ShapeMismatch(format!(
104 "expected {} features, got {}",
105 self.data_min.len(),
106 x.ncols()
107 )));
108 }
109
110 let range = self.feature_max - self.feature_min;
111 let mut result = x.to_owned();
112
113 for mut row in result.rows_mut() {
114 for (j, val) in row.iter_mut().enumerate() {
115 let data_range = self.data_max[j] - self.data_min[j];
116 if data_range > F::from_f64(1e-15).unwrap() {
117 *val = (*val - self.feature_min) / range * data_range + self.data_min[j];
118 } else {
119 *val = self.data_min[j];
120 }
121 }
122 }
123 Ok(result)
124 }
125}
126
127impl<F: Float> FittedMinMaxScaler<F> {
128 pub fn data_min(&self) -> &Array1<F> {
129 &self.data_min
130 }
131
132 pub fn data_max(&self) -> &Array1<F> {
133 &self.data_max
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Default range: column minimum maps to 0 and column maximum to 1.
    #[test]
    fn test_fit_transform_default() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 1]], 1.0, epsilon = 1e-10);
    }

    /// The accessors expose the per-column extrema found during fitting.
    #[test]
    fn test_fitted_statistics() {
        let x = array![[1.0, 30.0], [2.0, 20.0], [3.0, 10.0]];
        let fitted = MinMaxScaler::<f64>::default().fit(&x).unwrap();

        assert_eq!(fitted.data_min(), &array![1.0, 10.0]);
        assert_eq!(fitted.data_max(), &array![3.0, 30.0]);
    }

    /// A range set through the public fields is honoured.
    #[test]
    fn test_custom_range() {
        let x = array![[1.0], [2.0], [3.0]];
        let scaler = MinMaxScaler {
            feature_min: -1.0,
            feature_max: 1.0,
        };
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-10);
    }

    /// The `with_range` builder behaves like setting the fields directly.
    #[test]
    fn test_with_range_builder() {
        let x = array![[1.0], [2.0], [3.0]];
        let scaler = MinMaxScaler::<f64>::new().with_range(-1.0, 1.0);
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-10);
    }

    /// `inverse_transform` undoes `transform` on the training data.
    #[test]
    fn test_inverse_transform_roundtrip() {
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_large_values() {
        let x = array![[1e10], [2e10], [3e10]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-8);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-8);
        for &v in transformed.iter() {
            assert!(v.is_finite(), "large values produced non-finite: {}", v);
        }
    }

    #[test]
    fn test_small_values() {
        let x = array![[1e-10], [2e-10], [3e-10]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-8);
        assert_abs_diff_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-8);
    }

    /// A zero-range column must not divide by zero: it collapses to the lower
    /// output bound and inverts back to the fitted minimum.
    #[test]
    fn test_constant_column() {
        let x = array![[5.0, 1.0], [5.0, 2.0], [5.0, 3.0]];
        let scaler = MinMaxScaler::<f64>::default();
        let fitted = scaler.fit(&x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite(), "constant column produced non-finite: {}", v);
        }
        for row in 0..x.nrows() {
            assert_abs_diff_eq!(transformed[[row, 0]], 0.0, epsilon = 1e-10);
        }
        // The non-constant neighbour column is scaled as usual.
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 1]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 1]], 1.0, epsilon = 1e-10);

        let recovered = fitted.inverse_transform(&transformed).unwrap();
        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x = Array2::<f64>::zeros((0, 2));
        let result = MinMaxScaler::<f64>::default().fit(&x);

        assert!(matches!(result, Err(RustMlError::EmptyInput(_))));
    }

    #[test]
    fn test_fit_rejects_invalid_range() {
        let x = array![[1.0], [2.0]];

        let degenerate = MinMaxScaler::<f64>::new().with_range(1.0, 1.0).fit(&x);
        assert!(matches!(degenerate, Err(RustMlError::InvalidParameter(_))));

        let inverted = MinMaxScaler::<f64>::new().with_range(1.0, -1.0).fit(&x);
        assert!(matches!(inverted, Err(RustMlError::InvalidParameter(_))));
    }

    #[test]
    fn test_rejects_wrong_feature_count() {
        let train = array![[1.0, 10.0], [2.0, 20.0]];
        let fitted = MinMaxScaler::<f64>::default().fit(&train).unwrap();
        let wrong = array![[1.0], [2.0]];

        assert!(matches!(
            fitted.transform(&wrong),
            Err(RustMlError::ShapeMismatch(_))
        ));
        assert!(matches!(
            fitted.inverse_transform(&wrong),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }
}