use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

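/// Binary relevance multi-label classifier with per-label probability
/// calibration.
///
/// One independent logistic-regression model is trained per label; its raw
/// probabilities are then passed through the configured [`CalibrationMethod`].
///
/// Illustrative usage (a sketch, not compiled as a doctest):
/// ```ignore
/// let model = CalibratedBinaryRelevance::new()
///     .calibration_method(CalibrationMethod::Platt);
/// let trained = model.fit(&X.view(), &y)?;
/// let probs = trained.predict_proba(&X.view())?;
/// ```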
#[derive(Debug, Clone)]
pub struct CalibratedBinaryRelevance<S = Untrained> {
    state: S,
    calibration_method: CalibrationMethod,
}

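/// Strategy used to map raw per-label probabilities to calibrated ones:
/// `Platt` fits a sigmoid `sigma(a * p + b)` by gradient descent, while
/// `Isotonic` is currently a fixed monotone placeholder (see
/// `fit_calibration`).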
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CalibrationMethod {
    Platt,
    Isotonic,
}

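/// Trained state for [`CalibratedBinaryRelevance`]: one linear model
/// `(weights, bias)` and one calibration pair `(slope, intercept)` per label.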
#[derive(Debug, Clone)]
pub struct CalibratedBinaryRelevanceTrained {
    base_models: Vec<(Array1<Float>, Float)>,
    calibration_params: Vec<(Float, Float)>,
    calibration_method: CalibrationMethod,
    n_features: usize,
    n_labels: usize,
}

impl Default for CalibratedBinaryRelevance<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CalibratedBinaryRelevance<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

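// Training: for each label, fit an independent logistic-regression model by
// batch gradient descent, then fit the calibration map on that model's
// training-set probabilities. (Proper calibration would use held-out data;
// this simplified version reuses the training set.)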
impl Fit<ArrayView2<'_, Float>, Array2<i32>> for CalibratedBinaryRelevance<Untrained> {
    type Fitted = CalibratedBinaryRelevance<CalibratedBinaryRelevanceTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut base_models = Vec::new();
        let mut calibration_params = Vec::new();

        for label_idx in 0..n_labels {
            let y_label = y.column(label_idx);

            // Logistic regression for this label, trained by batch gradient descent.
            let mut weights = Array1::<Float>::zeros(n_features);
            let mut bias = 0.0;
            let learning_rate = 0.01;
            let max_iter = 100;

            for _iter in 0..max_iter {
                let mut weight_gradient = Array1::<Float>::zeros(n_features);
                let mut bias_gradient = 0.0;

                for sample_idx in 0..n_samples {
                    let x = X.row(sample_idx);
                    let y_true = y_label[sample_idx] as Float;

                    let logit = x.dot(&weights) + bias;
                    let prob = 1.0 / (1.0 + (-logit).exp());
                    let error = prob - y_true;

                    for feat_idx in 0..n_features {
                        weight_gradient[feat_idx] += error * x[feat_idx];
                    }
                    bias_gradient += error;
                }

                for i in 0..n_features {
                    weights[i] -= learning_rate * weight_gradient[i] / n_samples as Float;
                }
                bias -= learning_rate * bias_gradient / n_samples as Float;
            }

            // Collect the model's training-set probabilities for calibration.
            let mut probs = Vec::new();
            let mut labels = Vec::new();
            for sample_idx in 0..n_samples {
                let x = X.row(sample_idx);
                let logit = x.dot(&weights) + bias;
                let prob = 1.0 / (1.0 + (-logit).exp());
                probs.push(prob);
                labels.push(y_label[sample_idx] as Float);
            }

            let (slope, intercept) = self.fit_calibration(&probs, &labels)?;

            base_models.push((weights, bias));
            calibration_params.push((slope, intercept));
        }

        Ok(CalibratedBinaryRelevance {
            state: CalibratedBinaryRelevanceTrained {
                base_models,
                calibration_params,
                calibration_method: self.calibration_method,
                n_features,
                n_labels,
            },
            calibration_method: self.calibration_method,
        })
    }
}

impl CalibratedBinaryRelevance<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            calibration_method: CalibrationMethod::Platt,
        }
    }

    pub fn calibration_method(mut self, method: CalibrationMethod) -> Self {
        self.calibration_method = method;
        self
    }

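    /// Fit the calibration map `p -> sigma(slope * p + intercept)`.
    ///
    /// For `Platt`, the parameters minimize log loss over the supplied
    /// (probability, label) pairs via gradient descent. Note this is a
    /// simplified Platt scaling: it operates on probabilities rather than on
    /// raw decision values, and performs no held-out split.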
    fn fit_calibration(&self, probs: &[Float], labels: &[Float]) -> SklResult<(Float, Float)> {
        match self.calibration_method {
            CalibrationMethod::Platt => {
                // Initialize with a positive slope so the calibration map is
                // increasing in the base probability.
                let mut a = 1.0;
                let mut b = 0.0;
                let learning_rate = 0.01;

                for _iter in 0..100 {
                    let mut grad_a = 0.0;
                    let mut grad_b = 0.0;

                    for (i, &prob) in probs.iter().enumerate() {
                        let y_true = labels[i];
                        let logit = a * prob + b;
                        let cal_prob = 1.0 / (1.0 + (-logit).exp());
                        let error = cal_prob - y_true;

                        grad_a += error * prob;
                        grad_b += error;
                    }

                    a -= learning_rate * grad_a / probs.len() as Float;
                    b -= learning_rate * grad_b / probs.len() as Float;
                }

                Ok((a, b))
            }
            CalibrationMethod::Isotonic => {
                // Isotonic regression is not implemented yet; fall back to a
                // fixed monotone sigmoid map sigma(4p - 2), which preserves
                // p = 0.5.
                Ok((4.0, -2.0))
            }
        }
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for CalibratedBinaryRelevance<CalibratedBinaryRelevanceTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            for label_idx in 0..self.state.n_labels {
                let (weights, bias) = &self.state.base_models[label_idx];
                let (slope, intercept) = self.state.calibration_params[label_idx];

                let logit = x.dot(weights) + bias;
                let base_prob = 1.0 / (1.0 + (-logit).exp());

                let cal_logit = slope * base_prob + intercept;
                let cal_prob = 1.0 / (1.0 + (-cal_logit).exp());

                predictions[[sample_idx, label_idx]] = if cal_prob > 0.5 { 1 } else { 0 };
            }
        }

        Ok(predictions)
    }
}

impl CalibratedBinaryRelevance<CalibratedBinaryRelevanceTrained> {
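    /// Calibrated probability estimates, shape `(n_samples, n_labels)`.
    /// Each entry is `sigma(slope * p + intercept)` applied to the per-label
    /// base-model probability `p`.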
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            for label_idx in 0..self.state.n_labels {
                let (weights, bias) = &self.state.base_models[label_idx];
                let (slope, intercept) = self.state.calibration_params[label_idx];

                let logit = x.dot(weights) + bias;
                let base_prob = 1.0 / (1.0 + (-logit).exp());

                let cal_logit = slope * base_prob + intercept;
                let cal_prob = 1.0 / (1.0 + (-cal_logit).exp());

                probabilities[[sample_idx, label_idx]] = cal_prob;
            }
        }

        Ok(probabilities)
    }
}

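/// Generates random binary label combinations: each row is one combination of
/// `n_labels` labels, with each label independently set to 1 with probability
/// `label_density`.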
pub struct RandomLabelCombinations {
    n_labels: usize,
    n_combinations: usize,
    label_density: Float,
    random_state: Option<u64>,
}

impl RandomLabelCombinations {
    pub fn new(n_labels: usize) -> Self {
        Self {
            n_labels,
            n_combinations: 100,
            label_density: 0.3,
            random_state: None,
        }
    }

    pub fn n_combinations(mut self, n_combinations: usize) -> Self {
        self.n_combinations = n_combinations;
        self
    }

    pub fn label_density(mut self, density: Float) -> Self {
        self.label_density = density;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

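    /// Generate an `(n_combinations, n_labels)` matrix of 0/1 entries, where
    /// each entry is 1 with probability `label_density`.
    ///
    /// NOTE: `random_state` is stored but not yet honored; results are not
    /// reproducible until a seedable RNG is wired in.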
    pub fn generate(&self) -> Array2<i32> {
        // A seedable RNG is not wired in yet, so `random_state` is ignored
        // and a thread-local RNG is used unconditionally.
        let mut rng = thread_rng();

        let mut combinations = Array2::<i32>::zeros((self.n_combinations, self.n_labels));

        for i in 0..self.n_combinations {
            for j in 0..self.n_labels {
                combinations[[i, j]] = if rng.gen::<Float>() < self.label_density {
                    1
                } else {
                    0
                };
            }
        }

        combinations
    }
}

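/// Multi-label k-nearest-neighbors classifier in the spirit of ML-kNN
/// (Zhang & Zhou, 2007): label priors and neighbor-count conditionals are
/// estimated with Laplace smoothing, and prediction is a per-label MAP
/// decision. This is a simplified variant; see the `Predict` impl for the
/// approximation used for the negative class.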
#[derive(Debug, Clone)]
pub struct MLkNN<S = Untrained> {
    state: S,
    k: usize,
    smooth: Float,
    distance_metric: DistanceMetric,
}

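/// Distance used for neighbor search: Euclidean (L2), Manhattan (L1), or
/// cosine distance `1 - cos(a, b)`.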
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
    Euclidean,
    Manhattan,
    Cosine,
}

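/// Trained state for [`MLkNN`]: the stored training set plus smoothed prior
/// probabilities `P(H_j)` and conditionals `P(c positive neighbors | H_j)`
/// for `c = 0..=k`.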
#[derive(Debug, Clone)]
pub struct MLkNNTrained {
    training_data: Array2<Float>,
    training_labels: Array2<i32>,
    prior_probs: Array1<Float>,
    conditional_probs: Array2<Float>,
    k: usize,
    smooth: Float,
    distance_metric: DistanceMetric,
    n_labels: usize,
}

impl Default for MLkNN<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MLkNN<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

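// Training estimates, with Laplace smoothing `s`:
//   prior:       P(H_j)     = (n_positive_j + s) / (n_samples + 2s)
//   conditional: P(c | H_j) = (count_j[c] + s) / (n_positive_j + (k + 1) s)
// where count_j[c] is the number of positive training samples for label j
// whose k nearest neighbors (leave-one-out) contain exactly c positives.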
impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MLkNN<Untrained> {
    type Fitted = MLkNN<MLkNNTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, _n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if self.k >= n_samples {
            return Err(SklearsError::InvalidInput(
                "k must be smaller than the number of training samples".to_string(),
            ));
        }

        // Smoothed label priors.
        let mut prior_probs = Array1::<Float>::zeros(n_labels);
        for label_idx in 0..n_labels {
            let positive_count = y.column(label_idx).iter().filter(|&&x| x == 1).count();
            prior_probs[label_idx] =
                (positive_count as Float + self.smooth) / (n_samples as Float + 2.0 * self.smooth);
        }

        // Neighbor-count statistics for positive samples (leave-one-out).
        let mut conditional_probs = Array2::<Float>::zeros((n_labels, self.k + 1));

        for sample_idx in 0..n_samples {
            let neighbors = self.find_k_neighbors(X, sample_idx, &X.view())?;

            for label_idx in 0..n_labels {
                let label_count = neighbors
                    .iter()
                    .filter(|&&neighbor_idx| y[[neighbor_idx, label_idx]] == 1)
                    .count();

                if y[[sample_idx, label_idx]] == 1 {
                    conditional_probs[[label_idx, label_count]] += 1.0;
                }
            }
        }

        // Normalize the counts into smoothed conditional probabilities.
        for label_idx in 0..n_labels {
            let total_positive = y.column(label_idx).iter().filter(|&&x| x == 1).count() as Float;
            for count in 0..=self.k {
                conditional_probs[[label_idx, count]] = (conditional_probs[[label_idx, count]]
                    + self.smooth)
                    / (total_positive + (self.k + 1) as Float * self.smooth);
            }
        }

        Ok(MLkNN {
            state: MLkNNTrained {
                training_data: X.to_owned(),
                training_labels: y.clone(),
                prior_probs,
                conditional_probs,
                k: self.k,
                smooth: self.smooth,
                distance_metric: self.distance_metric,
                n_labels,
            },
            k: self.k,
            smooth: self.smooth,
            distance_metric: self.distance_metric,
        })
    }
}

impl MLkNN<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            k: 10,
            smooth: 1.0,
            distance_metric: DistanceMetric::Euclidean,
        }
    }

    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    pub fn smooth(mut self, smooth: Float) -> Self {
        self.smooth = smooth;
        self
    }

    pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
        self.distance_metric = metric;
        self
    }

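    /// Brute-force k-nearest-neighbor search (O(n^2) over the training set),
    /// excluding the query sample itself so that training statistics are
    /// leave-one-out.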
    fn find_k_neighbors(
        &self,
        X: &ArrayView2<'_, Float>,
        sample_idx: usize,
        training_data: &ArrayView2<'_, Float>,
    ) -> SklResult<Vec<usize>> {
        let query = X.row(sample_idx);
        let mut distances = Vec::new();

        for (train_idx, train_sample) in training_data.rows().into_iter().enumerate() {
            if train_idx != sample_idx {
                let distance = self.calculate_distance(&query, &train_sample);
                distances.push((distance, train_idx));
            }
        }

        // Distances are finite for finite inputs, so partial_cmp cannot fail.
        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let neighbors = distances
            .into_iter()
            .take(self.k)
            .map(|(_, idx)| idx)
            .collect();

        Ok(neighbors)
    }

    fn calculate_distance(&self, a: &ArrayView1<'_, Float>, b: &ArrayView1<'_, Float>) -> Float {
        match self.distance_metric {
            DistanceMetric::Euclidean => a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<Float>()
                .sqrt(),
            DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
            DistanceMetric::Cosine => {
                let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<Float>();
                let norm_a = a.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                let norm_b = b.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                if norm_a > 0.0 && norm_b > 0.0 {
                    1.0 - dot / (norm_a * norm_b)
                } else {
                    1.0
                }
            }
        }
    }
}

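// Prediction is a per-label MAP decision: with c positive neighbors, predict
// the label when P(H_j) * P(c | H_j) exceeds the negative score. The negative
// likelihood is approximated here as 1 - P(c | H_j); full ML-kNN would
// instead estimate separate conditionals P(c | not H_j) from negative
// samples.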
impl Predict<ArrayView2<'_, Float>, Array2<i32>> for MLkNN<MLkNNTrained> {
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.training_data.ncols() {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let neighbors = self.find_k_neighbors_trained(X, sample_idx)?;

            for label_idx in 0..self.state.n_labels {
                let positive_neighbors = neighbors
                    .iter()
                    .filter(|&&neighbor_idx| {
                        self.state.training_labels[[neighbor_idx, label_idx]] == 1
                    })
                    .count();

                let prob_positive = self.state.prior_probs[label_idx]
                    * self.state.conditional_probs[[label_idx, positive_neighbors]];
                let prob_negative = (1.0 - self.state.prior_probs[label_idx])
                    * (1.0 - self.state.conditional_probs[[label_idx, positive_neighbors]]);

                predictions[[sample_idx, label_idx]] =
                    if prob_positive > prob_negative { 1 } else { 0 };
            }
        }

        Ok(predictions)
    }
}

impl MLkNN<MLkNNTrained> {
    fn find_k_neighbors_trained(
        &self,
        X: &ArrayView2<'_, Float>,
        sample_idx: usize,
    ) -> SklResult<Vec<usize>> {
        let query = X.row(sample_idx);
        let mut distances = Vec::new();

        for (train_idx, train_sample) in self.state.training_data.rows().into_iter().enumerate() {
            let distance = self.calculate_distance_trained(&query, &train_sample);
            distances.push((distance, train_idx));
        }

        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let neighbors = distances
            .into_iter()
            .take(self.state.k)
            .map(|(_, idx)| idx)
            .collect();

        Ok(neighbors)
    }

    fn calculate_distance_trained(
        &self,
        a: &ArrayView1<'_, Float>,
        b: &ArrayView1<'_, Float>,
    ) -> Float {
        match self.state.distance_metric {
            DistanceMetric::Euclidean => a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<Float>()
                .sqrt(),
            DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
            DistanceMetric::Cosine => {
                let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<Float>();
                let norm_a = a.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                let norm_b = b.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                if norm_a > 0.0 && norm_b > 0.0 {
                    1.0 - dot / (norm_a * norm_b)
                } else {
                    1.0
                }
            }
        }
    }

    pub fn k(&self) -> usize {
        self.state.k
    }

    pub fn prior_probabilities(&self) -> &Array1<Float> {
        &self.state.prior_probs
    }
}

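/// Binary relevance with asymmetric misclassification costs: each label's
/// logistic model is trained with a cost-weighted gradient and classified
/// against a cost-sensitive decision threshold rather than 0.5.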
#[derive(Debug, Clone)]
pub struct CostSensitiveBinaryRelevance<S = Untrained> {
    state: S,
    cost_matrix: CostMatrix,
    learning_rate: Float,
    max_iterations: usize,
    regularization: Float,
}

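/// Per-label misclassification costs: `false_positive_costs[j]` is the cost
/// of predicting label `j` when absent; `false_negative_costs[j]` the cost of
/// missing it when present. Out-of-range lookups fall back to a cost of 1.0.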
#[derive(Debug, Clone)]
pub struct CostMatrix {
    false_positive_costs: Array1<Float>,
    false_negative_costs: Array1<Float>,
}

impl CostMatrix {
    pub fn new(false_positive_costs: Array1<Float>, false_negative_costs: Array1<Float>) -> Self {
        Self {
            false_positive_costs,
            false_negative_costs,
        }
    }

    pub fn uniform(n_labels: usize, fp_cost: Float, fn_cost: Float) -> Self {
        Self {
            false_positive_costs: Array1::from_elem(n_labels, fp_cost),
            false_negative_costs: Array1::from_elem(n_labels, fn_cost),
        }
    }

    pub fn fp_cost(&self, label_idx: usize) -> Float {
        self.false_positive_costs
            .get(label_idx)
            .copied()
            .unwrap_or(1.0)
    }

    pub fn fn_cost(&self, label_idx: usize) -> Float {
        self.false_negative_costs
            .get(label_idx)
            .copied()
            .unwrap_or(1.0)
    }
}

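/// Trained state: one [`SimpleBinaryModel`] per label plus the cost matrix
/// used during training.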
#[derive(Debug, Clone)]
pub struct CostSensitiveBinaryRelevanceTrained {
    models: Vec<SimpleBinaryModel>,
    cost_matrix: CostMatrix,
    n_features: usize,
    n_labels: usize,
}

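/// A linear model with a per-label decision threshold derived from the
/// false-positive and false-negative costs.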
#[derive(Debug, Clone)]
pub struct SimpleBinaryModel {
    weights: Array1<Float>,
    bias: Float,
    threshold: Float,
}

impl Default for CostSensitiveBinaryRelevance<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CostSensitiveBinaryRelevance<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

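// Training: per-label logistic regression where each sample's gradient
// contribution is scaled by its misclassification cost (fn_cost for positive
// samples, fp_cost for negative ones), plus L2 regularization on the weights.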
impl Fit<ArrayView2<'_, Float>, Array2<i32>> for CostSensitiveBinaryRelevance<Untrained> {
    type Fitted = CostSensitiveBinaryRelevance<CostSensitiveBinaryRelevanceTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut models = Vec::new();

        for label_idx in 0..n_labels {
            let y_label = y.column(label_idx);
            let fp_cost = self.cost_matrix.fp_cost(label_idx);
            let fn_cost = self.cost_matrix.fn_cost(label_idx);

            let mut weights = Array1::<Float>::zeros(n_features);
            let mut bias = 0.0;

            for _iter in 0..self.max_iterations {
                let mut weight_gradient = Array1::<Float>::zeros(n_features);
                let mut bias_gradient = 0.0;

                for sample_idx in 0..n_samples {
                    let x = X.row(sample_idx);
                    let y_true = y_label[sample_idx] as Float;

                    let logit = x.dot(&weights) + bias;
                    let prob = 1.0 / (1.0 + (-logit).exp());

                    // Weight the error by the cost of the mistake this sample
                    // could incur: a missed positive costs fn_cost, a spurious
                    // positive costs fp_cost.
                    let cost_weight = if y_true == 1.0 { fn_cost } else { fp_cost };
                    let error = (prob - y_true) * cost_weight;

                    for feat_idx in 0..n_features {
                        weight_gradient[feat_idx] += error * x[feat_idx];
                    }
                    bias_gradient += error;
                }

                // L2 regularization on the weights (not the bias).
                for i in 0..n_features {
                    weight_gradient[i] += self.regularization * weights[i];
                }

                for i in 0..n_features {
                    weights[i] -= self.learning_rate * weight_gradient[i] / n_samples as Float;
                }
                bias -= self.learning_rate * bias_gradient / n_samples as Float;
            }

            let threshold = self.calculate_cost_sensitive_threshold(fp_cost, fn_cost);

            models.push(SimpleBinaryModel {
                weights,
                bias,
                threshold,
            });
        }

        Ok(CostSensitiveBinaryRelevance {
            state: CostSensitiveBinaryRelevanceTrained {
                models,
                cost_matrix: self.cost_matrix,
                n_features,
                n_labels,
            },
            // The configured cost matrix was moved into the trained state
            // above; this outer field is a uniform placeholder and is unused
            // after fitting.
            cost_matrix: CostMatrix::uniform(n_labels, 1.0, 1.0),
            learning_rate: self.learning_rate,
            max_iterations: self.max_iterations,
            regularization: self.regularization,
        })
    }
}

impl CostSensitiveBinaryRelevance<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            cost_matrix: CostMatrix::uniform(1, 1.0, 1.0),
            learning_rate: 0.01,
            max_iterations: 100,
            regularization: 0.01,
        }
    }

    pub fn cost_matrix(mut self, cost_matrix: CostMatrix) -> Self {
        self.cost_matrix = cost_matrix;
        self
    }

    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    pub fn regularization(mut self, regularization: Float) -> Self {
        self.regularization = regularization;
        self
    }

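    /// Bayes-optimal threshold for asymmetric costs: predicting positive at
    /// probability `p` has expected cost `(1 - p) * fp_cost`, while predicting
    /// negative has expected cost `p * fn_cost`. Equating the two gives the
    /// break-even point `p* = fp_cost / (fp_cost + fn_cost)`.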
    fn calculate_cost_sensitive_threshold(&self, fp_cost: Float, fn_cost: Float) -> Float {
        fp_cost / (fp_cost + fn_cost)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for CostSensitiveBinaryRelevance<CostSensitiveBinaryRelevanceTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            for (label_idx, model) in self.state.models.iter().enumerate() {
                let logit = x.dot(&model.weights) + model.bias;
                let prob = 1.0 / (1.0 + (-logit).exp());

                predictions[[sample_idx, label_idx]] = if prob > model.threshold { 1 } else { 0 };
            }
        }

        Ok(predictions)
    }
}

impl CostSensitiveBinaryRelevance<CostSensitiveBinaryRelevanceTrained> {
    pub fn cost_matrix(&self) -> &CostMatrix {
        &self.state.cost_matrix
    }

    pub fn thresholds(&self) -> Vec<Float> {
        self.state.models.iter().map(|m| m.threshold).collect()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_calibrated_binary_relevance_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];

        let cbr = CalibratedBinaryRelevance::new().calibration_method(CalibrationMethod::Platt);
        let trained_cbr = cbr.fit(&X.view(), &y).unwrap();
        let predictions = trained_cbr.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (4, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_calibrated_binary_relevance_probabilities() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1, 0], [0, 1]];

        let cbr = CalibratedBinaryRelevance::new();
        let trained_cbr = cbr.fit(&X.view(), &y).unwrap();
        let probabilities = trained_cbr.predict_proba(&X.view()).unwrap();

        assert_eq!(probabilities.dim(), (2, 2));
        assert!(probabilities.iter().all(|&p| p >= 0.0 && p <= 1.0));
    }

    #[test]
    fn test_random_label_combinations() {
        let generator = RandomLabelCombinations::new(3)
            .n_combinations(5)
            .label_density(0.5)
            .random_state(42);

        let combinations = generator.generate();
        assert_eq!(combinations.dim(), (5, 3));
        assert!(combinations.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mlknn_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0], [1.5, 2.5]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0], [1, 0]];

        let mlknn = MLkNN::new().k(3).smooth(1.0);
        let trained_mlknn = mlknn.fit(&X.view(), &y).unwrap();
        let predictions = trained_mlknn.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (5, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
        assert_eq!(trained_mlknn.k(), 3);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mlknn_distance_metrics() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let y = array![[1, 0], [0, 1], [1, 1]];

        let mlknn_euclidean = MLkNN::new().k(2).distance_metric(DistanceMetric::Euclidean);
        let trained_euclidean = mlknn_euclidean.fit(&X.view(), &y).unwrap();

        let mlknn_manhattan = MLkNN::new().k(2).distance_metric(DistanceMetric::Manhattan);
        let trained_manhattan = mlknn_manhattan.fit(&X.view(), &y).unwrap();

        let pred_euclidean = trained_euclidean.predict(&X.view()).unwrap();
        let pred_manhattan = trained_manhattan.predict(&X.view()).unwrap();

        assert_eq!(pred_euclidean.dim(), (3, 2));
        assert_eq!(pred_manhattan.dim(), (3, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_cost_sensitive_binary_relevance() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];

        let fp_costs = array![2.0, 1.0];
        let fn_costs = array![1.0, 3.0];
        let cost_matrix = CostMatrix::new(fp_costs, fn_costs);

        let csbr = CostSensitiveBinaryRelevance::new()
            .cost_matrix(cost_matrix)
            .learning_rate(0.01)
            .max_iterations(50);

        let trained_csbr = csbr.fit(&X.view(), &y).unwrap();
        let predictions = trained_csbr.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (4, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));

        let thresholds = trained_csbr.thresholds();
        assert_eq!(thresholds.len(), 2);
    }

    #[test]
    fn test_cost_matrix_creation() {
        let fp_costs = array![1.0, 2.0, 3.0];
        let fn_costs = array![2.0, 1.0, 1.0];
        let cost_matrix = CostMatrix::new(fp_costs, fn_costs);

        assert_eq!(cost_matrix.fp_cost(0), 1.0);
        assert_eq!(cost_matrix.fp_cost(1), 2.0);
        assert_eq!(cost_matrix.fn_cost(0), 2.0);
        assert_eq!(cost_matrix.fn_cost(1), 1.0);

        let uniform_costs = CostMatrix::uniform(3, 1.5, 2.5);
        assert_eq!(uniform_costs.fp_cost(0), 1.5);
        assert_eq!(uniform_costs.fn_cost(2), 2.5);
    }

    #[test]
    fn test_calibration_methods() {
        let cbr_platt =
            CalibratedBinaryRelevance::new().calibration_method(CalibrationMethod::Platt);
        let cbr_isotonic =
            CalibratedBinaryRelevance::new().calibration_method(CalibrationMethod::Isotonic);

        assert_eq!(cbr_platt.calibration_method, CalibrationMethod::Platt);
        assert_eq!(cbr_isotonic.calibration_method, CalibrationMethod::Isotonic);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mlknn_prior_probabilities() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];

        let mlknn = MLkNN::new().k(2).smooth(1.0);
        let trained_mlknn = mlknn.fit(&X.view(), &y).unwrap();

        let priors = trained_mlknn.prior_probabilities();
        assert_eq!(priors.len(), 2);

        // Each label has 2 positives out of 4 samples; with smoothing s = 1
        // the prior is (2 + 1) / (4 + 2) = 0.5.
        assert!((priors[0] - 0.5).abs() < 1e-6);
        assert!((priors[1] - 0.5).abs() < 1e-6);
    }
}