// burn_train/learner/classification.rs

1use crate::metric::{
2    AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput,
3    PerplexityInput, TopKAccuracyInput, processor::ItemLazy,
4};
5use burn_core::tensor::backend::Backend;
6use burn_core::tensor::{Int, Tensor, Transaction};
7use burn_flex::Flex;
8
/// Simple classification output adapted for multiple metrics.
///
/// Bundles the loss, the raw model output, and the ground-truth targets for a
/// single batch so the learner can feed one value to every registered metric.
///
/// Supported metrics:
/// - Accuracy
/// - AUROC
/// - TopKAccuracy
/// - Perplexity
/// - Precision (via ConfusionStatsInput)
/// - Recall (via ConfusionStatsInput)
/// - FBetaScore (via ConfusionStatsInput)
/// - Loss.
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
    /// The loss.
    pub loss: Tensor<B, 1>,

    /// The class logits or probabilities. Shape: \[batch_size, num_classes\].
    pub output: Tensor<B, 2>,

    /// The ground truth class index for each sample. Shape: \[batch_size\].
    pub targets: Tensor<B, 1, Int>,
}
31
32impl<B: Backend> ItemLazy for ClassificationOutput<B> {
33    // Flex's IntElem is i32; class indices > i32::MAX would truncate on sync.
34    type ItemSync = ClassificationOutput<Flex>;
35
36    fn sync(self) -> Self::ItemSync {
37        let [output, loss, targets] = Transaction::default()
38            .register(self.output)
39            .register(self.loss)
40            .register(self.targets)
41            .execute()
42            .try_into()
43            .expect("Correct amount of tensor data");
44
45        let device = &Default::default();
46
47        ClassificationOutput {
48            output: Tensor::from_data(output, device),
49            loss: Tensor::from_data(loss, device),
50            targets: Tensor::from_data(targets, device),
51        }
52    }
53}
54
55impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
56    fn adapt(&self) -> AccuracyInput<B> {
57        AccuracyInput::new(self.output.clone(), self.targets.clone())
58    }
59}
60
61impl<B: Backend> Adaptor<AurocInput<B>> for ClassificationOutput<B> {
62    fn adapt(&self) -> AurocInput<B> {
63        AurocInput::new(self.output.clone(), self.targets.clone())
64    }
65}
66
67impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
68    fn adapt(&self) -> LossInput<B> {
69        LossInput::new(self.loss.clone())
70    }
71}
72
73impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
74    fn adapt(&self) -> TopKAccuracyInput<B> {
75        TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
76    }
77}
78
79impl<B: Backend> Adaptor<PerplexityInput<B>> for ClassificationOutput<B> {
80    fn adapt(&self) -> PerplexityInput<B> {
81        PerplexityInput::new(self.output.clone(), self.targets.clone())
82    }
83}
84
85impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
86    fn adapt(&self) -> ConfusionStatsInput<B> {
87        let [_, num_classes] = self.output.dims();
88        if num_classes > 1 {
89            ConfusionStatsInput::new(
90                self.output.clone(),
91                self.targets.clone().one_hot(num_classes).bool(),
92            )
93        } else {
94            ConfusionStatsInput::new(
95                self.output.clone(),
96                self.targets.clone().unsqueeze_dim(1).bool(),
97            )
98        }
99    }
100}
101
/// Multi-label classification output adapted for multiple metrics.
///
/// Bundles the loss, the raw per-label model output, and the ground-truth
/// label matrix for a single batch so the learner can feed one value to every
/// registered metric.
///
/// Supported metrics:
/// - HammingScore
/// - Precision (via ConfusionStatsInput)
/// - Recall (via ConfusionStatsInput)
/// - FBetaScore (via ConfusionStatsInput)
/// - Loss
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
    /// The loss.
    pub loss: Tensor<B, 1>,

    /// The label logits or probabilities. Shape: \[batch_size, num_classes\].
    pub output: Tensor<B, 2>,

    /// The ground truth labels. Shape: \[batch_size, num_classes\].
    pub targets: Tensor<B, 2, Int>,
}
121
122impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
123    // Flex's IntElem is i32; label indices > i32::MAX would truncate on sync.
124    type ItemSync = MultiLabelClassificationOutput<Flex>;
125
126    fn sync(self) -> Self::ItemSync {
127        let [output, loss, targets] = Transaction::default()
128            .register(self.output)
129            .register(self.loss)
130            .register(self.targets)
131            .execute()
132            .try_into()
133            .expect("Correct amount of tensor data");
134
135        let device = &Default::default();
136
137        MultiLabelClassificationOutput {
138            output: Tensor::from_data(output, device),
139            loss: Tensor::from_data(loss, device),
140            targets: Tensor::from_data(targets, device),
141        }
142    }
143}
144
145impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
146    fn adapt(&self) -> HammingScoreInput<B> {
147        HammingScoreInput::new(self.output.clone(), self.targets.clone())
148    }
149}
150
151impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
152    fn adapt(&self) -> LossInput<B> {
153        LossInput::new(self.loss.clone())
154    }
155}
156
157impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
158    fn adapt(&self) -> ConfusionStatsInput<B> {
159        ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
160    }
161}