burn_train/metric/classification.rs

use std::num::NonZeroUsize;
/// Necessary data for classification metrics.
///
/// The default configuration uses [`DecisionRule::Threshold`] with a threshold of `0.5`
/// and [`ClassReduction::Macro`].
#[derive(Default, Debug, Clone, PartialEq)]
pub struct ClassificationMetricConfig {
    /// Rule used to decide which classes count as predicted.
    pub decision_rule: DecisionRule,
    /// Strategy used to reduce per-class statistics into a single value.
    pub class_reduction: ClassReduction,
}

/// The prediction decision rule for classification metrics.
///
/// Defaults to [`DecisionRule::Threshold`] with a threshold of `0.5`.
// `Eq`/`Hash` cannot be derived because `Threshold` holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum DecisionRule {
    /// Consider a class predicted if its probability exceeds the threshold.
    Threshold(f64),
    /// Consider a class predicted correctly if it is within the top `k` predicted classes,
    /// ranked by score.
    TopK(NonZeroUsize),
}

impl Default for DecisionRule {
    /// Returns [`DecisionRule::Threshold`] with a threshold of `0.5`.
    fn default() -> Self {
        Self::Threshold(0.5)
    }
}

/// The reduction strategy for classification metrics.
///
/// Defaults to [`ClassReduction::Macro`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub enum ClassReduction {
    /// Computes the statistics over all classes before averaging.
    Micro,
    /// Computes the statistics independently for each class before averaging.
    #[default]
    Macro,
}