burn_train/metric/
confusion_stats.rs

1use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule};
2use burn_core::prelude::{Backend, Bool, Int, Tensor};
3use std::fmt::{self, Debug};
4
/// Input used to compute confusion statistics: a batch of predictions and the matching targets.
#[derive(new, Debug, Clone)]
pub struct ConfusionStatsInput<B: Backend> {
    /// Sample x Class normalized predictions, before any threshold or top-k rule is applied.
    pub predictions: Tensor<B, 2>,
    /// Sample x Class one-hot encoded target.
    pub targets: Tensor<B, 2, Bool>,
}
13
14impl<B: Backend> From<ConfusionStatsInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) {
15    fn from(input: ConfusionStatsInput<B>) -> Self {
16        (input.predictions, input.targets)
17    }
18}
19
20impl<B: Backend> From<(Tensor<B, 2>, Tensor<B, 2, Bool>)> for ConfusionStatsInput<B> {
21    fn from(value: (Tensor<B, 2>, Tensor<B, 2, Bool>)) -> Self {
22        Self::new(value.0, value.1)
23    }
24}
25
/// Per-element confusion classification of a batch, from which the confusion counts
/// (true/false positives and negatives) are aggregated.
#[derive(Clone)]
pub struct ConfusionStats<B: Backend> {
    /// Sample x Class tensor holding `prediction + 2 * target` for every element:
    /// 0 = true negative, 1 = false positive, 2 = false negative, 3 = true positive.
    confusion_classes: Tensor<B, 2, Int>,
    /// How the per-element counts are reduced over classes when aggregating.
    class_reduction: ClassReduction,
}
31
32impl<B: Backend> Debug for ConfusionStats<B> {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        let to_vec = |tensor_data: Tensor<B, 1>| {
35            tensor_data
36                .to_data()
37                .to_vec::<f32>()
38                .expect("A vector representation of the input Tensor is expected")
39        };
40        let ratio_of_support_vec =
41            |metric: Tensor<B, 1>| to_vec(self.clone().ratio_of_support(metric));
42        f.debug_struct("ConfusionStats")
43            .field("tp", &ratio_of_support_vec(self.clone().true_positive()))
44            .field("fp", &ratio_of_support_vec(self.clone().false_positive()))
45            .field("tn", &ratio_of_support_vec(self.clone().true_negative()))
46            .field("fn", &ratio_of_support_vec(self.clone().false_negative()))
47            .field("support", &to_vec(self.clone().support()))
48            .finish()
49    }
50}
51
52impl<B: Backend> ConfusionStats<B> {
53    /// Expects `predictions` to be normalized.
54    pub fn new(input: &ConfusionStatsInput<B>, config: &ClassificationMetricConfig) -> Self {
55        let prediction_mask = match config.decision_rule {
56            DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold),
57            DecisionRule::TopK(top_k) => {
58                let mask = input.predictions.zeros_like();
59                let indexes =
60                    input
61                        .predictions
62                        .clone()
63                        .argsort_descending(1)
64                        .narrow(1, 0, top_k.get());
65                let values = indexes.ones_like().float();
66                mask.scatter(1, indexes, values).bool()
67            }
68        };
69        Self {
70            confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2,
71            class_reduction: config.class_reduction,
72        }
73    }
74
75    /// sum over samples
76    fn aggregate(
77        sample_class_mask: Tensor<B, 2, Bool>,
78        class_reduction: ClassReduction,
79    ) -> Tensor<B, 1> {
80        use ClassReduction::{Macro, Micro};
81        match class_reduction {
82            Micro => sample_class_mask.float().sum(),
83            Macro => sample_class_mask.float().sum_dim(0).squeeze_dim(0),
84        }
85    }
86
87    pub fn true_positive(self) -> Tensor<B, 1> {
88        Self::aggregate(self.confusion_classes.equal_elem(3), self.class_reduction)
89    }
90
91    pub fn true_negative(self) -> Tensor<B, 1> {
92        Self::aggregate(self.confusion_classes.equal_elem(0), self.class_reduction)
93    }
94
95    pub fn false_positive(self) -> Tensor<B, 1> {
96        Self::aggregate(self.confusion_classes.equal_elem(1), self.class_reduction)
97    }
98
99    pub fn false_negative(self) -> Tensor<B, 1> {
100        Self::aggregate(self.confusion_classes.equal_elem(2), self.class_reduction)
101    }
102
103    pub fn positive(self) -> Tensor<B, 1> {
104        self.clone().true_positive() + self.false_negative()
105    }
106
107    pub fn negative(self) -> Tensor<B, 1> {
108        self.clone().true_negative() + self.false_positive()
109    }
110
111    pub fn predicted_positive(self) -> Tensor<B, 1> {
112        self.clone().true_positive() + self.false_positive()
113    }
114
115    pub fn support(self) -> Tensor<B, 1> {
116        self.clone().positive() + self.negative()
117    }
118
119    pub fn ratio_of_support(self, metric: Tensor<B, 1>) -> Tensor<B, 1> {
120        metric / self.clone().support()
121    }
122}
123
124#[cfg(test)]
125mod tests {
126    use super::{ConfusionStats, ConfusionStatsInput};
127    use crate::{
128        TestBackend,
129        metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
130        tests::{ClassificationType, THRESHOLD, dummy_classification_input},
131    };
132    use burn_core::prelude::TensorData;
133    use rstest::{fixture, rstest};
134    use std::num::NonZeroUsize;
135
136    fn top_k_config(
137        top_k: NonZeroUsize,
138        class_reduction: ClassReduction,
139    ) -> ClassificationMetricConfig {
140        ClassificationMetricConfig {
141            decision_rule: DecisionRule::TopK(top_k),
142            class_reduction,
143        }
144    }
145    #[fixture]
146    #[once]
147    fn top_k_config_k1_micro() -> ClassificationMetricConfig {
148        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro)
149    }
150
151    #[fixture]
152    #[once]
153    fn top_k_config_k1_macro() -> ClassificationMetricConfig {
154        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro)
155    }
156    #[fixture]
157    #[once]
158    fn top_k_config_k2_micro() -> ClassificationMetricConfig {
159        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro)
160    }
161    #[fixture]
162    #[once]
163    fn top_k_config_k2_macro() -> ClassificationMetricConfig {
164        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro)
165    }
166
167    fn threshold_config(
168        threshold: f64,
169        class_reduction: ClassReduction,
170    ) -> ClassificationMetricConfig {
171        ClassificationMetricConfig {
172            decision_rule: DecisionRule::Threshold(threshold),
173            class_reduction,
174        }
175    }
176    #[fixture]
177    #[once]
178    fn threshold_config_micro() -> ClassificationMetricConfig {
179        threshold_config(THRESHOLD, ClassReduction::Micro)
180    }
181    #[fixture]
182    #[once]
183    fn threshold_config_macro() -> ClassificationMetricConfig {
184        threshold_config(THRESHOLD, ClassReduction::Macro)
185    }
186
187    #[rstest]
188    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
189    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
190    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())]
191    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())]
192    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
193    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())]
194    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())]
195    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())]
196    fn test_true_positive(
197        #[case] classification_type: ClassificationType,
198        #[case] config: ClassificationMetricConfig,
199        #[case] expected: Vec<i64>,
200    ) {
201        let input: ConfusionStatsInput<TestBackend> =
202            dummy_classification_input(&classification_type).into();
203        ConfusionStats::new(&input, &config)
204            .true_positive()
205            .int()
206            .into_data()
207            .assert_eq(&TensorData::from(expected.as_slice()), true);
208    }
209
210    #[rstest]
211    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
212    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
213    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())]
214    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())]
215    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
216    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())]
217    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
218    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())]
219    fn test_true_negative(
220        #[case] classification_type: ClassificationType,
221        #[case] config: ClassificationMetricConfig,
222        #[case] expected: Vec<i64>,
223    ) {
224        let input: ConfusionStatsInput<TestBackend> =
225            dummy_classification_input(&classification_type).into();
226        ConfusionStats::new(&input, &config)
227            .true_negative()
228            .int()
229            .into_data()
230            .assert_eq(&TensorData::from(expected.as_slice()), true);
231    }
232
233    #[rstest]
234    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
235    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
236    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
237    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())]
238    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())]
239    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())]
240    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
241    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())]
242    fn test_false_positive(
243        #[case] classification_type: ClassificationType,
244        #[case] config: ClassificationMetricConfig,
245        #[case] expected: Vec<i64>,
246    ) {
247        let input: ConfusionStatsInput<TestBackend> =
248            dummy_classification_input(&classification_type).into();
249        ConfusionStats::new(&input, &config)
250            .false_positive()
251            .int()
252            .into_data()
253            .assert_eq(&TensorData::from(expected.as_slice()), true);
254    }
255
256    #[rstest]
257    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
258    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
259    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
260    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())]
261    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())]
262    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())]
263    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())]
264    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())]
265    fn test_false_negatives(
266        #[case] classification_type: ClassificationType,
267        #[case] config: ClassificationMetricConfig,
268        #[case] expected: Vec<i64>,
269    ) {
270        let input: ConfusionStatsInput<TestBackend> =
271            dummy_classification_input(&classification_type).into();
272        ConfusionStats::new(&input, &config)
273            .false_negative()
274            .int()
275            .into_data()
276            .assert_eq(&TensorData::from(expected.as_slice()), true);
277    }
278
279    #[rstest]
280    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
281    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
282    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
283    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())]
284    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())]
285    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())]
286    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())]
287    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())]
288    fn test_positive(
289        #[case] classification_type: ClassificationType,
290        #[case] config: ClassificationMetricConfig,
291        #[case] expected: Vec<i64>,
292    ) {
293        let input: ConfusionStatsInput<TestBackend> =
294            dummy_classification_input(&classification_type).into();
295        ConfusionStats::new(&input, &config)
296            .positive()
297            .int()
298            .into_data()
299            .assert_eq(&TensorData::from(expected.as_slice()), true);
300    }
301
302    #[rstest]
303    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())]
304    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())]
305    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())]
306    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())]
307    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
308    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())]
309    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())]
310    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())]
311    fn test_negative(
312        #[case] classification_type: ClassificationType,
313        #[case] config: ClassificationMetricConfig,
314        #[case] expected: Vec<i64>,
315    ) {
316        let input: ConfusionStatsInput<TestBackend> =
317            dummy_classification_input(&classification_type).into();
318        ConfusionStats::new(&input, &config)
319            .negative()
320            .int()
321            .into_data()
322            .assert_eq(&TensorData::from(expected.as_slice()), true);
323    }
324
325    #[rstest]
326    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
327    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
328    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
329    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())]
330    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
331    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())]
332    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())]
333    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())]
334    fn test_predicted_positive(
335        #[case] classification_type: ClassificationType,
336        #[case] config: ClassificationMetricConfig,
337        #[case] expected: Vec<i64>,
338    ) {
339        let input: ConfusionStatsInput<TestBackend> =
340            dummy_classification_input(&classification_type).into();
341        ConfusionStats::new(&input, &config)
342            .predicted_positive()
343            .int()
344            .into_data()
345            .assert_eq(&TensorData::from(expected.as_slice()), true);
346    }
347}