// burn_nn/loss/kldiv.rs

use burn_core as burn;

use super::Reduction;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::{config::Config, module::Module};

/// Configuration to create a [KLDiv loss](KLDivLoss).
#[derive(Config, Debug)]
pub struct KLDivLossConfig {
    /// Specifies whether the targets are provided in log space. Default: `false`.
    #[config(default = false)]
    pub log_target: bool,
}

impl KLDivLossConfig {
    /// Initialize [KLDiv loss](KLDivLoss).
    pub fn init(&self) -> KLDivLoss {
        KLDivLoss {
            log_target: self.log_target,
        }
    }
}

/// Kullback-Leibler Divergence Loss
///
/// KL divergence measures the difference between two probability distributions by
/// quantifying the information lost when one is used to approximate the other.
///
/// The element-wise loss is
///
/// ```tex
/// \text{KLDivLoss} = y_{true} \cdot (\log{y_{true}} - \log{y_{pred}})
/// ```
///
/// By default, the loss expects the predictions in log space.
/// The targets may also be provided in log space if `log_target` is `true`.
///
/// See
/// - [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback-Leibler_divergence)
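///
/// # Example
///
/// A minimal usage sketch; the `logits` and `targets` tensors are assumed to be
/// defined elsewhere, and `log_softmax` puts the predictions in the log space
/// the loss expects by default:
///
/// ```ignore
/// use burn::tensor::activation::log_softmax;
///
/// let loss = KLDivLossConfig::new().init();
/// // Predictions are always expected as log-probabilities.
/// let predictions = log_softmax(logits, 1);
/// // Targets are plain probabilities when `log_target` is `false` (the default).
/// let value = loss.forward(predictions, targets, Reduction::BatchMean);
/// ```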
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct KLDivLoss {
    /// Specifies whether the targets are provided in log space. Default: `false`.
    pub log_target: bool,
}

impl ModuleDisplay for KLDivLoss {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content.add("log_target", &self.log_target).optional()
    }
}

impl KLDivLoss {
    /// Compute the criterion on the input tensor.
    ///
    /// `Reduction::Auto` behaves as `Reduction::BatchMean`, since `Reduction::Mean`
    /// (a mean over all elements) does not align with the mathematical definition
    /// of KL divergence.
    ///
    /// # Shapes
    ///
    /// - predictions: \[batch_size, num_targets\]
    /// - targets: \[batch_size, num_targets\]
    /// - output: \[1\]
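    ///
    /// As a sketch of the reduction arithmetic (for a hypothetical
    /// `[batch_size, num_targets] = [2, 2]` input): `Sum` adds all four
    /// element-wise losses, `BatchMean` divides that sum by 2 (the batch size),
    /// and `Mean` divides it by 4 (the total element count).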
    pub fn forward<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::BatchMean | Reduction::Auto => {
                let batch_size = loss.dims()[0] as f32;
                loss.sum().div_scalar(batch_size)
            }
            Reduction::Mean => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }
    /// Compute the criterion on the input tensor without reducing.
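    ///
    /// Returns the element-wise loss, which has the same shape as the inputs.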
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        match self.log_target {
            // Targets already in log space: `exp` recovers the probabilities,
            // and `targets - predictions` is `log(y_true) - log(y_pred)`.
            true => targets.clone().exp().mul(targets.sub(predictions)),
            false => {
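                // Clamp the targets to [eps, 1] before taking the log so that
                // zero-probability targets do not produce `log(0)`. `eps` is the
                // smallest positive value of the tensor's float dtype, falling
                // back to f32's if the dtype reports no float info.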
                let epsilon = targets
                    .dtype()
                    .finfo()
                    .unwrap_or(burn::tensor::FloatDType::F32.finfo())
                    .min_positive;
                let log_target = targets.clone().clamp(epsilon, 1.0).log();
                targets.mul(log_target.sub(predictions))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type TestTensor<const D: usize> = Tensor<TestBackend, D>;
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_kl_div_loss() {
        let predict = TensorData::from([[-1.0, -0.5], [-2.0, -0.2]]);
        let targets = TensorData::from([[0.4, 0.6], [0.1, 0.9]]);

        let device = Default::default();
        let predict = TestTensor::<2>::from_data(predict, &device);
        let targets = TestTensor::<2>::from_data(targets, &device);

        let kl_loss = KLDivLossConfig { log_target: false }.init();

        let loss_sum = kl_loss.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss_batch_mean =
            kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);
        let loss_no_reduction = kl_loss.forward_no_reduction(predict, targets);

        let expected_no_reduction =
            TensorData::from([[0.0334837139, -0.0064953566], [-0.0302585065, 0.0851755068]]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::absolute(1e-5));

        let expected_sum = TensorData::from([0.08191]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));

        let expected_batch_mean = TensorData::from([0.04095]);
        loss_batch_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_batch_mean, Tolerance::absolute(1e-5));
    }

    #[test]
    fn test_kl_div_loss_log_target() {
        let device = Default::default();
        let predict = TestTensor::<1>::from_data([-1.0, -2.0], &device);
        let targets = TestTensor::<1>::from_data([-0.5, -1.5], &device);

        let kl_loss = KLDivLossConfig { log_target: true }.init();

        let loss_no_reduction = kl_loss.forward_no_reduction(predict.clone(), targets.clone());
        let expected_none = TensorData::from([0.3032653299, 0.1115650801]);
        loss_no_reduction
            .into_data()
            .assert_approx_eq::<FT>(&expected_none, Tolerance::absolute(1e-5));

        let loss_batch_mean =
            kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);
        let expected_bm = TensorData::from([0.207415204965]);
        loss_batch_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected_bm, Tolerance::absolute(1e-5));

        let loss_sum = kl_loss.forward(predict, targets, Reduction::Sum);
        let expected_sum = TensorData::from([0.414830409931]);
        loss_sum
            .into_data()
            .assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_kl_div_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 2>;

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data([[-1.0, -0.5]], &device).require_grad();
        let targets = TestAutodiffTensor::from_data([[0.4, 0.6]], &device);

        let kl_loss = KLDivLossConfig { log_target: false }.init();
        let loss = kl_loss.forward(predict.clone(), targets, Reduction::Sum);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        // d/d_pred [target * (log_target - pred)] = -target
        let expected = TensorData::from([[-0.4, -0.6]]);
        grads_predict
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = KLDivLossConfig { log_target: true };
        let loss = config.init();

        assert_eq!(alloc::format!("{loss}"), "KLDivLoss {log_target: true}");
    }
}