burn_optim/grad_clipping/base.rs

use burn_core as burn;

use burn::tensor::backend::Backend;
use burn::{config::Config, tensor::Tensor};

/// Configuration for gradient clipping, a way to mitigate exploding gradients.
#[derive(Config, Debug)]
pub enum GradientClippingConfig {
    /// Clip the gradient by value.
    Value(f32),

    /// Clip the gradient by norm.
    Norm(f32),
}

impl GradientClippingConfig {
    /// Initialize the gradient clipping.
    ///
    /// # Returns
    ///
    /// The gradient clipping strategy described by this configuration.
    pub fn init(&self) -> GradientClipping {
        match self {
            GradientClippingConfig::Value(val) => GradientClipping::Value(*val),
            GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val),
        }
    }
}
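
// A minimal usage sketch (illustrative, not part of the original source):
// build a clipping strategy from its config and apply it to a gradient
// tensor. `B` and `device` stand in for any backend and its device.
//
//     let clipping = GradientClippingConfig::Norm(1.0).init();
//     let grad = Tensor::<B, 2>::ones([2, 2], &device);
//     let clipped = clipping.clip_gradient(grad);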

/// Gradient Clipping provides a way to mitigate exploding gradients
/// by clipping every component of the gradient by value or by norm during
/// backpropagation.
#[derive(Clone)]
pub enum GradientClipping {
    /// Clip the gradient by value.
    Value(f32),

    /// Clip the gradient by norm.
    Norm(f32),
}

impl GradientClipping {
    /// Clip the gradient.
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient to clip.
    ///
    /// # Returns
    ///
    /// The clipped gradient.
    pub fn clip_gradient<B: Backend, const D: usize>(&self, grad: Tensor<B, D>) -> Tensor<B, D> {
        match self {
            GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold),
            GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm),
        }
    }

    fn clip_by_value<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        threshold: f32,
    ) -> Tensor<B, D> {
        // Mark the components that fall outside [-threshold, threshold].
        let greater_mask = grad.clone().greater_elem(threshold);
        let lower_mask = grad.clone().lower_elem(-threshold);

        // Replace out-of-range components with the corresponding bound.
        let clipped_grad = grad.mask_fill(greater_mask, threshold);

        clipped_grad.mask_fill(lower_mask, -threshold)
    }

    fn clip_by_norm<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        threshold: f32,
    ) -> Tensor<B, D> {
        let norm = Self::l2_norm(grad.clone());
        // Guard against division by zero with the smallest positive value of
        // the gradient's float dtype, falling back to f32's when unavailable.
        let min_positive = grad
            .dtype()
            .finfo()
            .unwrap_or(burn::tensor::FloatDType::F32.finfo())
            .min_positive;
        // Scale by threshold / norm, clamped to 1.0 so gradients whose norm
        // is already within the threshold pass through unchanged.
        let clip_coef = threshold / norm.add_scalar(min_positive);
        let clip_coef_clamped = clip_coef.clamp_max(1.0);
        grad.mul(clip_coef_clamped.unsqueeze())
    }

    fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
        let squared = tensor.square();
        let sum = squared.sum();
        sum.sqrt()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::Tensor;

    #[test]
    fn test_clip_by_value() {
        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &Default::default(),
        );

        let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient);
        let clipped_gradient_data = clipped_gradient.into_data();

        for value in clipped_gradient_data.iter::<f32>() {
            assert!(value <= 0.5);
        }
    }
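
    // Added sketch (not in the original tests): exercise the config path and
    // check that negative components are clipped to the lower bound as well.
    #[test]
    fn test_clip_by_value_from_config_with_negatives() {
        let gradient: Tensor<TestBackend, 1> =
            Tensor::from_floats([-0.9, -0.2, 0.3, 0.8], &Default::default());

        let clipping = GradientClippingConfig::Value(0.5).init();
        let clipped_gradient = clipping.clip_gradient(gradient);

        for value in clipped_gradient.into_data().iter::<f32>() {
            assert!(value.abs() <= 0.5);
        }
    }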

    #[test]
    fn test_clip_by_norm() {
        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &Default::default(),
        );

        let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient);
        let clipped_gradient_data = clipped_gradient.into_data();

        for value in clipped_gradient_data.iter::<f32>() {
            assert!(value <= 0.88);
        }
    }
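
    // Added sketch (not in the original tests): a hand-checked scaling case.
    // For grad = [3, 4] the L2 norm is 5, so with max_norm = 1.0 the gradient
    // is scaled by 1/5, giving [0.6, 0.8].
    #[test]
    fn test_clip_by_norm_worked_example() {
        let gradient: Tensor<TestBackend, 1> =
            Tensor::from_floats([3.0, 4.0], &Default::default());

        let clipped_gradient = GradientClipping::Norm(1.0).clip_gradient(gradient);

        let expected = [0.6f32, 0.8];
        for (value, expected) in clipped_gradient.into_data().iter::<f32>().zip(expected) {
            assert!((value - expected).abs() < 1e-4);
        }
    }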

    #[test]
    fn test_clip_by_norm_no_clipping() {
        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
            [[0.3, 0.4, 0.5, 0.2], [0.1, 0.6, 0.3, 0.4]],
            &Default::default(),
        );

        let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient.clone());

        clipped_gradient
            .into_data()
            .assert_eq(&gradient.into_data(), true);
    }
}