// burn_train/metric/classification.rs

1use std::num::NonZeroUsize;
2
3/// Necessary data for classification metrics.
4#[derive(Default, Debug, Clone)]
5pub struct ClassificationMetricConfig {
6    pub decision_rule: DecisionRule,
7    pub class_reduction: ClassReduction,
8}
9
/// The prediction decision rule for classification metrics.
///
/// Defaults to [`DecisionRule::Threshold`] with a threshold of `0.5`.
// `PartialEq` only: the `f64` payload rules out `Eq`/`Hash`.
#[derive(Debug, Clone, PartialEq)]
pub enum DecisionRule {
    /// Consider a class predicted if its probability exceeds the threshold.
    Threshold(f64),
    /// Consider a class predicted correctly if it is within the top `k` predicted
    /// classes based on scores.
    TopK(NonZeroUsize),
}

impl Default for DecisionRule {
    /// Returns `DecisionRule::Threshold(0.5)`.
    fn default() -> Self {
        Self::Threshold(0.5)
    }
}

/// The reduction strategy for classification metrics.
///
/// Defaults to [`ClassReduction::Macro`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub enum ClassReduction {
    /// Computes the statistics over all classes before averaging.
    Micro,
    /// Computes the statistics independently for each class before averaging.
    #[default]
    Macro,
}