scirs2_metrics/classification.rs
//! Classification metrics module
//!
//! This module provides functions for evaluating classification models, including
//! accuracy, precision, recall, F1 score, ROC AUC, and advanced metrics.
//!
//! ## Basic Metrics
//!
//! Basic classification metrics include accuracy, precision, recall, and F1 score.
//!
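//! A quick sketch of the basic metrics, reusing the small binary example from the
//! function-level docs below:
//!
//! ```
//! use scirs2_core::ndarray::array;
//! use scirs2_metrics::classification::{accuracy_score, precision_score, recall_score, f1_score};
//!
//! let y_true = array![0, 1, 0, 0, 1, 1];
//! let y_pred = array![0, 0, 1, 0, 1, 1];
//!
//! let acc = accuracy_score(&y_true, &y_pred).unwrap();
//! let precision = precision_score(&y_true, &y_pred, 1).unwrap();
//! let recall = recall_score(&y_true, &y_pred, 1).unwrap();
//! let f1 = f1_score(&y_true, &y_pred, 1).unwrap();
//! ```
//!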
//! ## Advanced Metrics
//!
//! Advanced metrics include Matthews Correlation Coefficient, balanced accuracy,
//! Cohen's kappa, Brier score, Jaccard similarity, and Hamming loss.
//!
//! ```
//! use scirs2_core::ndarray::array;
//! use scirs2_metrics::classification::advanced::{matthews_corrcoef, balanced_accuracy_score};
//!
//! let y_true = array![0, 1, 2, 0, 1, 2];
//! let y_pred = array![0, 2, 1, 0, 0, 2];
//!
//! let mcc = matthews_corrcoef(&y_true, &y_pred).unwrap();
//! let bal_acc = balanced_accuracy_score(&y_true, &y_pred).unwrap();
//! ```
//!
//! ## One-vs-One Metrics
//!
//! One-vs-One metrics are useful for evaluating multi-class classification problems by
//! considering each pair of classes separately.
//!
//! ```
//! use scirs2_core::ndarray::array;
//! use scirs2_metrics::classification::one_vs_one::{one_vs_one_accuracy, one_vs_one_f1_score};
//!
//! let y_true = array![0, 1, 2, 0, 1, 2];
//! let y_pred = array![0, 2, 1, 0, 0, 2];
//!
//! let ovo_acc = one_vs_one_accuracy(&y_true, &y_pred).unwrap();
//! let f1_scores = one_vs_one_f1_score(&y_true, &y_pred).unwrap();
//! ```

pub mod advanced;
pub mod curves;
pub mod one_vs_one;
pub mod threshold;
pub mod threshold_analyzer;

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Dimension};
use scirs2_core::numeric::NumCast;

use crate::error::{MetricsError, Result};

/// Calculates accuracy score, the fraction of correctly classified samples
///
/// # Mathematical Formulation
///
/// The accuracy score is defined as:
///
/// ```text
/// Accuracy = (Number of Correct Predictions) / (Total Number of Predictions)
///          = (TP + TN) / (TP + TN + FP + FN)
/// ```
///
/// Where:
/// - TP = True Positives
/// - TN = True Negatives
/// - FP = False Positives
/// - FN = False Negatives
///
/// For multi-class classification, accuracy is simply:
///
/// ```text
/// Accuracy = (1/n) * Σ I(ŷᵢ = yᵢ)
/// ```
///
/// Where:
/// - n = total number of samples
/// - I(·) = indicator function (1 if condition is true, 0 otherwise)
/// - ŷᵢ = predicted label for sample i
/// - yᵢ = true label for sample i
///
/// # Range
///
/// Accuracy is bounded between 0 and 1, where:
/// - 0 = worst possible accuracy (all predictions wrong)
/// - 1 = perfect accuracy (all predictions correct)
/// - 0.5 = random guessing for balanced binary classification
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) labels
/// * `y_pred` - Predicted labels, as returned by a classifier
///
/// # Returns
///
/// * The fraction of correctly classified samples (float)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::accuracy_score;
///
/// let y_true = array![0, 1, 2, 3];
/// let y_pred = array![0, 2, 1, 3];
///
/// let acc = accuracy_score(&y_true, &y_pred).unwrap();
/// assert!((acc - 0.5).abs() < 1e-10); // 2 out of 4 are correct
/// ```
#[allow(dead_code)]
pub fn accuracy_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Count correct predictions
    let mut n_correct = 0;
    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if yt == yp {
            n_correct += 1;
        }
    }

    Ok(n_correct as f64 / n_samples as f64)
}

/// Calculates a confusion matrix to evaluate the accuracy of a classification
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) labels
/// * `y_pred` - Predicted labels, as returned by a classifier
/// * `labels` - Optional list of label values to index the matrix. This may be
///   used to reorder or select a subset of labels. If None, those that appear
///   at least once in y_true or y_pred are used in sorted order.
///
/// # Returns
///
/// * Confusion matrix (`Array2<u64>`)
/// * Vector of classes in order (`Array1<T>`)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::confusion_matrix;
///
/// let y_true = array![0, 1, 2, 0, 1, 2];
/// let y_pred = array![0, 2, 1, 0, 0, 2];
///
/// let (cm, classes) = confusion_matrix(&y_true, &y_pred, None).unwrap();
/// // Expected confusion matrix:
/// // [[2, 0, 0],
/// //  [1, 0, 1],
/// //  [0, 1, 1]]
///
/// assert_eq!(cm[[0, 0]], 2); // True 0, predicted 0
/// assert_eq!(cm[[1, 0]], 1); // True 1, predicted 0
/// assert_eq!(cm[[1, 2]], 1); // True 1, predicted 2
/// assert_eq!(cm[[2, 1]], 1); // True 2, predicted 1
/// assert_eq!(cm[[2, 2]], 1); // True 2, predicted 2
/// ```
#[allow(dead_code)]
pub fn confusion_matrix<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    labels: Option<&[T]>,
) -> Result<(Array2<u64>, Array1<T>)>
where
    T: PartialEq + NumCast + Clone + Ord + std::hash::Hash + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Determine the classes
    let classes = if let Some(labels) = labels {
        let mut cls = Vec::with_capacity(labels.len());
        for label in labels {
            cls.push(label.clone());
        }
        cls
    } else {
        let mut cls = std::collections::BTreeSet::new();
        for yt in y_true.iter() {
            cls.insert(yt.clone());
        }
        for yp in y_pred.iter() {
            cls.insert(yp.clone());
        }
        cls.into_iter().collect()
    };

    // Create the confusion matrix
    let n_classes = classes.len();
    let mut cm = Array2::zeros((n_classes, n_classes));

    // Create a map from class to index
    let mut class_to_idx = std::collections::HashMap::new();
    for (i, c) in classes.iter().enumerate() {
        class_to_idx.insert(c, i);
    }

    // Fill the confusion matrix
    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if let (Some(&i), Some(&j)) = (class_to_idx.get(yt), class_to_idx.get(yp)) {
            cm[[i, j]] += 1;
        }
    }

    Ok((cm, Array1::from(classes)))
}

/// Calculates the precision score for binary classification
///
/// # Mathematical Formulation
///
/// Precision is defined as:
///
/// ```text
/// Precision = TP / (TP + FP)
/// ```
///
/// Where:
/// - TP = True Positives (correctly predicted positive cases)
/// - FP = False Positives (incorrectly predicted as positive)
///
/// Alternatively, precision can be expressed as:
///
/// ```text
/// Precision = P(y_true = positive | ŷ = positive)
/// ```
///
/// This represents the probability that a sample is actually positive
/// given that the classifier predicted it as positive.
///
/// # Interpretation
///
/// Precision answers the question: "Of all the samples the classifier
/// predicted as positive, how many were actually positive?"
///
/// - High precision means low false positive rate
/// - Precision = 1.0 means no false positives
/// - Precision = 0.0 means no true positives (all positive predictions are wrong)
///
/// # Range
///
/// Precision is bounded between 0 and 1:
/// - 0 = worst precision (no correct positive predictions)
/// - 1 = perfect precision (no false positive predictions)
///
/// # Use Cases
///
/// High precision is important when the cost of false positives is high,
/// such as:
/// - Medical diagnosis (avoid unnecessary treatments)
/// - Spam detection (avoid blocking legitimate emails)
/// - Quality control (avoid rejecting good products)
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
///
/// # Returns
///
/// * The precision score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::precision_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// let precision = precision_score(&y_true, &y_pred, 1).unwrap();
/// // There are 2 true positives and 1 false positive
/// assert!((precision - 2.0/3.0).abs() < 1e-10);
/// ```
#[allow(dead_code)]
pub fn precision_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Count true positives and false positives
    let mut true_positives = 0;
    let mut false_positives = 0;

    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if yp == &pos_label {
            if yt == yp {
                true_positives += 1;
            } else {
                false_positives += 1;
            }
        }
    }

    // Calculate precision
    if true_positives + false_positives == 0 {
        Ok(0.0) // No positive predictions, precision is 0
    } else {
        Ok(true_positives as f64 / (true_positives + false_positives) as f64)
    }
}

/// Calculates the recall score for binary classification
///
/// # Mathematical Formulation
///
/// Recall (also known as sensitivity or true positive rate) is defined as:
///
/// ```text
/// Recall = TP / (TP + FN)
/// ```
///
/// Where:
/// - TP = True Positives (correctly predicted positive cases)
/// - FN = False Negatives (incorrectly predicted as negative)
///
/// Alternatively, recall can be expressed as:
///
/// ```text
/// Recall = P(ŷ = positive | y_true = positive)
/// ```
///
/// This represents the probability that the classifier predicts positive
/// given that the sample is actually positive.
///
/// # Interpretation
///
/// Recall answers the question: "Of all the actual positive samples,
/// how many did the classifier correctly identify?"
///
/// - High recall means low false negative rate
/// - Recall = 1.0 means no false negatives (all positive cases found)
/// - Recall = 0.0 means no true positives (all positive cases missed)
///
/// # Range
///
/// Recall is bounded between 0 and 1:
/// - 0 = worst recall (no positive cases identified)
/// - 1 = perfect recall (all positive cases identified)
///
/// # Use Cases
///
/// High recall is important when the cost of false negatives is high,
/// such as:
/// - Medical screening (avoid missing diseases)
/// - Security systems (avoid missing threats)
/// - Search engines (avoid missing relevant results)
///
/// # Relationship to Other Metrics
///
/// Recall is complementary to precision:
/// - Precision focuses on minimizing false positives
/// - Recall focuses on minimizing false negatives
/// - There's often a trade-off between precision and recall
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
///
/// # Returns
///
/// * The recall score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::recall_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// let recall = recall_score(&y_true, &y_pred, 1).unwrap();
/// // There are 2 true positives and 1 false negative
/// assert!((recall - 2.0/3.0).abs() < 1e-10);
/// ```
#[allow(dead_code)]
pub fn recall_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Count true positives and false negatives
    let mut true_positives = 0;
    let mut false_negatives = 0;

    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if yt == &pos_label {
            if yp == yt {
                true_positives += 1;
            } else {
                false_negatives += 1;
            }
        }
    }

    // Calculate recall
    if true_positives + false_negatives == 0 {
        Ok(0.0) // No actual positives, recall is 0
    } else {
        Ok(true_positives as f64 / (true_positives + false_negatives) as f64)
    }
}

/// Calculates the F1 score for binary classification
///
/// # Mathematical Formulation
///
/// The F1 score is the harmonic mean of precision and recall:
///
/// ```text
/// F1 = 2 * (Precision × Recall) / (Precision + Recall)
/// ```
///
/// Equivalently, it can be expressed in terms of confusion matrix elements:
///
/// ```text
/// F1 = 2TP / (2TP + FP + FN)
/// ```
///
/// Where:
/// - TP = True Positives
/// - FP = False Positives
/// - FN = False Negatives
///
/// # Harmonic vs Arithmetic Mean
///
/// The F1 score uses the harmonic mean rather than the arithmetic mean because:
/// - The harmonic mean gives more weight to the smaller value
/// - If either precision or recall is low, F1 will be low
/// - It penalizes extreme imbalances between precision and recall
///
/// ```text
/// Arithmetic mean: (P + R) / 2
/// Harmonic mean:   2PR / (P + R)
/// ```
///
/// # Interpretation
///
/// The F1 score provides a single metric that balances precision and recall:
/// - F1 = 1.0: Perfect precision and recall
/// - F1 = 0.0: Either precision or recall (or both) is zero
/// - F1 is closer to the smaller of precision and recall
///
/// # Range and Properties
///
/// - Range: [0, 1]
/// - F1 ≤ (Precision + Recall) / 2, i.e. it never exceeds the arithmetic mean
/// - F1 = 0 if either Precision = 0 or Recall = 0
/// - F1 approaches Precision and Recall as they become similar, and equals both when Precision = Recall
///
/// # Use Cases
///
/// F1 score is particularly useful when:
/// - You need a single metric balancing precision and recall
/// - Class distribution is imbalanced
/// - Both false positives and false negatives are costly
/// - You want to avoid optimizing for just one metric
///
/// # Relationship to F-beta Score
///
/// F1 is a special case of the F-beta score with β = 1:
///
/// ```text
/// F_β = (1 + β²) × (Precision × Recall) / (β² × Precision + Recall)
/// ```
///
/// When β = 1, this reduces to the F1 formula above.
///
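/// Since `f1_score` is implemented as `fbeta_score` with `beta = 1.0`, the two
/// agree exactly (a small doctest sketch reusing the binary example below):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::{f1_score, fbeta_score};
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// let f1 = f1_score(&y_true, &y_pred, 1).unwrap();
/// let fb1 = fbeta_score(&y_true, &y_pred, 1, 1.0).unwrap();
/// assert!((f1 - fb1).abs() < 1e-12);
/// ```
///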
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
///
/// # Returns
///
/// * The F1 score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::f1_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// let f1 = f1_score(&y_true, &y_pred, 1).unwrap();
/// ```
#[allow(dead_code)]
pub fn f1_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // F1 score is a special case of fbeta_score with beta = 1.0
    fbeta_score(y_true, y_pred, pos_label, 1.0)
}

/// Calculates the F-beta score for binary classification
///
/// The F-beta score is the weighted harmonic mean of precision and recall:
/// `F-beta = (1 + beta^2) * (precision * recall) / ((beta^2 * precision) + recall)`
///
/// The beta parameter determines the weight of recall in the combined score:
/// - beta < 1 gives more weight to precision
/// - beta > 1 gives more weight to recall
/// - beta = 1 gives equal weight to precision and recall (F1 score)
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
/// * `beta` - The weight of recall relative to precision (must be positive)
///
/// # Returns
///
/// * The F-beta score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::fbeta_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// // F0.5 score (weighs precision higher than recall)
/// let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
///
/// // F2 score (weighs recall higher than precision)
/// let f_two = fbeta_score(&y_true, &y_pred, 1, 2.0).unwrap();
/// ```
#[allow(dead_code)]
pub fn fbeta_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
    beta: f64,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    if beta <= 0.0 {
        return Err(MetricsError::InvalidInput(format!(
            "beta must be positive, got {beta}"
        )));
    }

    let precision = precision_score(y_true, y_pred, pos_label.clone())?;
    let recall = recall_score(y_true, y_pred, pos_label)?;

    if precision + recall == 0.0 {
        return Ok(0.0);
    }

    let beta_squared = beta * beta;
    Ok((1.0 + beta_squared) * precision * recall / ((beta_squared * precision) + recall))
}

/// Calculates the binary log loss, also known as binary cross-entropy
///
/// This is the loss function used in binary logistic regression and in
/// neural networks with a sigmoid output for binary classification.
///
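/// # Mathematical Formulation
///
/// For predicted probabilities pᵢ of the positive class (clipped to
/// [eps, 1 - eps] to avoid log(0)), the value computed below is the mean
/// negative log-likelihood:
///
/// ```text
/// LogLoss = -(1/n) * Σ [ yᵢ·ln(pᵢ) + (1 - yᵢ)·ln(1 - pᵢ) ]
/// ```
///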
/// # Arguments
///
/// * `y_true` - Ground truth binary labels
/// * `y_prob` - Predicted probabilities for the positive class
/// * `eps` - Small value to avoid log(0)
///
/// # Returns
///
/// * The log loss (float)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::binary_log_loss;
///
/// let y_true = array![0, 1, 1, 0];
/// let y_prob = array![0.1, 0.9, 0.8, 0.3];
///
/// let loss = binary_log_loss(&y_true, &y_prob, 1e-15).unwrap();
/// ```
#[allow(dead_code)]
pub fn binary_log_loss<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_prob: &ArrayBase<S2, D2>,
    eps: f64,
) -> Result<f64>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_prob.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_prob have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_prob.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Compute log loss
    let mut loss = 0.0;
    for (yt, yp) in y_true.iter().zip(y_prob.iter()) {
        // Clip probability to avoid log(0)
        let clipped_yp = yp.max(eps).min(1.0 - eps);

        if *yt == 1 {
            loss -= clipped_yp.ln();
        } else {
            loss -= (1.0 - clipped_yp).ln();
        }
    }

    Ok(loss / n_samples as f64)
}

/// Computes the Area Under the Receiver Operating Characteristic Curve (ROC AUC)
///
/// The ROC curve is created by plotting the true positive rate (TPR) against
/// the false positive rate (FPR) at various threshold settings.
///
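/// Equivalently, the AUC obtained from the trapezoidal summation below can be read
/// as a ranking statistic (with tied scores counted as one half):
///
/// ```text
/// AUC = P(score(x⁺) > score(x⁻))
/// ```
///
/// i.e. the probability that a randomly chosen positive sample receives a higher
/// score than a randomly chosen negative sample.
///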
/// # Arguments
///
/// * `y_true` - Ground truth binary labels
/// * `y_score` - Target scores (can be probability estimates of the positive class)
///
/// # Returns
///
/// * The ROC AUC score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::roc_auc_score;
///
/// let y_true = array![0, 0, 1, 1];
/// let y_score = array![0.1, 0.4, 0.35, 0.8];
///
/// let auc = roc_auc_score(&y_true, &y_score).unwrap();
/// ```
#[allow(dead_code)]
pub fn roc_auc_score<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_score: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Compute the number of positive and negative samples
    let mut n_pos = 0;
    let mut n_neg = 0;
    for &yt in y_true.iter() {
        if yt == 1 {
            n_pos += 1;
        } else {
            n_neg += 1;
        }
    }

    if n_pos == 0 || n_neg == 0 {
        return Err(MetricsError::InvalidInput(
            "ROC AUC score is not defined when only one class is present".to_string(),
        ));
    }

    // Collect scores and true labels
    let mut scores_and_labels = Vec::with_capacity(n_samples);
    for (yt, ys) in y_true.iter().zip(y_score.iter()) {
        scores_and_labels.push((*ys, *yt));
    }

    // Sort scores in descending order
    scores_and_labels.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

    // Compute AUC by summing up the trapezoids
    let mut auc = 0.0;
    let mut false_positive = 0;
    let mut true_positive = 0;
    let mut last_false_positive = 0;
    let mut last_true_positive = 0;
    let mut last_score = f64::INFINITY;

    for (score, label) in scores_and_labels {
        if score != last_score {
            // Add the area of the trapezoid
            auc += (false_positive - last_false_positive) as f64
                * (true_positive + last_true_positive) as f64
                / 2.0;
            last_score = score;
            last_false_positive = false_positive;
            last_true_positive = true_positive;
        }

        if label == 1 {
            true_positive += 1;
        } else {
            false_positive += 1;
        }
    }

    // Add the final trapezoid
    auc += (n_neg - last_false_positive) as f64 * (true_positive + last_true_positive) as f64 / 2.0;

    // Normalize
    auc /= (n_pos * n_neg) as f64;

    Ok(auc)
}

/// Computes the lift chart values for binary classification
///
/// The lift chart shows how much better a model performs compared to a random model.
/// It is particularly useful in marketing and customer targeting applications.
///
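/// For each percentile cut-off, the lift value computed below compares the
/// positive rate among the top-scored samples with the overall positive rate:
///
/// ```text
/// Lift@p = (positive rate in the top p% of scores) / (overall positive rate)
/// ```
///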
/// # Arguments
///
/// * `y_true` - Ground truth binary labels (0 or 1)
/// * `y_score` - Predicted probabilities for the positive class
/// * `n_bins` - Number of bins for the lift chart
///
/// # Returns
///
/// * A tuple containing three arrays:
///   * `percentiles` - The percentiles used (0-100)
///   * `lift_values` - The lift values for each percentile
///   * `cum_gains` - Cumulative gains values (for gain chart)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::lift_chart;
///
/// let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
/// let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];
///
/// let (percentiles, lift_values, cum_gains) = lift_chart(&y_true, &y_score, 10).unwrap();
/// ```
#[allow(dead_code)]
pub fn lift_chart<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_score: &ArrayBase<S2, D2>,
    n_bins: usize,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>)>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Validate that y_true contains only binary values
    for yt in y_true.iter() {
        if *yt != 0 && *yt != 1 {
            return Err(MetricsError::InvalidInput(
                "y_true must contain only binary values (0 or 1)".to_string(),
            ));
        }
    }

    // Validate n_bins
    if n_bins < 1 {
        return Err(MetricsError::InvalidInput(
            "n_bins must be at least 1".to_string(),
        ));
    }

    // Compute the overall positive rate (baseline)
    let n_positives = y_true.iter().filter(|&&y| y == 1).count();
    if n_positives == 0 || n_positives == n_samples {
        return Err(MetricsError::InvalidInput(
            "y_true must contain both positive and negative samples".to_string(),
        ));
    }
    let baseline_rate = n_positives as f64 / n_samples as f64;

    // Pair scores with true labels and sort by score in descending order
    let mut paired_data: Vec<(f64, u32)> = y_score
        .iter()
        .zip(y_true.iter())
        .map(|(&score, &label)| (score, label))
        .collect();
    paired_data.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

    // Calculate percentiles, lift values, and cumulative gains
    let bin_size = n_samples / n_bins;
    let mut percentiles = Vec::with_capacity(n_bins);
    let mut lift_values = Vec::with_capacity(n_bins);
    let mut cum_gains = Vec::with_capacity(n_bins);

    for i in 0..n_bins {
        // Calculate percentile
        let percentile = (i + 1) as f64 * 100.0 / n_bins as f64;

        // Calculate number of samples to consider (based on percentile)
        let n_considered = if i == n_bins - 1 {
            // Include all samples in the last bin
            n_samples
        } else {
            (i + 1) * bin_size
        };

        // Count positives in this subset
        let positives_in_bin = paired_data[0..n_considered]
            .iter()
            .filter(|(_, label)| *label == 1)
            .count();

        // Calculate lift and cumulative gain
        let bin_rate = positives_in_bin as f64 / n_considered as f64;
        let lift = bin_rate / baseline_rate;
        let cum_gain = positives_in_bin as f64 / n_positives as f64;

        percentiles.push(percentile);
        lift_values.push(lift);
        cum_gains.push(cum_gain);
    }

    Ok((
        Array1::from(percentiles),
        Array1::from(lift_values),
        Array1::from(cum_gains),
    ))
}

/// Computes the gain chart values for binary classification
///
/// The gain chart (or cumulative gains chart) shows the percentage of positive
/// outcomes captured at each percentile when observations are ranked by predicted probability.
///
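/// At each percentile cut-off, the cumulative gain computed below is the share of
/// all positives that fall within the top-scored portion of the data:
///
/// ```text
/// Gain@p = (positives in the top p% of scores) / (total positives)
/// ```
///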
/// # Arguments
///
/// * `y_true` - Ground truth binary labels (0 or 1)
/// * `y_score` - Predicted probabilities for the positive class
/// * `n_bins` - Number of bins for the gain chart
///
/// # Returns
///
/// * A tuple containing two arrays:
///   * `percentiles` - The percentiles used (0-100)
///   * `cum_gains` - Cumulative gains values at each percentile
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::gain_chart;
///
/// let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
/// let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];
///
/// let (percentiles, cum_gains) = gain_chart(&y_true, &y_score, 10).unwrap();
/// ```
#[allow(dead_code)]
pub fn gain_chart<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_score: &ArrayBase<S2, D2>,
    n_bins: usize,
) -> Result<(Array1<f64>, Array1<f64>)>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Reuse the lift_chart function to get the data
    let (percentiles, _lift_values, cum_gains) = lift_chart(y_true, y_score, n_bins)?;
    Ok((percentiles, cum_gains))
}

/// Generates a text report showing the main classification metrics
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) target values
/// * `y_pred` - Estimated targets as returned by a classifier
/// * `labels` - Optional list of label values to include in the report
///
/// # Returns
///
/// * A string containing the classification report
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::classification_report;
///
/// let y_true = array![0, 1, 2, 0, 1, 2];
/// let y_pred = array![0, 2, 1, 0, 0, 2];
///
/// let report = classification_report(&y_true, &y_pred, None).unwrap();
/// println!("{}", report);
/// ```
#[allow(dead_code)]
pub fn classification_report<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    labels: Option<&[T]>,
) -> Result<String>
where
    T: PartialEq + NumCast + Clone + Ord + std::hash::Hash + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Get confusion matrix
    let (cm, classes) = confusion_matrix(y_true, y_pred, labels)?;

    // Prepare report with a header aligned to the per-class rows below
    let mut report = String::new();
    report.push_str("               precision    recall  f1-score   support\n\n");

    let n_classes = classes.len();
    let mut total_precision = 0.0;
    let mut total_recall = 0.0;
    let mut total_f1 = 0.0;
    let mut total_support = 0;

    for i in 0..n_classes {
        let class_label = format!("{:?}", classes[i]);

        // Calculate metrics for this class
        let true_positives = cm[[i, i]];
        let false_positives = cm.column(i).sum() - true_positives;
        let false_negatives = cm.row(i).sum() - true_positives;
        let support = cm.row(i).sum();

        let precision = if true_positives + false_positives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_positives) as f64
        };

        let recall = if true_positives + false_negatives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_negatives) as f64
        };

        let f1 = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        // Add to totals
        total_precision += precision;
        total_recall += recall;
        total_f1 += f1;
        total_support += support as usize;

        // Add line to report
        report.push_str(&format!(
            "{class_label:>14} {precision:9.2} {recall:9.2} {f1:9.2} {support:9}\n"
        ));
    }

    report.push('\n');

    // Calculate macro averages
    let avg_precision = total_precision / n_classes as f64;
    let avg_recall = total_recall / n_classes as f64;
    let avg_f1 = total_f1 / n_classes as f64;

    // Add averages to report
    report.push_str(&format!(
        "   avg / total {avg_precision:9.2} {avg_recall:9.2} {avg_f1:9.2} {total_support:9}\n"
    ));

    Ok(report)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_accuracy_score() {
        let y_true = array![0, 1, 2, 3];
        let y_pred = array![0, 2, 1, 3];

        let acc = accuracy_score(&y_true, &y_pred).unwrap();
        assert_abs_diff_eq!(acc, 0.5, epsilon = 1e-10); // 2 out of 4 are correct
    }

    #[test]
    fn test_confusion_matrix() {
        let y_true = array![0, 1, 2, 0, 1, 2];
        let y_pred = array![0, 2, 1, 0, 0, 2];

        let (cm, classes) = confusion_matrix(&y_true, &y_pred, None).unwrap();

        assert_eq!(cm.shape(), &[3, 3]);
        assert_eq!(classes.len(), 3);

        // Expected confusion matrix:
        // [[2, 0, 0],
        //  [1, 0, 1],
        //  [0, 1, 1]]
        assert_eq!(cm[[0, 0]], 2); // True 0, predicted 0
        assert_eq!(cm[[1, 0]], 1); // True 1, predicted 0
        assert_eq!(cm[[1, 2]], 1); // True 1, predicted 2
        assert_eq!(cm[[2, 1]], 1); // True 2, predicted 1
        assert_eq!(cm[[2, 2]], 1); // True 2, predicted 2
    }

    #[test]
    fn test_precision_recall_f1() {
        let y_true = array![0, 1, 0, 0, 1, 1];
        let y_pred = array![0, 0, 1, 0, 1, 1];

        let precision = precision_score(&y_true, &y_pred, 1).unwrap();
        let recall = recall_score(&y_true, &y_pred, 1).unwrap();
        let f1 = f1_score(&y_true, &y_pred, 1).unwrap();

        // precision = 2/3, recall = 2/3, f1 = 2*(2/3)*(2/3) / (2/3 + 2/3) = 2/3
        assert_abs_diff_eq!(precision, 2.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(recall, 2.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(f1, 2.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_fbeta_score() {
        let y_true = array![0, 1, 0, 0, 1, 1];
        let y_pred = array![0, 0, 1, 0, 1, 1];

        // F1 score (beta = 1.0)
        let f1 = fbeta_score(&y_true, &y_pred, 1, 1.0).unwrap();
        assert_abs_diff_eq!(f1, 2.0 / 3.0, epsilon = 1e-10);

        // F0.5 score (weighs precision higher than recall)
        let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
        // With beta=0.5, beta²=0.25
        // F0.5 = (1+0.25) * 2/3 * 2/3 / (0.25*2/3 + 2/3) = 1.25 * 4/9 / (1/6 + 2/3) = (5/9) / (5/6) = 2/3
        assert_abs_diff_eq!(f_half, 2.0 / 3.0, epsilon = 1e-10);

        // F2 score (weighs recall higher than precision)
        let f_two = fbeta_score(&y_true, &y_pred, 1, 2.0).unwrap();
        // With beta=2.0, beta²=4.0
        // F2 = (1+4) * 2/3 * 2/3 / (4*2/3 + 2/3) = 5 * 4/9 / (10/3) = (20/9) * (3/10) = 2/3
        assert_abs_diff_eq!(f_two, 2.0 / 3.0, epsilon = 1e-10);

        // This example has equal precision and recall, so all F-beta scores are the same.
        // Try another case; precision and recall are again equal, but at a different level.
        let y_true = array![1, 1, 1, 1, 1, 0, 0, 0, 0, 0];
        let y_pred = array![1, 1, 1, 0, 0, 0, 0, 0, 1, 1];
        // precision = 3/5 = 0.6, recall = 3/5 = 0.6

        // F1 score (beta = 1.0)
        let f1 = fbeta_score(&y_true, &y_pred, 1, 1.0).unwrap();
        assert_abs_diff_eq!(f1, 0.6, epsilon = 1e-10);

        // F0.5 score (weighs precision higher than recall)
        let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
        // With beta=0.5, beta²=0.25
        // F0.5 = (1+0.25) * 0.6 * 0.6 / (0.25*0.6 + 0.6) = 1.25 * 0.36 / (0.15 + 0.6) = 0.45 / 0.75 = 0.6
        assert_abs_diff_eq!(f_half, 0.6, epsilon = 1e-10);

        // F2 score (weighs recall higher than precision)
        let f_two = fbeta_score(&y_true, &y_pred, 1, 2.0).unwrap();
        // With beta=2.0, beta²=4.0
        // F2 = (1+4) * 0.6 * 0.6 / (4*0.6 + 0.6) = 5 * 0.36 / (2.4 + 0.6) = 1.8 / 3.0 = 0.6
        assert_abs_diff_eq!(f_two, 0.6, epsilon = 1e-10);

        // One more case at yet another precision/recall level
        let y_true = array![1, 1, 1, 1, 0, 0, 0, 0];
        let y_pred = array![1, 1, 0, 0, 0, 0, 1, 1];
        // precision = 2/4 = 0.5, recall = 2/4 = 0.5

        // F0.5 score (weighs precision higher than recall)
        let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
        // With beta=0.5, beta²=0.25
        // F0.5 = (1+0.25) * 0.5 * 0.5 / (0.25*0.5 + 0.5) = 1.25 * 0.25 / (0.125 + 0.5) = 0.3125 / 0.625 = 0.5
        assert_abs_diff_eq!(f_half, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_log_loss() {
        let y_true = array![0, 1, 1, 0];
        let y_prob = array![0.1, 0.9, 0.8, 0.3];

        let loss = binary_log_loss(&y_true, &y_prob, 1e-15).unwrap();
        // Expected loss: -[(ln(0.9) + ln(0.9) + ln(0.8) + ln(0.7)) / 4]
        let expected =
            -(((1.0_f64 - 0.1).ln() + 0.9_f64.ln() + 0.8_f64.ln() + (1.0_f64 - 0.3).ln()) / 4.0);
        assert_abs_diff_eq!(loss, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_roc_auc() {
        // Perfect separation
        let y_true = array![0, 0, 1, 1];
        let y_score = array![0.1, 0.2, 0.8, 0.9];
        let auc = roc_auc_score(&y_true, &y_score).unwrap();
        assert_abs_diff_eq!(auc, 1.0, epsilon = 1e-10);

        // Random
        let y_true = array![0, 1, 0, 1];
        let y_score = array![0.5, 0.5, 0.5, 0.5];
        let auc = roc_auc_score(&y_true, &y_score).unwrap();
        assert_abs_diff_eq!(auc, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_lift_chart() {
        let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
        let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];

        // Test with 5 bins
        let (percentiles, lift_values, cum_gains) = lift_chart(&y_true, &y_score, 5).unwrap();

        // Verify dimensions
        assert_eq!(percentiles.len(), 5);
        assert_eq!(lift_values.len(), 5);
        assert_eq!(cum_gains.len(), 5);

        // Verify percentiles
        assert_abs_diff_eq!(percentiles[0], 20.0, epsilon = 1e-10);
        assert_abs_diff_eq!(percentiles[4], 100.0, epsilon = 1e-10);

        // The first 20% contains the highest-scored cases, which should be mostly positive,
        // so the lift value should be higher than 1
        assert!(lift_values[0] > 1.0);

        // The cumulative gain at 100% should be 1.0 (all positives)
        assert_abs_diff_eq!(cum_gains[4], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gain_chart() {
        let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
        let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];

        // Test with 5 bins
        let (percentiles, cum_gains) = gain_chart(&y_true, &y_score, 5).unwrap();

        // Verify dimensions
        assert_eq!(percentiles.len(), 5);
        assert_eq!(cum_gains.len(), 5);

        // Verify percentiles
        assert_abs_diff_eq!(percentiles[0], 20.0, epsilon = 1e-10);
        assert_abs_diff_eq!(percentiles[4], 100.0, epsilon = 1e-10);

        // Cumulative gains should be non-decreasing
        for i in 1..cum_gains.len() {
            assert!(cum_gains[i] >= cum_gains[i - 1]);
        }

        // The cumulative gain at 100% should be 1.0 (all positives)
        assert_abs_diff_eq!(cum_gains[4], 1.0, epsilon = 1e-10);
    }
}