1use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule};
2use burn_core::prelude::{Backend, Bool, Int, Tensor};
3use std::fmt::{self, Debug};
4
/// Input for [`ConfusionStats`]: model scores paired with boolean ground-truth labels.
///
/// Can be converted to and from a `(predictions, targets)` tuple.
#[derive(new, Debug, Clone)]
pub struct ConfusionStatsInput<B: Backend> {
    /// Model scores, with classes along dimension 1 (`ConfusionStats::new` thresholds or
    /// argsorts them along that axis). NOTE(review): presumably shaped
    /// `[batch_size, num_classes]` — confirm against callers.
    pub predictions: Tensor<B, 2>,
    /// Ground-truth labels; `true` marks a positive class for a sample. Combined element-wise
    /// with the prediction mask, so it is expected to match `predictions` in shape.
    pub targets: Tensor<B, 2, Bool>,
}
13
14impl<B: Backend> From<ConfusionStatsInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) {
15 fn from(input: ConfusionStatsInput<B>) -> Self {
16 (input.predictions, input.targets)
17 }
18}
19
20impl<B: Backend> From<(Tensor<B, 2>, Tensor<B, 2, Bool>)> for ConfusionStatsInput<B> {
21 fn from(value: (Tensor<B, 2>, Tensor<B, 2, Bool>)) -> Self {
22 Self::new(value.0, value.1)
23 }
24}
25
/// Per-sample, per-class confusion encoding from which TP/FP/TN/FN counts are derived.
#[derive(Clone)]
pub struct ConfusionStats<B: Backend> {
    /// `prediction_mask + 2 * target_mask` for every cell:
    /// `0` = true negative, `1` = false positive, `2` = false negative, `3` = true positive.
    confusion_classes: Tensor<B, 2, Int>,
    /// `Micro` sums counts over every cell; `Macro` keeps one count per class (dimension 1).
    class_reduction: ClassReduction,
}
31
32impl<B: Backend> Debug for ConfusionStats<B> {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 let to_vec = |tensor_data: Tensor<B, 1>| {
35 tensor_data
36 .to_data()
37 .to_vec::<f32>()
38 .expect("A vector representation of the input Tensor is expected")
39 };
40 let ratio_of_support_vec =
41 |metric: Tensor<B, 1>| to_vec(self.clone().ratio_of_support(metric));
42 f.debug_struct("ConfusionStats")
43 .field("tp", &ratio_of_support_vec(self.clone().true_positive()))
44 .field("fp", &ratio_of_support_vec(self.clone().false_positive()))
45 .field("tn", &ratio_of_support_vec(self.clone().true_negative()))
46 .field("fn", &ratio_of_support_vec(self.clone().false_negative()))
47 .field("support", &to_vec(self.clone().support()))
48 .finish()
49 }
50}
51
52impl<B: Backend> ConfusionStats<B> {
53 pub fn new(input: &ConfusionStatsInput<B>, config: &ClassificationMetricConfig) -> Self {
55 let prediction_mask = match config.decision_rule {
56 DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold),
57 DecisionRule::TopK(top_k) => {
58 let mask = input.predictions.zeros_like();
59 let indexes =
60 input
61 .predictions
62 .clone()
63 .argsort_descending(1)
64 .narrow(1, 0, top_k.get());
65 let values = indexes.ones_like().float();
66 mask.scatter(1, indexes, values).bool()
67 }
68 };
69 Self {
70 confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2,
71 class_reduction: config.class_reduction,
72 }
73 }
74
75 fn aggregate(
77 sample_class_mask: Tensor<B, 2, Bool>,
78 class_reduction: ClassReduction,
79 ) -> Tensor<B, 1> {
80 use ClassReduction::{Macro, Micro};
81 match class_reduction {
82 Micro => sample_class_mask.float().sum(),
83 Macro => sample_class_mask.float().sum_dim(0).squeeze_dim(0),
84 }
85 }
86
87 pub fn true_positive(self) -> Tensor<B, 1> {
88 Self::aggregate(self.confusion_classes.equal_elem(3), self.class_reduction)
89 }
90
91 pub fn true_negative(self) -> Tensor<B, 1> {
92 Self::aggregate(self.confusion_classes.equal_elem(0), self.class_reduction)
93 }
94
95 pub fn false_positive(self) -> Tensor<B, 1> {
96 Self::aggregate(self.confusion_classes.equal_elem(1), self.class_reduction)
97 }
98
99 pub fn false_negative(self) -> Tensor<B, 1> {
100 Self::aggregate(self.confusion_classes.equal_elem(2), self.class_reduction)
101 }
102
103 pub fn positive(self) -> Tensor<B, 1> {
104 self.clone().true_positive() + self.false_negative()
105 }
106
107 pub fn negative(self) -> Tensor<B, 1> {
108 self.clone().true_negative() + self.false_positive()
109 }
110
111 pub fn predicted_positive(self) -> Tensor<B, 1> {
112 self.clone().true_positive() + self.false_positive()
113 }
114
115 pub fn support(self) -> Tensor<B, 1> {
116 self.clone().positive() + self.negative()
117 }
118
119 pub fn ratio_of_support(self, metric: Tensor<B, 1>) -> Tensor<B, 1> {
120 metric / self.clone().support()
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::{ConfusionStats, ConfusionStatsInput};
    use crate::{
        TestBackend,
        metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
        tests::{ClassificationType, THRESHOLD, dummy_classification_input},
    };
    use burn_core::prelude::{Tensor, TensorData};
    use rstest::{fixture, rstest};
    use std::num::NonZeroUsize;

    fn top_k_config(
        top_k: NonZeroUsize,
        class_reduction: ClassReduction,
    ) -> ClassificationMetricConfig {
        ClassificationMetricConfig {
            decision_rule: DecisionRule::TopK(top_k),
            class_reduction,
        }
    }

    #[fixture]
    #[once]
    fn top_k_config_k1_micro() -> ClassificationMetricConfig {
        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro)
    }

    #[fixture]
    #[once]
    fn top_k_config_k1_macro() -> ClassificationMetricConfig {
        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro)
    }

    #[fixture]
    #[once]
    fn top_k_config_k2_micro() -> ClassificationMetricConfig {
        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro)
    }

    #[fixture]
    #[once]
    fn top_k_config_k2_macro() -> ClassificationMetricConfig {
        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro)
    }

    fn threshold_config(
        threshold: f64,
        class_reduction: ClassReduction,
    ) -> ClassificationMetricConfig {
        ClassificationMetricConfig {
            decision_rule: DecisionRule::Threshold(threshold),
            class_reduction,
        }
    }

    #[fixture]
    #[once]
    fn threshold_config_micro() -> ClassificationMetricConfig {
        threshold_config(THRESHOLD, ClassReduction::Micro)
    }

    #[fixture]
    #[once]
    fn threshold_config_macro() -> ClassificationMetricConfig {
        threshold_config(THRESHOLD, ClassReduction::Macro)
    }

    /// Builds confusion statistics for the shared dummy input of `classification_type`.
    fn confusion_stats(
        classification_type: &ClassificationType,
        config: &ClassificationMetricConfig,
    ) -> ConfusionStats<TestBackend> {
        let input: ConfusionStatsInput<TestBackend> =
            dummy_classification_input(classification_type).into();
        ConfusionStats::new(&input, config)
    }

    /// Asserts that a count tensor, cast to integers, equals `expected`.
    fn assert_counts(counts: Tensor<TestBackend, 1>, expected: &[i64]) {
        counts
            .int()
            .into_data()
            .assert_eq(&TensorData::from(expected), true);
    }

    // Case labels carry the top-k value (`k1`/`k2`) so every generated test name is unambiguous.

    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())]
    fn test_true_positive(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).true_positive(),
            &expected,
        );
    }

    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())]
    fn test_true_negative(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).true_negative(),
            &expected,
        );
    }

    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())]
    fn test_false_positive(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).false_positive(),
            &expected,
        );
    }

    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())]
    fn test_false_negative(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).false_negative(),
            &expected,
        );
    }

    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())]
    fn test_positive(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).positive(),
            &expected,
        );
    }

    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())]
    fn test_negative(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).negative(),
            &expected,
        );
    }

    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())]
    fn test_predicted_positive(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).predicted_positive(),
            &expected,
        );
    }

    // Support = positive + negative, i.e. every (sample, class) cell (micro) or every sample
    // per class (macro); the expectations follow from `test_positive` + `test_negative`.
    #[rstest]
    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [5].into())]
    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [5].into())]
    #[case::multiclass_k1_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [15].into())]
    #[case::multiclass_k1_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [5, 5, 5].into())]
    #[case::multiclass_k2_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [15].into())]
    #[case::multiclass_k2_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [5, 5, 5].into())]
    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [15].into())]
    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [5, 5, 5].into())]
    fn test_support(
        #[case] classification_type: ClassificationType,
        #[case] config: ClassificationMetricConfig,
        #[case] expected: Vec<i64>,
    ) {
        assert_counts(
            confusion_stats(&classification_type, &config).support(),
            &expected,
        );
    }
}