sklears_multioutput/
multi_label.rs

//! Multi-label learning algorithms
//!
//! This module provides algorithms for multi-label classification, where each instance
//! can belong to multiple classes simultaneously.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Binary Relevance
///
/// A multi-label classification strategy that treats each label as a separate
/// binary classification problem. For each label, it trains a binary classifier
/// to predict whether that label is relevant or not, independently of other labels.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::BinaryRelevance;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[1, 0], [0, 1], [1, 1]]; // Multi-label: each column is a binary label
/// ```
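///
/// A minimal end-to-end sketch (marked `ignore` since it is illustrative; it
/// assumes `Float` is `f64` and that the `Fit`/`Predict` traits are in scope):
///
/// ```ignore
/// use sklears_core::traits::{Fit, Predict};
///
/// let fitted = BinaryRelevance::new().fit(&data.view(), &labels)?;
/// let predictions = fitted.predict(&data.view())?;
/// assert_eq!(predictions.dim(), (3, 2)); // one column per label
/// ```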
#[derive(Debug, Clone)]
pub struct BinaryRelevance<S = Untrained> {
    /// Type-state marker: `Untrained` before fitting, `BinaryRelevanceTrained` after
    pub state: S,
    n_jobs: Option<i32>,
}

impl BinaryRelevance<Untrained> {
    /// Create a new BinaryRelevance instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_jobs: None,
        }
    }

    /// Set the number of parallel jobs
    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
        self.n_jobs = n_jobs;
        self
    }
}

impl Default for BinaryRelevance<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for BinaryRelevance<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for BinaryRelevance<Untrained> {
    type Fitted = BinaryRelevance<BinaryRelevanceTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_labels = y.ncols();
        if n_labels == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one label".to_string(),
            ));
        }

        let mut binary_classifiers = HashMap::new();
        let mut classes_per_label = Vec::new();

        // Train one binary classifier per label
        for label_idx in 0..n_labels {
            let y_label = y.column(label_idx);

            // Get unique classes for this label (should be binary: 0 and 1)
            let mut label_classes: Vec<i32> = y_label
                .iter()
                .cloned()
                .collect::<std::collections::HashSet<_>>()
                .into_iter()
                .collect();
            label_classes.sort();

            // Validate that we have binary labels
            if label_classes.len() > 2 {
                return Err(SklearsError::InvalidInput(format!(
                    "Label {} has {} classes, but BinaryRelevance expects binary labels",
                    label_idx,
                    label_classes.len()
                )));
            }

            // Ensure the label column actually contains binary (0/1) values
            let has_positive = label_classes.contains(&1);
            let has_negative = label_classes.contains(&0);

            if !has_positive && !has_negative {
                return Err(SklearsError::InvalidInput(format!(
                    "Label {} contains no binary (0/1) values",
                    label_idx
                )));
            }

            // Train a binary classifier using a logistic-regression-style approach
            let weights = train_binary_classifier(&X, &y_label)?;
            binary_classifiers.insert(label_idx, weights);
            classes_per_label.push(label_classes);
        }

        Ok(BinaryRelevance {
            state: BinaryRelevanceTrained {
                binary_classifiers,
                classes_per_label,
                n_labels,
                n_features,
            },
            n_jobs: self.n_jobs,
        })
    }
}

/// Simple binary classifier training using a logistic regression approximation
fn train_binary_classifier(
    X: &Array2<Float>,
    y: &ArrayView1<i32>,
) -> SklResult<(Array1<f64>, f64)> {
    let (n_samples, n_features) = X.dim();

    // Simple approach: use correlation-based weights, similar to linear regression
    let mut weights = Array1::<Float>::zeros(n_features);

    // Compute the mean of the labels (the proportion of the positive class)
    let y_mean: f64 = y.iter().map(|&label| label as f64).sum::<f64>() / n_samples as f64;

    // Use the logit of the mean as the bias
    let bias = if y_mean > 0.0 && y_mean < 1.0 {
        (y_mean / (1.0 - y_mean)).ln()
    } else if y_mean >= 1.0 {
        2.0 // All labels positive: saturate at a large positive value
    } else {
        -2.0 // All labels negative: saturate at a large negative value
    };

    // Compute feature-label correlations
    for feature_idx in 0..n_features {
        let mut x_mean = 0.0;
        for sample_idx in 0..n_samples {
            x_mean += X[[sample_idx, feature_idx]];
        }
        x_mean /= n_samples as f64;

        // Compute the Pearson correlation between the feature and the label
        let mut numerator: f64 = 0.0;
        let mut x_var: f64 = 0.0;
        let mut y_var: f64 = 0.0;

        for sample_idx in 0..n_samples {
            let x_diff = X[[sample_idx, feature_idx]] - x_mean;
            let y_diff = y[sample_idx] as f64 - y_mean;
            numerator += x_diff * y_diff;
            x_var += x_diff * x_diff;
            y_var += y_diff * y_diff;
        }

        if x_var > 1e-10 && y_var > 1e-10 {
            let correlation = numerator / (x_var.sqrt() * y_var.sqrt());
            weights[feature_idx] = correlation; // Use the correlation as the weight
        }
    }

    Ok((weights, bias))
}
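
// A hedged sketch (test-only) of the bias heuristic above: for a label column
// [1, 1, 1, 0], y_mean = 0.75 and bias = logit(0.75) = ln(0.75 / 0.25) = ln 3,
// so with zero weights sigmoid(bias) ≈ 0.75 reproduces the class prior. The
// all-zero feature column keeps the correlation weights at zero.
#[cfg(test)]
mod bias_heuristic_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn bias_matches_logit_of_label_mean() {
        let x = array![[0.0], [0.0], [0.0], [0.0]];
        let y = array![1, 1, 1, 0];
        let (weights, bias) = train_binary_classifier(&x, &y.view()).unwrap();
        assert_eq!(weights[0], 0.0); // zero-variance feature gets zero weight
        assert!((bias - 3.0_f64.ln()).abs() < 1e-12);
    }
}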

impl BinaryRelevance<BinaryRelevanceTrained> {
    /// Get the classes for each label
    pub fn classes(&self) -> &[Vec<i32>] {
        &self.state.classes_per_label
    }

    /// Get the number of labels
    pub fn n_labels(&self) -> usize {
        self.state.n_labels
    }

    /// Get the number of features
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>> for BinaryRelevance<BinaryRelevanceTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        // Get predictions from each binary classifier
        for label_idx in 0..self.state.n_labels {
            if let Some((weights, bias)) = self.state.binary_classifiers.get(&label_idx) {
                for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
                    // Compute logistic regression score
                    let score: f64 = sample
                        .iter()
                        .zip(weights.iter())
                        .map(|(&x, &w)| x * w)
                        .sum::<f64>()
                        + bias;

                    // Apply sigmoid and threshold at 0.5 for binary classification
                    let prob = 1.0 / (1.0 + (-score).exp());
                    let prediction = if prob > 0.5 { 1 } else { 0 };

                    predictions[[sample_idx, label_idx]] = prediction;
                }
            }
        }

        Ok(predictions)
    }
}

/// Predict probabilities for each label
impl BinaryRelevance<BinaryRelevanceTrained> {
    /// Predict class probabilities for each label
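    ///
    /// A hedged usage sketch (marked `ignore`; assumes a fitted model `fitted`):
    ///
    /// ```ignore
    /// let probs = fitted.predict_proba(&data.view())?;
    /// // Entries are independent per-label probabilities in [0, 1]; rows do
    /// // not sum to 1 because each label is predicted by its own classifier.
    /// ```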
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));

        // Get probabilities from each binary classifier
        for label_idx in 0..self.state.n_labels {
            if let Some((weights, bias)) = self.state.binary_classifiers.get(&label_idx) {
                for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
                    // Compute logistic regression score
                    let score: f64 = sample
                        .iter()
                        .zip(weights.iter())
                        .map(|(&x, &w)| x * w)
                        .sum::<f64>()
                        + bias;

                    // Apply sigmoid to get probability
                    let prob = 1.0 / (1.0 + (-score).exp());
                    probabilities[[sample_idx, label_idx]] = prob;
                }
            }
        }

        Ok(probabilities)
    }
}

/// Trained state for BinaryRelevance
#[derive(Debug, Clone)]
pub struct BinaryRelevanceTrained {
    /// Binary classifiers for each label (weights, bias)
    pub binary_classifiers: HashMap<usize, (Array1<f64>, f64)>,
    /// Classes for each label
    pub classes_per_label: Vec<Vec<i32>>,
    /// Number of labels
    pub n_labels: usize,
    /// Number of features
    pub n_features: usize,
}

/// Label Powerset
///
/// A multi-label classification strategy that transforms the multi-label problem
/// into a multi-class problem by treating each unique combination of labels as
/// a single class. Each label combination becomes a distinct class in the transformed
/// problem space.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::LabelPowerset;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[1, 0], [0, 1], [1, 1]]; // Multi-label: each row is a label combination
/// ```
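///
/// A hedged sketch of the transformation (marked `ignore`; assumes the `Fit`
/// trait is in scope): the rows `[1, 0]`, `[0, 1]`, and `[1, 1]` above are
/// three distinct combinations, so the transformed problem has three classes.
///
/// ```ignore
/// use sklears_core::traits::Fit;
///
/// let fitted = LabelPowerset::new().fit(&data.view(), &labels)?;
/// assert_eq!(fitted.n_classes(), 3); // one class per unique label combination
/// ```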
#[derive(Debug, Clone)]
pub struct LabelPowerset<S = Untrained> {
    state: S,
}

impl LabelPowerset<Untrained> {
    /// Create a new LabelPowerset instance
    pub fn new() -> Self {
        Self { state: Untrained }
    }
}

impl Default for LabelPowerset<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for LabelPowerset<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for LabelPowerset<Untrained> {
    type Fitted = LabelPowerset<LabelPowersetTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_labels = y.ncols();
        if n_labels == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one label".to_string(),
            ));
        }

        // Validate that labels are binary
        for sample_idx in 0..n_samples {
            for label_idx in 0..n_labels {
                let label_value = y[[sample_idx, label_idx]];
                if label_value != 0 && label_value != 1 {
                    return Err(SklearsError::InvalidInput(format!(
                        "LabelPowerset expects binary labels, but found {} at position ({}, {})",
                        label_value, sample_idx, label_idx
                    )));
                }
            }
        }

        // Transform multi-label to multi-class by creating unique combinations
        let mut class_to_combination: HashMap<usize, Vec<i32>> = HashMap::new();
        let mut combination_to_class: HashMap<Vec<i32>, usize> = HashMap::new();
        let mut transformed_labels = Vec::new();
        let mut next_class_id = 0;

        for sample_idx in 0..n_samples {
            // Extract the label combination for this sample
            let combination: Vec<i32> = (0..n_labels)
                .map(|label_idx| y[[sample_idx, label_idx]])
                .collect();

            // Check if we've seen this combination before
            let class_id = if let Some(&existing_class_id) = combination_to_class.get(&combination)
            {
                existing_class_id
            } else {
                // New combination, assign a new class ID
                let class_id = next_class_id;
                combination_to_class.insert(combination.clone(), class_id);
                class_to_combination.insert(class_id, combination);
                next_class_id += 1;
                class_id
            };

            transformed_labels.push(class_id);
        }

        // Train a single multi-class classifier on the transformed problem
        // We'll use a simple nearest centroid approach
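        // (A hedged aside: each centroid is the per-feature mean of the samples
        // in that class, e.g. samples [1.0, 2.0] and [3.0, 4.0] give [2.0, 3.0].)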
        let mut class_centroids: HashMap<usize, Array1<f64>> = HashMap::new();

        for &class_id in class_to_combination.keys() {
            let mut centroid = Array1::<Float>::zeros(n_features);
            let mut count = 0;

            // Compute centroid for this class
            for (sample_idx, &sample_class) in transformed_labels.iter().enumerate() {
                if sample_class == class_id {
                    for feature_idx in 0..n_features {
                        centroid[feature_idx] += X[[sample_idx, feature_idx]];
                    }
                    count += 1;
                }
            }

            if count > 0 {
                centroid /= count as f64;
            }
            class_centroids.insert(class_id, centroid);
        }

        let unique_classes: Vec<usize> = class_to_combination.keys().cloned().collect();

        Ok(LabelPowerset {
            state: LabelPowersetTrained {
                class_to_combination,
                combination_to_class,
                class_centroids,
                unique_classes,
                n_labels,
                n_features,
            },
        })
    }
}

impl LabelPowerset<LabelPowersetTrained> {
    /// Get the unique label combinations
    pub fn classes(&self) -> &HashMap<usize, Vec<i32>> {
        &self.state.class_to_combination
    }

    /// Get the number of unique label combinations
    pub fn n_classes(&self) -> usize {
        self.state.unique_classes.len()
    }

    /// Get the number of labels
    pub fn n_labels(&self) -> usize {
        self.state.n_labels
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>> for LabelPowerset<LabelPowersetTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        // For each sample, find the nearest class centroid
        for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
            let mut min_distance = f64::INFINITY;
            let mut best_class_id = 0;

            // Find the closest class centroid
            for (&class_id, centroid) in &self.state.class_centroids {
                let mut distance = 0.0;
                for feature_idx in 0..n_features {
                    let diff = sample[feature_idx] - centroid[feature_idx];
                    distance += diff * diff;
                }
                distance = distance.sqrt();

                if distance < min_distance {
                    min_distance = distance;
                    best_class_id = class_id;
                }
            }

            // Convert the predicted class back to label combination
            if let Some(label_combination) = self.state.class_to_combination.get(&best_class_id) {
                for label_idx in 0..self.state.n_labels {
                    predictions[[sample_idx, label_idx]] = label_combination[label_idx];
                }
            }
        }

        Ok(predictions)
    }
}

/// Get decision scores for each class
impl LabelPowerset<LabelPowersetTrained> {
    /// Predict decision scores (negative distances to centroids)
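    ///
    /// A hedged usage sketch (marked `ignore`; assumes a fitted model `fitted`):
    ///
    /// ```ignore
    /// let scores = fitted.decision_function(&data.view())?;
    /// // The argmax of each row is the class whose centroid is nearest,
    /// // matching what `predict` returns for that sample.
    /// ```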
    #[allow(non_snake_case)]
    pub fn decision_function(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let n_classes = self.state.unique_classes.len();
        let mut scores = Array2::<Float>::zeros((n_samples, n_classes));

        // For each sample, compute distances to all class centroids
        for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
            for (class_idx, &class_id) in self.state.unique_classes.iter().enumerate() {
                if let Some(centroid) = self.state.class_centroids.get(&class_id) {
                    let mut distance = 0.0;
                    for feature_idx in 0..n_features {
                        let diff = sample[feature_idx] - centroid[feature_idx];
                        distance += diff * diff;
                    }
                    distance = distance.sqrt();

                    // Use negative distance as score (higher score = closer)
                    scores[[sample_idx, class_idx]] = -distance;
                }
            }
        }

        Ok(scores)
    }
}

/// Trained state for LabelPowerset
#[derive(Debug, Clone)]
pub struct LabelPowersetTrained {
    /// Mapping from class ID to label combination
    pub class_to_combination: HashMap<usize, Vec<i32>>,
    /// Mapping from label combination to class ID
    pub combination_to_class: HashMap<Vec<i32>, usize>,
    /// Centroids for each class (nearest centroid classifier)
    pub class_centroids: HashMap<usize, Array1<f64>>,
    /// List of unique class IDs
    pub unique_classes: Vec<usize>,
    /// Number of labels
    pub n_labels: usize,
    /// Number of features
    pub n_features: usize,
}

/// Pruned Label Powerset
///
/// An extension of the Label Powerset method that prunes rare label combinations
/// to reduce the complexity of the multi-class problem. Label combinations that
/// appear fewer than `min_frequency` times in the training data are either mapped
/// to the most similar frequent combination or to a default combination.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::PrunedLabelPowerset;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
/// ```
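///
/// A hedged configuration sketch (marked `ignore`; assumes the `Fit` trait is
/// in scope):
///
/// ```ignore
/// use sklears_core::traits::Fit;
///
/// // Keep combinations seen at least twice; map rare ones to all-zeros.
/// let model = PrunedLabelPowerset::new()
///     .min_frequency(2)
///     .strategy(PruningStrategy::DefaultMapping(vec![0, 0]));
/// let fitted = model.fit(&data.view(), &labels)?;
/// ```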
#[derive(Debug, Clone)]
pub struct PrunedLabelPowerset<S = Untrained> {
    state: S,
    min_frequency: usize,
    strategy: PruningStrategy,
}

/// Strategy for handling pruned label combinations
#[derive(Debug, Clone)]
pub enum PruningStrategy {
    /// Map rare combinations to the most similar frequent combination
    SimilarityMapping,
    /// Map rare combinations to a default combination (typically all zeros)
    DefaultMapping(Vec<i32>),
}

impl PrunedLabelPowerset<Untrained> {
    /// Create a new PrunedLabelPowerset instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            min_frequency: 2,
            strategy: PruningStrategy::DefaultMapping(vec![]),
        }
    }

    /// Set the minimum frequency threshold for label combinations
    pub fn min_frequency(mut self, min_frequency: usize) -> Self {
        self.min_frequency = min_frequency;
        self
    }

    /// Set the pruning strategy
    pub fn strategy(mut self, strategy: PruningStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Get the minimum frequency threshold
    pub fn get_min_frequency(&self) -> usize {
        self.min_frequency
    }

    /// Get the pruning strategy
    pub fn get_strategy(&self) -> &PruningStrategy {
        &self.strategy
    }
}

impl Default for PrunedLabelPowerset<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for PrunedLabelPowerset<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for PrunedLabelPowerset<Untrained> {
    type Fitted = PrunedLabelPowerset<PrunedLabelPowersetTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_labels = y.ncols();
        if n_labels == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one label".to_string(),
            ));
        }

        // Validate that labels are binary
        for sample_idx in 0..n_samples {
            for label_idx in 0..n_labels {
                let label_value = y[[sample_idx, label_idx]];
                if label_value != 0 && label_value != 1 {
                    return Err(SklearsError::InvalidInput(format!(
                        "PrunedLabelPowerset expects binary labels, but found {} at position ({}, {})",
                        label_value, sample_idx, label_idx
                    )));
                }
            }
        }

        // Count frequencies of label combinations
        let mut combination_counts: HashMap<Vec<i32>, usize> = HashMap::new();
        for sample_idx in 0..n_samples {
            let combination: Vec<i32> = (0..n_labels)
                .map(|label_idx| y[[sample_idx, label_idx]])
                .collect();
            *combination_counts.entry(combination).or_insert(0) += 1;
        }

        // Determine which combinations to keep (frequent ones)
        let frequent_combinations: Vec<Vec<i32>> = combination_counts
            .iter()
            .filter(|(_, &count)| count >= self.min_frequency)
            .map(|(combination, _)| combination.clone())
            .collect();

        if frequent_combinations.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No label combinations meet the minimum frequency threshold".to_string(),
            ));
        }

        // Resolve the default combination for the configured strategy
        let default_combination = match &self.strategy {
            PruningStrategy::DefaultMapping(default) => {
                if default.is_empty() {
                    vec![0; n_labels] // Default to all zeros if not specified
                } else if default.len() != n_labels {
                    return Err(SklearsError::InvalidInput(
                        "Default combination length must match number of labels".to_string(),
                    ));
                } else {
                    default.clone()
                }
            }
            PruningStrategy::SimilarityMapping => vec![], // Not used for similarity mapping
        };

        // For default mapping strategy, ensure the default combination is included
        // in frequent combinations if it's not already there
        let mut final_frequent_combinations = frequent_combinations.clone();
        if let PruningStrategy::DefaultMapping(_) = &self.strategy {
            if !final_frequent_combinations.contains(&default_combination) {
                final_frequent_combinations.push(default_combination.clone());
            }
        }

        // Create mapping for rare combinations
        let mut combination_mapping: HashMap<Vec<i32>, Vec<i32>> = HashMap::new();

        for (combination, &count) in &combination_counts {
            if count >= self.min_frequency {
                // Keep frequent combinations as-is
                combination_mapping.insert(combination.clone(), combination.clone());
            } else {
                // Map rare combinations based on strategy
                let mapped_combination = match &self.strategy {
                    PruningStrategy::SimilarityMapping => {
                        // Find the most similar frequent combination
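                        // (Worked Jaccard example, as a sketch: a rare [1, 1, 0]
                        // vs a frequent [1, 0, 0] has intersection 1 and union 2,
                        // i.e. similarity 0.5; vs [1, 1, 1] it scores 2/3, so
                        // [1, 1, 1] would be chosen.)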
                        let mut best_similarity = -1.0;
                        let mut best_combination = &final_frequent_combinations[0];

                        for freq_combo in &final_frequent_combinations {
                            // Compute Jaccard similarity
                            let intersection: i32 = combination
                                .iter()
                                .zip(freq_combo.iter())
                                .map(|(&a, &b)| if a == 1 && b == 1 { 1 } else { 0 })
                                .sum();
                            let union: i32 = combination
                                .iter()
                                .zip(freq_combo.iter())
                                .map(|(&a, &b)| if a == 1 || b == 1 { 1 } else { 0 })
                                .sum();

                            let similarity = if union > 0 {
                                intersection as f64 / union as f64
                            } else {
                                1.0 // Both empty sets
                            };

                            if similarity > best_similarity {
                                best_similarity = similarity;
                                best_combination = freq_combo;
                            }
                        }
                        best_combination.clone()
                    }
                    PruningStrategy::DefaultMapping(_) => default_combination.clone(),
                };
                combination_mapping.insert(combination.clone(), mapped_combination);
            }
        }

        // Create the final class mapping using only frequent combinations
        let mut class_to_combination: HashMap<usize, Vec<i32>> = HashMap::new();
        let mut combination_to_class: HashMap<Vec<i32>, usize> = HashMap::new();

        for (class_id, combo) in final_frequent_combinations.iter().enumerate() {
            class_to_combination.insert(class_id, combo.clone());
            combination_to_class.insert(combo.clone(), class_id);
        }

        // Transform labels using the mapping
        let mut transformed_labels = Vec::new();
        for sample_idx in 0..n_samples {
            let original_combination: Vec<i32> = (0..n_labels)
                .map(|label_idx| y[[sample_idx, label_idx]])
                .collect();

            let mapped_combination = combination_mapping
                .get(&original_combination)
                .unwrap()
                .clone();

            let class_id = *combination_to_class.get(&mapped_combination).unwrap();
            transformed_labels.push(class_id);
        }

        // Train a nearest centroid classifier on the pruned problem
        let mut class_centroids: HashMap<usize, Array1<f64>> = HashMap::new();

        for &class_id in class_to_combination.keys() {
            let mut centroid = Array1::<Float>::zeros(n_features);
            let mut count = 0;

            for (sample_idx, &sample_class) in transformed_labels.iter().enumerate() {
                if sample_class == class_id {
                    for feature_idx in 0..n_features {
                        centroid[feature_idx] += X[[sample_idx, feature_idx]];
                    }
                    count += 1;
                }
            }

            if count > 0 {
                centroid /= count as f64;
            }
            class_centroids.insert(class_id, centroid);
        }

        let unique_classes: Vec<usize> = class_to_combination.keys().cloned().collect();

        Ok(PrunedLabelPowerset {
            state: PrunedLabelPowersetTrained {
                class_to_combination,
                combination_to_class,
                combination_mapping,
                class_centroids,
                unique_classes,
                frequent_combinations: final_frequent_combinations,
                n_labels,
                n_features,
                min_frequency: self.min_frequency,
                strategy: self.strategy.clone(),
            },
            min_frequency: self.min_frequency,
            strategy: self.strategy.clone(),
        })
    }
}

impl PrunedLabelPowerset<PrunedLabelPowersetTrained> {
    /// Get the frequent label combinations
    pub fn frequent_combinations(&self) -> &[Vec<i32>] {
        &self.state.frequent_combinations
    }

    /// Get the number of frequent combinations
    pub fn n_frequent_classes(&self) -> usize {
        self.state.unique_classes.len()
    }

    /// Get the combination mapping used for pruning
    pub fn combination_mapping(&self) -> &HashMap<Vec<i32>, Vec<i32>> {
        &self.state.combination_mapping
    }

    /// Get the minimum frequency threshold used
    pub fn min_frequency(&self) -> usize {
        self.state.min_frequency
    }

    /// Get the number of features
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Get the number of labels
    pub fn n_labels(&self) -> usize {
        self.state.n_labels
    }

    /// Get the class centroids
    pub fn class_centroids(&self) -> &HashMap<usize, Array1<f64>> {
        &self.state.class_centroids
    }

    /// Get the class to combination mapping
    pub fn class_to_combination(&self) -> &HashMap<usize, Vec<i32>> {
        &self.state.class_to_combination
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for PrunedLabelPowerset<PrunedLabelPowersetTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        // For each sample, find the nearest class centroid
        for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
            let mut min_distance = f64::INFINITY;
            let mut best_class_id = 0;

            // Find the closest class centroid among frequent combinations
            for (&class_id, centroid) in &self.state.class_centroids {
                let mut distance = 0.0;
                for feature_idx in 0..n_features {
                    let diff = sample[feature_idx] - centroid[feature_idx];
                    distance += diff * diff;
                }
                distance = distance.sqrt();

                if distance < min_distance {
                    min_distance = distance;
                    best_class_id = class_id;
                }
            }

            // Convert the predicted class back to label combination
            if let Some(label_combination) = self.state.class_to_combination.get(&best_class_id) {
                for label_idx in 0..self.state.n_labels {
                    predictions[[sample_idx, label_idx]] = label_combination[label_idx];
                }
            }
        }

        Ok(predictions)
    }
}

/// Trained state for PrunedLabelPowerset
#[derive(Debug, Clone)]
pub struct PrunedLabelPowersetTrained {
    /// Mapping from class ID to frequent label combination
    pub class_to_combination: HashMap<usize, Vec<i32>>,
    /// Mapping from frequent label combination to class ID
    pub combination_to_class: HashMap<Vec<i32>, usize>,
    /// Mapping from original combinations to frequent combinations
    pub combination_mapping: HashMap<Vec<i32>, Vec<i32>>,
    /// Centroids for each frequent class
    pub class_centroids: HashMap<usize, Array1<f64>>,
    /// List of unique class IDs for frequent combinations
    pub unique_classes: Vec<usize>,
    /// The frequent label combinations that were kept
    pub frequent_combinations: Vec<Vec<i32>>,
    /// Number of labels
    pub n_labels: usize,
    /// Number of features
    pub n_features: usize,
    /// Minimum frequency threshold used
    pub min_frequency: usize,
    /// Pruning strategy used
    pub strategy: PruningStrategy,
}

/// One-vs-Rest Classifier
///
/// A multi-label classification strategy that treats each label as a separate
/// binary classification problem using a one-vs-rest approach. This is essentially
/// the same as BinaryRelevance but with explicit one-vs-rest semantics.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::OneVsRestClassifier;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[1, 0], [0, 1], [1, 1]];
/// ```
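///
/// A hedged usage sketch (marked `ignore`; internally this delegates to
/// `BinaryRelevance`, so predictions match it exactly):
///
/// ```ignore
/// use sklears_core::traits::{Fit, Predict};
///
/// let fitted = OneVsRestClassifier::new().fit(&data.view(), &labels)?;
/// let scores = fitted.decision_function(&data.view())?; // raw pre-sigmoid scores
/// ```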
#[derive(Debug, Clone)]
pub struct OneVsRestClassifier<S = Untrained> {
    state: S,
    n_jobs: Option<i32>,
}

impl OneVsRestClassifier<Untrained> {
    /// Create a new OneVsRestClassifier instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_jobs: None,
        }
    }

    /// Set the number of parallel jobs
    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
        self.n_jobs = n_jobs;
        self
    }
}

impl Default for OneVsRestClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for OneVsRestClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for OneVsRestClassifier<Untrained> {
    type Fitted = OneVsRestClassifier<OneVsRestClassifierTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        // Delegate to BinaryRelevance implementation
        let br = BinaryRelevance::new().n_jobs(self.n_jobs);
        let fitted_br = br.fit(X, y)?;

        Ok(OneVsRestClassifier {
            state: OneVsRestClassifierTrained {
                binary_relevance: fitted_br,
            },
            n_jobs: self.n_jobs,
        })
    }
}

impl OneVsRestClassifier<OneVsRestClassifierTrained> {
    /// Get the classes for each label
    pub fn classes(&self) -> &[Vec<i32>] {
        self.state.binary_relevance.classes()
    }

    /// Get the number of labels
    pub fn n_labels(&self) -> usize {
        self.state.binary_relevance.n_labels()
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for OneVsRestClassifier<OneVsRestClassifierTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        self.state.binary_relevance.predict(X)
    }
}

impl OneVsRestClassifier<OneVsRestClassifierTrained> {
    /// Predict class probabilities for each label
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        self.state.binary_relevance.predict_proba(X)
    }

    /// Get decision scores for each label
    #[allow(non_snake_case)]
    pub fn decision_function(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.binary_relevance.n_features() {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut scores =
            Array2::<Float>::zeros((n_samples, self.state.binary_relevance.n_labels()));

        // Get decision scores from each binary classifier (raw logistic regression scores)
        for label_idx in 0..self.state.binary_relevance.n_labels() {
            if let Some((weights, bias)) = self
                .state
                .binary_relevance
                .state
                .binary_classifiers
                .get(&label_idx)
            {
                for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
                    let score: f64 = sample
                        .iter()
                        .zip(weights.iter())
                        .map(|(&x, &w)| x * w)
                        .sum::<f64>()
                        + bias;

                    scores[[sample_idx, label_idx]] = score;
                }
            }
        }

        Ok(scores)
    }
}

/// Trained state for OneVsRestClassifier
#[derive(Debug, Clone)]
pub struct OneVsRestClassifierTrained {
    /// The underlying BinaryRelevance classifier
    pub binary_relevance: BinaryRelevance<BinaryRelevanceTrained>,
}