1use crate::Tensor;
4
5pub fn clip_grad_norm(params: &mut [Tensor], max_norm: f32) -> f32 {
24 let mut total_norm_sq = 0.0;
26
27 for param in params.iter() {
28 if let Some(grad) = param.grad() {
29 let grad_norm_sq: f32 = grad.iter().map(|&g| g * g).sum();
31 total_norm_sq += grad_norm_sq;
32 }
33 }
34
35 let global_norm = total_norm_sq.sqrt();
36
37 if global_norm > max_norm {
39 let clip_coef = max_norm / global_norm;
40
41 for param in params.iter_mut() {
43 if let Some(grad) = param.grad() {
44 let clipped_grad = grad * clip_coef;
45 param.set_grad(clipped_grad);
46 }
47 }
48 }
49
50 global_norm
51}
52
53pub fn clip_grad_norm_refs(params: &mut [&mut Tensor], max_norm: f32) -> f32 {
66 let mut total_norm_sq = 0.0;
68
69 for param in params.iter() {
70 if let Some(grad) = param.grad() {
71 let grad_norm_sq: f32 = grad.iter().map(|&g| g * g).sum();
72 total_norm_sq += grad_norm_sq;
73 }
74 }
75
76 let global_norm = total_norm_sq.sqrt();
77
78 if global_norm > max_norm {
80 let clip_coef = max_norm / global_norm;
81
82 for param in params.iter_mut() {
83 if let Some(grad) = param.grad() {
84 let clipped_grad = grad * clip_coef;
85 param.set_grad(clipped_grad);
86 }
87 }
88 }
89
90 global_norm
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::*;
    use approx::assert_abs_diff_eq;

    /// Reads element `idx` of `tensor`'s gradient, panicking when no gradient is set.
    fn grad_at(tensor: &Tensor, idx: usize) -> f32 {
        tensor.grad().expect("gradient should be available")[idx]
    }

    // ----- clip_grad_norm (slice of owned tensors) -----

    #[test]
    fn test_clip_grad_norm_no_clipping() {
        let mut tensors = vec![
            Tensor::from_vec(vec![1.0, 2.0], true),
            Tensor::from_vec(vec![3.0], true),
        ];
        tensors[0].set_grad(ndarray::arr1(&[0.1, 0.2]));
        tensors[1].set_grad(ndarray::arr1(&[0.1]));

        let norm = clip_grad_norm(&mut tensors, 1.0);

        // sqrt(0.01 + 0.04 + 0.01) ~= 0.245, which is below the threshold.
        assert_abs_diff_eq!(norm, 0.245, epsilon = 1e-3);

        // Gradients must come back unchanged.
        assert_abs_diff_eq!(grad_at(&tensors[0], 0), 0.1, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&tensors[0], 1), 0.2, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&tensors[1], 0), 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_with_clipping() {
        let mut tensors = vec![
            Tensor::from_vec(vec![1.0, 2.0], true),
            Tensor::from_vec(vec![3.0], true),
        ];
        tensors[0].set_grad(ndarray::arr1(&[3.0, 4.0]));
        tensors[1].set_grad(ndarray::arr1(&[0.0]));

        let norm = clip_grad_norm(&mut tensors, 1.0);

        // The reported norm is the pre-clipping one: sqrt(9 + 16) = 5.
        assert_abs_diff_eq!(norm, 5.0, epsilon = 1e-6);

        // Every gradient is scaled by 1 / 5.
        assert_abs_diff_eq!(grad_at(&tensors[0], 0), 0.6, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&tensors[0], 1), 0.8, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&tensors[1], 0), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_exactly_at_threshold() {
        let mut tensors = vec![Tensor::from_vec(vec![3.0, 4.0], true)];
        tensors[0].set_grad(ndarray::arr1(&[3.0, 4.0]));

        let norm = clip_grad_norm(&mut tensors, 5.0);

        assert_abs_diff_eq!(norm, 5.0, epsilon = 1e-6);

        // A norm equal to the threshold must not trigger clipping.
        assert_abs_diff_eq!(grad_at(&tensors[0], 0), 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&tensors[0], 1), 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_preserves_relative_magnitudes() {
        let mut tensors = vec![
            Tensor::from_vec(vec![1.0], true),
            Tensor::from_vec(vec![1.0], true),
        ];
        tensors[0].set_grad(ndarray::arr1(&[10.0]));
        tensors[1].set_grad(ndarray::arr1(&[5.0]));

        let _norm = clip_grad_norm(&mut tensors, 1.0);

        let first = grad_at(&tensors[0], 0);
        let second = grad_at(&tensors[1], 0);

        // Both gradients share one scale factor, so the 2:1 ratio survives.
        assert_abs_diff_eq!(first / second, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_no_gradients() {
        let mut tensors = vec![
            Tensor::from_vec(vec![1.0, 2.0], false),
            Tensor::from_vec(vec![3.0], false),
        ];

        let norm = clip_grad_norm(&mut tensors, 1.0);

        // No gradient was set on any tensor, so the norm is zero.
        assert_abs_diff_eq!(norm, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_mixed_gradients() {
        let mut tensors = vec![
            Tensor::from_vec(vec![1.0], true),
            Tensor::from_vec(vec![1.0], true),
        ];
        // Only the first tensor receives a gradient.
        tensors[0].set_grad(ndarray::arr1(&[3.0]));

        let norm = clip_grad_norm(&mut tensors, 1.0);

        assert_abs_diff_eq!(norm, 3.0, epsilon = 1e-6);

        // The present gradient is clipped; the absent one stays absent.
        assert_abs_diff_eq!(grad_at(&tensors[0], 0), 1.0, epsilon = 1e-6);
        assert!(tensors[1].grad().is_none());
    }

    #[test]
    fn test_clip_grad_norm_zero_max_norm() {
        let mut tensors = vec![Tensor::from_vec(vec![1.0], true)];
        tensors[0].set_grad(ndarray::arr1(&[5.0]));

        let norm = clip_grad_norm(&mut tensors, 0.0);

        assert_abs_diff_eq!(norm, 5.0, epsilon = 1e-6);

        // A zero threshold scales the gradient down to zero.
        assert_abs_diff_eq!(grad_at(&tensors[0], 0), 0.0, epsilon = 1e-6);
    }

    // ----- clip_grad_norm_refs (slice of mutable references) -----

    #[test]
    fn test_clip_grad_norm_refs_no_clipping() {
        let mut a = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut b = Tensor::from_vec(vec![3.0], true);
        a.set_grad(ndarray::arr1(&[0.1, 0.2]));
        b.set_grad(ndarray::arr1(&[0.1]));

        let norm = clip_grad_norm_refs(&mut [&mut a, &mut b], 1.0);
        assert_abs_diff_eq!(norm, 0.245, epsilon = 1e-3);

        // Below the threshold: gradients are unchanged.
        assert_abs_diff_eq!(grad_at(&a, 0), 0.1, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&a, 1), 0.2, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&b, 0), 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_refs_with_clipping() {
        let mut a = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut b = Tensor::from_vec(vec![3.0], true);
        a.set_grad(ndarray::arr1(&[3.0, 4.0]));
        b.set_grad(ndarray::arr1(&[0.0]));

        let norm = clip_grad_norm_refs(&mut [&mut a, &mut b], 1.0);
        assert_abs_diff_eq!(norm, 5.0, epsilon = 1e-6);

        // Every gradient is scaled by 1 / 5.
        assert_abs_diff_eq!(grad_at(&a, 0), 0.6, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&a, 1), 0.8, epsilon = 1e-6);
        assert_abs_diff_eq!(grad_at(&b, 0), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_refs_preserves_relative_magnitudes() {
        let mut a = Tensor::from_vec(vec![1.0], true);
        let mut b = Tensor::from_vec(vec![1.0], true);
        a.set_grad(ndarray::arr1(&[10.0]));
        b.set_grad(ndarray::arr1(&[5.0]));

        let _norm = clip_grad_norm_refs(&mut [&mut a, &mut b], 1.0);

        let first = grad_at(&a, 0);
        let second = grad_at(&b, 0);
        assert_abs_diff_eq!(first / second, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_refs_no_gradients() {
        let mut a = Tensor::from_vec(vec![1.0, 2.0], false);
        let mut b = Tensor::from_vec(vec![3.0], false);

        let norm = clip_grad_norm_refs(&mut [&mut a, &mut b], 1.0);
        assert_abs_diff_eq!(norm, 0.0, epsilon = 1e-6);
    }
}