1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
12pub struct Binarizer<F: Float> {
13 pub threshold: F,
15}
16
17impl<F: Float> Binarizer<F> {
18 pub fn new(threshold: F) -> Self {
20 Self { threshold }
21 }
22}
23
24impl<F: Float> Default for Binarizer<F> {
25 fn default() -> Self {
27 Self::new(F::zero())
28 }
29}
30
31#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
33#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
34pub struct FittedBinarizer<F: Float> {
35 threshold: F,
36}
37
38impl<F: Float> FitUnsupervised<F> for Binarizer<F> {
39 type Fitted = FittedBinarizer<F>;
40
41 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
42 if x.is_empty() {
43 return Err(RustMlError::EmptyInput("input array is empty".into()));
44 }
45
46 Ok(FittedBinarizer {
47 threshold: self.threshold,
48 })
49 }
50}
51
52impl<F: Float> Transform<F> for FittedBinarizer<F> {
53 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
54 let one = F::one();
55 let zero = F::zero();
56
57 let result = x.mapv(|v| if v > self.threshold { one } else { zero });
58 Ok(result)
59 }
60}
61
62impl<F: Float> FittedBinarizer<F> {
63 pub fn threshold(&self) -> F {
65 self.threshold
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Outputs are exactly `F::one()` / `F::zero()`, so whole-array `assert_eq!`
    // is both stricter and clearer than per-cell epsilon comparisons.

    /// Fits a `Binarizer` with `threshold` on `x`, then transforms the same data.
    fn fit_transform(threshold: f64, x: &Array2<f64>) -> Array2<f64> {
        let binarizer = Binarizer::new(threshold);
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, x).unwrap();
        fitted.transform(x).unwrap()
    }

    #[test]
    fn test_basic_thresholding() {
        let x = array![[1.0, -1.0, 2.0], [0.5, 0.0, -0.5]];
        // 0.5 sits exactly on the threshold and must map to 0 (strict `>`).
        let expected = array![[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]];
        assert_eq!(fit_transform(0.5, &x), expected);
    }

    #[test]
    fn test_default_threshold_zero() {
        let x = array![[1.0, 0.0, -1.0], [0.1, -0.1, 0.0]];
        let binarizer = Binarizer::<f64>::default();
        assert_eq!(binarizer.threshold, 0.0);

        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed, array![[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]);
    }

    #[test]
    fn test_negative_threshold() {
        let x = array![[-2.0, -1.0, 0.0], [1.0, -0.5, 0.5]];
        // Every cell is checked, including [1, 2] which was previously skipped.
        let expected = array![[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]];
        assert_eq!(fit_transform(-1.0, &x), expected);
    }

    #[test]
    fn test_all_ones_and_zeros() {
        let binarizer = Binarizer::new(0.0_f64);
        let x = array![[10.0, 20.0], [30.0, 40.0]];
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        assert_eq!(fitted.transform(&x).unwrap(), array![[1.0, 1.0], [1.0, 1.0]]);

        // The fitted transformer holds no data-dependent state, so it applies
        // to data it was not fitted on.
        let x2 = array![[-10.0, -20.0], [-30.0, -40.0]];
        assert_eq!(fitted.transform(&x2).unwrap(), array![[0.0, 0.0], [0.0, 0.0]]);
    }

    #[test]
    fn test_f32_support() {
        let x = array![[1.0f32, -1.0, 0.5], [0.0, 2.0, -0.5]];
        let binarizer = Binarizer::new(0.0f32);
        let fitted = FitUnsupervised::<f32>::fit(&binarizer, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed, array![[1.0f32, 0.0, 1.0], [0.0, 1.0, 0.0]]);
    }

    #[test]
    fn test_nan_maps_to_zero() {
        // NaN is never `>` the threshold, so it is binarized to zero.
        let x = array![[f64::NAN, 1.0]];
        assert_eq!(fit_transform(0.0, &x), array![[0.0, 1.0]]);
    }

    #[test]
    fn test_fitted_threshold_matches_config() {
        let x = array![[1.0_f64]];
        let binarizer = Binarizer::new(0.25_f64);
        let fitted = FitUnsupervised::<f64>::fit(&binarizer, &x).unwrap();
        assert_eq!(fitted.threshold(), 0.25);
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let binarizer = Binarizer::<f64>::default();
        assert!(FitUnsupervised::<f64>::fit(&binarizer, &x).is_err());
    }

    #[test]
    fn test_zero_rows_rejected() {
        // An array with columns but no rows still holds no data.
        let x: Array2<f64> = Array2::zeros((0, 3));
        let binarizer = Binarizer::<f64>::default();
        assert!(FitUnsupervised::<f64>::fit(&binarizer, &x).is_err());
    }
}