// burn_train/learner/classification.rs

1use crate::metric::TopKAccuracyInput;
2use crate::metric::{
3    AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, PerplexityInput,
4    processor::ItemLazy,
5};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{Int, Tensor, Transaction};
8use burn_ndarray::NdArray;
9
10/// Simple classification output adapted for multiple metrics.
11#[derive(new)]
12pub struct ClassificationOutput<B: Backend> {
13    /// The loss.
14    pub loss: Tensor<B, 1>,
15
16    /// The output.
17    pub output: Tensor<B, 2>,
18
19    /// The targets.
20    pub targets: Tensor<B, 1, Int>,
21}
22
23impl<B: Backend> ItemLazy for ClassificationOutput<B> {
24    type ItemSync = ClassificationOutput<NdArray>;
25
26    fn sync(self) -> Self::ItemSync {
27        let [output, loss, targets] = Transaction::default()
28            .register(self.output)
29            .register(self.loss)
30            .register(self.targets)
31            .execute()
32            .try_into()
33            .expect("Correct amount of tensor data");
34
35        let device = &Default::default();
36
37        ClassificationOutput {
38            output: Tensor::from_data(output, device),
39            loss: Tensor::from_data(loss, device),
40            targets: Tensor::from_data(targets, device),
41        }
42    }
43}
44
45impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
46    fn adapt(&self) -> AccuracyInput<B> {
47        AccuracyInput::new(self.output.clone(), self.targets.clone())
48    }
49}
50
51impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
52    fn adapt(&self) -> LossInput<B> {
53        LossInput::new(self.loss.clone())
54    }
55}
56
57impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
58    fn adapt(&self) -> TopKAccuracyInput<B> {
59        TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
60    }
61}
62
63impl<B: Backend> Adaptor<PerplexityInput<B>> for ClassificationOutput<B> {
64    fn adapt(&self) -> PerplexityInput<B> {
65        PerplexityInput::new(self.output.clone(), self.targets.clone())
66    }
67}
68
69impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
70    fn adapt(&self) -> ConfusionStatsInput<B> {
71        let [_, num_classes] = self.output.dims();
72        if num_classes > 1 {
73            ConfusionStatsInput::new(
74                self.output.clone(),
75                self.targets.clone().one_hot(num_classes).bool(),
76            )
77        } else {
78            ConfusionStatsInput::new(
79                self.output.clone(),
80                self.targets.clone().unsqueeze_dim(1).bool(),
81            )
82        }
83    }
84}
85
86/// Multi-label classification output adapted for multiple metrics.
87#[derive(new)]
88pub struct MultiLabelClassificationOutput<B: Backend> {
89    /// The loss.
90    pub loss: Tensor<B, 1>,
91
92    /// The output.
93    pub output: Tensor<B, 2>,
94
95    /// The targets.
96    pub targets: Tensor<B, 2, Int>,
97}
98
99impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
100    type ItemSync = MultiLabelClassificationOutput<NdArray>;
101
102    fn sync(self) -> Self::ItemSync {
103        let [output, loss, targets] = Transaction::default()
104            .register(self.output)
105            .register(self.loss)
106            .register(self.targets)
107            .execute()
108            .try_into()
109            .expect("Correct amount of tensor data");
110
111        let device = &Default::default();
112
113        MultiLabelClassificationOutput {
114            output: Tensor::from_data(output, device),
115            loss: Tensor::from_data(loss, device),
116            targets: Tensor::from_data(targets, device),
117        }
118    }
119}
120
121impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
122    fn adapt(&self) -> HammingScoreInput<B> {
123        HammingScoreInput::new(self.output.clone(), self.targets.clone())
124    }
125}
126
127impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
128    fn adapt(&self) -> LossInput<B> {
129        LossInput::new(self.loss.clone())
130    }
131}
132
133impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
134    fn adapt(&self) -> ConfusionStatsInput<B> {
135        ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
136    }
137}