burn_train/learner/
classification.rs

1use crate::metric::TopKAccuracyInput;
2use crate::metric::{
3    AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, processor::ItemLazy,
4};
5use burn_core::tensor::backend::Backend;
6use burn_core::tensor::{Int, Tensor, Transaction};
7use burn_ndarray::NdArray;
8
/// Simple classification output adapted for multiple metrics.
///
/// Bundles a loss, the model output and integer class targets. This file implements
/// `Adaptor` for it to build `AccuracyInput`, `TopKAccuracyInput`, `LossInput` and
/// `ConfusionStatsInput`, and `ItemLazy` to move it onto the `NdArray` backend.
///
/// The constructor generated by `#[derive(new)]` takes the fields in declaration
/// order (`loss`, `output`, `targets`), so that order is part of the public API.
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
    /// The loss, stored as a rank-1 tensor.
    pub loss: Tensor<B, 1>,

    /// The model output, a rank-2 tensor.
    ///
    /// The confusion-stats adapter reads its second dimension as the number of
    /// classes (assumed layout `[batch_size, num_classes]` — TODO confirm).
    pub output: Tensor<B, 2>,

    /// The targets as a rank-1 integer tensor.
    ///
    /// The confusion-stats adapter one-hot encodes these to `num_classes` columns,
    /// so they are presumably class indices, one per output row — verify against callers.
    pub targets: Tensor<B, 1, Int>,
}
21
22impl<B: Backend> ItemLazy for ClassificationOutput<B> {
23    type ItemSync = ClassificationOutput<NdArray>;
24
25    fn sync(self) -> Self::ItemSync {
26        let [output, loss, targets] = Transaction::default()
27            .register(self.output)
28            .register(self.loss)
29            .register(self.targets)
30            .execute()
31            .try_into()
32            .expect("Correct amount of tensor data");
33
34        let device = &Default::default();
35
36        ClassificationOutput {
37            output: Tensor::from_data(output, device),
38            loss: Tensor::from_data(loss, device),
39            targets: Tensor::from_data(targets, device),
40        }
41    }
42}
43
44impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
45    fn adapt(&self) -> AccuracyInput<B> {
46        AccuracyInput::new(self.output.clone(), self.targets.clone())
47    }
48}
49
50impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
51    fn adapt(&self) -> LossInput<B> {
52        LossInput::new(self.loss.clone())
53    }
54}
55
56impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
57    fn adapt(&self) -> TopKAccuracyInput<B> {
58        TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
59    }
60}
61
62impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
63    fn adapt(&self) -> ConfusionStatsInput<B> {
64        let [_, num_classes] = self.output.dims();
65        if num_classes > 1 {
66            ConfusionStatsInput::new(
67                self.output.clone(),
68                self.targets.clone().one_hot(num_classes).bool(),
69            )
70        } else {
71            ConfusionStatsInput::new(
72                self.output.clone(),
73                self.targets.clone().unsqueeze_dim(1).bool(),
74            )
75        }
76    }
77}
78
/// Multi-label classification output adapted for multiple metrics.
///
/// Bundles a loss, the model output and a rank-2 integer target tensor. This file
/// implements `Adaptor` for it to build `HammingScoreInput`, `LossInput` and
/// `ConfusionStatsInput`, and `ItemLazy` to move it onto the `NdArray` backend.
///
/// The constructor generated by `#[derive(new)]` takes the fields in declaration
/// order (`loss`, `output`, `targets`), so that order is part of the public API.
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
    /// The loss, stored as a rank-1 tensor.
    pub loss: Tensor<B, 1>,

    /// The model output, a rank-2 tensor (assumed `[batch_size, num_labels]` — TODO confirm).
    pub output: Tensor<B, 2>,

    /// The targets as a rank-2 integer tensor.
    ///
    /// The confusion-stats adapter converts these directly to bool, so they are
    /// presumably a per-label 0/1 indicator matrix shaped like `output` — verify
    /// against callers.
    pub targets: Tensor<B, 2, Int>,
}
91
92impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
93    type ItemSync = MultiLabelClassificationOutput<NdArray>;
94
95    fn sync(self) -> Self::ItemSync {
96        let [output, loss, targets] = Transaction::default()
97            .register(self.output)
98            .register(self.loss)
99            .register(self.targets)
100            .execute()
101            .try_into()
102            .expect("Correct amount of tensor data");
103
104        let device = &Default::default();
105
106        MultiLabelClassificationOutput {
107            output: Tensor::from_data(output, device),
108            loss: Tensor::from_data(loss, device),
109            targets: Tensor::from_data(targets, device),
110        }
111    }
112}
113
114impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
115    fn adapt(&self) -> HammingScoreInput<B> {
116        HammingScoreInput::new(self.output.clone(), self.targets.clone())
117    }
118}
119
120impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
121    fn adapt(&self) -> LossInput<B> {
122        LossInput::new(self.loss.clone())
123    }
124}
125
126impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
127    fn adapt(&self) -> ConfusionStatsInput<B> {
128        ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
129    }
130}