1use crate::{UtilsError, UtilsResult};
2use scirs2_core::ndarray::{Array1, Array2, Axis};
3use std::collections::{HashMap, HashSet};
4
/// Strategy for reducing a multiclass problem to binary sub-problems.
///
/// The module helpers [`one_vs_rest_transform`] (one class versus all others) and
/// [`one_vs_one_pairs`] / [`one_vs_one_transform`] (one class versus another) implement
/// the target transformations corresponding to the two variants.
///
/// `Eq` and `Hash` are derived so the strategy can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiClassStrategy {
    /// One binary problem per class: that class versus every other class.
    OneVsRest,
    /// One binary problem per unordered pair of distinct classes.
    OneVsOne,
}
10
11pub struct OneVsRestClassifier<C> {
12 pub estimators: Vec<C>,
13 pub classes: Vec<i32>,
14 pub strategy: MultiClassStrategy,
15}
16
17impl<C> Default for OneVsRestClassifier<C> {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl<C> OneVsRestClassifier<C> {
24 pub fn new() -> Self {
25 Self {
26 estimators: Vec::new(),
27 classes: Vec::new(),
28 strategy: MultiClassStrategy::OneVsRest,
29 }
30 }
31}
32
33pub fn type_of_target(y: &Array1<i32>) -> UtilsResult<String> {
34 if y.is_empty() {
35 return Err(UtilsError::EmptyInput);
36 }
37
38 let unique_values: HashSet<i32> = y.iter().copied().collect();
39 let n_unique = unique_values.len();
40
41 if n_unique == 1 {
42 Ok("unknown".to_string())
43 } else if n_unique == 2 {
44 Ok("binary".to_string())
45 } else {
46 Ok("multiclass".to_string())
47 }
48}
49
50pub fn check_classification_targets(y: &Array1<i32>) -> UtilsResult<()> {
51 let target_type = type_of_target(y)?;
52
53 match target_type.as_str() {
54 "binary" | "multiclass" => Ok(()),
55 "unknown" => Err(UtilsError::InvalidParameter(
56 "Unknown label type: all samples have the same label".to_string(),
57 )),
58 _ => Err(UtilsError::InvalidParameter(format!(
59 "Unknown target type: {target_type}"
60 ))),
61 }
62}
63
64pub fn unique_labels_multiclass(y: &Array1<i32>) -> Vec<i32> {
65 let mut unique: Vec<i32> = y.iter().copied().collect();
66 unique.sort();
67 unique.dedup();
68 unique
69}
70
71pub fn class_distribution(y: &Array1<i32>) -> HashMap<i32, usize> {
72 let mut counts = HashMap::new();
73 for &label in y.iter() {
74 *counts.entry(label).or_insert(0) += 1;
75 }
76 counts
77}
78
79pub fn check_multi_class(y: &Array1<i32>) -> UtilsResult<bool> {
80 let unique_labels = unique_labels_multiclass(y);
81
82 if unique_labels.len() < 2 {
83 Err(UtilsError::InvalidParameter(
84 "Need at least 2 classes for classification".to_string(),
85 ))
86 } else {
87 Ok(unique_labels.len() > 2)
88 }
89}
90
91pub fn one_vs_rest_transform(y: &Array1<i32>, positive_class: i32) -> Array1<i32> {
92 y.mapv(|label| if label == positive_class { 1 } else { 0 })
93}
94
/// Enumerates every pair `(classes[i], classes[j])` with `i < j`.
///
/// Pairs are ordered by `i` and then by `j`, giving `n * (n - 1) / 2` pairs for `n`
/// classes. No deduplication is done, so repeated entries in `classes` produce
/// repeated or self pairs.
pub fn one_vs_one_pairs(classes: &[i32]) -> Vec<(i32, i32)> {
    classes
        .iter()
        .enumerate()
        .flat_map(|(pos, &first)| {
            classes[pos + 1..]
                .iter()
                .map(move |&second| (first, second))
        })
        .collect()
}
106
107pub fn one_vs_one_transform(
108 y: &Array1<i32>,
109 class_a: i32,
110 class_b: i32,
111) -> (Array1<i32>, Vec<usize>) {
112 let mut new_y = Vec::new();
113 let mut indices = Vec::new();
114
115 for (i, &label) in y.iter().enumerate() {
116 if label == class_a || label == class_b {
117 new_y.push(if label == class_a { 0 } else { 1 });
118 indices.push(i);
119 }
120 }
121
122 (Array1::from_vec(new_y), indices)
123}
124
125pub fn is_multilabel(y: &Array2<i32>) -> bool {
126 y.ncols() > 1
127}
128
129pub fn multilabel_to_indicator(y: &Array1<i32>, classes: &[i32]) -> UtilsResult<Array2<i32>> {
130 let n_samples = y.len();
131 let n_classes = classes.len();
132 let mut indicator = Array2::zeros((n_samples, n_classes));
133
134 for (i, &label) in y.iter().enumerate() {
135 if let Some(class_idx) = classes.iter().position(|&c| c == label) {
136 indicator[[i, class_idx]] = 1;
137 } else {
138 return Err(UtilsError::InvalidParameter(format!(
139 "Label {label} not found in classes"
140 )));
141 }
142 }
143
144 Ok(indicator)
145}
146
147pub fn indicator_to_multilabel(
148 indicator: &Array2<i32>,
149 classes: &[i32],
150) -> UtilsResult<Vec<Vec<i32>>> {
151 if indicator.ncols() != classes.len() {
152 return Err(UtilsError::ShapeMismatch {
153 expected: vec![indicator.nrows(), classes.len()],
154 actual: vec![indicator.nrows(), indicator.ncols()],
155 });
156 }
157
158 let mut result = Vec::new();
159
160 for row in indicator.axis_iter(Axis(0)) {
161 let mut labels = Vec::new();
162 for (j, &value) in row.iter().enumerate() {
163 if value == 1 {
164 labels.push(classes[j]);
165 }
166 }
167 result.push(labels);
168 }
169
170 Ok(result)
171}
172
173pub fn check_binary_indicators_multioutput(y: &Array2<i32>) -> UtilsResult<()> {
174 for value in y.iter() {
175 if *value != 0 && *value != 1 {
176 return Err(UtilsError::InvalidParameter(
177 "Binary indicators must contain only 0 and 1".to_string(),
178 ));
179 }
180 }
181 Ok(())
182}
183
184pub fn compute_class_weight_balanced(y: &Array1<i32>) -> HashMap<i32, f64> {
185 let class_counts = class_distribution(y);
186 let n_samples = y.len() as f64;
187 let n_classes = class_counts.len() as f64;
188
189 let mut weights = HashMap::new();
190 for (&class, &count) in &class_counts {
191 weights.insert(class, n_samples / (n_classes * count as f64));
192 }
193
194 weights
195}
196
197pub fn compute_sample_weight(y: &Array1<i32>, class_weight: &HashMap<i32, f64>) -> Array1<f64> {
198 let mut sample_weights = Array1::zeros(y.len());
199
200 for (i, &label) in y.iter().enumerate() {
201 sample_weights[i] = class_weight.get(&label).copied().unwrap_or(1.0);
202 }
203
204 sample_weights
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_type_of_target() {
        let binary = array![0, 1, 0, 1];
        assert_eq!(type_of_target(&binary).unwrap(), "binary");

        let multiclass = array![0, 1, 2, 0, 1, 2];
        assert_eq!(type_of_target(&multiclass).unwrap(), "multiclass");

        let constant = array![1, 1, 1, 1];
        assert_eq!(type_of_target(&constant).unwrap(), "unknown");
    }

    #[test]
    fn test_type_of_target_empty() {
        let empty = Array1::<i32>::zeros(0);
        assert!(type_of_target(&empty).is_err());
    }

    #[test]
    fn test_check_classification_targets() {
        assert!(check_classification_targets(&array![0, 1, 2]).is_ok());
        assert!(check_classification_targets(&array![0, 1]).is_ok());
        // A single repeated label is rejected.
        assert!(check_classification_targets(&array![4, 4]).is_err());
    }

    #[test]
    fn test_check_multi_class() {
        let binary = array![0, 1, 0, 1];
        assert!(!check_multi_class(&binary).unwrap());

        let multiclass = array![0, 1, 2, 0, 1, 2];
        assert!(check_multi_class(&multiclass).unwrap());

        let single = array![3, 3];
        assert!(check_multi_class(&single).is_err());
    }

    #[test]
    fn test_unique_labels_multiclass() {
        assert_eq!(unique_labels_multiclass(&array![3, 1, 3, 2]), vec![1, 2, 3]);
    }

    #[test]
    fn test_one_vs_rest_transform() {
        let y = array![0, 1, 2, 0, 1, 2];
        let binary_y = one_vs_rest_transform(&y, 1);
        assert_eq!(binary_y, array![0, 1, 0, 0, 1, 0]);
    }

    #[test]
    fn test_one_vs_one_pairs() {
        let classes = vec![0, 1, 2];
        let pairs = one_vs_one_pairs(&classes);
        assert_eq!(pairs, vec![(0, 1), (0, 2), (1, 2)]);
    }

    #[test]
    fn test_one_vs_one_transform() {
        let y = array![0, 1, 2, 0, 1, 2];
        let (binary_y, indices) = one_vs_one_transform(&y, 0, 2);
        assert_eq!(binary_y, array![0, 1, 0, 1]);
        assert_eq!(indices, vec![0, 2, 3, 5]);
    }

    #[test]
    fn test_is_multilabel() {
        assert!(is_multilabel(&array![[0, 1], [1, 0]]));
        assert!(!is_multilabel(&array![[1], [0]]));
    }

    #[test]
    fn test_multilabel_to_indicator() {
        let y = array![0, 1, 2];
        let classes = vec![0, 1, 2];
        let indicator = multilabel_to_indicator(&y, &classes).unwrap();

        let expected = Array2::from_shape_vec((3, 3), vec![1, 0, 0, 0, 1, 0, 0, 0, 1]).unwrap();
        assert_eq!(indicator, expected);
    }

    #[test]
    fn test_multilabel_to_indicator_unknown_label() {
        let y = array![0, 5];
        assert!(multilabel_to_indicator(&y, &[0, 1]).is_err());
    }

    #[test]
    fn test_indicator_to_multilabel() {
        let indicator = Array2::from_shape_vec((2, 3), vec![1, 0, 1, 0, 0, 0]).unwrap();
        let labels = indicator_to_multilabel(&indicator, &[10, 20, 30]).unwrap();
        let expected: Vec<Vec<i32>> = vec![vec![10, 30], vec![]];
        assert_eq!(labels, expected);

        // Column count must match the number of classes.
        assert!(indicator_to_multilabel(&indicator, &[1, 2]).is_err());
    }

    #[test]
    fn test_check_binary_indicators_multioutput() {
        assert!(check_binary_indicators_multioutput(&array![[0, 1], [1, 0]]).is_ok());
        assert!(check_binary_indicators_multioutput(&array![[0, 2]]).is_err());
    }

    #[test]
    fn test_compute_class_weight_balanced() {
        let y = array![0, 0, 1, 1, 1, 2];
        let weights = compute_class_weight_balanced(&y);

        // n_samples / (n_classes * count) = 6 / (3 * count)
        assert!((weights[&0] - 1.0).abs() < 1e-10);
        assert!((weights[&1] - 2.0 / 3.0).abs() < 1e-10);
        assert!((weights[&2] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_sample_weight() {
        let mut class_weight = HashMap::new();
        class_weight.insert(0, 2.0);
        let weights = compute_sample_weight(&array![0, 1, 0], &class_weight);
        // Label 1 is absent from the map and falls back to 1.0.
        assert_eq!(weights, array![2.0, 1.0, 2.0]);
    }

    #[test]
    fn test_class_distribution() {
        let y = array![0, 1, 0, 2, 1, 1];
        let dist = class_distribution(&y);

        assert_eq!(dist[&0], 2);
        assert_eq!(dist[&1], 3);
        assert_eq!(dist[&2], 1);
    }
}