Skip to main content

entrenar/optim/
clip.rs

1//! Gradient clipping utilities
2
3use crate::Tensor;
4
5/// Clip gradients by global norm
6///
7/// Computes the global norm of all gradients and scales them down if the norm
8/// exceeds max_norm. This prevents exploding gradients while preserving the
9/// relative magnitudes of gradients across parameters.
10///
11/// Algorithm:
12/// 1. global_norm = sqrt(sum of all gradient squared norms)
13/// 2. If global_norm > max_norm:
14///    - clip_coef = max_norm / global_norm
15///    - For each gradient: grad *= clip_coef
16///
17/// # Arguments
18/// * `params` - Mutable slice of parameters with gradients
19/// * `max_norm` - Maximum allowed global norm
20///
21/// # Returns
22/// The actual global norm before clipping
23pub fn clip_grad_norm(params: &mut [Tensor], max_norm: f32) -> f32 {
24    // Compute global norm: sqrt(sum of squared norms)
25    let mut total_norm_sq = 0.0;
26
27    for param in params.iter() {
28        if let Some(grad) = param.grad() {
29            // Compute squared norm of this gradient
30            let grad_norm_sq: f32 = grad.iter().map(|&g| g * g).sum();
31            total_norm_sq += grad_norm_sq;
32        }
33    }
34
35    let global_norm = total_norm_sq.sqrt();
36
37    // Only clip if global norm exceeds max_norm
38    if global_norm > max_norm {
39        let clip_coef = max_norm / global_norm;
40
41        // Scale all gradients
42        for param in params.iter_mut() {
43            if let Some(grad) = param.grad() {
44                let clipped_grad = grad * clip_coef;
45                param.set_grad(clipped_grad);
46            }
47        }
48    }
49
50    global_norm
51}
52
53/// Clip gradients by global norm on borrowed parameter references.
54///
55/// Identical to [`clip_grad_norm`] but accepts `&mut [&mut Tensor]` instead of
56/// `&mut [Tensor]`. This is useful when parameters are collected as mutable
57/// references from a model (e.g., LoRA layers + classification head).
58///
59/// # Arguments
60/// * `params` - Mutable slice of parameter references with gradients
61/// * `max_norm` - Maximum allowed global norm
62///
63/// # Returns
64/// The actual global norm before clipping
65pub fn clip_grad_norm_refs(params: &mut [&mut Tensor], max_norm: f32) -> f32 {
66    // Compute global norm: sqrt(sum of squared norms)
67    let mut total_norm_sq = 0.0;
68
69    for param in params.iter() {
70        if let Some(grad) = param.grad() {
71            let grad_norm_sq: f32 = grad.iter().map(|&g| g * g).sum();
72            total_norm_sq += grad_norm_sq;
73        }
74    }
75
76    let global_norm = total_norm_sq.sqrt();
77
78    // Only clip if global norm exceeds max_norm
79    if global_norm > max_norm {
80        let clip_coef = max_norm / global_norm;
81
82        for param in params.iter_mut() {
83            if let Some(grad) = param.grad() {
84                let clipped_grad = grad * clip_coef;
85                param.set_grad(clipped_grad);
86            }
87        }
88    }
89
90    global_norm
91}
92
#[cfg(test)]
mod tests {
    //! Unit tests for both clipping entry points.
    //!
    //! Every scenario is exercised through `clip_grad_norm` (owned slice) and
    //! `clip_grad_norm_refs` (slice of mutable references) so the two variants
    //! stay in agreement.

    use super::*;
    use crate::autograd::*;
    use approx::assert_abs_diff_eq;

    /// Asserts that the leading entries of `param`'s gradient match `expected`
    /// to within `1e-6`. Panics if the parameter has no gradient.
    fn assert_grad(param: &Tensor, expected: &[f32]) {
        let grad = param.grad().expect("gradient should be available");
        for (i, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(grad[i], want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_clip_grad_norm_no_clipping() {
        // Gradients with norm below threshold shouldn't be clipped
        let mut params =
            vec![Tensor::from_vec(vec![1.0, 2.0], true), Tensor::from_vec(vec![3.0], true)];

        // Set small gradients
        params[0].set_grad(ndarray::arr1(&[0.1, 0.2]));
        params[1].set_grad(ndarray::arr1(&[0.1]));

        // Global norm = sqrt(0.1^2 + 0.2^2 + 0.1^2) = sqrt(0.06) ≈ 0.245
        let global_norm = clip_grad_norm(&mut params, 1.0);
        assert_abs_diff_eq!(global_norm, 0.245, epsilon = 1e-3);

        // Gradients should be unchanged
        assert_grad(&params[0], &[0.1, 0.2]);
        assert_grad(&params[1], &[0.1]);
    }

    #[test]
    fn test_clip_grad_norm_with_clipping() {
        // Gradients with norm above threshold should be scaled
        let mut params =
            vec![Tensor::from_vec(vec![1.0, 2.0], true), Tensor::from_vec(vec![3.0], true)];

        // Set large gradients
        params[0].set_grad(ndarray::arr1(&[3.0, 4.0]));
        params[1].set_grad(ndarray::arr1(&[0.0]));

        // Global norm = sqrt(3^2 + 4^2 + 0^2) = sqrt(25) = 5.0
        let global_norm = clip_grad_norm(&mut params, 1.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);

        // Gradients should be scaled by clip_coef = 1.0 / 5.0 = 0.2
        assert_grad(&params[0], &[0.6, 0.8]); // 3.0 * 0.2, 4.0 * 0.2
        assert_grad(&params[1], &[0.0]); // 0.0 * 0.2
    }

    #[test]
    fn test_clip_grad_norm_exactly_at_threshold() {
        let mut params = vec![Tensor::from_vec(vec![3.0, 4.0], true)];

        // Set gradients with norm exactly equal to max_norm
        params[0].set_grad(ndarray::arr1(&[3.0, 4.0])); // norm = 5.0

        let global_norm = clip_grad_norm(&mut params, 5.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);

        // Should not be clipped (norm == max_norm, not >)
        assert_grad(&params[0], &[3.0, 4.0]);
    }

    #[test]
    fn test_clip_grad_norm_preserves_relative_magnitudes() {
        let mut params = vec![Tensor::from_vec(vec![1.0], true), Tensor::from_vec(vec![1.0], true)];

        // Set gradients with different magnitudes
        params[0].set_grad(ndarray::arr1(&[10.0]));
        params[1].set_grad(ndarray::arr1(&[5.0]));

        // Global norm = sqrt(10^2 + 5^2) = sqrt(125) ≈ 11.18
        let _global_norm = clip_grad_norm(&mut params, 1.0);

        let grad0 = params[0].grad().expect("gradient should be available")[0];
        let grad1 = params[1].grad().expect("gradient should be available")[0];

        // Relative magnitude should be preserved: grad0 / grad1 ≈ 10 / 5 = 2
        assert_abs_diff_eq!(grad0 / grad1, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_no_gradients() {
        // Parameters without gradients
        let mut params = vec![
            Tensor::from_vec(vec![1.0, 2.0], false), // requires_grad = false
            Tensor::from_vec(vec![3.0], false),
        ];

        let global_norm = clip_grad_norm(&mut params, 1.0);

        // Global norm should be 0 (no gradients)
        assert_abs_diff_eq!(global_norm, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_mixed_gradients() {
        // Some params have gradients, others don't
        let mut params = vec![Tensor::from_vec(vec![1.0], true), Tensor::from_vec(vec![1.0], true)];

        params[0].set_grad(ndarray::arr1(&[3.0]));
        // params[1] has no gradient set

        // Global norm = sqrt(3^2) = 3.0
        let global_norm = clip_grad_norm(&mut params, 1.0);
        assert_abs_diff_eq!(global_norm, 3.0, epsilon = 1e-6);

        // Only params[0] should be clipped
        assert_grad(&params[0], &[1.0]); // 3.0 * (1.0/3.0)
        assert!(params[1].grad().is_none()); // No gradient
    }

    #[test]
    fn test_clip_grad_norm_zero_max_norm() {
        let mut params = vec![Tensor::from_vec(vec![1.0], true)];
        params[0].set_grad(ndarray::arr1(&[5.0]));

        let global_norm = clip_grad_norm(&mut params, 0.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);

        // With max_norm = 0, gradients should be clipped to 0
        assert_grad(&params[0], &[0.0]);
    }

    // ── clip_grad_norm_refs tests (SSC-025) ──────────────────────────

    #[test]
    fn test_clip_grad_norm_refs_no_clipping() {
        let mut p0 = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut p1 = Tensor::from_vec(vec![3.0], true);
        p0.set_grad(ndarray::arr1(&[0.1, 0.2]));
        p1.set_grad(ndarray::arr1(&[0.1]));

        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        assert_abs_diff_eq!(global_norm, 0.245, epsilon = 1e-3);

        // Below the threshold: gradients are untouched
        assert_grad(&p0, &[0.1, 0.2]);
        assert_grad(&p1, &[0.1]);
    }

    #[test]
    fn test_clip_grad_norm_refs_with_clipping() {
        let mut p0 = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut p1 = Tensor::from_vec(vec![3.0], true);
        p0.set_grad(ndarray::arr1(&[3.0, 4.0]));
        p1.set_grad(ndarray::arr1(&[0.0]));

        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);

        // Scaled by clip_coef = 1.0 / 5.0 = 0.2
        assert_grad(&p0, &[0.6, 0.8]);
        assert_grad(&p1, &[0.0]);
    }

    #[test]
    fn test_clip_grad_norm_refs_exactly_at_threshold() {
        let mut p0 = Tensor::from_vec(vec![3.0, 4.0], true);
        p0.set_grad(ndarray::arr1(&[3.0, 4.0])); // norm = 5.0

        let global_norm = clip_grad_norm_refs(&mut [&mut p0], 5.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);

        // Should not be clipped (norm == max_norm, not >)
        assert_grad(&p0, &[3.0, 4.0]);
    }

    #[test]
    fn test_clip_grad_norm_refs_preserves_relative_magnitudes() {
        let mut p0 = Tensor::from_vec(vec![1.0], true);
        let mut p1 = Tensor::from_vec(vec![1.0], true);
        p0.set_grad(ndarray::arr1(&[10.0]));
        p1.set_grad(ndarray::arr1(&[5.0]));

        let _global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);

        let grad0 = p0.grad().expect("gradient should be available")[0];
        let grad1 = p1.grad().expect("gradient should be available")[0];
        assert_abs_diff_eq!(grad0 / grad1, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_refs_no_gradients() {
        let mut p0 = Tensor::from_vec(vec![1.0, 2.0], false);
        let mut p1 = Tensor::from_vec(vec![3.0], false);

        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        assert_abs_diff_eq!(global_norm, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_refs_mixed_gradients() {
        // Some params have gradients, others don't
        let mut p0 = Tensor::from_vec(vec![1.0], true);
        let mut p1 = Tensor::from_vec(vec![1.0], true);
        p0.set_grad(ndarray::arr1(&[3.0]));
        // p1 has no gradient set

        // Global norm = sqrt(3^2) = 3.0
        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        assert_abs_diff_eq!(global_norm, 3.0, epsilon = 1e-6);

        // Only p0 should be clipped
        assert_grad(&p0, &[1.0]); // 3.0 * (1.0/3.0)
        assert!(p1.grad().is_none()); // No gradient
    }

    #[test]
    fn test_clip_grad_norm_refs_zero_max_norm() {
        let mut p0 = Tensor::from_vec(vec![1.0], true);
        p0.set_grad(ndarray::arr1(&[5.0]));

        let global_norm = clip_grad_norm_refs(&mut [&mut p0], 0.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);

        // With max_norm = 0, gradients should be clipped to 0
        assert_grad(&p0, &[0.0]);
    }
}
334
335        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
336        assert_abs_diff_eq!(global_norm, 0.0, epsilon = 1e-6);
337    }
338}