// burn_train/metric/classification.rs
use std::num::NonZeroUsize;
2
3/// Necessary data for classification metrics.
4#[derive(Default)]
5pub struct ClassificationMetricConfig {
6 pub decision_rule: DecisionRule,
7 pub class_reduction: ClassReduction,
8}
9
/// The prediction decision rule for classification metrics.
///
/// Both payloads are plain `Copy` values, so the rule itself is cheap to copy, compare and
/// print for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DecisionRule {
    /// Consider a class predicted if its probability exceeds the threshold.
    Threshold(f64),
    /// Consider a class predicted correctly if it is within the top k predicted classes based on
    /// scores.
    ///
    /// `k` is a [`NonZeroUsize`], so a "top 0" rule cannot be constructed.
    TopK(NonZeroUsize),
}
17
18impl Default for DecisionRule {
19 fn default() -> Self {
20 Self::Threshold(0.5)
21 }
22}
23
/// The reduction strategy for classification metrics.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash)]
pub enum ClassReduction {
    /// Computes the statistics over all classes before averaging.
    Micro,
    /// Computes the statistics independently for each class before averaging.
    ///
    /// This is the default strategy.
    #[default]
    Macro,
}