scirs2_metrics/classification.rs

//! Classification metrics module
//!
//! This module provides functions for evaluating classification models, including
//! accuracy, precision, recall, F1 score, ROC AUC, and advanced metrics.
//!
//! ## Basic Metrics
//!
//! Basic classification metrics include accuracy, precision, recall, and F1 score.
//!
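//! A minimal sketch using these functions (all defined in this module):
//!
//! ```
//! use scirs2_core::ndarray::array;
//! use scirs2_metrics::classification::{accuracy_score, precision_score, recall_score, f1_score};
//!
//! let y_true = array![0, 1, 1, 0, 1, 0];
//! let y_pred = array![0, 1, 0, 0, 1, 1];
//!
//! let acc = accuracy_score(&y_true, &y_pred).unwrap();      // 4 of 6 correct
//! let prec = precision_score(&y_true, &y_pred, 1).unwrap(); // 2 TP, 1 FP
//! let rec = recall_score(&y_true, &y_pred, 1).unwrap();     // 2 TP, 1 FN
//! let f1 = f1_score(&y_true, &y_pred, 1).unwrap();
//!
//! assert!((acc - 4.0 / 6.0).abs() < 1e-10);
//! assert!((prec - 2.0 / 3.0).abs() < 1e-10);
//! assert!((rec - 2.0 / 3.0).abs() < 1e-10);
//! assert!((f1 - 2.0 / 3.0).abs() < 1e-10);
//! ```
//!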
//! ## Advanced Metrics
//!
//! Advanced metrics include Matthews Correlation Coefficient, balanced accuracy,
//! Cohen's kappa, Brier score, Jaccard similarity, and Hamming loss.
//!
//! ```
//! use scirs2_core::ndarray::array;
//! use scirs2_metrics::classification::advanced::{matthews_corrcoef, balanced_accuracy_score};
//!
//! let y_true = array![0, 1, 2, 0, 1, 2];
//! let y_pred = array![0, 2, 1, 0, 0, 2];
//!
//! let mcc = matthews_corrcoef(&y_true, &y_pred).unwrap();
//! let bal_acc = balanced_accuracy_score(&y_true, &y_pred).unwrap();
//! ```
//!
//! ## One-vs-One Metrics
//!
//! One-vs-One metrics are useful for evaluating multi-class classification problems by
//! considering each pair of classes separately.
//!
//! ```
//! use scirs2_core::ndarray::array;
//! use scirs2_metrics::classification::one_vs_one::{one_vs_one_accuracy, one_vs_one_f1_score};
//!
//! let y_true = array![0, 1, 2, 0, 1, 2];
//! let y_pred = array![0, 2, 1, 0, 0, 2];
//!
//! let ovo_acc = one_vs_one_accuracy(&y_true, &y_pred).unwrap();
//! let f1_scores = one_vs_one_f1_score(&y_true, &y_pred).unwrap();
//! ```

pub mod advanced;
pub mod curves;
pub mod one_vs_one;
pub mod threshold;
pub mod threshold_analyzer;

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Dimension};
use scirs2_core::numeric::NumCast;

use crate::error::{MetricsError, Result};

/// Calculates accuracy score, the fraction of correctly classified samples
///
/// # Mathematical Formulation
///
/// The accuracy score is defined as:
///
/// ```text
/// Accuracy = (Number of Correct Predictions) / (Total Number of Predictions)
///          = (TP + TN) / (TP + TN + FP + FN)
/// ```
///
/// Where:
/// - TP = True Positives
/// - TN = True Negatives
/// - FP = False Positives
/// - FN = False Negatives
///
/// For multi-class classification, accuracy is simply:
///
/// ```text
/// Accuracy = (1/n) * Σ I(ŷᵢ = yᵢ)
/// ```
///
/// Where:
/// - n = total number of samples
/// - I(·) = indicator function (1 if condition is true, 0 otherwise)
/// - ŷᵢ = predicted label for sample i
/// - yᵢ = true label for sample i
///
/// # Range
///
/// Accuracy is bounded between 0 and 1, where:
/// - 0 = worst possible accuracy (all predictions wrong)
/// - 1 = perfect accuracy (all predictions correct)
/// - 0.5 = random guessing for balanced binary classification
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) labels
/// * `y_pred` - Predicted labels, as returned by a classifier
///
/// # Returns
///
/// * The fraction of correctly classified samples (float)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::accuracy_score;
///
/// let y_true = array![0, 1, 2, 3];
/// let y_pred = array![0, 2, 1, 3];
///
/// let acc = accuracy_score(&y_true, &y_pred).unwrap();
/// assert!((acc - 0.5).abs() < 1e-10); // 2 out of 4 are correct
/// ```
#[allow(dead_code)]
pub fn accuracy_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Count correct predictions
    let mut n_correct = 0;
    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if yt == yp {
            n_correct += 1;
        }
    }

    Ok(n_correct as f64 / n_samples as f64)
}

/// Calculates a confusion matrix to evaluate the accuracy of a classification
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) labels
/// * `y_pred` - Predicted labels, as returned by a classifier
/// * `labels` - Optional list of label values to index the matrix. This may be
///   used to reorder or select a subset of labels. If None, those that appear
///   at least once in y_true or y_pred are used in sorted order.
///
/// # Returns
///
/// * Confusion matrix (Array2<u64>)
/// * Vector of classes in order (Array1<T>)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::confusion_matrix;
///
/// let y_true = array![0, 1, 2, 0, 1, 2];
/// let y_pred = array![0, 2, 1, 0, 0, 2];
///
/// let (cm, classes) = confusion_matrix(&y_true, &y_pred, None).unwrap();
/// // Expected confusion matrix:
/// // [[2, 0, 0],
/// //  [1, 0, 1],
/// //  [0, 1, 1]]
///
/// assert_eq!(cm[[0, 0]], 2); // True 0, predicted 0
/// assert_eq!(cm[[1, 0]], 1); // True 1, predicted 0
/// assert_eq!(cm[[1, 2]], 1); // True 1, predicted 2
/// assert_eq!(cm[[2, 1]], 1); // True 2, predicted 1
/// assert_eq!(cm[[2, 2]], 1); // True 2, predicted 2
/// ```
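///
/// Restricting the matrix to a subset of labels (a minimal sketch; samples whose
/// true or predicted label is not listed are simply skipped):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::confusion_matrix;
///
/// let y_true = array![0, 1, 2, 0, 1, 2];
/// let y_pred = array![0, 2, 1, 0, 0, 2];
///
/// let (cm, classes) = confusion_matrix(&y_true, &y_pred, Some([0, 1].as_slice())).unwrap();
/// assert_eq!(cm.shape(), &[2, 2]);
/// assert_eq!(classes.len(), 2);
/// assert_eq!(cm[[0, 0]], 2); // True 0, predicted 0
/// ```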
#[allow(dead_code)]
pub fn confusion_matrix<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    labels: Option<&[T]>,
) -> Result<(Array2<u64>, Array1<T>)>
where
    T: PartialEq + NumCast + Clone + Ord + std::hash::Hash + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Determine the classes
    let classes = if let Some(labels) = labels {
        let mut cls = Vec::with_capacity(labels.len());
        for label in labels {
            cls.push(label.clone());
        }
        cls
    } else {
        let mut cls = std::collections::BTreeSet::new();
        for yt in y_true.iter() {
            cls.insert(yt.clone());
        }
        for yp in y_pred.iter() {
            cls.insert(yp.clone());
        }
        cls.into_iter().collect()
    };

    // Create the confusion matrix
    let n_classes = classes.len();
    let mut cm = Array2::zeros((n_classes, n_classes));

    // Create a map from class to index
    let mut class_to_idx = std::collections::HashMap::new();
    for (i, c) in classes.iter().enumerate() {
        class_to_idx.insert(c, i);
    }

    // Fill the confusion matrix
    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if let (Some(&i), Some(&j)) = (class_to_idx.get(yt), class_to_idx.get(yp)) {
            cm[[i, j]] += 1;
        }
    }

    Ok((cm, Array1::from(classes)))
}

/// Calculates the precision score for binary classification
///
/// # Mathematical Formulation
///
/// Precision is defined as:
///
/// ```text
/// Precision = TP / (TP + FP)
/// ```
///
/// Where:
/// - TP = True Positives (correctly predicted positive cases)
/// - FP = False Positives (incorrectly predicted as positive)
///
/// Alternatively, precision can be expressed as:
///
/// ```text
/// Precision = P(y_true = positive | ŷ = positive)
/// ```
///
/// This represents the probability that a sample is actually positive
/// given that the classifier predicted it as positive.
///
/// # Interpretation
///
/// Precision answers the question: "Of all the samples the classifier
/// predicted as positive, how many were actually positive?"
///
/// - High precision means low false positive rate
/// - Precision = 1.0 means no false positives
/// - Precision = 0.0 means no true positives (all positive predictions are wrong)
///
/// # Range
///
/// Precision is bounded between 0 and 1:
/// - 0 = worst precision (no correct positive predictions)
/// - 1 = perfect precision (no false positive predictions)
///
/// # Use Cases
///
/// High precision is important when the cost of false positives is high,
/// such as:
/// - Medical diagnosis (avoid unnecessary treatments)
/// - Spam detection (avoid blocking legitimate emails)
/// - Quality control (avoid rejecting good products)
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
///
/// # Returns
///
/// * The precision score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::precision_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// let precision = precision_score(&y_true, &y_pred, 1).unwrap();
/// // There are 2 true positives and 1 false positive
/// assert!((precision - 2.0/3.0).abs() < 1e-10);
/// ```
#[allow(dead_code)]
pub fn precision_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Count true positives and false positives
    let mut true_positives = 0;
    let mut false_positives = 0;

    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if yp == &pos_label {
            if yt == yp {
                true_positives += 1;
            } else {
                false_positives += 1;
            }
        }
    }

    // Calculate precision
    if true_positives + false_positives == 0 {
        Ok(0.0) // No positive predictions, precision is 0
    } else {
        Ok(true_positives as f64 / (true_positives + false_positives) as f64)
    }
}

/// Calculates the recall score for binary classification
///
/// # Mathematical Formulation
///
/// Recall (also known as sensitivity or true positive rate) is defined as:
///
/// ```text
/// Recall = TP / (TP + FN)
/// ```
///
/// Where:
/// - TP = True Positives (correctly predicted positive cases)
/// - FN = False Negatives (incorrectly predicted as negative)
///
/// Alternatively, recall can be expressed as:
///
/// ```text
/// Recall = P(ŷ = positive | y_true = positive)
/// ```
///
/// This represents the probability that the classifier predicts positive
/// given that the sample is actually positive.
///
/// # Interpretation
///
/// Recall answers the question: "Of all the actual positive samples,
/// how many did the classifier correctly identify?"
///
/// - High recall means low false negative rate
/// - Recall = 1.0 means no false negatives (all positive cases found)
/// - Recall = 0.0 means no true positives (all positive cases missed)
///
/// # Range
///
/// Recall is bounded between 0 and 1:
/// - 0 = worst recall (no positive cases identified)
/// - 1 = perfect recall (all positive cases identified)
///
/// # Use Cases
///
/// High recall is important when the cost of false negatives is high,
/// such as:
/// - Medical screening (avoid missing diseases)
/// - Security systems (avoid missing threats)
/// - Search engines (avoid missing relevant results)
///
/// # Relationship to Other Metrics
///
/// Recall is complementary to precision:
/// - Precision focuses on minimizing false positives
/// - Recall focuses on minimizing false negatives
/// - There's often a trade-off between precision and recall
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
///
/// # Returns
///
/// * The recall score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::recall_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// let recall = recall_score(&y_true, &y_pred, 1).unwrap();
/// // There are 2 true positives and 1 false negative
/// assert!((recall - 2.0/3.0).abs() < 1e-10);
/// ```
#[allow(dead_code)]
pub fn recall_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Count true positives and false negatives
    let mut true_positives = 0;
    let mut false_negatives = 0;

    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        if yt == &pos_label {
            if yp == yt {
                true_positives += 1;
            } else {
                false_negatives += 1;
            }
        }
    }

    // Calculate recall
    if true_positives + false_negatives == 0 {
        Ok(0.0) // No actual positives, recall is 0
    } else {
        Ok(true_positives as f64 / (true_positives + false_negatives) as f64)
    }
}

/// Calculates the F1 score for binary classification
///
/// # Mathematical Formulation
///
/// The F1 score is the harmonic mean of precision and recall:
///
/// ```text
/// F1 = 2 * (Precision × Recall) / (Precision + Recall)
/// ```
///
/// Equivalently, it can be expressed in terms of confusion matrix elements:
///
/// ```text
/// F1 = 2TP / (2TP + FP + FN)
/// ```
///
/// Where:
/// - TP = True Positives
/// - FP = False Positives
/// - FN = False Negatives
///
/// # Harmonic vs Arithmetic Mean
///
/// The F1 score uses harmonic mean rather than arithmetic mean because:
/// - Harmonic mean gives more weight to smaller values
/// - If either precision or recall is low, F1 will be low
/// - It penalizes extreme imbalances between precision and recall
///
/// ```text
/// Arithmetic mean: (P + R) / 2
/// Harmonic mean:   2PR / (P + R)
/// ```
///
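/// For example, with Precision = 0.5 and Recall = 1.0 (illustrative numbers):
///
/// ```text
/// Arithmetic mean: (0.5 + 1.0) / 2           = 0.75
/// Harmonic mean:   2(0.5)(1.0) / (0.5 + 1.0) ≈ 0.67
/// ```
///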
/// # Interpretation
///
/// The F1 score provides a single metric that balances precision and recall:
/// - F1 = 1.0: Perfect precision and recall
/// - F1 = 0.0: Either precision or recall (or both) is zero
/// - F1 is closer to the smaller of precision and recall
///
/// # Range and Properties
///
/// - Range: [0, 1]
/// - min(Precision, Recall) ≤ F1 ≤ max(Precision, Recall)
/// - F1 = 0 if either Precision = 0 or Recall = 0
/// - F1 ≤ (Precision + Recall) / 2, with equality only when Precision = Recall
///
/// # Use Cases
///
/// F1 score is particularly useful when:
/// - You need a single metric balancing precision and recall
/// - Class distribution is imbalanced
/// - Both false positives and false negatives are costly
/// - You want to avoid optimizing for just one metric
///
/// # Relationship to F-beta Score
///
/// F1 is a special case of the F-beta score with β = 1:
///
/// ```text
/// F_β = (1 + β²) × (Precision × Recall) / (β² × Precision + Recall)
/// ```
///
/// When β = 1, this reduces to the F1 formula above.
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
///
/// # Returns
///
/// * The F1 score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::f1_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// let f1 = f1_score(&y_true, &y_pred, 1).unwrap();
/// ```
#[allow(dead_code)]
pub fn f1_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // F1 score is a special case of fbeta_score with beta = 1.0
    fbeta_score(y_true, y_pred, pos_label, 1.0)
}

/// Calculates the F-beta score for binary classification
///
/// The F-beta score is the weighted harmonic mean of precision and recall:
/// `F-beta = (1 + beta^2) * (precision * recall) / ((beta^2 * precision) + recall)`
///
/// The beta parameter determines the weight of recall in the combined score:
/// - beta < 1 gives more weight to precision
/// - beta > 1 gives more weight to recall
/// - beta = 1 gives equal weight to precision and recall (F1 score)
///
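/// For example, with precision = 0.5 and recall = 1.0 (illustrative numbers):
///
/// ```text
/// F0.5 ≈ 0.56   (closer to precision)
/// F1   ≈ 0.67
/// F2   ≈ 0.83   (closer to recall)
/// ```
///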
/// # Arguments
///
/// * `y_true` - Ground truth (correct) binary labels
/// * `y_pred` - Predicted binary labels, as returned by a classifier
/// * `pos_label` - The label to report as positive class
/// * `beta` - The weight of recall relative to precision (must be positive)
///
/// # Returns
///
/// * The F-beta score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::fbeta_score;
///
/// let y_true = array![0, 1, 0, 0, 1, 1];
/// let y_pred = array![0, 0, 1, 0, 1, 1];
///
/// // F0.5 score (weighs precision higher than recall)
/// let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
///
/// // F2 score (weighs recall higher than precision)
/// let f_two = fbeta_score(&y_true, &y_pred, 1, 2.0).unwrap();
/// ```
#[allow(dead_code)]
pub fn fbeta_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
    beta: f64,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    if beta <= 0.0 {
        return Err(MetricsError::InvalidInput(format!(
            "beta must be positive, got {beta}"
        )));
    }

    let precision = precision_score(y_true, y_pred, pos_label.clone())?;
    let recall = recall_score(y_true, y_pred, pos_label)?;

    if precision + recall == 0.0 {
        return Ok(0.0);
    }

    let beta_squared = beta * beta;
    Ok((1.0 + beta_squared) * precision * recall / ((beta_squared * precision) + recall))
}

/// Calculate binary log loss, also known as binary cross-entropy
///
/// This is the loss function minimized in binary logistic regression and in
/// neural networks with a sigmoid output.
///
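/// For a sample with true label y ∈ {0, 1} and predicted probability p, the loss
/// averaged over n samples is:
///
/// ```text
/// LogLoss = -(1/n) * Σ [ yᵢ·ln(pᵢ) + (1 - yᵢ)·ln(1 - pᵢ) ]
/// ```
///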
/// # Arguments
///
/// * `y_true` - Ground truth binary labels
/// * `y_prob` - Predicted probabilities for the positive class
/// * `eps` - Small value to avoid log(0)
///
/// # Returns
///
/// * The log loss (float)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::binary_log_loss;
///
/// let y_true = array![0, 1, 1, 0];
/// let y_prob = array![0.1, 0.9, 0.8, 0.3];
///
/// let loss = binary_log_loss(&y_true, &y_prob, 1e-15).unwrap();
/// ```
#[allow(dead_code)]
pub fn binary_log_loss<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_prob: &ArrayBase<S2, D2>,
    eps: f64,
) -> Result<f64>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_prob.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_prob have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_prob.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Compute log loss
    let mut loss = 0.0;
    for (yt, yp) in y_true.iter().zip(y_prob.iter()) {
        // Clip probability to avoid log(0)
        let clipped_yp = yp.max(eps).min(1.0 - eps);

        if *yt == 1 {
            loss -= (clipped_yp).ln();
        } else {
            loss -= (1.0 - clipped_yp).ln();
        }
    }

    Ok(loss / n_samples as f64)
}

/// Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
///
/// The ROC curve is created by plotting the true positive rate (TPR) against
/// the false positive rate (FPR) at various threshold settings.
///
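/// Equivalently, the AUC is the probability that a randomly chosen positive
/// sample receives a higher score than a randomly chosen negative sample,
/// with ties counted as one half:
///
/// ```text
/// AUC = P(score(x⁺) > score(x⁻)) + 0.5 · P(score(x⁺) = score(x⁻))
/// ```
///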
/// # Arguments
///
/// * `y_true` - Ground truth binary labels
/// * `y_score` - Target scores (can be probability estimates of the positive class)
///
/// # Returns
///
/// * The ROC AUC score (float between 0.0 and 1.0)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::roc_auc_score;
///
/// let y_true = array![0, 0, 1, 1];
/// let y_score = array![0.1, 0.4, 0.35, 0.8];
///
/// let auc = roc_auc_score(&y_true, &y_score).unwrap();
/// ```
#[allow(dead_code)]
pub fn roc_auc_score<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_score: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Compute the number of positive and negative samples
    let mut n_pos = 0;
    let mut n_neg = 0;
    for &yt in y_true.iter() {
        if yt == 1 {
            n_pos += 1;
        } else {
            n_neg += 1;
        }
    }

    if n_pos == 0 || n_neg == 0 {
        return Err(MetricsError::InvalidInput(
            "ROC AUC score is not defined when only one class is present".to_string(),
        ));
    }

    // Collect scores and true labels
    let mut scores_and_labels = Vec::with_capacity(n_samples);
    for (yt, ys) in y_true.iter().zip(y_score.iter()) {
        scores_and_labels.push((*ys, *yt));
    }

    // Sort scores in descending order
    scores_and_labels.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

    // Compute AUC by summing up the trapezoids
    let mut auc = 0.0;
    let mut false_positive = 0;
    let mut true_positive = 0;
    let mut last_false_positive = 0;
    let mut last_true_positive = 0;
    let mut last_score = f64::INFINITY;

    for (score, label) in scores_and_labels {
        if score != last_score {
            // Add the area of the trapezoid
            auc += (false_positive - last_false_positive) as f64
                * (true_positive + last_true_positive) as f64
                / 2.0;
            last_score = score;
            last_false_positive = false_positive;
            last_true_positive = true_positive;
        }

        if label == 1 {
            true_positive += 1;
        } else {
            false_positive += 1;
        }
    }

    // Add the final trapezoid
    auc += (n_neg - last_false_positive) as f64 * (true_positive + last_true_positive) as f64 / 2.0;

    // Normalize
    auc /= (n_pos * n_neg) as f64;

    Ok(auc)
}

/// Computes the lift chart values for binary classification
///
/// The lift chart shows how much better a model performs compared to a random model.
/// It is particularly useful in marketing and customer targeting applications.
///
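/// For each percentile bin, the lift compares the positive rate among the
/// top-scored samples to the overall positive rate:
///
/// ```text
/// Lift@k = (positive rate in the top k% of scores) / (overall positive rate)
/// ```
///
/// A lift of 2.0 at the 10th percentile means the top 10% of scored samples
/// contain twice as many positives as a random 10% sample would.
///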
/// # Arguments
///
/// * `y_true` - Ground truth binary labels (0 or 1)
/// * `y_score` - Predicted probabilities for the positive class
/// * `n_bins` - Number of bins for the lift chart
///
/// # Returns
///
/// * A tuple containing three arrays:
///   * `percentiles` - The percentiles used (0-100)
///   * `lift_values` - The lift values for each percentile
///   * `cum_gains` - Cumulative gains values (for gain chart)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::lift_chart;
///
/// let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
/// let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];
///
/// let (percentiles, lift_values, cum_gains) = lift_chart(&y_true, &y_score, 10).unwrap();
/// ```
#[allow(dead_code)]
pub fn lift_chart<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_score: &ArrayBase<S2, D2>,
    n_bins: usize,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>)>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Check that arrays have the same shape
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Validate y_true contains only binary values
    for yt in y_true.iter() {
        if *yt != 0 && *yt != 1 {
            return Err(MetricsError::InvalidInput(
                "y_true must contain only binary values (0 or 1)".to_string(),
            ));
        }
    }

    // Validate n_bins: it must be at least 1 and no larger than the number of
    // samples, otherwise the leading bins would be empty
    if n_bins < 1 || n_bins > n_samples {
        return Err(MetricsError::InvalidInput(
            "n_bins must be between 1 and the number of samples".to_string(),
        ));
    }

    // Compute the overall positive rate (baseline)
    let n_positives = y_true.iter().filter(|&&y| y == 1).count();
    if n_positives == 0 || n_positives == n_samples {
        return Err(MetricsError::InvalidInput(
            "y_true must contain both positive and negative samples".to_string(),
        ));
    }
    let baseline_rate = n_positives as f64 / n_samples as f64;

    // Pair scores with true labels and sort by scores in descending order
    let mut paired_data: Vec<(f64, u32)> = y_score
        .iter()
        .zip(y_true.iter())
        .map(|(&score, &label)| (score, label))
        .collect();
    paired_data.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

    // Calculate percentiles, lift values, and cumulative gains
    let bin_size = n_samples / n_bins;
    let mut percentiles = Vec::with_capacity(n_bins);
    let mut lift_values = Vec::with_capacity(n_bins);
    let mut cum_gains = Vec::with_capacity(n_bins);

    for i in 0..n_bins {
        // Calculate percentile
        let percentile = (i + 1) as f64 * 100.0 / n_bins as f64;

        // Calculate number of samples to consider (based on percentile)
        let n_considered = if i == n_bins - 1 {
            // Include all samples in the last bin
            n_samples
        } else {
            (i + 1) * bin_size
        };

        // Count positives in this subset
        let positives_in_bin = paired_data[0..n_considered]
            .iter()
            .filter(|(_, label)| *label == 1)
            .count();

        // Calculate lift and cumulative gain
        let bin_rate = positives_in_bin as f64 / n_considered as f64;
        let lift = bin_rate / baseline_rate;
        let cum_gain = positives_in_bin as f64 / n_positives as f64;

        percentiles.push(percentile);
        lift_values.push(lift);
        cum_gains.push(cum_gain);
    }

    Ok((
        Array1::from(percentiles),
        Array1::from(lift_values),
        Array1::from(cum_gains),
    ))
}

/// Computes the gain chart values for binary classification
///
/// The gain chart (or cumulative gains chart) shows the percentage of positive
/// outcomes captured at each percentile when observations are ranked by predicted probability.
///
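/// The cumulative gain at percentile k is the fraction of all positives that
/// fall in the top k% of scored samples:
///
/// ```text
/// CumulativeGain@k = (positives in the top k% of scores) / (total positives)
/// ```
///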
/// # Arguments
///
/// * `y_true` - Ground truth binary labels (0 or 1)
/// * `y_score` - Predicted probabilities for the positive class
/// * `n_bins` - Number of bins for the gain chart
///
/// # Returns
///
/// * A tuple containing two arrays:
///   * `percentiles` - The percentiles used (0-100)
///   * `cum_gains` - Cumulative gains values at each percentile
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::gain_chart;
///
/// let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
/// let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];
///
/// let (percentiles, cum_gains) = gain_chart(&y_true, &y_score, 10).unwrap();
/// ```
#[allow(dead_code)]
pub fn gain_chart<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_score: &ArrayBase<S2, D2>,
    n_bins: usize,
) -> Result<(Array1<f64>, Array1<f64>)>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    // Reuse lift_chart function to get the data
    let (percentiles, _lift_values, cum_gains) = lift_chart(y_true, y_score, n_bins)?;
    Ok((percentiles, cum_gains))
}

/// Generates a text report showing the main classification metrics
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) target values
/// * `y_pred` - Estimated targets as returned by a classifier
/// * `labels` - Optional list of label values to include in the report
///
/// # Returns
///
/// * A string containing the classification report
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::classification::classification_report;
///
/// let y_true = array![0, 1, 2, 0, 1, 2];
/// let y_pred = array![0, 2, 1, 0, 0, 2];
///
/// let report = classification_report(&y_true, &y_pred, None).unwrap();
/// println!("{}", report);
/// ```
#[allow(dead_code)]
pub fn classification_report<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    labels: Option<&[T]>,
) -> Result<String>
where
    T: PartialEq + NumCast + Clone + Ord + std::hash::Hash + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    // Get confusion matrix
    let (cm, classes) = confusion_matrix(y_true, y_pred, labels)?;

    // Prepare report
    let mut report = String::new();
    report.push_str("              precision    recall  f1-score   support\n\n");

    let n_classes = classes.len();
    let mut total_precision = 0.0;
    let mut total_recall = 0.0;
    let mut total_f1 = 0.0;
    let mut total_support = 0;

    for i in 0..n_classes {
        let class_label = format!("{:?}", classes[i]);

        // Calculate metrics for this class
        let true_positives = cm[[i, i]];
        let false_positives = cm.column(i).sum() - true_positives;
        let false_negatives = cm.row(i).sum() - true_positives;
        let support = cm.row(i).sum();

        let precision = if true_positives + false_positives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_positives) as f64
        };

        let recall = if true_positives + false_negatives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_negatives) as f64
        };

        let f1 = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        // Add to totals
        total_precision += precision;
        total_recall += recall;
        total_f1 += f1;
        total_support += support as usize;

        // Add line to report
        report.push_str(&format!(
            "{class_label:>14} {precision:9.2} {recall:9.2} {f1:9.2} {support:9}\n"
        ));
    }

    report.push('\n');

    // Calculate averages
    let avg_precision = total_precision / n_classes as f64;
    let avg_recall = total_recall / n_classes as f64;
    let avg_f1 = total_f1 / n_classes as f64;

    // Add averages to report
    report.push_str(&format!(
        "    avg / total {avg_precision:9.2} {avg_recall:9.2} {avg_f1:9.2} {total_support:9}\n"
    ));

    Ok(report)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_accuracy_score() {
        let y_true = array![0, 1, 2, 3];
        let y_pred = array![0, 2, 1, 3];

        let acc = accuracy_score(&y_true, &y_pred).unwrap();
        assert_abs_diff_eq!(acc, 0.5, epsilon = 1e-10); // 2 out of 4 are correct
    }

    #[test]
    fn test_confusion_matrix() {
        let y_true = array![0, 1, 2, 0, 1, 2];
        let y_pred = array![0, 2, 1, 0, 0, 2];

        let (cm, classes) = confusion_matrix(&y_true, &y_pred, None).unwrap();

        assert_eq!(cm.shape(), &[3, 3]);
        assert_eq!(classes.len(), 3);

        // Expected confusion matrix:
        // [[2, 0, 0],
        //  [1, 0, 1],
        //  [0, 1, 1]]
        assert_eq!(cm[[0, 0]], 2); // True 0, predicted 0
        assert_eq!(cm[[1, 0]], 1); // True 1, predicted 0
        assert_eq!(cm[[1, 2]], 1); // True 1, predicted 2
        assert_eq!(cm[[2, 1]], 1); // True 2, predicted 1
        assert_eq!(cm[[2, 2]], 1); // True 2, predicted 2
    }

    #[test]
    fn test_precision_recall_f1() {
        let y_true = array![0, 1, 0, 0, 1, 1];
        let y_pred = array![0, 0, 1, 0, 1, 1];

        let precision = precision_score(&y_true, &y_pred, 1).unwrap();
        let recall = recall_score(&y_true, &y_pred, 1).unwrap();
        let f1 = f1_score(&y_true, &y_pred, 1).unwrap();

        // precision = 2/3, recall = 2/3, f1 = 2*2/3*2/3 / (2/3 + 2/3) = 2/3
        assert_abs_diff_eq!(precision, 2.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(recall, 2.0 / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(f1, 2.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_fbeta_score() {
        let y_true = array![0, 1, 0, 0, 1, 1];
        let y_pred = array![0, 0, 1, 0, 1, 1];

        // F1 score (beta = 1.0)
        let f1 = fbeta_score(&y_true, &y_pred, 1, 1.0).unwrap();
        assert_abs_diff_eq!(f1, 2.0 / 3.0, epsilon = 1e-10);

        // F0.5 score (weighs precision higher than recall)
        let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
        // With beta=0.5, beta²=0.25
        // F0.5 = (1+0.25) * 2/3 * 2/3 / (0.25*2/3 + 2/3) = 1.25 * 4/9 / (1/6 + 2/3) = 5/9 / 5/6 = 5/9 * 6/5 = 30/45 = 2/3
        assert_abs_diff_eq!(f_half, 2.0 / 3.0, epsilon = 1e-10);

        // F2 score (weighs recall higher than precision)
        let f_two = fbeta_score(&y_true, &y_pred, 1, 2.0).unwrap();
        // With beta=2.0, beta²=4.0
        // F2 = (1+4) * 2/3 * 2/3 / (4*2/3 + 2/3) = 5 * 4/9 / (8/3 + 2/3) = 20/9 / 10/3 = 20/9 * 3/10 = 60/90 = 2/3
        assert_abs_diff_eq!(f_two, 2.0 / 3.0, epsilon = 1e-10);

        // This example has equal precision and recall, so all F-beta scores are the same
        // Let's try a more interesting example with different precision and recall
        let y_true = array![1, 1, 1, 1, 1, 0, 0, 0, 0, 0];
        let y_pred = array![1, 1, 1, 0, 0, 0, 0, 0, 1, 1];
        // precision = 3/5 = 0.6, recall = 3/5 = 0.6

        // F1 score (beta = 1.0)
        let f1 = fbeta_score(&y_true, &y_pred, 1, 1.0).unwrap();
        assert_abs_diff_eq!(f1, 0.6, epsilon = 1e-10);

        // F0.5 score (weighs precision higher than recall)
        let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
        // With beta=0.5, beta²=0.25
        // F0.5 = (1+0.25) * 0.6 * 0.6 / (0.25*0.6 + 0.6) = 1.25 * 0.36 / (0.15 + 0.6) = 0.45 / 0.75 = 0.6
        assert_abs_diff_eq!(f_half, 0.6, epsilon = 1e-10);

        // F2 score (weighs recall higher than precision)
        let f_two = fbeta_score(&y_true, &y_pred, 1, 2.0).unwrap();
        // With beta=2.0, beta²=4.0
        // F2 = (1+4) * 0.6 * 0.6 / (4*0.6 + 0.6) = 5 * 0.36 / (2.4 + 0.6) = 1.8 / 3.0 = 0.6
        assert_abs_diff_eq!(f_two, 0.6, epsilon = 1e-10);

        // Let's try one more with different precision and recall
        let y_true = array![1, 1, 1, 1, 0, 0, 0, 0];
        let y_pred = array![1, 1, 0, 0, 0, 0, 1, 1];
        // precision = 2/4 = 0.5, recall = 2/4 = 0.5

        // F0.5 score (weighs precision higher than recall)
        let f_half = fbeta_score(&y_true, &y_pred, 1, 0.5).unwrap();
        // With beta=0.5, beta²=0.25
        // F0.5 = (1+0.25) * 0.5 * 0.5 / (0.25*0.5 + 0.5) = 1.25 * 0.25 / (0.125 + 0.5) = 0.3125 / 0.625 = 0.5
        assert_abs_diff_eq!(f_half, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_log_loss() {
        let y_true = array![0, 1, 1, 0];
        let y_prob = array![0.1, 0.9, 0.8, 0.3];

        let loss = binary_log_loss(&y_true, &y_prob, 1e-15).unwrap();
        // Expected loss: -[(log(0.9) + log(0.8) + log(0.9) + log(0.7))/4]
        let expected =
            -(((1.0_f64 - 0.1).ln() + 0.9_f64.ln() + 0.8_f64.ln() + (1.0_f64 - 0.3).ln()) / 4.0);
        assert_abs_diff_eq!(loss, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_roc_auc() {
        // Perfect separation
        let y_true = array![0, 0, 1, 1];
        let y_score = array![0.1, 0.2, 0.8, 0.9];
        let auc = roc_auc_score(&y_true, &y_score).unwrap();
        assert_abs_diff_eq!(auc, 1.0, epsilon = 1e-10);

        // Random
        let y_true = array![0, 1, 0, 1];
        let y_score = array![0.5, 0.5, 0.5, 0.5];
        let auc = roc_auc_score(&y_true, &y_score).unwrap();
        assert_abs_diff_eq!(auc, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_lift_chart() {
        let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
        let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];

        // Test with 5 bins
        let (percentiles, lift_values, cum_gains) = lift_chart(&y_true, &y_score, 5).unwrap();

        // Verify dimensions
        assert_eq!(percentiles.len(), 5);
        assert_eq!(lift_values.len(), 5);
        assert_eq!(cum_gains.len(), 5);

        // Verify percentiles
        assert_abs_diff_eq!(percentiles[0], 20.0, epsilon = 1e-10);
        assert_abs_diff_eq!(percentiles[4], 100.0, epsilon = 1e-10);

        // The first 20% contains the highest scored cases, which should be mostly positive
        // This should give a lift value higher than 1
        assert!(lift_values[0] > 1.0);

        // The cumulative gains at 100% should be 1.0 (all positives)
        assert_abs_diff_eq!(cum_gains[4], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gain_chart() {
        let y_true = array![0, 0, 1, 0, 1, 1, 0, 1, 0, 1];
        let y_score = array![0.1, 0.2, 0.7, 0.3, 0.8, 0.9, 0.4, 0.6, 0.2, 0.5];

        // Test with 5 bins
        let (percentiles, cum_gains) = gain_chart(&y_true, &y_score, 5).unwrap();

        // Verify dimensions
        assert_eq!(percentiles.len(), 5);
        assert_eq!(cum_gains.len(), 5);

        // Verify percentiles
        assert_abs_diff_eq!(percentiles[0], 20.0, epsilon = 1e-10);
        assert_abs_diff_eq!(percentiles[4], 100.0, epsilon = 1e-10);

        // Cumulative gains should be non-decreasing
        for i in 1..cum_gains.len() {
            assert!(cum_gains[i] >= cum_gains[i - 1]);
        }

        // The cumulative gains at 100% should be 1.0 (all positives)
        assert_abs_diff_eq!(cum_gains[4], 1.0, epsilon = 1e-10);
    }
}