1use ferrolearn_core::error::FerroError;
25use ferrolearn_core::traits::{Fit, Transform};
26use ndarray::{Array1, Array2};
27
/// Unfitted label binarizer: turns integer class labels into an indicator matrix.
///
/// This type carries no configuration. Call [`Fit::fit`] on a label vector to
/// obtain a [`FittedLabelBinarizer`], which performs the actual
/// transform / inverse transform.
#[derive(Clone, Debug, Default)]
pub struct LabelBinarizer;

impl LabelBinarizer {
    /// Creates a new, unfitted `LabelBinarizer`.
    #[must_use]
    pub fn new() -> Self {
        LabelBinarizer
    }
}
46
/// A [`LabelBinarizer`] that has been fitted to a set of labels.
///
/// Produced by `LabelBinarizer::fit`. It stores the distinct classes seen
/// during fitting. It can map labels to an indicator matrix (`transform`) and
/// map such a matrix back to labels (`inverse_transform`).
#[derive(Debug, Clone)]
pub struct FittedLabelBinarizer {
    /// Distinct class labels, sorted ascending and de-duplicated by `fit`.
    /// In the multiclass (and single-class) case, output column `j` corresponds to
    /// `classes[j]`. With exactly two classes, a single column is emitted that is
    /// `1.0` for `classes[1]` and `0.0` for `classes[0]`.
    classes: Vec<usize>,
}
59
60impl FittedLabelBinarizer {
61 #[must_use]
63 pub fn classes(&self) -> &[usize] {
64 &self.classes
65 }
66
67 #[must_use]
69 pub fn n_classes(&self) -> usize {
70 self.classes.len()
71 }
72
73 pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Array1<usize>, FerroError> {
82 let k = self.classes.len();
83 let expected_cols = if k == 2 { 1 } else { k };
84
85 if y.ncols() != expected_cols {
86 return Err(FerroError::ShapeMismatch {
87 expected: vec![y.nrows(), expected_cols],
88 actual: vec![y.nrows(), y.ncols()],
89 context: "FittedLabelBinarizer::inverse_transform".into(),
90 });
91 }
92
93 let n = y.nrows();
94 let mut result = Array1::zeros(n);
95
96 if k == 2 {
97 for i in 0..n {
99 result[i] = if y[[i, 0]] >= 0.5 {
100 self.classes[1]
101 } else {
102 self.classes[0]
103 };
104 }
105 } else {
106 for i in 0..n {
108 let row = y.row(i);
109 let mut best_j = 0;
110 let mut best_v = f64::NEG_INFINITY;
111 for (j, &v) in row.iter().enumerate() {
112 if v > best_v {
113 best_v = v;
114 best_j = j;
115 }
116 }
117 result[i] = self.classes[best_j];
118 }
119 }
120
121 Ok(result)
122 }
123}
124
125impl Fit<Array1<usize>, ()> for LabelBinarizer {
130 type Fitted = FittedLabelBinarizer;
131 type Error = FerroError;
132
133 fn fit(&self, y: &Array1<usize>, _target: &()) -> Result<FittedLabelBinarizer, FerroError> {
139 if y.is_empty() {
140 return Err(FerroError::InsufficientSamples {
141 required: 1,
142 actual: 0,
143 context: "LabelBinarizer::fit".into(),
144 });
145 }
146
147 let mut classes: Vec<usize> = y.iter().copied().collect();
148 classes.sort_unstable();
149 classes.dedup();
150
151 Ok(FittedLabelBinarizer { classes })
152 }
153}
154
155impl Transform<Array1<usize>> for FittedLabelBinarizer {
156 type Output = Array2<f64>;
157 type Error = FerroError;
158
159 fn transform(&self, y: &Array1<usize>) -> Result<Array2<f64>, FerroError> {
169 let k = self.classes.len();
170 let n = y.len();
171
172 let class_to_idx: std::collections::HashMap<usize, usize> = self
174 .classes
175 .iter()
176 .enumerate()
177 .map(|(i, &c)| (c, i))
178 .collect();
179
180 if k == 2 {
181 let mut out = Array2::zeros((n, 1));
183 for (i, &label) in y.iter().enumerate() {
184 let idx = class_to_idx.get(&label).ok_or_else(|| {
185 FerroError::InvalidParameter {
186 name: "y".into(),
187 reason: format!("unknown label {label} not seen during fit"),
188 }
189 })?;
190 out[[i, 0]] = if *idx == 1 { 1.0 } else { 0.0 };
191 }
192 Ok(out)
193 } else {
194 let mut out = Array2::zeros((n, k));
196 for (i, &label) in y.iter().enumerate() {
197 let &idx = class_to_idx.get(&label).ok_or_else(|| {
198 FerroError::InvalidParameter {
199 name: "y".into(),
200 reason: format!("unknown label {label} not seen during fit"),
201 }
202 })?;
203 out[[i, idx]] = 1.0;
204 }
205 Ok(out)
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_fit_discovers_sorted_classes() {
        let lb = LabelBinarizer::new();
        let y = array![2_usize, 0, 1, 2, 0];
        let fitted = lb.fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[0, 1, 2]);
    }

    #[test]
    fn test_fit_empty_input_error() {
        let lb = LabelBinarizer::new();
        let y: Array1<usize> = Array1::zeros(0);
        assert!(lb.fit(&y, &()).is_err());
    }

    #[test]
    fn test_binary_transform_single_column() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 0, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[4, 1]);
        assert_eq!(mat[[0, 0]], 0.0);
        assert_eq!(mat[[1, 0]], 1.0);
        assert_eq!(mat[[2, 0]], 0.0);
        assert_eq!(mat[[3, 0]], 1.0);
    }

    #[test]
    fn test_multiclass_transform_indicator_matrix() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[4, 3]);
        assert_eq!(mat[[0, 0]], 1.0);
        assert_eq!(mat[[0, 1]], 0.0);
        assert_eq!(mat[[0, 2]], 0.0);
        assert_eq!(mat[[2, 0]], 0.0);
        assert_eq!(mat[[2, 1]], 0.0);
        assert_eq!(mat[[2, 2]], 1.0);
    }

    #[test]
    fn test_inverse_transform_multiclass() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_inverse_transform_binary() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 0, 1];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }

    #[test]
    fn test_inverse_transform_binary_threshold() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![3_usize, 7], &()).unwrap();
        // Values at or above 0.5 decode to the second (larger) class.
        let scores = array![[0.49], [0.5], [0.9], [-1.0]];
        let recovered = fitted.inverse_transform(&scores).unwrap();
        assert_eq!(recovered, array![3_usize, 7, 7, 3]);
    }

    #[test]
    fn test_inverse_transform_multiclass_argmax_first_tie() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
        // Second row ties between columns 0 and 1; the first maximum wins.
        let scores = array![[0.1, 0.7, 0.2], [0.5, 0.5, 0.0]];
        let recovered = fitted.inverse_transform(&scores).unwrap();
        assert_eq!(recovered, array![1_usize, 0]);
    }

    #[test]
    fn test_transform_unknown_label_error() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2];
        let fitted = lb.fit(&y, &()).unwrap();
        let y2 = array![0_usize, 3];
        assert!(fitted.transform(&y2).is_err());
    }

    #[test]
    fn test_binary_transform_unknown_label_error() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![0_usize, 1], &()).unwrap();
        assert!(fitted.transform(&array![0_usize, 2]).is_err());
    }

    #[test]
    fn test_transform_empty_input() {
        let lb = LabelBinarizer::new();
        let fitted = lb.fit(&array![0_usize, 1, 2], &()).unwrap();
        let empty: Array1<usize> = Array1::zeros(0);
        let mat = fitted.transform(&empty).unwrap();
        assert_eq!(mat.shape(), &[0, 3]);
    }

    #[test]
    fn test_inverse_transform_shape_mismatch() {
        let lb = LabelBinarizer::new();
        let y = array![0_usize, 1, 2];
        let fitted = lb.fit(&y, &()).unwrap();
        let bad = Array2::<f64>::zeros((2, 2));
        assert!(fitted.inverse_transform(&bad).is_err());
    }

    #[test]
    fn test_single_class() {
        let lb = LabelBinarizer::new();
        let y = array![5_usize, 5, 5];
        let fitted = lb.fit(&y, &()).unwrap();
        assert_eq!(fitted.n_classes(), 1);
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[3, 1]);
    }

    #[test]
    fn test_single_class_roundtrip() {
        let lb = LabelBinarizer::new();
        let y = array![5_usize, 5, 5];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(fitted.inverse_transform(&mat).unwrap(), y);
    }

    #[test]
    fn test_non_contiguous_classes() {
        let lb = LabelBinarizer::new();
        let y = array![10_usize, 20, 30, 10];
        let fitted = lb.fit(&y, &()).unwrap();
        assert_eq!(fitted.classes(), &[10, 20, 30]);
        let mat = fitted.transform(&y).unwrap();
        assert_eq!(mat.shape(), &[4, 3]);
        assert_eq!(mat[[0, 0]], 1.0);
        assert_eq!(mat[[1, 1]], 1.0);
        assert_eq!(mat[[2, 2]], 1.0);
    }

    #[test]
    fn test_roundtrip_multiclass_non_contiguous() {
        let lb = LabelBinarizer::new();
        let y = array![10_usize, 20, 30, 20];
        let fitted = lb.fit(&y, &()).unwrap();
        let mat = fitted.transform(&y).unwrap();
        let recovered = fitted.inverse_transform(&mat).unwrap();
        assert_eq!(recovered, y);
    }
}