Skip to main content

ferrolearn_preprocess/
binarizer.rs

//! Binarizer: threshold features to binary values.
//!
//! Values strictly greater than the threshold are set to `1.0`; all other
//! values (including values equal to the threshold and `NaN`) are set to `0.0`.
//!
//! This transformer is **stateless** — no fitting is required. Call
//! [`Transform::transform`] directly.
8
9use ferrolearn_core::error::FerroError;
10use ferrolearn_core::traits::Transform;
11use ndarray::Array2;
12use num_traits::Float;
13
14// ---------------------------------------------------------------------------
15// Binarizer
16// ---------------------------------------------------------------------------
17
/// A stateless feature binarizer.
///
/// Each value strictly greater than `threshold` becomes `1.0`; every other
/// value — including values exactly equal to the threshold and `NaN` —
/// becomes `0.0`. The default threshold is `0.0`.
///
/// This transformer is stateless — no fitting is needed. Call
/// [`Transform::transform`] directly.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::binarizer::Binarizer;
/// use ferrolearn_core::traits::Transform;
/// use ndarray::array;
///
/// let binarizer = Binarizer::<f64>::new(0.5);
/// let x = array![[0.0, 0.5, 1.0]];
/// let out = binarizer.transform(&x).unwrap();
/// // 0.5 is not *strictly* greater than 0.5, so it maps to 0.0.
/// assert_eq!(out, array![[0.0, 0.0, 1.0]]);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Binarizer<F> {
    /// The threshold value. Values strictly greater than this become `1.0`.
    pub(crate) threshold: F,
}
43
44impl<F: Float + Send + Sync + 'static> Binarizer<F> {
45    /// Create a new `Binarizer` with the given threshold.
46    #[must_use]
47    pub fn new(threshold: F) -> Self {
48        Self { threshold }
49    }
50
51    /// Return the configured threshold.
52    #[must_use]
53    pub fn threshold(&self) -> F {
54        self.threshold
55    }
56}
57
58impl<F: Float + Send + Sync + 'static> Default for Binarizer<F> {
59    fn default() -> Self {
60        Self::new(F::zero())
61    }
62}
63
64// ---------------------------------------------------------------------------
65// Trait implementations
66// ---------------------------------------------------------------------------
67
68impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for Binarizer<F> {
69    type Output = Array2<F>;
70    type Error = FerroError;
71
72    /// Apply the threshold: values > threshold become `1.0`, others become `0.0`.
73    ///
74    /// # Errors
75    ///
76    /// This implementation never returns an error for well-formed inputs.
77    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
78        let out = x.mapv(|v| {
79            if v > self.threshold {
80                F::one()
81            } else {
82                F::zero()
83            }
84        });
85        Ok(out)
86    }
87}
88
89// ---------------------------------------------------------------------------
90// Tests
91// ---------------------------------------------------------------------------
92
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_binarizer_default_threshold() {
        let b = Binarizer::<f64>::default();
        assert_eq!(b.threshold(), 0.0);
        let x = array![[-1.0, 0.0, 0.5, 1.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // -1 <= 0
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10); // 0 not > 0
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10); // 0.5 > 0
        assert_abs_diff_eq!(out[[0, 3]], 1.0, epsilon = 1e-10); // 1.0 > 0
    }

    #[test]
    fn test_binarizer_custom_threshold() {
        let b = Binarizer::<f64>::new(0.5);
        let x = array![[0.0, 0.5, 1.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // 0.0 not > 0.5
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10); // 0.5 not > 0.5 (strict)
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10); // 1.0 > 0.5
    }

    #[test]
    fn test_binarizer_all_zeros() {
        let b = Binarizer::<f64>::new(0.0);
        let x = array![[0.0, 0.0, 0.0]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_binarizer_all_ones() {
        let b = Binarizer::<f64>::new(0.0);
        let x = array![[1.0, 2.0, 3.0]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_binarizer_negative_threshold() {
        let b = Binarizer::<f64>::new(-1.0);
        let x = array![[-2.0, -1.0, -0.5, 0.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // -2 <= -1
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10); // -1 not > -1
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10); // -0.5 > -1
        assert_abs_diff_eq!(out[[0, 3]], 1.0, epsilon = 1e-10); // 0.0 > -1
    }

    #[test]
    fn test_binarizer_multiple_rows() {
        let b = Binarizer::<f64>::new(2.0);
        let x = array![[1.0, 3.0], [2.0, 4.0], [5.0, 0.0]];
        let out = b.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 2]);
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10); // 1 <= 2
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = 1e-10); // 3 > 2
        assert_abs_diff_eq!(out[[1, 0]], 0.0, epsilon = 1e-10); // 2 not > 2
        assert_abs_diff_eq!(out[[1, 1]], 1.0, epsilon = 1e-10); // 4 > 2
        assert_abs_diff_eq!(out[[2, 0]], 1.0, epsilon = 1e-10); // 5 > 2
        assert_abs_diff_eq!(out[[2, 1]], 0.0, epsilon = 1e-10); // 0 <= 2
    }

    #[test]
    fn test_binarizer_preserves_shape() {
        let b = Binarizer::<f64>::default();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = b.transform(&x).unwrap();
        assert_eq!(out.shape(), x.shape());
    }

    #[test]
    fn test_binarizer_f32() {
        let b = Binarizer::<f32>::new(0.0f32);
        let x: Array2<f32> = array![[1.0f32, -1.0, 0.0]];
        let out = b.transform(&x).unwrap();
        assert!((out[[0, 0]] - 1.0f32).abs() < 1e-6);
        assert!((out[[0, 1]] - 0.0f32).abs() < 1e-6);
        assert!((out[[0, 2]] - 0.0f32).abs() < 1e-6);
    }

    #[test]
    fn test_output_values_are_zero_or_one() {
        let b = Binarizer::<f64>::new(0.0);
        let x = array![[-5.0, -1.0, 0.0, 0.001, 1.0, 100.0]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert!(*v == 0.0 || *v == 1.0, "expected 0 or 1, got {v}");
        }
    }

    /// `NaN > threshold` is always false, so `NaN` inputs map to `0.0`
    /// rather than propagating into the output.
    #[test]
    fn test_binarizer_nan_input_maps_to_zero() {
        let b = Binarizer::<f64>::default();
        let x = array![[f64::NAN, 1.0, -1.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 0.0, epsilon = 1e-10);
    }

    /// With a `NaN` threshold no comparison can succeed, so everything is `0.0`.
    #[test]
    fn test_binarizer_nan_threshold_yields_all_zeros() {
        let b = Binarizer::<f64>::new(f64::NAN);
        let x = array![[-1.0, 0.0, 1.0, f64::INFINITY]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_binarizer_infinities() {
        let b = Binarizer::<f64>::default();
        let x = array![[f64::NEG_INFINITY, f64::INFINITY]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binarizer_empty_input() {
        let b = Binarizer::<f64>::default();
        let x = Array2::<f64>::zeros((0, 3));
        let out = b.transform(&x).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
        assert_eq!(out.len(), 0);
    }
}
193}