1use ferrolearn_core::error::FerroError;
25use ferrolearn_core::traits::{Fit, Transform};
26use ndarray::{Array1, Array2};
27
/// An unfitted label binarizer.
///
/// The type carries no configuration or state of its own; all learned
/// information lives in the fitted value produced by fitting it on labels.
#[derive(Debug, Clone, Default)]
pub struct LabelBinarizer;

impl LabelBinarizer {
    /// Creates a new, unfitted binarizer.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
46
47#[derive(Debug, Clone)]
55pub struct FittedLabelBinarizer {
56 classes: Vec<usize>,
58}
59
60impl FittedLabelBinarizer {
61 #[must_use]
63 pub fn classes(&self) -> &[usize] {
64 &self.classes
65 }
66
67 #[must_use]
69 pub fn n_classes(&self) -> usize {
70 self.classes.len()
71 }
72
73 pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Array1<usize>, FerroError> {
82 let k = self.classes.len();
83 let expected_cols = if k == 2 { 1 } else { k };
84
85 if y.ncols() != expected_cols {
86 return Err(FerroError::ShapeMismatch {
87 expected: vec![y.nrows(), expected_cols],
88 actual: vec![y.nrows(), y.ncols()],
89 context: "FittedLabelBinarizer::inverse_transform".into(),
90 });
91 }
92
93 let n = y.nrows();
94 let mut result = Array1::zeros(n);
95
96 if k == 2 {
97 for i in 0..n {
99 result[i] = if y[[i, 0]] >= 0.5 {
100 self.classes[1]
101 } else {
102 self.classes[0]
103 };
104 }
105 } else {
106 for i in 0..n {
108 let row = y.row(i);
109 let mut best_j = 0;
110 let mut best_v = f64::NEG_INFINITY;
111 for (j, &v) in row.iter().enumerate() {
112 if v > best_v {
113 best_v = v;
114 best_j = j;
115 }
116 }
117 result[i] = self.classes[best_j];
118 }
119 }
120
121 Ok(result)
122 }
123}
124
125impl Fit<Array1<usize>, ()> for LabelBinarizer {
130 type Fitted = FittedLabelBinarizer;
131 type Error = FerroError;
132
133 fn fit(&self, y: &Array1<usize>, _target: &()) -> Result<FittedLabelBinarizer, FerroError> {
139 if y.is_empty() {
140 return Err(FerroError::InsufficientSamples {
141 required: 1,
142 actual: 0,
143 context: "LabelBinarizer::fit".into(),
144 });
145 }
146
147 let mut classes: Vec<usize> = y.iter().copied().collect();
148 classes.sort_unstable();
149 classes.dedup();
150
151 Ok(FittedLabelBinarizer { classes })
152 }
153}
154
155impl Transform<Array1<usize>> for FittedLabelBinarizer {
156 type Output = Array2<f64>;
157 type Error = FerroError;
158
159 fn transform(&self, y: &Array1<usize>) -> Result<Array2<f64>, FerroError> {
169 let k = self.classes.len();
170 let n = y.len();
171
172 let class_to_idx: std::collections::HashMap<usize, usize> = self
174 .classes
175 .iter()
176 .enumerate()
177 .map(|(i, &c)| (c, i))
178 .collect();
179
180 if k == 2 {
181 let mut out = Array2::zeros((n, 1));
183 for (i, &label) in y.iter().enumerate() {
184 let idx = class_to_idx
185 .get(&label)
186 .ok_or_else(|| FerroError::InvalidParameter {
187 name: "y".into(),
188 reason: format!("unknown label {label} not seen during fit"),
189 })?;
190 out[[i, 0]] = if *idx == 1 { 1.0 } else { 0.0 };
191 }
192 Ok(out)
193 } else {
194 let mut out = Array2::zeros((n, k));
196 for (i, &label) in y.iter().enumerate() {
197 let &idx =
198 class_to_idx
199 .get(&label)
200 .ok_or_else(|| FerroError::InvalidParameter {
201 name: "y".into(),
202 reason: format!("unknown label {label} not seen during fit"),
203 })?;
204 out[[i, idx]] = 1.0;
205 }
206 Ok(out)
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Fits a fresh binarizer on `labels`, panicking if fitting fails.
    fn fit_on(labels: &Array1<usize>) -> FittedLabelBinarizer {
        LabelBinarizer::new().fit(labels, &()).unwrap()
    }

    /// Encodes and decodes `labels`, asserting the original labels come back.
    fn assert_roundtrip(labels: &Array1<usize>) {
        let fitted = fit_on(labels);
        let encoded = fitted.transform(labels).unwrap();
        let decoded = fitted.inverse_transform(&encoded).unwrap();
        assert_eq!(&decoded, labels);
    }

    #[test]
    fn test_fit_discovers_sorted_classes() {
        let fitted = fit_on(&array![2_usize, 0, 1, 2, 0]);
        assert_eq!(fitted.classes(), &[0, 1, 2]);
    }

    #[test]
    fn test_fit_empty_input_error() {
        let empty = Array1::<usize>::zeros(0);
        let outcome = LabelBinarizer::new().fit(&empty, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_binary_transform_single_column() {
        let labels = array![0_usize, 1, 0, 1];
        let encoded = fit_on(&labels).transform(&labels).unwrap();
        // Two classes collapse to a single 0/1 column.
        assert_eq!(encoded.dim(), (4, 1));
        assert_eq!(encoded.column(0).to_vec(), vec![0.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_multiclass_transform_indicator_matrix() {
        let labels = array![0_usize, 1, 2, 1];
        let encoded = fit_on(&labels).transform(&labels).unwrap();
        assert_eq!(encoded.dim(), (4, 3));
        // Row 0 is class 0, row 2 is class 2.
        assert_eq!(encoded.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
        assert_eq!(encoded.row(2).to_vec(), vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_inverse_transform_multiclass() {
        assert_roundtrip(&array![0_usize, 1, 2, 1]);
    }

    #[test]
    fn test_inverse_transform_binary() {
        assert_roundtrip(&array![0_usize, 1, 0, 1]);
    }

    #[test]
    fn test_transform_unknown_label_error() {
        let fitted = fit_on(&array![0_usize, 1, 2]);
        // Label 3 was never seen while fitting.
        let unseen = array![0_usize, 3];
        assert!(fitted.transform(&unseen).is_err());
    }

    #[test]
    fn test_inverse_transform_shape_mismatch() {
        let fitted = fit_on(&array![0_usize, 1, 2]);
        // Three classes need three columns; two must be rejected.
        let wrong_width = Array2::<f64>::zeros((2, 2));
        assert!(fitted.inverse_transform(&wrong_width).is_err());
    }

    #[test]
    fn test_single_class() {
        let labels = array![5_usize, 5, 5];
        let fitted = fit_on(&labels);
        assert_eq!(fitted.n_classes(), 1);
        let encoded = fitted.transform(&labels).unwrap();
        assert_eq!(encoded.dim(), (3, 1));
    }

    #[test]
    fn test_non_contiguous_classes() {
        let labels = array![10_usize, 20, 30, 10];
        let fitted = fit_on(&labels);
        assert_eq!(fitted.classes(), &[10, 20, 30]);
        let encoded = fitted.transform(&labels).unwrap();
        assert_eq!(encoded.dim(), (4, 3));
        // The first three samples are classes 10, 20, 30 in order, so the
        // leading 3x3 part of the matrix has ones on its diagonal.
        for d in 0..3 {
            assert_eq!(encoded[[d, d]], 1.0);
        }
    }

    #[test]
    fn test_roundtrip_multiclass_non_contiguous() {
        assert_roundtrip(&array![10_usize, 20, 30, 20]);
    }
}