// burn_train/learner/classification.rs

1use crate::metric::{
2    processor::ItemLazy, AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput,
3};
4use burn_core::tensor::backend::Backend;
5use burn_core::tensor::{Int, Tensor, Transaction};
6use burn_ndarray::NdArray;
7
/// Simple classification output adapted for multiple metrics.
///
/// One instance carries everything the training loop's metrics need for a
/// single-label classification step; the `Adaptor` impls below project it
/// onto the individual metric inputs (accuracy, loss, confusion statistics).
///
/// NOTE(review): field order matters — `#[derive(new)]` generates
/// `new(loss, output, targets)` in declaration order.
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
    /// The loss.
    pub loss: Tensor<B, 1>,

    /// The output. Rank 2; presumably `[batch_size, num_classes]` given the
    /// `dims()` destructuring in the `ConfusionStatsInput` adaptor — TODO confirm.
    pub output: Tensor<B, 2>,

    /// The targets. Rank-1 integer tensor; presumably one class index per
    /// batch element (it is one-hot expanded by the confusion-stats adaptor).
    pub targets: Tensor<B, 1, Int>,
}
20
21impl<B: Backend> ItemLazy for ClassificationOutput<B> {
22    type ItemSync = ClassificationOutput<NdArray>;
23
24    fn sync(self) -> Self::ItemSync {
25        let [output, loss, targets] = Transaction::default()
26            .register(self.output)
27            .register(self.loss)
28            .register(self.targets)
29            .execute()
30            .try_into()
31            .expect("Correct amount of tensor data");
32
33        let device = &Default::default();
34
35        ClassificationOutput {
36            output: Tensor::from_data(output, device),
37            loss: Tensor::from_data(loss, device),
38            targets: Tensor::from_data(targets, device),
39        }
40    }
41}
42
43impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
44    fn adapt(&self) -> AccuracyInput<B> {
45        AccuracyInput::new(self.output.clone(), self.targets.clone())
46    }
47}
48
49impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
50    fn adapt(&self) -> LossInput<B> {
51        LossInput::new(self.loss.clone())
52    }
53}
54
55impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
56    fn adapt(&self) -> ConfusionStatsInput<B> {
57        let [_, num_classes] = self.output.dims();
58        if num_classes > 1 {
59            ConfusionStatsInput::new(
60                self.output.clone(),
61                self.targets.clone().one_hot(num_classes).bool(),
62            )
63        } else {
64            ConfusionStatsInput::new(
65                self.output.clone(),
66                self.targets.clone().unsqueeze_dim(1).bool(),
67            )
68        }
69    }
70}
71
/// Multi-label classification output adapted for multiple metrics.
///
/// One instance carries everything the training loop's metrics need for a
/// multi-label classification step; the `Adaptor` impls below project it
/// onto the individual metric inputs (Hamming score, loss, confusion
/// statistics).
///
/// NOTE(review): field order matters — `#[derive(new)]` generates
/// `new(loss, output, targets)` in declaration order.
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
    /// The loss.
    pub loss: Tensor<B, 1>,

    /// The output. Rank 2; presumably `[batch_size, num_classes]` — TODO confirm.
    pub output: Tensor<B, 2>,

    /// The targets. Rank-2 integer tensor; presumably a per-class 0/1
    /// membership matrix matching `output`'s shape (the confusion-stats
    /// adaptor casts it directly to `bool`).
    pub targets: Tensor<B, 2, Int>,
}
84
85impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
86    type ItemSync = MultiLabelClassificationOutput<NdArray>;
87
88    fn sync(self) -> Self::ItemSync {
89        let [output, loss, targets] = Transaction::default()
90            .register(self.output)
91            .register(self.loss)
92            .register(self.targets)
93            .execute()
94            .try_into()
95            .expect("Correct amount of tensor data");
96
97        let device = &Default::default();
98
99        MultiLabelClassificationOutput {
100            output: Tensor::from_data(output, device),
101            loss: Tensor::from_data(loss, device),
102            targets: Tensor::from_data(targets, device),
103        }
104    }
105}
106
107impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
108    fn adapt(&self) -> HammingScoreInput<B> {
109        HammingScoreInput::new(self.output.clone(), self.targets.clone())
110    }
111}
112
113impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
114    fn adapt(&self) -> LossInput<B> {
115        LossInput::new(self.loss.clone())
116    }
117}
118
119impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
120    fn adapt(&self) -> ConfusionStatsInput<B> {
121        ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
122    }
123}