Skip to main content

anofox_ml_preprocessing/
binarizer.rs

1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4/// Parameters for Binarizer (unfitted state).
5///
6/// Thresholds features: values strictly greater than the threshold become 1,
7/// all others become 0.
8///
9/// `x_out[i, j] = if x[i, j] > threshold { 1 } else { 0 }`
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
12pub struct Binarizer<F: Float> {
13    /// The threshold value. Features greater than this are set to 1, else 0.
14    pub threshold: F,
15}
16
17impl<F: Float> Binarizer<F> {
18    /// Create a new `Binarizer` with the given threshold.
19    pub fn new(threshold: F) -> Self {
20        Self { threshold }
21    }
22}
23
24impl<F: Float> Default for Binarizer<F> {
25    /// Default binarizer with threshold 0.
26    fn default() -> Self {
27        Self::new(F::zero())
28    }
29}
30
31/// Fitted Binarizer — stateless, stores only the threshold.
32#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
33#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
34pub struct FittedBinarizer<F: Float> {
35    threshold: F,
36}
37
38impl<F: Float> FitUnsupervised<F> for Binarizer<F> {
39    type Fitted = FittedBinarizer<F>;
40
41    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
42        if x.is_empty() {
43            return Err(RustMlError::EmptyInput("input array is empty".into()));
44        }
45
46        Ok(FittedBinarizer {
47            threshold: self.threshold,
48        })
49    }
50}
51
52impl<F: Float> Transform<F> for FittedBinarizer<F> {
53    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
54        let one = F::one();
55        let zero = F::zero();
56
57        let result = x.mapv(|v| if v > self.threshold { one } else { zero });
58        Ok(result)
59    }
60}
61
62impl<F: Float> FittedBinarizer<F> {
63    /// Return the threshold value.
64    pub fn threshold(&self) -> F {
65        self.threshold
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// A positive threshold: only values strictly above 0.5 map to 1.
    #[test]
    fn test_basic_thresholding() {
        let x = array![[1.0, -1.0, 2.0], [0.5, 0.0, -0.5]];
        let binarizer = Binarizer::new(0.5);
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // Values > 0.5 become 1, others become 0
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10); // 1.0 > 0.5
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10); // -1.0 <= 0.5
        assert_abs_diff_eq!(transformed[[0, 2]], 1.0, epsilon = 1e-10); // 2.0 > 0.5
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10); // 0.5 is NOT > 0.5
        assert_abs_diff_eq!(transformed[[1, 1]], 0.0, epsilon = 1e-10); // 0.0 <= 0.5
        assert_abs_diff_eq!(transformed[[1, 2]], 0.0, epsilon = 1e-10); // -0.5 <= 0.5
    }

    /// The default threshold is zero, and zero itself maps to 0.
    #[test]
    fn test_default_threshold_zero() {
        let x = array![[1.0, 0.0, -1.0], [0.1, -0.1, 0.0]];
        let binarizer = Binarizer::<f64>::default();
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // threshold = 0: values > 0 become 1
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10); // 1.0 > 0
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10); // 0.0 is NOT > 0
        assert_abs_diff_eq!(transformed[[0, 2]], 0.0, epsilon = 1e-10); // -1.0 <= 0
        assert_abs_diff_eq!(transformed[[1, 0]], 1.0, epsilon = 1e-10); // 0.1 > 0
        assert_abs_diff_eq!(transformed[[1, 1]], 0.0, epsilon = 1e-10); // -0.1 <= 0
        assert_abs_diff_eq!(transformed[[1, 2]], 0.0, epsilon = 1e-10); // 0.0 is NOT > 0
    }

    /// A negative threshold: every cell of the 2x3 input is checked.
    #[test]
    fn test_negative_threshold() {
        let x = array![[-2.0, -1.0, 0.0], [1.0, -0.5, 0.5]];
        let binarizer = Binarizer::new(-1.0);
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // threshold = -1.0: values > -1.0 become 1
        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10); // -2.0 <= -1.0
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10); // -1.0 is NOT > -1.0
        assert_abs_diff_eq!(transformed[[0, 2]], 1.0, epsilon = 1e-10); // 0.0 > -1.0
        assert_abs_diff_eq!(transformed[[1, 0]], 1.0, epsilon = 1e-10); // 1.0 > -1.0
        assert_abs_diff_eq!(transformed[[1, 1]], 1.0, epsilon = 1e-10); // -0.5 > -1.0
        assert_abs_diff_eq!(transformed[[1, 2]], 1.0, epsilon = 1e-10); // 0.5 > -1.0
    }

    /// Uniform outputs: all ones above the threshold, all zeros below it.
    /// The same fitted instance is reused for both inputs.
    #[test]
    fn test_all_ones_and_zeros() {
        // All values above threshold -> all ones
        let x = array![[10.0, 20.0], [30.0, 40.0]];
        let binarizer = Binarizer::new(0.0);
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert_abs_diff_eq!(v, 1.0, epsilon = 1e-10);
        }

        // All values below threshold -> all zeros
        let x2 = array![[-10.0, -20.0], [-30.0, -40.0]];
        let transformed2 = fitted.transform(&x2).unwrap();

        for &v in transformed2.iter() {
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
        }
    }

    /// The same behaviour holds for `f32`; every cell is checked.
    #[test]
    fn test_f32_support() {
        let x = array![[1.0f32, -1.0, 0.5], [0.0, 2.0, -0.5]];
        let binarizer = Binarizer::new(0.0f32);
        let fitted = FitUnsupervised::<f32>::fit(&binarizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(transformed[[0, 0]], 1.0f32, epsilon = 1e-6); // 1.0 > 0
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0f32, epsilon = 1e-6); // -1.0 <= 0
        assert_abs_diff_eq!(transformed[[0, 2]], 1.0f32, epsilon = 1e-6); // 0.5 > 0
        assert_abs_diff_eq!(transformed[[1, 0]], 0.0f32, epsilon = 1e-6); // 0.0 is NOT > 0
        assert_abs_diff_eq!(transformed[[1, 1]], 1.0f32, epsilon = 1e-6); // 2.0 > 0
        assert_abs_diff_eq!(transformed[[1, 2]], 0.0f32, epsilon = 1e-6); // -0.5 <= 0
    }

    /// Fitting on an empty array must fail with the `EmptyInput` variant,
    /// not just any error.
    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let binarizer = Binarizer::<f64>::default();
        let result = FitUnsupervised::<f64>::fit(&binarizer, &x);
        assert!(matches!(result, Err(RustMlError::EmptyInput(_))));
    }

    /// The fitted value exposes the threshold it was configured with.
    #[test]
    fn test_fitted_threshold_accessor() {
        let x = array![[1.0, 2.0]];
        let binarizer = Binarizer::new(0.25);
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        assert_abs_diff_eq!(fitted.threshold(), 0.25, epsilon = 1e-12);
    }
}