burn_train/metric/classification.rs
1use std::num::NonZeroUsize;
2
3/// Necessary data for classification metrics.
4#[derive(Default, Debug, Clone)]
5pub struct ClassificationMetricConfig {
6 pub decision_rule: DecisionRule,
7 pub class_reduction: ClassReduction,
8}
9
/// The prediction decision rule for classification metrics.
///
/// Both payloads (`f64`, [`NonZeroUsize`]) are `Copy`, so the rule itself is
/// `Copy`. This matches [`ClassReduction`] and lets configs be passed by value
/// cheaply. `PartialEq` allows comparing configurations (note that a `NaN`
/// threshold never compares equal, as with any `f64`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DecisionRule {
    /// Consider a class predicted if its probability exceeds the threshold.
    Threshold(f64),
    /// Consider a class predicted correctly if it is within the top `k`
    /// predicted classes based on scores. `NonZeroUsize` rules out `k == 0`.
    TopK(NonZeroUsize),
}
18
19impl Default for DecisionRule {
20 fn default() -> Self {
21 Self::Threshold(0.5)
22 }
23}
24
/// The reduction strategy for classification metrics.
///
/// Fieldless and `Copy`. It also derives `PartialEq`, `Eq` and `Hash` so
/// callers can compare strategies or use them as map/set keys.
/// The default is [`ClassReduction::Macro`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub enum ClassReduction {
    /// Pools the statistics of all classes together before averaging.
    Micro,
    /// Computes the statistics independently for each class before averaging.
    #[default]
    Macro,
}