1use ferrolearn_core::error::FerroError;
24use ferrolearn_core::traits::{Fit, Transform};
25use ndarray::Array2;
26
/// Converts between per-sample label sets and a multi-hot indicator matrix.
///
/// This is the unfitted configuration. Calling `fit` on a collection of label
/// sets discovers the class vocabulary and yields a
/// [`FittedMultiLabelBinarizer`]. The type carries no hyper-parameters, so all
/// instances are interchangeable (and compare equal).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MultiLabelBinarizer;

impl MultiLabelBinarizer {
    /// Creates a new, unfitted binarizer.
    ///
    /// Equivalent to [`MultiLabelBinarizer::default`]. It is declared `const`
    /// so it can initialise constants and statics.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
45
/// A [`MultiLabelBinarizer`] that has been fitted and knows its class vocabulary.
///
/// Produced by `MultiLabelBinarizer::fit`. Use it to `transform` label sets into
/// an indicator matrix, and `inverse_transform` to turn indicator matrices back
/// into label sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FittedMultiLabelBinarizer {
    /// Distinct class labels in ascending order. Column `j` of the indicator
    /// matrix corresponds to `classes[j]`.
    classes: Vec<usize>,
}
58
59impl FittedMultiLabelBinarizer {
60 #[must_use]
62 pub fn classes(&self) -> &[usize] {
63 &self.classes
64 }
65
66 #[must_use]
68 pub fn n_classes(&self) -> usize {
69 self.classes.len()
70 }
71
72 pub fn inverse_transform(&self, y: &Array2<f64>) -> Result<Vec<Vec<usize>>, FerroError> {
81 let k = self.classes.len();
82 if y.ncols() != k {
83 return Err(FerroError::ShapeMismatch {
84 expected: vec![y.nrows(), k],
85 actual: vec![y.nrows(), y.ncols()],
86 context: "FittedMultiLabelBinarizer::inverse_transform".into(),
87 });
88 }
89
90 let n = y.nrows();
91 let mut result = Vec::with_capacity(n);
92
93 for i in 0..n {
94 let mut labels = Vec::new();
95 for (j, &cls) in self.classes.iter().enumerate() {
96 if y[[i, j]] >= 0.5 {
97 labels.push(cls);
98 }
99 }
100 result.push(labels);
101 }
102
103 Ok(result)
104 }
105}
106
107impl Fit<Vec<Vec<usize>>, ()> for MultiLabelBinarizer {
112 type Fitted = FittedMultiLabelBinarizer;
113 type Error = FerroError;
114
115 fn fit(
121 &self,
122 y: &Vec<Vec<usize>>,
123 _target: &(),
124 ) -> Result<FittedMultiLabelBinarizer, FerroError> {
125 if y.is_empty() {
126 return Err(FerroError::InsufficientSamples {
127 required: 1,
128 actual: 0,
129 context: "MultiLabelBinarizer::fit".into(),
130 });
131 }
132
133 let mut classes: Vec<usize> = y.iter().flatten().copied().collect();
134 classes.sort_unstable();
135 classes.dedup();
136
137 Ok(FittedMultiLabelBinarizer { classes })
138 }
139}
140
141impl Transform<Vec<Vec<usize>>> for FittedMultiLabelBinarizer {
142 type Output = Array2<f64>;
143 type Error = FerroError;
144
145 fn transform(&self, y: &Vec<Vec<usize>>) -> Result<Array2<f64>, FerroError> {
155 let k = self.classes.len();
156 let n = y.len();
157
158 let class_to_idx: std::collections::HashMap<usize, usize> = self
160 .classes
161 .iter()
162 .enumerate()
163 .map(|(i, &c)| (c, i))
164 .collect();
165
166 let mut out = Array2::zeros((n, k));
167
168 for (i, labels) in y.iter().enumerate() {
169 for &label in labels {
170 let &idx = class_to_idx.get(&label).ok_or_else(|| {
171 FerroError::InvalidParameter {
172 name: "y".into(),
173 reason: format!("unknown label {label} not seen during fit"),
174 }
175 })?;
176 out[[i, idx]] = 1.0;
177 }
178 }
179
180 Ok(out)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_fit_discovers_sorted_classes() {
        let samples = vec![vec![2, 0], vec![1]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        // Classes come back in ascending order regardless of input order.
        assert_eq!(fitted.classes(), &[0, 1, 2]);
    }

    #[test]
    fn test_fit_empty_input_error() {
        let no_samples: Vec<Vec<usize>> = Vec::new();
        let outcome = MultiLabelBinarizer::new().fit(&no_samples, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_transform_multi_hot() {
        let samples = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let indicator = fitted.transform(&samples).unwrap();
        assert_eq!(indicator.shape(), &[3, 3]);

        let expected = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]];
        for ((row, col), &want) in expected.indexed_iter() {
            assert_eq!(indicator[[row, col]], want, "entry ({row}, {col})");
        }
    }

    #[test]
    fn test_transform_unknown_label_error() {
        let training = vec![vec![0, 1]];
        let fitted = MultiLabelBinarizer::new().fit(&training, &()).unwrap();
        // Label 5 never appeared during fit.
        let unseen = vec![vec![0, 5]];
        assert!(fitted.transform(&unseen).is_err());
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let samples = vec![vec![0, 2], vec![1], vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let indicator = fitted.transform(&samples).unwrap();
        let decoded = fitted.inverse_transform(&indicator).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn test_inverse_transform_shape_mismatch() {
        let training = vec![vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&training, &()).unwrap();
        // Two columns cannot match a three-class vocabulary.
        let wrong_width = Array2::<f64>::zeros((2, 2));
        assert!(fitted.inverse_transform(&wrong_width).is_err());
    }

    #[test]
    fn test_empty_label_set() {
        let samples = vec![vec![0, 1], vec![]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let indicator = fitted.transform(&samples).unwrap();
        assert_eq!(indicator.shape(), &[2, 2]);
        // The sample with no labels must encode as an all-zero row.
        for col in 0..2 {
            assert_eq!(indicator[[1, col]], 0.0);
        }
    }

    #[test]
    fn test_inverse_transform_empty_row() {
        let samples = vec![vec![0, 1], vec![]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let indicator = fitted.transform(&samples).unwrap();
        let decoded = fitted.inverse_transform(&indicator).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn test_non_contiguous_classes() {
        let samples = vec![vec![10, 30], vec![20]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        assert_eq!(fitted.classes(), &[10, 20, 30]);

        let indicator = fitted.transform(&samples).unwrap();
        assert_eq!(indicator.shape(), &[2, 3]);
        // Sparse label ids still map to dense, ordered columns.
        for (col, &want) in [1.0, 0.0, 1.0].iter().enumerate() {
            assert_eq!(indicator[[0, col]], want);
        }
    }

    #[test]
    fn test_inverse_transform_non_contiguous_roundtrip() {
        let samples = vec![vec![10, 30], vec![20]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let indicator = fitted.transform(&samples).unwrap();
        let decoded = fitted.inverse_transform(&indicator).unwrap();
        assert_eq!(decoded, samples);
    }

    #[test]
    fn test_duplicate_labels_in_input() {
        let samples = vec![vec![0, 0, 1]];
        let fitted = MultiLabelBinarizer::new().fit(&samples, &()).unwrap();
        let indicator = fitted.transform(&samples).unwrap();
        assert_eq!(indicator.shape(), &[1, 2]);
        // A repeated label must not push its entry above 1.0.
        assert_eq!(indicator[[0, 0]], 1.0);
        assert_eq!(indicator[[0, 1]], 1.0);
    }

    #[test]
    fn test_inverse_threshold() {
        let training = vec![vec![0, 1, 2]];
        let fitted = MultiLabelBinarizer::new().fit(&training, &()).unwrap();
        // 0.4 falls below the cut-off. 0.6 is above it, and 0.5 sits exactly on
        // it and is kept.
        let scores = array![[0.4, 0.6, 0.5]];
        let decoded = fitted.inverse_transform(&scores).unwrap();
        assert_eq!(decoded, vec![vec![1, 2]]);
    }
}