1use ferrolearn_core::error::FerroError;
10use ferrolearn_core::traits::Transform;
11use ndarray::Array2;
12use num_traits::Float;
13
/// Feature binarizer: thresholds every element of a 2-D array.
///
/// When applied through its `Transform` implementation, each element is
/// mapped to `1` if it is strictly greater than `threshold` and to `0`
/// otherwise (values equal to the threshold therefore become `0`).
///
/// The only configuration is the threshold itself. `F` is the element type
/// (in practice a floating-point type such as `f32` or `f64`).
#[derive(Debug, Clone)]
pub struct Binarizer<F> {
    /// Cut-off value: elements `> threshold` become `1`, all others `0`.
    pub(crate) threshold: F,
}
43
44impl<F: Float + Send + Sync + 'static> Binarizer<F> {
45 #[must_use]
47 pub fn new(threshold: F) -> Self {
48 Self { threshold }
49 }
50
51 #[must_use]
53 pub fn threshold(&self) -> F {
54 self.threshold
55 }
56}
57
58impl<F: Float + Send + Sync + 'static> Default for Binarizer<F> {
59 fn default() -> Self {
60 Self::new(F::zero())
61 }
62}
63
64impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for Binarizer<F> {
69 type Output = Array2<F>;
70 type Error = FerroError;
71
72 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
78 let out = x.mapv(|v| {
79 if v > self.threshold {
80 F::one()
81 } else {
82 F::zero()
83 }
84 });
85 Ok(out)
86 }
87}
88
#[cfg(test)]
mod tests {
    // Unit tests for `Binarizer`: default/custom thresholds, the strict `>`
    // tie-breaking rule, shape preservation, element types, and edge cases
    // (NaN, infinities, empty input).
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_binarizer_default_threshold() {
        let b = Binarizer::<f64>::default();
        assert_eq!(b.threshold(), 0.0);
        let x = array![[-1.0, 0.0, 0.5, 1.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 3]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binarizer_custom_threshold() {
        let b = Binarizer::<f64>::new(0.5);
        let x = array![[0.0, 0.5, 1.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        // A value equal to the threshold is not strictly greater, so it maps to 0.
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binarizer_all_zeros() {
        let b = Binarizer::<f64>::new(0.0);
        let x = array![[0.0, 0.0, 0.0]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_binarizer_all_ones() {
        let b = Binarizer::<f64>::new(0.0);
        let x = array![[1.0, 2.0, 3.0]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_binarizer_negative_threshold() {
        let b = Binarizer::<f64>::new(-1.0);
        let x = array![[-2.0, -1.0, -0.5, 0.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 3]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binarizer_multiple_rows() {
        let b = Binarizer::<f64>::new(2.0);
        let x = array![[1.0, 3.0], [2.0, 4.0], [5.0, 0.0]];
        let out = b.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 2]);
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binarizer_preserves_shape() {
        let b = Binarizer::<f64>::default();
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let out = b.transform(&x).unwrap();
        assert_eq!(out.shape(), x.shape());
    }

    #[test]
    fn test_binarizer_f32() {
        let b = Binarizer::<f32>::new(0.0f32);
        let x: Array2<f32> = array![[1.0f32, -1.0, 0.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0f32, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[0, 1]], 0.0f32, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[0, 2]], 0.0f32, epsilon = 1e-6);
    }

    #[test]
    fn test_output_values_are_zero_or_one() {
        let b = Binarizer::<f64>::new(0.0);
        let x = array![[-5.0, -1.0, 0.0, 0.001, 1.0, 100.0]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert!(*v == 0.0 || *v == 1.0, "expected 0 or 1, got {v}");
        }
    }

    #[test]
    fn test_binarizer_nan_element_maps_to_zero() {
        // NaN never compares greater than anything, so it falls into the `0` bucket.
        let b = Binarizer::<f64>::new(0.0);
        let x = array![[f64::NAN, 1.0]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binarizer_nan_threshold_maps_everything_to_zero() {
        let b = Binarizer::<f64>::new(f64::NAN);
        let x = array![[-1.0, 0.0, 1.0, f64::INFINITY]];
        let out = b.transform(&x).unwrap();
        for v in out.iter() {
            assert_abs_diff_eq!(*v, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_binarizer_infinities() {
        let b = Binarizer::<f64>::default();
        let x = array![[f64::INFINITY, f64::NEG_INFINITY]];
        let out = b.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_binarizer_empty_input() {
        let b = Binarizer::<f64>::default();
        let x = Array2::<f64>::zeros((0, 3));
        let out = b.transform(&x).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }
}