1use ferrolearn_core::error::FerroError;
14use ferrolearn_core::traits::{Fit, FitTransform, Transform};
15use ndarray::Array2;
16use num_traits::Float;
17
/// Strategy used to place the bin edges of each feature during fitting.
///
/// `Hash` is derived alongside the comparison traits so the strategy can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinStrategy {
    /// Equal-width bins spanning the observed minimum and maximum.
    Uniform,
    /// Edges taken at evenly spaced, linearly interpolated quantiles.
    Quantile,
    /// Edges placed midway between the sorted centroids of a 1-D k-means run.
    KMeans,
}
32
/// Output encoding produced when transforming with a fitted discretizer.
///
/// `Hash` is derived alongside the comparison traits so the encoding can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinEncoding {
    /// One output column per feature, holding the bin index as a float.
    Ordinal,
    /// `n_bins` indicator columns per feature, with a one in the assigned bin's column.
    OneHot,
}
41
42#[must_use]
68#[derive(Debug, Clone)]
69pub struct KBinsDiscretizer<F> {
70 n_bins: usize,
72 encode: BinEncoding,
74 strategy: BinStrategy,
76 _marker: std::marker::PhantomData<F>,
77}
78
79impl<F: Float + Send + Sync + 'static> KBinsDiscretizer<F> {
80 pub fn new(n_bins: usize, encode: BinEncoding, strategy: BinStrategy) -> Self {
82 Self {
83 n_bins,
84 encode,
85 strategy,
86 _marker: std::marker::PhantomData,
87 }
88 }
89
90 #[must_use]
92 pub fn n_bins(&self) -> usize {
93 self.n_bins
94 }
95
96 #[must_use]
98 pub fn encode(&self) -> BinEncoding {
99 self.encode
100 }
101
102 #[must_use]
104 pub fn strategy(&self) -> BinStrategy {
105 self.strategy
106 }
107}
108
109impl<F: Float + Send + Sync + 'static> Default for KBinsDiscretizer<F> {
110 fn default() -> Self {
111 Self::new(5, BinEncoding::Ordinal, BinStrategy::Uniform)
112 }
113}
114
115#[derive(Debug, Clone)]
123pub struct FittedKBinsDiscretizer<F> {
124 bin_edges: Vec<Vec<F>>,
126 n_bins: usize,
128 encode: BinEncoding,
130}
131
132impl<F: Float + Send + Sync + 'static> FittedKBinsDiscretizer<F> {
133 #[must_use]
135 pub fn bin_edges(&self) -> &[Vec<F>] {
136 &self.bin_edges
137 }
138
139 #[must_use]
141 pub fn n_bins(&self) -> usize {
142 self.n_bins
143 }
144
145 #[must_use]
147 pub fn encode(&self) -> BinEncoding {
148 self.encode
149 }
150}
151
152fn assign_bin<F: Float>(value: F, edges: &[F]) -> usize {
158 let n_bins = edges.len() - 1;
159 if n_bins == 0 {
160 return 0;
161 }
162 for (i, edge) in edges.iter().enumerate().skip(1) {
164 if value < *edge {
165 return i - 1;
166 }
167 }
168 n_bins - 1
170}
171
172fn kmeans_1d<F: Float>(values: &[F], n_bins: usize, max_iter: usize) -> Vec<F> {
174 let n = values.len();
175 if n <= n_bins || n_bins == 0 {
176 let min_v = values.iter().copied().fold(F::infinity(), |a, b| a.min(b));
178 let max_v = values
179 .iter()
180 .copied()
181 .fold(F::neg_infinity(), |a, b| a.max(b));
182 return (0..=n_bins)
183 .map(|i| min_v + (max_v - min_v) * F::from(i).unwrap() / F::from(n_bins).unwrap())
184 .collect();
185 }
186
187 let min_v = values.iter().copied().fold(F::infinity(), |a, b| a.min(b));
189 let max_v = values
190 .iter()
191 .copied()
192 .fold(F::neg_infinity(), |a, b| a.max(b));
193
194 let mut centroids: Vec<F> = (0..n_bins)
195 .map(|i| {
196 min_v
197 + (max_v - min_v) * (F::from(i).unwrap() + F::from(0.5).unwrap())
198 / F::from(n_bins).unwrap()
199 })
200 .collect();
201
202 for _ in 0..max_iter {
203 let mut sums = vec![F::zero(); n_bins];
205 let mut counts = vec![0usize; n_bins];
206
207 for &v in values {
208 let mut best_c = 0;
209 let mut best_dist = F::infinity();
210 for (c, ¢roid) in centroids.iter().enumerate() {
211 let d = (v - centroid).abs();
212 if d < best_dist {
213 best_dist = d;
214 best_c = c;
215 }
216 }
217 sums[best_c] = sums[best_c] + v;
218 counts[best_c] += 1;
219 }
220
221 let mut converged = true;
223 for c in 0..n_bins {
224 if counts[c] > 0 {
225 let new_centroid = sums[c] / F::from(counts[c]).unwrap();
226 if (new_centroid - centroids[c]).abs() > F::from(1e-10).unwrap_or(F::epsilon()) {
227 converged = false;
228 }
229 centroids[c] = new_centroid;
230 }
231 }
232 if converged {
233 break;
234 }
235 }
236
237 centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
239
240 let mut edges = Vec::with_capacity(n_bins + 1);
241 edges.push(min_v);
242 for i in 0..n_bins - 1 {
243 let mid = (centroids[i] + centroids[i + 1]) / (F::one() + F::one());
244 edges.push(mid);
245 }
246 edges.push(max_v);
247
248 edges
249}
250
251impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KBinsDiscretizer<F> {
256 type Fitted = FittedKBinsDiscretizer<F>;
257 type Error = FerroError;
258
259 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKBinsDiscretizer<F>, FerroError> {
266 let n_samples = x.nrows();
267 if n_samples < 2 {
268 return Err(FerroError::InsufficientSamples {
269 required: 2,
270 actual: n_samples,
271 context: "KBinsDiscretizer::fit".into(),
272 });
273 }
274 if self.n_bins < 2 {
275 return Err(FerroError::InvalidParameter {
276 name: "n_bins".into(),
277 reason: "n_bins must be at least 2".into(),
278 });
279 }
280
281 let n_features = x.ncols();
282 let mut bin_edges = Vec::with_capacity(n_features);
283
284 for j in 0..n_features {
285 let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
286 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
287
288 let min_val = col_vals[0];
289 let max_val = col_vals[col_vals.len() - 1];
290
291 let edges = match self.strategy {
292 BinStrategy::Uniform => (0..=self.n_bins)
293 .map(|i| {
294 min_val
295 + (max_val - min_val) * F::from(i).unwrap()
296 / F::from(self.n_bins).unwrap()
297 })
298 .collect(),
299 BinStrategy::Quantile => {
300 let n = col_vals.len();
301 (0..=self.n_bins)
302 .map(|i| {
303 let frac = F::from(i).unwrap() / F::from(self.n_bins).unwrap();
304 let pos = frac * F::from(n.saturating_sub(1)).unwrap();
305 let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
306 let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
307 let f = pos - F::from(lo).unwrap();
308 col_vals[lo] * (F::one() - f) + col_vals[hi] * f
309 })
310 .collect()
311 }
312 BinStrategy::KMeans => kmeans_1d(&col_vals, self.n_bins, 100),
313 };
314
315 bin_edges.push(edges);
316 }
317
318 Ok(FittedKBinsDiscretizer {
319 bin_edges,
320 n_bins: self.n_bins,
321 encode: self.encode,
322 })
323 }
324}
325
326impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKBinsDiscretizer<F> {
327 type Output = Array2<F>;
328 type Error = FerroError;
329
330 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
337 let n_features = self.bin_edges.len();
338 if x.ncols() != n_features {
339 return Err(FerroError::ShapeMismatch {
340 expected: vec![x.nrows(), n_features],
341 actual: vec![x.nrows(), x.ncols()],
342 context: "FittedKBinsDiscretizer::transform".into(),
343 });
344 }
345
346 let n_samples = x.nrows();
347
348 match self.encode {
349 BinEncoding::Ordinal => {
350 let mut out = Array2::zeros((n_samples, n_features));
351 for j in 0..n_features {
352 let edges = &self.bin_edges[j];
353 for i in 0..n_samples {
354 let bin = assign_bin(x[[i, j]], edges);
355 out[[i, j]] = F::from(bin).unwrap_or(F::zero());
356 }
357 }
358 Ok(out)
359 }
360 BinEncoding::OneHot => {
361 let n_out = n_features * self.n_bins;
362 let mut out = Array2::zeros((n_samples, n_out));
363 for j in 0..n_features {
364 let edges = &self.bin_edges[j];
365 let col_offset = j * self.n_bins;
366 for i in 0..n_samples {
367 let bin = assign_bin(x[[i, j]], edges);
368 out[[i, col_offset + bin]] = F::one();
369 }
370 }
371 Ok(out)
372 }
373 }
374 }
375}
376
377impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for KBinsDiscretizer<F> {
379 type Output = Array2<F>;
380 type Error = FerroError;
381
382 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
384 Err(FerroError::InvalidParameter {
385 name: "KBinsDiscretizer".into(),
386 reason: "discretizer must be fitted before calling transform; use fit() first".into(),
387 })
388 }
389}
390
391impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for KBinsDiscretizer<F> {
392 type FitError = FerroError;
393
394 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
400 let fitted = self.fit(x, &())?;
401 fitted.transform(x)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Shorthand for an `f64` discretizer with ordinal encoding and uniform bins.
    fn ordinal_uniform(n_bins: usize) -> KBinsDiscretizer<f64> {
        KBinsDiscretizer::new(n_bins, BinEncoding::Ordinal, BinStrategy::Uniform)
    }

    /// The smallest and largest samples land in the first and last bin.
    #[test]
    fn test_kbins_ordinal_uniform() {
        let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let model = ordinal_uniform(3).fit(&data, &()).unwrap();
        let binned = model.transform(&data).unwrap();
        assert_eq!(binned.ncols(), 1);
        assert_abs_diff_eq!(binned[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(binned[[5, 0]], 2.0, epsilon = 1e-10);
    }

    /// One-hot output has `n_bins` columns and exactly one hot entry per row.
    #[test]
    fn test_kbins_onehot_uniform() {
        let config =
            KBinsDiscretizer::<f64>::new(3, BinEncoding::OneHot, BinStrategy::Uniform);
        let data = array![[0.0], [2.5], [5.0]];
        let model = config.fit(&data, &()).unwrap();
        let encoded = model.transform(&data).unwrap();
        assert_eq!(encoded.ncols(), 3);
        for row in 0..encoded.nrows() {
            let total: f64 = encoded.row(row).iter().sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        }
    }

    /// Quantile binning only produces indices in `0..n_bins`.
    #[test]
    fn test_kbins_quantile_strategy() {
        let config =
            KBinsDiscretizer::<f64>::new(4, BinEncoding::Ordinal, BinStrategy::Quantile);
        let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]];
        let model = config.fit(&data, &()).unwrap();
        let binned = model.transform(&data).unwrap();
        for &bin in binned.iter() {
            assert!(bin >= 0.0 && bin < 4.0);
        }
    }

    /// K-means binning only produces indices in `0..n_bins`.
    #[test]
    fn test_kbins_kmeans_strategy() {
        let config =
            KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::KMeans);
        let data = array![
            [0.0],
            [0.1],
            [0.2],
            [5.0],
            [5.1],
            [5.2],
            [10.0],
            [10.1],
            [10.2]
        ];
        let model = config.fit(&data, &()).unwrap();
        let binned = model.transform(&data).unwrap();
        for &bin in binned.iter() {
            assert!(bin >= 0.0 && bin < 3.0);
        }
    }

    /// Ordinal encoding keeps one output column per input feature.
    #[test]
    fn test_kbins_multi_feature() {
        let data = array![[0.0, 10.0], [2.5, 15.0], [5.0, 20.0]];
        let model = ordinal_uniform(3).fit(&data, &()).unwrap();
        let binned = model.transform(&data).unwrap();
        assert_eq!(binned.ncols(), 2);
    }

    /// Uniform edges start at the minimum, end at the maximum, and number `n_bins + 1`.
    #[test]
    fn test_kbins_bin_edges() {
        let data = array![[0.0], [3.0], [6.0]];
        let model = ordinal_uniform(3).fit(&data, &()).unwrap();
        let first_feature = &model.bin_edges()[0];
        assert_eq!(first_feature.len(), 4);
        assert_abs_diff_eq!(first_feature[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(first_feature[3], 6.0, epsilon = 1e-10);
    }

    /// `fit_transform` succeeds and keeps the column count for ordinal encoding.
    #[test]
    fn test_kbins_fit_transform() {
        let data = array![[0.0], [2.5], [5.0]];
        let binned = ordinal_uniform(3).fit_transform(&data).unwrap();
        assert_eq!(binned.ncols(), 1);
    }

    /// A single sample is rejected.
    #[test]
    fn test_kbins_insufficient_samples_error() {
        let data = array![[1.0]];
        let outcome = ordinal_uniform(3).fit(&data, &());
        assert!(outcome.is_err());
    }

    /// Fewer than two bins is rejected.
    #[test]
    fn test_kbins_too_few_bins_error() {
        let data = array![[0.0], [1.0]];
        let outcome = ordinal_uniform(1).fit(&data, &());
        assert!(outcome.is_err());
    }

    /// Transforming data with a different column count fails.
    #[test]
    fn test_kbins_shape_mismatch_error() {
        let train = array![[0.0, 1.0], [2.0, 3.0]];
        let model = ordinal_uniform(3).fit(&train, &()).unwrap();
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(model.transform(&wrong_width).is_err());
    }

    /// Calling `transform` on the unfitted configuration fails.
    #[test]
    fn test_kbins_unfitted_error() {
        let config = ordinal_uniform(3);
        let data = array![[0.0]];
        assert!(config.transform(&data).is_err());
    }

    /// The default configuration is five uniform bins with ordinal encoding.
    #[test]
    fn test_kbins_default() {
        let config = KBinsDiscretizer::<f64>::default();
        assert_eq!(config.n_bins(), 5);
        assert_eq!(config.encode(), BinEncoding::Ordinal);
        assert_eq!(config.strategy(), BinStrategy::Uniform);
    }

    /// Every ordinal output lies in `0..n_bins`.
    #[test]
    fn test_kbins_ordinal_values_in_range() {
        let data = array![[0.0], [2.5], [5.0], [7.5], [10.0]];
        let model = ordinal_uniform(5).fit(&data, &()).unwrap();
        let binned = model.transform(&data).unwrap();
        for v in binned.iter() {
            assert!(*v >= 0.0 && *v < 5.0, "Bin index {} out of range", v);
        }
    }
}