1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
6pub enum BinStrategy {
7 Uniform,
9 Quantile,
11}
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
15pub enum EncodeStrategy {
16 Ordinal,
18 Onehot,
20}
21
22#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
28pub struct KBinsDiscretizer {
29 pub n_bins: usize,
31 pub strategy: BinStrategy,
33 pub encode: EncodeStrategy,
35}
36
37impl KBinsDiscretizer {
38 pub fn new() -> Self {
40 Self {
41 n_bins: 5,
42 strategy: BinStrategy::Quantile,
43 encode: EncodeStrategy::Ordinal,
44 }
45 }
46
47 pub fn n_bins(mut self, n_bins: usize) -> Self {
49 self.n_bins = n_bins;
50 self
51 }
52
53 pub fn strategy(mut self, strategy: BinStrategy) -> Self {
55 self.strategy = strategy;
56 self
57 }
58
59 pub fn encode(mut self, encode: EncodeStrategy) -> Self {
61 self.encode = encode;
62 self
63 }
64}
65
66impl Default for KBinsDiscretizer {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
74#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
75pub struct FittedKBinsDiscretizer<F: Float> {
76 bin_edges: Vec<Vec<F>>,
78 n_bins: usize,
79 encode: EncodeStrategy,
80}
81
82fn percentile_sorted<F: Float>(sorted: &[F], p: f64) -> F {
84 let n = sorted.len();
85 if n == 1 {
86 return sorted[0];
87 }
88 let idx = p * (n - 1) as f64;
89 let lo = idx.floor() as usize;
90 let hi = idx.ceil().min((n - 1) as f64) as usize;
91 if lo == hi {
92 sorted[lo]
93 } else {
94 let frac = F::from_f64(idx - lo as f64).unwrap();
95 sorted[lo] * (F::one() - frac) + sorted[hi] * frac
96 }
97}
98
99impl<F: Float> FitUnsupervised<F> for KBinsDiscretizer {
100 type Fitted = FittedKBinsDiscretizer<F>;
101
102 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
103 if x.is_empty() {
104 return Err(RustMlError::EmptyInput("input array is empty".into()));
105 }
106 if self.n_bins < 2 {
107 return Err(RustMlError::InvalidParameter(
108 "n_bins must be at least 2".into(),
109 ));
110 }
111
112 let ncols = x.ncols();
113 let mut bin_edges = Vec::with_capacity(ncols);
114
115 for j in 0..ncols {
116 let mut col: Vec<F> = x.column(j).to_vec();
117 col.sort_by(|a, b| a.partial_cmp(b).unwrap());
118
119 let edges = match self.strategy {
120 BinStrategy::Uniform => {
121 let min_val = col[0];
122 let max_val = col[col.len() - 1];
123 let range = max_val - min_val;
124 let step = range / F::from_usize(self.n_bins).unwrap();
125 let mut e = Vec::with_capacity(self.n_bins + 1);
126 for i in 0..=self.n_bins {
127 e.push(min_val + step * F::from_usize(i).unwrap());
128 }
129 e
130 }
131 BinStrategy::Quantile => {
132 let mut e = Vec::with_capacity(self.n_bins + 1);
133 for i in 0..=self.n_bins {
134 let p = i as f64 / self.n_bins as f64;
135 e.push(percentile_sorted(&col, p));
136 }
137 e
138 }
139 };
140
141 bin_edges.push(edges);
142 }
143
144 Ok(FittedKBinsDiscretizer {
145 bin_edges,
146 n_bins: self.n_bins,
147 encode: self.encode,
148 })
149 }
150}
151
152fn find_bin<F: Float>(val: F, edges: &[F], n_bins: usize) -> usize {
155 let mut lo = 0;
157 let mut hi = edges.len() - 1;
158
159 if val <= edges[0] {
161 return 0;
162 }
163 if val >= edges[edges.len() - 1] {
164 return n_bins - 1;
165 }
166
167 while lo + 1 < hi {
168 let mid = (lo + hi) / 2;
169 if edges[mid] <= val {
170 lo = mid;
171 } else {
172 hi = mid;
173 }
174 }
175
176 lo.min(n_bins - 1)
179}
180
181impl<F: Float> Transform<F> for FittedKBinsDiscretizer<F> {
182 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
183 let expected_cols = self.bin_edges.len();
184 if x.ncols() != expected_cols {
185 return Err(RustMlError::ShapeMismatch(format!(
186 "expected {} features, got {}",
187 expected_cols,
188 x.ncols()
189 )));
190 }
191
192 match self.encode {
193 EncodeStrategy::Ordinal => {
194 let mut result = Array2::<F>::zeros(x.raw_dim());
195 for i in 0..x.nrows() {
196 for j in 0..x.ncols() {
197 let bin = find_bin(x[[i, j]], &self.bin_edges[j], self.n_bins);
198 result[[i, j]] = F::from_usize(bin).unwrap();
199 }
200 }
201 Ok(result)
202 }
203 EncodeStrategy::Onehot => {
204 let out_cols = expected_cols * self.n_bins;
205 let mut result = Array2::<F>::zeros((x.nrows(), out_cols));
206 for i in 0..x.nrows() {
207 for j in 0..x.ncols() {
208 let bin = find_bin(x[[i, j]], &self.bin_edges[j], self.n_bins);
209 let col_offset = j * self.n_bins + bin;
210 result[[i, col_offset]] = F::one();
211 }
212 }
213 Ok(result)
214 }
215 }
216 }
217}
218
219impl<F: Float> FittedKBinsDiscretizer<F> {
220 pub fn bin_edges(&self) -> &Vec<Vec<F>> {
222 &self.bin_edges
223 }
224
225 pub fn n_bins(&self) -> usize {
227 self.n_bins
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Asserts that column `col` of `transformed` holds exactly the values in `expected`.
    fn assert_column(transformed: &Array2<f64>, col: usize, expected: &[f64]) {
        assert_eq!(transformed.nrows(), expected.len());
        for (i, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(transformed[[i, col]], want, epsilon = 1e-10);
        }
    }

    /// Asserts that a one-hot block of width `n_bins` starting at `offset` has its single
    /// `1` at `expected_bins[i]` for every row `i`.
    fn assert_onehot(transformed: &Array2<f64>, offset: usize, n_bins: usize, expected_bins: &[usize]) {
        for (i, &bin) in expected_bins.iter().enumerate() {
            for k in 0..n_bins {
                let want = if k == bin { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(transformed[[i, offset + k]], want, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_uniform_ordinal() {
        let x = array![
            [0.0, 0.0],
            [2.5, 5.0],
            [5.0, 10.0],
            [7.5, 15.0],
            [10.0, 20.0],
        ];
        let kbd = KBinsDiscretizer::new()
            .n_bins(4)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // A value lying exactly on an interior edge belongs to the bin starting there, and
        // the maximum falls into the last bin. The second column is the first scaled by 2,
        // so it gets the same bins.
        let expected = [0.0, 1.0, 2.0, 3.0, 3.0];
        assert_column(&transformed, 0, &expected);
        assert_column(&transformed, 1, &expected);
    }

    #[test]
    fn test_quantile_ordinal() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0],
        ];
        let kbd = KBinsDiscretizer::new()
            .n_bins(5)
            .strategy(BinStrategy::Quantile)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!((0.0..=4.0).contains(&v), "bin index out of range: {}", v);
        }

        for i in 1..x.nrows() {
            assert!(
                transformed[[i, 0]] >= transformed[[i - 1, 0]],
                "monotonicity violated at row {}",
                i
            );
        }

        // Edges are [1, 2.8, 4.6, 6.4, 8.2, 10], so each bin receives two samples.
        assert_column(
            &transformed,
            0,
            &[0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0],
        );
    }

    #[test]
    fn test_onehot_encoding() {
        let x = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Onehot);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.ncols(), 3);

        for i in 0..transformed.nrows() {
            let row_sum: f64 = transformed.row(i).sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }

        // Edges are roughly [1, 3.67, 6.33, 9].
        assert_onehot(&transformed, 0, 3, &[0, 0, 1, 2, 2]);
    }

    #[test]
    fn test_onehot_multiple_features() {
        let x = array![[1.0, 10.0], [5.0, 50.0], [9.0, 90.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Onehot);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.ncols(), 6);

        for i in 0..transformed.nrows() {
            let row_sum: f64 = transformed.row(i).sum();
            assert_abs_diff_eq!(row_sum, 2.0, epsilon = 1e-10);
        }

        // Feature 0 occupies columns 0..3 and feature 1 occupies columns 3..6.
        assert_onehot(&transformed, 0, 3, &[0, 1, 2]);
        assert_onehot(&transformed, 3, 3, &[0, 1, 2]);
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let kbd = KBinsDiscretizer::default();
        assert!(FitUnsupervised::<f64>::fit(&kbd, &x).is_err());
    }

    #[test]
    fn test_invalid_n_bins() {
        let x = array![[1.0], [2.0], [3.0]];
        let kbd = KBinsDiscretizer::new().n_bins(1);
        assert!(FitUnsupervised::<f64>::fit(&kbd, &x).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kbd = KBinsDiscretizer::default();
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();

        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
    }

    #[test]
    fn test_out_of_range_values() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();

        // Values below/above the fitted range are clipped into the first/last bin. The
        // edges are roughly [1, 2.33, 3.67, 5].
        let x_test = array![[-10.0], [0.0], [3.0], [6.0], [100.0]];
        let transformed = fitted.transform(&x_test).unwrap();

        assert_column(&transformed, 0, &[0.0, 0.0, 1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_constant_feature() {
        let x = array![[5.0], [5.0], [5.0], [5.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Uniform)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite(), "constant feature produced non-finite: {}", v);
            // All edges coincide, so every sample sits on the first edge.
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_f32() {
        let x = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(BinStrategy::Quantile)
            .encode(EncodeStrategy::Ordinal);
        let fitted = FitUnsupervised::<f32>::fit(&kbd, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(v.is_finite());
            assert!((0.0..3.0).contains(&v));
        }
    }

    #[test]
    fn test_bin_edges_accessor() {
        let x = array![[0.0], [10.0]];
        let kbd = KBinsDiscretizer::new()
            .n_bins(4)
            .strategy(BinStrategy::Uniform);
        let fitted = FitUnsupervised::<f64>::fit(&kbd, &x).unwrap();

        assert_eq!(fitted.n_bins(), 4);
        assert_eq!(fitted.bin_edges().len(), 1);
        let edges = &fitted.bin_edges()[0];
        assert_eq!(edges.len(), 5);
        for (got, want) in edges.iter().zip([0.0, 2.5, 5.0, 7.5, 10.0].iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_default_config() {
        let kbd = KBinsDiscretizer::default();
        assert_eq!(kbd.n_bins, 5);
        assert_eq!(kbd.strategy, BinStrategy::Quantile);
        assert_eq!(kbd.encode, EncodeStrategy::Ordinal);
    }
}