use burn_core as burn;

use super::Reduction;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::{config::Config, module::Module};

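/// Configuration to create a [Kullback-Leibler divergence loss](KLDivLoss) using the
/// [init function](KLDivLossConfig::init).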
#[derive(Config, Debug)]
pub struct KLDivLossConfig {
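    /// Whether the targets passed to the loss are already in log-space. Default: `false`.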
    #[config(default = false)]
    pub log_target: bool,
}

impl KLDivLossConfig {
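    /// Initialize [KLDivLoss] from this configuration.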
    pub fn init(&self) -> KLDivLoss {
        KLDivLoss {
            log_target: self.log_target,
        }
    }
}

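/// The Kullback-Leibler divergence loss.
///
/// With `log_target = false`, `predictions` are log-probabilities and `targets` are
/// probabilities; the element-wise loss is `targets * (ln(targets) - predictions)`.
/// With `log_target = true`, `targets` are also log-probabilities and the element-wise
/// loss is `exp(targets) * (targets - predictions)`.
///
/// # Example
///
/// A minimal usage sketch; `log_probs` and `target_probs` are placeholder tensors,
/// with `log_probs` assumed to hold log-probabilities:
///
/// ```ignore
/// let loss = KLDivLossConfig::new().init();
/// let value = loss.forward(log_probs, target_probs, Reduction::BatchMean);
/// ```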
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct KLDivLoss {
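    /// Whether the targets are given in log-space.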
    pub log_target: bool,
}

impl ModuleDisplay for KLDivLoss {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content.add("log_target", &self.log_target).optional()
    }
}

impl KLDivLoss {
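    /// Compute the loss with the given reduction applied.
    ///
    /// `Reduction::Auto` falls back to `Reduction::BatchMean`, which divides the
    /// summed loss by the size of the first dimension.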
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::BatchMean | Reduction::Auto => {
                let batch_size = loss.dims()[0] as f32;
                loss.sum().div_scalar(batch_size)
            }
            Reduction::Mean => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }
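
    /// Compute the element-wise loss without reduction.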
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        match self.log_target {
            true => targets.clone().exp().mul(targets.sub(predictions)),
            false => {
                // Clamp targets to the smallest positive value of the dtype before
                // taking the log, so that zero targets do not produce `ln(0)`.
                let epsilon = targets
                    .dtype()
                    .finfo()
                    .unwrap_or(burn::tensor::FloatDType::F32.finfo())
                    .min_positive;
                let log_target = targets.clone().clamp(epsilon, 1.0).log();
                targets.mul(log_target.sub(predictions))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_kl_div_loss() {
        let predict = TensorData::from([[-1.0, -0.5], [-2.0, -0.2]]);
        let targets = TensorData::from([[0.4, 0.6], [0.1, 0.9]]);

        let device = Default::default();
        let predict = TestTensor::<2>::from_data(predict, &device);
        let targets = TestTensor::<2>::from_data(targets, &device);

        let kl_loss = KLDivLossConfig { log_target: false }.init();

        let loss_sum = kl_loss.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss_batch_mean =
            kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);
        let loss_no_reduction = kl_loss.forward_no_reduction(predict, targets);

        // Element-wise: targets * (ln(targets) - predictions).
        let expected_no_reduction =
            TensorData::from([[0.0334837139, -0.0064953566], [-0.0302585065, 0.0851755068]]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::absolute(1e-5));

        let expected_sum = TensorData::from([0.08191]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));

        let expected_batch_mean = TensorData::from([0.04095]);
        loss_batch_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_batch_mean, Tolerance::absolute(1e-5));
    }

    #[test]
    fn test_kl_div_loss_log_target() {
        let device = Default::default();
        let predict = TestTensor::<1>::from_data([-1.0, -2.0], &device);
        let targets = TestTensor::<1>::from_data([-0.5, -1.5], &device);

        let kl_loss = KLDivLossConfig { log_target: true }.init();

        let loss_no_reduction = kl_loss.forward_no_reduction(predict.clone(), targets.clone());
        // Element-wise: exp(targets) * (targets - predictions).
        let expected_none = TensorData::from([0.3032653299, 0.1115650801]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_none, Tolerance::absolute(1e-5));

        let loss_batch_mean =
            kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);
        let expected_bm = TensorData::from([0.207415204965]);
        loss_batch_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_bm, Tolerance::absolute(1e-5));

        let loss_sum = kl_loss.forward(predict, targets, Reduction::Sum);
        let expected_sum = TensorData::from([0.414830409931]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_kl_div_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 2>;

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data([[-1.0, -0.5]], &device).require_grad();
        let targets = TestAutodiffTensor::from_data([[0.4, 0.6]], &device);

        let kl_loss = KLDivLossConfig { log_target: false }.init();
        let loss = kl_loss.forward(predict.clone(), targets, Reduction::Sum);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        // Gradient of sum(targets * (ln(targets) - predictions)) w.r.t. predictions is -targets.
        let expected = TensorData::from([[-0.4, -0.6]]);
        grads_predict
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = KLDivLossConfig { log_target: true };
        let loss = config.init();

        assert_eq!(alloc::format!("{loss}"), "KLDivLoss {log_target: true}");
    }
}