1use ferrolearn_core::error::FerroError;
14use ferrolearn_core::traits::{Fit, FitTransform, Transform};
15use ndarray::Array2;
16use num_traits::Float;
17
/// How bin edges are chosen for each feature during fitting.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinStrategy {
    /// Equal-width bins spanning the observed minimum to maximum of the feature.
    Uniform,
    /// Edges placed at evenly spaced quantiles of the feature, with linear
    /// interpolation between neighbouring sorted samples.
    Quantile,
    /// Edges placed at the midpoints between the centroids of a 1-D k-means
    /// clustering of the feature.
    KMeans,
}
32
/// Layout of the matrix produced when a fitted discretizer transforms data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinEncoding {
    /// One output column per input feature, holding the bin index as a float.
    Ordinal,
    /// `n_bins` output columns per input feature, with a one in the column of
    /// the assigned bin and zeros elsewhere.
    OneHot,
}
41
42#[must_use]
68#[derive(Debug, Clone)]
69pub struct KBinsDiscretizer<F> {
70 n_bins: usize,
72 encode: BinEncoding,
74 strategy: BinStrategy,
76 _marker: std::marker::PhantomData<F>,
77}
78
79impl<F: Float + Send + Sync + 'static> KBinsDiscretizer<F> {
80 pub fn new(n_bins: usize, encode: BinEncoding, strategy: BinStrategy) -> Self {
82 Self {
83 n_bins,
84 encode,
85 strategy,
86 _marker: std::marker::PhantomData,
87 }
88 }
89
90 #[must_use]
92 pub fn n_bins(&self) -> usize {
93 self.n_bins
94 }
95
96 #[must_use]
98 pub fn encode(&self) -> BinEncoding {
99 self.encode
100 }
101
102 #[must_use]
104 pub fn strategy(&self) -> BinStrategy {
105 self.strategy
106 }
107}
108
109impl<F: Float + Send + Sync + 'static> Default for KBinsDiscretizer<F> {
110 fn default() -> Self {
111 Self::new(5, BinEncoding::Ordinal, BinStrategy::Uniform)
112 }
113}
114
115#[derive(Debug, Clone)]
123pub struct FittedKBinsDiscretizer<F> {
124 bin_edges: Vec<Vec<F>>,
126 n_bins: usize,
128 encode: BinEncoding,
130}
131
132impl<F: Float + Send + Sync + 'static> FittedKBinsDiscretizer<F> {
133 #[must_use]
135 pub fn bin_edges(&self) -> &[Vec<F>] {
136 &self.bin_edges
137 }
138
139 #[must_use]
141 pub fn n_bins(&self) -> usize {
142 self.n_bins
143 }
144
145 #[must_use]
147 pub fn encode(&self) -> BinEncoding {
148 self.encode
149 }
150}
151
152fn assign_bin<F: Float>(value: F, edges: &[F]) -> usize {
158 let n_bins = edges.len() - 1;
159 if n_bins == 0 {
160 return 0;
161 }
162 for (i, edge) in edges.iter().enumerate().skip(1) {
164 if value < *edge {
165 return i - 1;
166 }
167 }
168 n_bins - 1
170}
171
172fn kmeans_1d<F: Float>(values: &[F], n_bins: usize, max_iter: usize) -> Vec<F> {
174 let n = values.len();
175 if n <= n_bins || n_bins == 0 {
176 let min_v = values
178 .iter()
179 .copied()
180 .fold(F::infinity(), num_traits::Float::min);
181 let max_v = values
182 .iter()
183 .copied()
184 .fold(F::neg_infinity(), num_traits::Float::max);
185 return (0..=n_bins)
186 .map(|i| min_v + (max_v - min_v) * F::from(i).unwrap() / F::from(n_bins).unwrap())
187 .collect();
188 }
189
190 let min_v = values
192 .iter()
193 .copied()
194 .fold(F::infinity(), num_traits::Float::min);
195 let max_v = values
196 .iter()
197 .copied()
198 .fold(F::neg_infinity(), num_traits::Float::max);
199
200 let mut centroids: Vec<F> = (0..n_bins)
201 .map(|i| {
202 min_v
203 + (max_v - min_v) * (F::from(i).unwrap() + F::from(0.5).unwrap())
204 / F::from(n_bins).unwrap()
205 })
206 .collect();
207
208 for _ in 0..max_iter {
209 let mut sums = vec![F::zero(); n_bins];
211 let mut counts = vec![0usize; n_bins];
212
213 for &v in values {
214 let mut best_c = 0;
215 let mut best_dist = F::infinity();
216 for (c, ¢roid) in centroids.iter().enumerate() {
217 let d = (v - centroid).abs();
218 if d < best_dist {
219 best_dist = d;
220 best_c = c;
221 }
222 }
223 sums[best_c] = sums[best_c] + v;
224 counts[best_c] += 1;
225 }
226
227 let mut converged = true;
229 for c in 0..n_bins {
230 if counts[c] > 0 {
231 let new_centroid = sums[c] / F::from(counts[c]).unwrap();
232 if (new_centroid - centroids[c]).abs() > F::from(1e-10).unwrap_or_else(F::epsilon) {
233 converged = false;
234 }
235 centroids[c] = new_centroid;
236 }
237 }
238 if converged {
239 break;
240 }
241 }
242
243 centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
245
246 let mut edges = Vec::with_capacity(n_bins + 1);
247 edges.push(min_v);
248 for i in 0..n_bins - 1 {
249 let mid = (centroids[i] + centroids[i + 1]) / (F::one() + F::one());
250 edges.push(mid);
251 }
252 edges.push(max_v);
253
254 edges
255}
256
257impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KBinsDiscretizer<F> {
262 type Fitted = FittedKBinsDiscretizer<F>;
263 type Error = FerroError;
264
265 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKBinsDiscretizer<F>, FerroError> {
272 let n_samples = x.nrows();
273 if n_samples < 2 {
274 return Err(FerroError::InsufficientSamples {
275 required: 2,
276 actual: n_samples,
277 context: "KBinsDiscretizer::fit".into(),
278 });
279 }
280 if self.n_bins < 2 {
281 return Err(FerroError::InvalidParameter {
282 name: "n_bins".into(),
283 reason: "n_bins must be at least 2".into(),
284 });
285 }
286
287 let n_features = x.ncols();
288 let mut bin_edges = Vec::with_capacity(n_features);
289
290 for j in 0..n_features {
291 let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
292 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
293
294 let min_val = col_vals[0];
295 let max_val = col_vals[col_vals.len() - 1];
296
297 let edges = match self.strategy {
298 BinStrategy::Uniform => (0..=self.n_bins)
299 .map(|i| {
300 min_val
301 + (max_val - min_val) * F::from(i).unwrap()
302 / F::from(self.n_bins).unwrap()
303 })
304 .collect(),
305 BinStrategy::Quantile => {
306 let n = col_vals.len();
307 (0..=self.n_bins)
308 .map(|i| {
309 let frac = F::from(i).unwrap() / F::from(self.n_bins).unwrap();
310 let pos = frac * F::from(n.saturating_sub(1)).unwrap();
311 let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
312 let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
313 let f = pos - F::from(lo).unwrap();
314 col_vals[lo] * (F::one() - f) + col_vals[hi] * f
315 })
316 .collect()
317 }
318 BinStrategy::KMeans => kmeans_1d(&col_vals, self.n_bins, 100),
319 };
320
321 bin_edges.push(edges);
322 }
323
324 Ok(FittedKBinsDiscretizer {
325 bin_edges,
326 n_bins: self.n_bins,
327 encode: self.encode,
328 })
329 }
330}
331
332impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKBinsDiscretizer<F> {
333 type Output = Array2<F>;
334 type Error = FerroError;
335
336 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
343 let n_features = self.bin_edges.len();
344 if x.ncols() != n_features {
345 return Err(FerroError::ShapeMismatch {
346 expected: vec![x.nrows(), n_features],
347 actual: vec![x.nrows(), x.ncols()],
348 context: "FittedKBinsDiscretizer::transform".into(),
349 });
350 }
351
352 let n_samples = x.nrows();
353
354 match self.encode {
355 BinEncoding::Ordinal => {
356 let mut out = Array2::zeros((n_samples, n_features));
357 for j in 0..n_features {
358 let edges = &self.bin_edges[j];
359 for i in 0..n_samples {
360 let bin = assign_bin(x[[i, j]], edges);
361 out[[i, j]] = F::from(bin).unwrap_or_else(F::zero);
362 }
363 }
364 Ok(out)
365 }
366 BinEncoding::OneHot => {
367 let n_out = n_features * self.n_bins;
368 let mut out = Array2::zeros((n_samples, n_out));
369 for j in 0..n_features {
370 let edges = &self.bin_edges[j];
371 let col_offset = j * self.n_bins;
372 for i in 0..n_samples {
373 let bin = assign_bin(x[[i, j]], edges);
374 out[[i, col_offset + bin]] = F::one();
375 }
376 }
377 Ok(out)
378 }
379 }
380 }
381}
382
383impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for KBinsDiscretizer<F> {
385 type Output = Array2<F>;
386 type Error = FerroError;
387
388 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
390 Err(FerroError::InvalidParameter {
391 name: "KBinsDiscretizer".into(),
392 reason: "discretizer must be fitted before calling transform; use fit() first".into(),
393 })
394 }
395}
396
397impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for KBinsDiscretizer<F> {
398 type FitError = FerroError;
399
400 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
406 let fitted = self.fit(x, &())?;
407 fitted.transform(x)
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Shorthand for the configuration most tests share: ordinal output with
    /// uniform-width bins.
    fn ordinal_uniform(n_bins: usize) -> KBinsDiscretizer<f64> {
        KBinsDiscretizer::new(n_bins, BinEncoding::Ordinal, BinStrategy::Uniform)
    }

    #[test]
    fn test_kbins_ordinal_uniform() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];
        let fitted = ordinal_uniform(3).fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
        // The smallest value lands in the first bin, the largest in the last.
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[5, 0]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kbins_onehot_uniform() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::OneHot, BinStrategy::Uniform);
        let x = array![[0.0], [2.5], [5.0]];
        let out = disc.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_eq!(out.ncols(), 3);
        // Exactly one indicator is set in every row.
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kbins_quantile_strategy() {
        let disc = KBinsDiscretizer::<f64>::new(4, BinEncoding::Ordinal, BinStrategy::Quantile);
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]];
        let out = disc.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert!(out.iter().all(|v| (0.0..4.0).contains(v)));
    }

    #[test]
    fn test_kbins_kmeans_strategy() {
        let disc = KBinsDiscretizer::<f64>::new(3, BinEncoding::Ordinal, BinStrategy::KMeans);
        let x = array![
            [0.0],
            [0.1],
            [0.2],
            [5.0],
            [5.1],
            [5.2],
            [10.0],
            [10.1],
            [10.2]
        ];
        let out = disc.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert!(out.iter().all(|v| (0.0..3.0).contains(v)));
    }

    #[test]
    fn test_kbins_multi_feature() {
        let x = array![[0.0, 10.0], [2.5, 15.0], [5.0, 20.0]];
        let fitted = ordinal_uniform(3).fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }

    #[test]
    fn test_kbins_bin_edges() {
        let x = array![[0.0], [3.0], [6.0]];
        let fitted = ordinal_uniform(3).fit(&x, &()).unwrap();
        let edges = &fitted.bin_edges()[0];
        // Three bins are delimited by four edges running from min to max.
        assert_eq!(edges.len(), 4);
        assert_abs_diff_eq!(edges[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[3], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kbins_fit_transform() {
        let x = array![[0.0], [2.5], [5.0]];
        let out = ordinal_uniform(3).fit_transform(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_kbins_insufficient_samples_error() {
        let x = array![[1.0]];
        assert!(ordinal_uniform(3).fit(&x, &()).is_err());
    }

    #[test]
    fn test_kbins_too_few_bins_error() {
        let x = array![[0.0], [1.0]];
        assert!(ordinal_uniform(1).fit(&x, &()).is_err());
    }

    #[test]
    fn test_kbins_shape_mismatch_error() {
        let x_train = array![[0.0, 1.0], [2.0, 3.0]];
        let fitted = ordinal_uniform(3).fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_kbins_unfitted_error() {
        let x = array![[0.0]];
        assert!(ordinal_uniform(3).transform(&x).is_err());
    }

    #[test]
    fn test_kbins_default() {
        let disc = KBinsDiscretizer::<f64>::default();
        assert_eq!(disc.n_bins(), 5);
        assert_eq!(disc.encode(), BinEncoding::Ordinal);
        assert_eq!(disc.strategy(), BinStrategy::Uniform);
    }

    #[test]
    fn test_kbins_ordinal_values_in_range() {
        let x = array![[0.0], [2.5], [5.0], [7.5], [10.0]];
        let fitted = ordinal_uniform(5).fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for v in &out {
            assert!(*v >= 0.0 && *v < 5.0, "Bin index {v} out of range");
        }
    }
}