Skip to main content

sklears_preprocessing/
binarization.rs

1//! Binarization transformers
2//!
3//! This module provides transformers for binarizing data:
4//! - Binarizer: Binarize data according to a threshold
5//! - KBinsDiscretizer: Discretize continuous features into bins
6
7use scirs2_core::ndarray::{Array1, Array2};
8use std::marker::PhantomData;
9
10use sklears_core::{
11    error::{Result, SklearsError},
12    traits::{Estimator, Fit, Trained, Transform, Untrained},
13    types::Float,
14};
15
16/// Configuration for Binarizer
17#[derive(Debug, Clone)]
18pub struct BinarizerConfig {
19    /// Feature values below or equal to this are replaced by 0, above it by 1
20    pub threshold: Float,
21    /// Whether to copy the input array
22    pub copy: bool,
23}
24
25impl Default for BinarizerConfig {
26    fn default() -> Self {
27        Self {
28            threshold: 0.0,
29            copy: true,
30        }
31    }
32}
33
34/// Binarizer transforms data to binary values based on a threshold
35pub struct Binarizer<State = Untrained> {
36    config: BinarizerConfig,
37    state: PhantomData<State>,
38}
39
40impl Binarizer<Untrained> {
41    /// Create a new Binarizer with default configuration
42    pub fn new() -> Self {
43        Self {
44            config: BinarizerConfig::default(),
45            state: PhantomData,
46        }
47    }
48
49    /// Create a new Binarizer with specified threshold
50    pub fn with_threshold(threshold: Float) -> Self {
51        Self {
52            config: BinarizerConfig {
53                threshold,
54                copy: true,
55            },
56            state: PhantomData,
57        }
58    }
59
60    /// Set the threshold
61    pub fn threshold(mut self, threshold: Float) -> Self {
62        self.config.threshold = threshold;
63        self
64    }
65
66    /// Set whether to copy the input array
67    pub fn copy(mut self, copy: bool) -> Self {
68        self.config.copy = copy;
69        self
70    }
71}
72
73impl Default for Binarizer<Untrained> {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl Estimator for Binarizer<Untrained> {
80    type Config = BinarizerConfig;
81    type Error = SklearsError;
82    type Float = Float;
83
84    fn config(&self) -> &Self::Config {
85        &self.config
86    }
87}
88
89impl Estimator for Binarizer<Trained> {
90    type Config = BinarizerConfig;
91    type Error = SklearsError;
92    type Float = Float;
93
94    fn config(&self) -> &Self::Config {
95        &self.config
96    }
97}
98
99impl Fit<Array2<Float>, ()> for Binarizer<Untrained> {
100    type Fitted = Binarizer<Trained>;
101
102    fn fit(self, _x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
103        // Binarizer doesn't need to learn anything from the data
104        Ok(Binarizer {
105            config: self.config,
106            state: PhantomData,
107        })
108    }
109}
110
111impl Transform<Array2<Float>, Array2<Float>> for Binarizer<Trained> {
112    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
113        let result = if self.config.copy {
114            x.clone()
115        } else {
116            x.to_owned()
117        };
118
119        Ok(result.mapv(|v| if v > self.config.threshold { 1.0 } else { 0.0 }))
120    }
121}
122
123impl Transform<Array1<Float>, Array1<Float>> for Binarizer<Trained> {
124    fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
125        let result = if self.config.copy {
126            x.clone()
127        } else {
128            x.to_owned()
129        };
130
131        Ok(result.mapv(|v| if v > self.config.threshold { 1.0 } else { 0.0 }))
132    }
133}
134
/// Strategy that decides where `KBinsDiscretizer` places its bin edges.
///
/// The enum carries no data, so it also implements `Eq` and `Hash` and can be
/// used as a map key or set element.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiscretizationStrategy {
    /// All bins have identical widths
    Uniform,
    /// All bins have the same number of points
    Quantile,
    /// Bins are clustered using k-means.
    ///
    /// Fitting currently falls back to the quantile edges for this variant.
    KMeans,
}
145
/// Output layout produced when transforming with `KBinsDiscretizer`.
///
/// The enum carries no data, so it also implements `Eq` and `Hash` and can be
/// used as a map key or set element.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiscretizerEncoding {
    /// Encode as one-hot vectors (one output column per bin of every feature)
    OneHot,
    /// Encode with the bin identifier as an integer (one output column per feature)
    Ordinal,
}
154
155/// Configuration for KBinsDiscretizer
156#[derive(Debug, Clone)]
157pub struct KBinsDiscretizerConfig {
158    /// Number of bins to produce
159    pub n_bins: usize,
160    /// Encoding method
161    pub encode: DiscretizerEncoding,
162    /// Strategy used to define the widths of the bins
163    pub strategy: DiscretizationStrategy,
164    /// Subsample size for KMeans strategy
165    pub subsample: Option<usize>,
166    /// Random state for KMeans
167    pub random_state: Option<u64>,
168}
169
170impl Default for KBinsDiscretizerConfig {
171    fn default() -> Self {
172        Self {
173            n_bins: 5,
174            encode: DiscretizerEncoding::OneHot,
175            strategy: DiscretizationStrategy::Quantile,
176            subsample: Some(200_000),
177            random_state: None,
178        }
179    }
180}
181
182/// KBinsDiscretizer bins continuous data into intervals
183pub struct KBinsDiscretizer<State = Untrained> {
184    config: KBinsDiscretizerConfig,
185    state: PhantomData<State>,
186    /// The edges of each bin for each feature
187    bin_edges_: Option<Vec<Array1<Float>>>,
188    /// Number of bins for each feature
189    n_bins_: Option<Vec<usize>>,
190}
191
192impl KBinsDiscretizer<Untrained> {
193    /// Create a new KBinsDiscretizer with default configuration
194    pub fn new() -> Self {
195        Self {
196            config: KBinsDiscretizerConfig::default(),
197            state: PhantomData,
198            bin_edges_: None,
199            n_bins_: None,
200        }
201    }
202
203    /// Set the number of bins
204    pub fn n_bins(mut self, n_bins: usize) -> Result<Self> {
205        if n_bins < 2 {
206            return Err(SklearsError::InvalidParameter {
207                name: "n_bins".to_string(),
208                reason: "must be at least 2".to_string(),
209            });
210        }
211        self.config.n_bins = n_bins;
212        Ok(self)
213    }
214
215    /// Set the encoding method
216    pub fn encode(mut self, encode: DiscretizerEncoding) -> Self {
217        self.config.encode = encode;
218        self
219    }
220
221    /// Set the discretization strategy
222    pub fn strategy(mut self, strategy: DiscretizationStrategy) -> Self {
223        self.config.strategy = strategy;
224        self
225    }
226}
227
228impl Default for KBinsDiscretizer<Untrained> {
229    fn default() -> Self {
230        Self::new()
231    }
232}
233
234impl Estimator for KBinsDiscretizer<Untrained> {
235    type Config = KBinsDiscretizerConfig;
236    type Error = SklearsError;
237    type Float = Float;
238
239    fn config(&self) -> &Self::Config {
240        &self.config
241    }
242}
243
244impl Estimator for KBinsDiscretizer<Trained> {
245    type Config = KBinsDiscretizerConfig;
246    type Error = SklearsError;
247    type Float = Float;
248
249    fn config(&self) -> &Self::Config {
250        &self.config
251    }
252}
253
254/// Compute uniform bin edges
255fn compute_uniform_bins(data: &Array1<Float>, n_bins: usize) -> Array1<Float> {
256    let min_val = data.iter().cloned().fold(Float::INFINITY, Float::min);
257    let max_val = data.iter().cloned().fold(Float::NEG_INFINITY, Float::max);
258
259    if (max_val - min_val).abs() < Float::EPSILON {
260        // All values are the same
261        return Array1::from_vec(vec![min_val - 0.5, max_val + 0.5]);
262    }
263
264    let width = (max_val - min_val) / n_bins as Float;
265    let mut edges = Vec::with_capacity(n_bins + 1);
266
267    for i in 0..=n_bins {
268        edges.push(min_val + i as Float * width);
269    }
270
271    // Extend the last edge slightly to include the maximum value
272    edges[n_bins] = max_val + Float::EPSILON;
273
274    Array1::from_vec(edges)
275}
276
277/// Compute quantile bin edges
278fn compute_quantile_bins(data: &Array1<Float>, n_bins: usize) -> Array1<Float> {
279    let mut sorted_data = data.to_vec();
280    sorted_data.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
281
282    let n_samples = sorted_data.len();
283    let mut edges = Vec::with_capacity(n_bins + 1);
284
285    // Add minimum
286    edges.push(sorted_data[0]);
287
288    // Add quantile edges
289    for i in 1..n_bins {
290        let idx = (i * n_samples) / n_bins;
291        let value = sorted_data[idx.min(n_samples - 1)];
292
293        // Avoid duplicate edges
294        if value > edges.last().expect("collection should not be empty") + Float::EPSILON {
295            edges.push(value);
296        }
297    }
298
299    // Add maximum
300    edges.push(sorted_data[n_samples - 1] + Float::EPSILON);
301
302    // Ensure we have at least 2 bins
303    if edges.len() < 3 {
304        edges.clear();
305        edges.push(sorted_data[0]);
306        edges.push(sorted_data[n_samples - 1] + Float::EPSILON);
307    }
308
309    Array1::from_vec(edges)
310}
311
312impl Fit<Array2<Float>, ()> for KBinsDiscretizer<Untrained> {
313    type Fitted = KBinsDiscretizer<Trained>;
314
315    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
316        let n_features = x.ncols();
317        let mut bin_edges = Vec::with_capacity(n_features);
318        let mut n_bins = Vec::with_capacity(n_features);
319
320        // Compute bin edges for each feature
321        for j in 0..n_features {
322            let feature_data = x.column(j).to_owned();
323
324            let edges = match self.config.strategy {
325                DiscretizationStrategy::Uniform => {
326                    compute_uniform_bins(&feature_data, self.config.n_bins)
327                }
328                DiscretizationStrategy::Quantile => {
329                    compute_quantile_bins(&feature_data, self.config.n_bins)
330                }
331                DiscretizationStrategy::KMeans => {
332                    // For now, fall back to quantile
333                    compute_quantile_bins(&feature_data, self.config.n_bins)
334                }
335            };
336
337            n_bins.push(edges.len() - 1);
338            bin_edges.push(edges);
339        }
340
341        Ok(KBinsDiscretizer {
342            config: self.config,
343            state: PhantomData,
344            bin_edges_: Some(bin_edges),
345            n_bins_: Some(n_bins),
346        })
347    }
348}
349
350/// Find the bin index for a value given bin edges
351fn find_bin(value: Float, edges: &Array1<Float>) -> usize {
352    // Binary search for the bin
353    let n_edges = edges.len();
354
355    if value <= edges[0] {
356        return 0;
357    }
358    if value >= edges[n_edges - 1] {
359        return n_edges - 2;
360    }
361
362    let mut left = 0;
363    let mut right = n_edges - 1;
364
365    while left < right - 1 {
366        let mid = (left + right) / 2;
367        if value < edges[mid] {
368            right = mid;
369        } else {
370            left = mid;
371        }
372    }
373
374    left
375}
376
377impl Transform<Array2<Float>, Array2<Float>> for KBinsDiscretizer<Trained> {
378    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
379        let n_samples = x.nrows();
380        let n_features = x.ncols();
381        let bin_edges = self.bin_edges_.as_ref().expect("operation should succeed");
382        let n_bins = self.n_bins_.as_ref().expect("operation should succeed");
383
384        match self.config.encode {
385            DiscretizerEncoding::Ordinal => {
386                let mut result = Array2::zeros((n_samples, n_features));
387
388                for i in 0..n_samples {
389                    for j in 0..n_features {
390                        let bin_idx = find_bin(x[[i, j]], &bin_edges[j]);
391                        result[[i, j]] = bin_idx as Float;
392                    }
393                }
394
395                Ok(result)
396            }
397            DiscretizerEncoding::OneHot => {
398                // Calculate total number of columns for one-hot encoding
399                let total_bins: usize = n_bins.iter().sum();
400                let mut result = Array2::zeros((n_samples, total_bins));
401
402                for i in 0..n_samples {
403                    let mut col_offset = 0;
404                    for j in 0..n_features {
405                        let bin_idx = find_bin(x[[i, j]], &bin_edges[j]);
406                        result[[i, col_offset + bin_idx]] = 1.0;
407                        col_offset += n_bins[j];
408                    }
409                }
410
411                Ok(result)
412            }
413        }
414    }
415}
416
417impl KBinsDiscretizer<Trained> {
418    /// Get the bin edges for each feature
419    pub fn bin_edges(&self) -> &Vec<Array1<Float>> {
420        self.bin_edges_.as_ref().expect("operation should succeed")
421    }
422
423    /// Get the number of bins for each feature
424    pub fn n_bins(&self) -> &Vec<usize> {
425        self.n_bins_.as_ref().expect("operation should succeed")
426    }
427
428    /// Transform back from bin indices to representative values (inverse transform)
429    pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
430        let bin_edges = self.bin_edges_.as_ref().expect("operation should succeed");
431        let n_features = bin_edges.len();
432
433        match self.config.encode {
434            DiscretizerEncoding::Ordinal => {
435                if x.ncols() != n_features {
436                    return Err(SklearsError::InvalidInput(
437                        "Input must have the same number of features as during fit".to_string(),
438                    ));
439                }
440
441                let mut result = Array2::zeros(x.dim());
442
443                for i in 0..x.nrows() {
444                    for j in 0..n_features {
445                        let bin_idx = x[[i, j]] as usize;
446                        let edges = &bin_edges[j];
447
448                        if bin_idx >= edges.len() - 1 {
449                            return Err(SklearsError::InvalidInput(format!(
450                                "Invalid bin index {bin_idx} for feature {j}"
451                            )));
452                        }
453
454                        // Use bin center as representative value
455                        result[[i, j]] = (edges[bin_idx] + edges[bin_idx + 1]) / 2.0;
456                    }
457                }
458
459                Ok(result)
460            }
461            DiscretizerEncoding::OneHot => Err(SklearsError::InvalidInput(
462                "Inverse transform not supported for one-hot encoding".to_string(),
463            )),
464        }
465    }
466}
467
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fits a three-bin, ordinal-encoded discretizer on `x` with `strategy`.
    fn fit_ordinal_three_bins(
        x: &Array2<Float>,
        strategy: DiscretizationStrategy,
    ) -> KBinsDiscretizer<Trained> {
        KBinsDiscretizer::new()
            .n_bins(3)
            .expect("valid parameter")
            .strategy(strategy)
            .encode(DiscretizerEncoding::Ordinal)
            .fit(x, &())
            .expect("operation should succeed")
    }

    #[test]
    fn test_binarizer() {
        let x = array![[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0],];

        let binarizer = Binarizer::with_threshold(0.0)
            .fit(&x, &())
            .expect("model fitting should succeed");

        let x_bin = binarizer
            .transform(&x)
            .expect("transformation should succeed");

        // Values equal to the threshold (0.0) map to 0, not 1.
        let expected = array![[1.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0],];

        assert_eq!(x_bin, expected);
    }

    #[test]
    fn test_binarizer_custom_threshold() {
        let x = array![[1.0, 2.0, 3.0, 4.0]];

        let binarizer = Binarizer::new()
            .threshold(2.5)
            .fit(&x, &())
            .expect("model fitting should succeed");

        let x_bin = binarizer
            .transform(&x)
            .expect("transformation should succeed");
        let expected = array![[0.0, 0.0, 1.0, 1.0]];

        assert_eq!(x_bin, expected);
    }

    #[test]
    fn test_binarizer_1d() {
        let x = array![1.0, -1.0, 2.0, 0.0];

        // Fitting ignores the data, so any 2-D array will do here.
        let binarizer = Binarizer::new()
            .fit(&array![[0.0]], &())
            .expect("model fitting should succeed");

        let x_bin = binarizer
            .transform(&x)
            .expect("transformation should succeed");
        let expected = array![1.0, 0.0, 1.0, 0.0];

        assert_eq!(x_bin, expected);
    }

    #[test]
    fn test_kbins_discretizer_uniform() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0],];

        let discretizer = fit_ordinal_three_bins(&x, DiscretizationStrategy::Uniform);

        let x_disc = discretizer
            .transform(&x)
            .expect("transformation should succeed");

        // The range [0, 5] is split into 3 bins of width 5/3:
        // [0, 5/3), [5/3, 10/3), [10/3, 5+ε]
        // Values should be binned as: 0, 0, 1, 1, 2, 2
        assert_eq!(
            x_disc.column(0).to_vec(),
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
        );
    }

    #[test]
    fn test_kbins_discretizer_quantile() {
        let x = array![[0.0], [1.0], [1.0], [2.0], [3.0], [10.0],];

        let discretizer = fit_ordinal_three_bins(&x, DiscretizationStrategy::Quantile);

        let x_disc = discretizer
            .transform(&x)
            .expect("transformation should succeed");

        // Check that each bin has approximately the same number of samples
        let bin_counts = [0.0, 1.0, 2.0].map(|bin| x_disc.iter().filter(|&&v| v == bin).count());

        // Each bin should have about 2 samples
        for count in bin_counts {
            assert!((1..=3).contains(&count));
        }
    }

    #[test]
    fn test_kbins_discretizer_onehot() {
        let x = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0],];

        let discretizer = KBinsDiscretizer::new()
            .n_bins(2)
            .expect("valid parameter")
            .encode(DiscretizerEncoding::OneHot)
            .fit(&x, &())
            .expect("operation should succeed");

        let x_disc = discretizer
            .transform(&x)
            .expect("transformation should succeed");

        // With 2 bins per feature and 2 features, we should get 4 columns
        assert_eq!(x_disc.ncols(), 4);

        // Each row should have exactly 2 ones (one per feature)
        for row in x_disc.rows() {
            let row_sum: Float = row.sum();
            assert_eq!(row_sum, 2.0);
        }
    }

    #[test]
    fn test_kbins_discretizer_inverse_transform() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0],];

        let discretizer = fit_ordinal_three_bins(&x, DiscretizationStrategy::Uniform);

        let x_disc = discretizer
            .transform(&x)
            .expect("transformation should succeed");
        let x_inv = discretizer
            .inverse_transform(&x_disc)
            .expect("operation should succeed");

        // The inverse transform yields the bin centers.
        // Bins: [0, 5/3), [5/3, 10/3), [10/3, 5+ε] with centers at ~0.83, 2.5, ~4.17
        assert!(x_inv[[0, 0]] < 2.0); // First bin center
        assert!(x_inv[[2, 0]] > 2.0 && x_inv[[2, 0]] < 4.0); // Second bin center
        assert!(x_inv[[4, 0]] > 4.0); // Third bin center
    }

    #[test]
    fn test_kbins_discretizer_inverse_transform_onehot_unsupported() {
        let x = array![[0.0], [1.0], [2.0],];

        let discretizer = KBinsDiscretizer::new()
            .n_bins(2)
            .expect("valid parameter")
            .encode(DiscretizerEncoding::OneHot)
            .fit(&x, &())
            .expect("operation should succeed");

        let x_disc = discretizer
            .transform(&x)
            .expect("transformation should succeed");

        assert!(discretizer.inverse_transform(&x_disc).is_err());
    }

    #[test]
    fn test_kbins_discretizer_rejects_too_few_bins() {
        assert!(KBinsDiscretizer::new().n_bins(0).is_err());
        assert!(KBinsDiscretizer::new().n_bins(1).is_err());
        assert!(KBinsDiscretizer::new().n_bins(2).is_ok());
    }

    #[test]
    fn test_find_bin() {
        let edges = array![0.0, 2.0, 4.0, 6.0];

        // Below or at the first edge: clamped to the first bin.
        assert_eq!(find_bin(-1.0, &edges), 0);
        assert_eq!(find_bin(0.0, &edges), 0);
        assert_eq!(find_bin(1.0, &edges), 0);
        // An interior edge belongs to the bin it opens.
        assert_eq!(find_bin(2.0, &edges), 1);
        assert_eq!(find_bin(3.0, &edges), 1);
        assert_eq!(find_bin(4.0, &edges), 2);
        assert_eq!(find_bin(5.0, &edges), 2);
        // At or above the last edge: clamped to the last bin.
        assert_eq!(find_bin(6.0, &edges), 2);
        assert_eq!(find_bin(7.0, &edges), 2);
    }
}