// burn_optim/grad_clipping/base.rs
use burn_core as burn;

use burn::tensor::backend::Backend;
use burn::{config::Config, tensor::Tensor};

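/// Configuration for gradient clipping, selecting whether gradients are
/// clipped element-wise by value or rescaled by their L2 norm.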
#[derive(Config, Debug)]
pub enum GradientClippingConfig {
    /// Clip each gradient element into `[-value, value]`.
    Value(f32),

    /// Rescale the gradient so its L2 norm does not exceed this value.
    Norm(f32),
}

impl GradientClippingConfig {
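    /// Build the [`GradientClipping`] strategy described by this config.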
    pub fn init(&self) -> GradientClipping {
        match self {
            GradientClippingConfig::Value(val) => GradientClipping::Value(*val),
            GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val),
        }
    }
}

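/// Gradient clipping strategy, mitigating exploding gradients by clipping
/// each gradient element by value or rescaling the gradient by its norm.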
#[derive(Clone)]
pub enum GradientClipping {
    /// Clip each gradient element into `[-value, value]`.
    Value(f32),

    /// Rescale the gradient so its L2 norm does not exceed this value.
    Norm(f32),
}

impl GradientClipping {
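    /// Clip the given gradient tensor according to the configured strategy.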
    pub fn clip_gradient<B: Backend, const D: usize>(&self, grad: Tensor<B, D>) -> Tensor<B, D> {
        match self {
            GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold),
            GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm),
        }
    }

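    /// Clamp every element of `grad` into `[-threshold, threshold]`.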
    fn clip_by_value<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        threshold: f32,
    ) -> Tensor<B, D> {
        // Mask elements outside the allowed range, then fill them with the bound.
        let greater_mask = grad.clone().greater_elem(threshold);
        let lower_mask = grad.clone().lower_elem(-threshold);

        let clipped_grad = grad.mask_fill(greater_mask, threshold);

        clipped_grad.mask_fill(lower_mask, -threshold)
    }

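    /// Rescale `grad` by `min(1, threshold / (||grad||_2 + eps))` so that its
    /// L2 norm does not exceed `threshold`.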
    fn clip_by_norm<B: Backend, const D: usize>(
        &self,
        grad: Tensor<B, D>,
        threshold: f32,
    ) -> Tensor<B, D> {
        let norm = Self::l2_norm(grad.clone());
        // Smallest positive value of the gradient's dtype, added to the norm
        // to avoid a division by zero for an all-zero gradient.
        let min_positive = grad
            .dtype()
            .finfo()
            .unwrap_or(burn::tensor::FloatDType::F32.finfo())
            .min_positive;
        let clip_coef = threshold / norm.add_scalar(min_positive);
        // Only scale down: gradients with a norm below the threshold are untouched.
        let clip_coef_clamped = clip_coef.clamp_max(1.0);
        grad.mul(clip_coef_clamped.unsqueeze())
    }

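    /// L2 norm over all elements: `sqrt(sum(tensor^2))`, as a rank-1 tensor.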
    fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {
        let squared = tensor.square();
        let sum = squared.sum();
        sum.sqrt()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::Tensor;

    #[test]
    fn test_clip_by_value() {
        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &Default::default(),
        );

        let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient);
        let clipped_gradient_data = clipped_gradient.into_data();

        // All inputs are positive, so checking the upper bound is sufficient.
        for value in clipped_gradient_data.iter::<f32>() {
            assert!(value <= 0.5);
        }
    }

    #[test]
    fn test_clip_by_norm() {
        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
            [
                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
            ],
            &Default::default(),
        );

        let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient);
        let clipped_gradient_data = clipped_gradient.into_data();

        // The gradient norm is roughly 2.41, so every element is scaled by
        // about 2.2 / 2.41 ~= 0.91; the largest input (0.9559) lands near 0.87.
        for value in clipped_gradient_data.iter::<f32>() {
            assert!(value <= 0.88);
        }
    }

    #[test]
    fn test_clip_by_norm_no_clipping() {
        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(
            [[0.3, 0.4, 0.5, 0.2], [0.1, 0.6, 0.3, 0.4]],
            &Default::default(),
        );

        let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient.clone());

        clipped_gradient
            .into_data()
            .assert_eq(&gradient.into_data(), true);
    }
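
    // A minimal usage sketch (not part of the original tests): building the
    // clipper through `GradientClippingConfig::init`, as an optimizer would.
    #[test]
    fn test_config_init_clip_by_value() {
        let gradient: Tensor<TestBackend, 1> =
            Tensor::from_floats([1.5, -2.0, 0.3], &Default::default());

        // `Value(1.0)` should clamp every element into [-1.0, 1.0].
        let clipper = GradientClippingConfig::Value(1.0).init();
        let clipped = clipper.clip_gradient(gradient);

        for value in clipped.into_data().iter::<f32>() {
            assert!((-1.0..=1.0).contains(&value));
        }
    }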
}