1use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::{HashMap, HashSet};
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Estimator, Fit, Trained, Transform, Untrained},
15 types::Float,
16};
17
/// Output encoding settings for [`LabelBinarizer`].
///
/// `neg_label` and `pos_label` are the numeric values written into the indicator
/// matrix for "not this class" and "this class" respectively.
#[derive(Debug, Clone)]
pub struct LabelBinarizerConfig {
    /// Value emitted for cells that do not correspond to the sample's class.
    pub neg_label: i32,
    /// Value emitted for the cell that corresponds to the sample's class.
    pub pos_label: i32,
    /// Requests a sparse result.
    ///
    /// NOTE(review): the `transform` implementation for `LabelBinarizer` always returns
    /// a dense `Array2` and never reads this flag.
    pub sparse_output: bool,
}

impl Default for LabelBinarizerConfig {
    /// `neg_label = 0`, `pos_label = 1`, dense output.
    fn default() -> Self {
        let (neg_label, pos_label) = (0, 1);
        LabelBinarizerConfig {
            neg_label,
            pos_label,
            sparse_output: false,
        }
    }
}
38
39pub struct LabelBinarizer<T: Eq + Hash + Clone = i32, State = Untrained> {
41 config: LabelBinarizerConfig,
42 state: PhantomData<State>,
43 classes_: Option<Vec<T>>,
44 class_to_index_: Option<HashMap<T, usize>>,
45}
46
47impl<T: Eq + Hash + Clone> LabelBinarizer<T, Untrained> {
48 pub fn new() -> Self {
50 Self {
51 config: LabelBinarizerConfig::default(),
52 state: PhantomData,
53 classes_: None,
54 class_to_index_: None,
55 }
56 }
57
58 pub fn neg_label(mut self, neg_label: i32) -> Self {
60 self.config.neg_label = neg_label;
61 self
62 }
63
64 pub fn pos_label(mut self, pos_label: i32) -> Self {
66 self.config.pos_label = pos_label;
67 self
68 }
69}
70
71impl<T: Eq + Hash + Clone> Default for LabelBinarizer<T, Untrained> {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Untrained> {
78 type Config = LabelBinarizerConfig;
79 type Error = SklearsError;
80 type Float = Float;
81
82 fn config(&self) -> &Self::Config {
83 &self.config
84 }
85}
86
87impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Trained> {
88 type Config = LabelBinarizerConfig;
89 type Error = SklearsError;
90 type Float = Float;
91
92 fn config(&self) -> &Self::Config {
93 &self.config
94 }
95}
96
97impl<T: Eq + Hash + Clone + Ord + Send + Sync> Fit<Array1<T>, ()> for LabelBinarizer<T, Untrained> {
98 type Fitted = LabelBinarizer<T, Trained>;
99
100 fn fit(self, y: &Array1<T>, _x: &()) -> Result<Self::Fitted> {
101 let mut classes = HashSet::new();
103 for label in y.iter() {
104 classes.insert(label.clone());
105 }
106
107 let mut sorted_classes: Vec<T> = classes.into_iter().collect();
109 sorted_classes.sort();
110
111 let class_to_index: HashMap<T, usize> = sorted_classes
113 .iter()
114 .enumerate()
115 .map(|(i, c)| (c.clone(), i))
116 .collect();
117
118 Ok(LabelBinarizer {
119 config: self.config,
120 state: PhantomData,
121 classes_: Some(sorted_classes),
122 class_to_index_: Some(class_to_index),
123 })
124 }
125}
126
127impl<T: Eq + Hash + Clone> Transform<Array1<T>, Array2<Float>> for LabelBinarizer<T, Trained> {
128 fn transform(&self, y: &Array1<T>) -> Result<Array2<Float>> {
129 let classes = self.classes_.as_ref().unwrap();
130 let class_to_index = self.class_to_index_.as_ref().unwrap();
131 let n_samples = y.len();
132 let n_classes = classes.len();
133
134 if n_classes == 0 {
135 return Err(SklearsError::InvalidInput(
136 "No classes found during fit".to_string(),
137 ));
138 }
139
140 if n_classes == 2 {
142 let mut result = Array2::zeros((n_samples, 1));
143 for (i, label) in y.iter().enumerate() {
144 if let Some(&class_idx) = class_to_index.get(label) {
145 result[[i, 0]] = if class_idx == 1 {
146 self.config.pos_label as Float
147 } else {
148 self.config.neg_label as Float
149 };
150 } else {
151 return Err(SklearsError::InvalidInput(
152 "Unknown label encountered during transform".to_string(),
153 ));
154 }
155 }
156 Ok(result)
157 } else {
158 let mut result =
160 Array2::from_elem((n_samples, n_classes), self.config.neg_label as Float);
161 for (i, label) in y.iter().enumerate() {
162 if let Some(&class_idx) = class_to_index.get(label) {
163 result[[i, class_idx]] = self.config.pos_label as Float;
164 } else {
165 return Err(SklearsError::InvalidInput(
166 "Unknown label encountered during transform".to_string(),
167 ));
168 }
169 }
170 Ok(result)
171 }
172 }
173}
174
175impl<T: Eq + Hash + Clone> LabelBinarizer<T, Trained> {
176 pub fn classes(&self) -> &Vec<T> {
178 self.classes_.as_ref().unwrap()
179 }
180
181 pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Array1<T>> {
183 let classes = self.classes_.as_ref().unwrap();
184 let n_samples = y.nrows();
185 let n_classes = classes.len();
186
187 if n_classes == 2 && y.ncols() == 1 {
188 let mut result = Vec::with_capacity(n_samples);
190 let threshold = (self.config.neg_label + self.config.pos_label) as Float / 2.0;
191
192 for i in 0..n_samples {
193 let class_idx = if y[[i, 0]] > threshold { 1 } else { 0 };
194 result.push(classes[class_idx].clone());
195 }
196 Ok(Array1::from_vec(result))
197 } else if y.ncols() == n_classes {
198 let mut result = Vec::with_capacity(n_samples);
200
201 for i in 0..n_samples {
202 let row = y.row(i);
204 let mut max_idx = 0;
205 let mut max_val = row[0];
206
207 for j in 1..n_classes {
208 if row[j] > max_val {
209 max_val = row[j];
210 max_idx = j;
211 }
212 }
213
214 result.push(classes[max_idx].clone());
215 }
216 Ok(Array1::from_vec(result))
217 } else {
218 Err(SklearsError::InvalidInput(format!(
219 "Shape mismatch: y has {} columns but {} classes were expected",
220 y.ncols(),
221 n_classes
222 )))
223 }
224 }
225}
226
/// Settings for [`MultiLabelBinarizer`].
#[derive(Clone, Debug, Default)]
pub struct MultiLabelBinarizerConfig {
    /// Explicit class list, which also fixes the column order.
    ///
    /// When `None`, `fit` discovers the classes from the training data and sorts them.
    pub classes: Option<Vec<String>>,
    /// Requests a sparse result.
    ///
    /// NOTE(review): the `transform` implementation for `MultiLabelBinarizer` always
    /// returns a dense `Array2` and never reads this flag.
    pub sparse_output: bool,
}
235
236pub struct MultiLabelBinarizer<State = Untrained> {
238 config: MultiLabelBinarizerConfig,
239 state: PhantomData<State>,
240 classes_: Option<Vec<String>>,
241 class_to_index_: Option<HashMap<String, usize>>,
242}
243
244impl MultiLabelBinarizer<Untrained> {
245 pub fn new() -> Self {
247 Self {
248 config: MultiLabelBinarizerConfig::default(),
249 state: PhantomData,
250 classes_: None,
251 class_to_index_: None,
252 }
253 }
254
255 pub fn classes(mut self, classes: Vec<String>) -> Self {
257 self.config.classes = Some(classes);
258 self
259 }
260}
261
262impl Default for MultiLabelBinarizer<Untrained> {
263 fn default() -> Self {
264 Self::new()
265 }
266}
267
268impl Estimator for MultiLabelBinarizer<Untrained> {
269 type Config = MultiLabelBinarizerConfig;
270 type Error = SklearsError;
271 type Float = Float;
272
273 fn config(&self) -> &Self::Config {
274 &self.config
275 }
276}
277
278impl Estimator for MultiLabelBinarizer<Trained> {
279 type Config = MultiLabelBinarizerConfig;
280 type Error = SklearsError;
281 type Float = Float;
282
283 fn config(&self) -> &Self::Config {
284 &self.config
285 }
286}
287
288impl Fit<Vec<Vec<String>>, ()> for MultiLabelBinarizer<Untrained> {
289 type Fitted = MultiLabelBinarizer<Trained>;
290
291 fn fit(self, y: &Vec<Vec<String>>, _x: &()) -> Result<Self::Fitted> {
292 let classes = if let Some(ref classes) = self.config.classes {
293 classes.clone()
294 } else {
295 let mut unique_classes = HashSet::new();
297 for labels in y.iter() {
298 for label in labels.iter() {
299 unique_classes.insert(label.clone());
300 }
301 }
302
303 let mut sorted_classes: Vec<String> = unique_classes.into_iter().collect();
304 sorted_classes.sort();
305 sorted_classes
306 };
307
308 let class_to_index: HashMap<String, usize> = classes
310 .iter()
311 .enumerate()
312 .map(|(i, c)| (c.clone(), i))
313 .collect();
314
315 Ok(MultiLabelBinarizer {
316 config: self.config,
317 state: PhantomData,
318 classes_: Some(classes),
319 class_to_index_: Some(class_to_index),
320 })
321 }
322}
323
324impl Transform<Vec<Vec<String>>, Array2<Float>> for MultiLabelBinarizer<Trained> {
325 fn transform(&self, y: &Vec<Vec<String>>) -> Result<Array2<Float>> {
326 let classes = self.classes_.as_ref().unwrap();
327 let class_to_index = self.class_to_index_.as_ref().unwrap();
328 let n_samples = y.len();
329 let n_classes = classes.len();
330
331 let mut result = Array2::zeros((n_samples, n_classes));
332
333 for (i, labels) in y.iter().enumerate() {
334 for label in labels.iter() {
335 if let Some(&class_idx) = class_to_index.get(label) {
336 result[[i, class_idx]] = 1.0;
337 }
338 }
340 }
341
342 Ok(result)
343 }
344}
345
346impl MultiLabelBinarizer<Trained> {
347 pub fn classes(&self) -> &Vec<String> {
349 self.classes_.as_ref().unwrap()
350 }
351
352 pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Vec<Vec<String>>> {
354 let classes = self.classes_.as_ref().unwrap();
355 let n_samples = y.nrows();
356 let n_classes = classes.len();
357
358 if y.ncols() != n_classes {
359 return Err(SklearsError::InvalidInput(format!(
360 "Shape mismatch: y has {} columns but {} classes were expected",
361 y.ncols(),
362 n_classes
363 )));
364 }
365
366 let mut result = Vec::with_capacity(n_samples);
367
368 for i in 0..n_samples {
369 let mut labels = Vec::new();
370 for j in 0..n_classes {
371 if y[[i, j]] > 0.5 {
372 labels.push(classes[j].clone());
373 }
374 }
375 result.push(labels);
376 }
377
378 Ok(result)
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_label_binarizer_binary() {
        let y = array![1, 0, 1, 0, 1];

        let binarizer = LabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        // Binary problems collapse to a single column.
        assert_eq!(y_bin.shape(), &[5, 1]);
        assert_eq!(y_bin[[0, 0]], 1.0);
        assert_eq!(y_bin[[1, 0]], 0.0);
        assert_eq!(y_bin[[2, 0]], 1.0);
    }

    #[test]
    fn test_label_binarizer_multiclass() {
        let y = array![0, 1, 2, 1, 0];

        let binarizer = LabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        assert_eq!(y_bin.shape(), &[5, 3]);
        assert_eq!(y_bin.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
        assert_eq!(y_bin.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
        assert_eq!(y_bin.row(2).to_vec(), vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_label_binarizer_inverse_transform() {
        let y = array!["cat", "dog", "cat", "bird", "dog"];

        let binarizer = LabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();
        let y_inv = binarizer.inverse_transform(&y_bin).unwrap();

        assert_eq!(y, y_inv);
    }

    #[test]
    fn test_label_binarizer_custom_labels() {
        let y = array![1, 0, 1, 0];

        let binarizer = LabelBinarizer::new()
            .neg_label(-1)
            .pos_label(1)
            .fit(&y, &())
            .unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        assert_eq!(y_bin[[0, 0]], 1.0);
        assert_eq!(y_bin[[1, 0]], -1.0);

        // Decoding thresholds halfway between the custom labels, so this round-trips.
        let y_inv = binarizer.inverse_transform(&y_bin).unwrap();
        assert_eq!(y, y_inv);
    }

    #[test]
    fn test_label_binarizer_unknown_label_errors() {
        let binarizer = LabelBinarizer::new().fit(&array![0, 1], &()).unwrap();

        assert!(binarizer.transform(&array![0, 2]).is_err());
    }

    #[test]
    fn test_label_binarizer_inverse_shape_mismatch() {
        let binarizer = LabelBinarizer::new().fit(&array![0, 1, 2], &()).unwrap();

        let wrong_width = Array2::<Float>::zeros((2, 2));
        assert!(binarizer.inverse_transform(&wrong_width).is_err());
    }

    #[test]
    fn test_multilabel_binarizer() {
        let y = vec![
            vec!["sci-fi".to_string(), "thriller".to_string()],
            vec!["comedy".to_string()],
            vec!["sci-fi".to_string(), "comedy".to_string()],
        ];

        let binarizer = MultiLabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        assert_eq!(y_bin.shape(), &[3, 3]);
        let classes = binarizer.classes();
        assert_eq!(classes.len(), 3);

        let row0_sum: Float = y_bin.row(0).sum();
        assert_eq!(row0_sum, 2.0);

        let row1_sum: Float = y_bin.row(1).sum();
        assert_eq!(row1_sum, 1.0);
    }

    #[test]
    fn test_multilabel_binarizer_inverse() {
        let y = vec![
            vec!["red".to_string(), "blue".to_string()],
            vec!["green".to_string()],
            vec!["red".to_string(), "green".to_string()],
        ];

        let binarizer = MultiLabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();
        let y_inv = binarizer.inverse_transform(&y_bin).unwrap();

        // Label order within a sample is not preserved, so compare as sets.
        for (original, reconstructed) in y.iter().zip(y_inv.iter()) {
            let orig_set: HashSet<_> = original.iter().collect();
            let recon_set: HashSet<_> = reconstructed.iter().collect();
            assert_eq!(orig_set, recon_set);
        }
    }

    #[test]
    fn test_multilabel_binarizer_with_classes() {
        let y = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["c".to_string()],
        ];

        let classes = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];

        let binarizer = MultiLabelBinarizer::new()
            .classes(classes.clone())
            .fit(&y, &())
            .unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        assert_eq!(y_bin.shape(), &[2, 4]);
        assert_eq!(binarizer.classes(), &classes);
    }

    #[test]
    fn test_multilabel_binarizer_ignores_unknown_labels() {
        let train = vec![vec!["a".to_string()], vec!["b".to_string()]];
        let binarizer = MultiLabelBinarizer::new().fit(&train, &()).unwrap();

        let y_bin = binarizer
            .transform(&vec![vec!["a".to_string(), "z".to_string()]])
            .unwrap();

        assert_eq!(y_bin.row(0).to_vec(), vec![1.0, 0.0]);
    }
}