Skip to main content

yscv_optim/
clip.rs

1use yscv_autograd::{Graph, NodeId};
2
3/// Clips the total norm of gradients for the given nodes in-place.
4///
5/// Computes the combined norm (controlled by `norm_type`, typically 2.0 for L2)
6/// across all gradient tensors for `node_ids`. If the total norm exceeds
7/// `max_norm`, every gradient is scaled by `max_norm / total_norm`.
8///
9/// Returns the computed total norm before clipping (useful for monitoring).
10///
11/// Nodes without gradients are silently skipped.
12/// If `max_norm` is not positive or `node_ids` is empty, no clipping is
13/// performed and the function returns 0.0.
14pub fn clip_grad_norm_(
15    graph: &mut Graph,
16    node_ids: &[NodeId],
17    max_norm: f32,
18    norm_type: f32,
19) -> f32 {
20    if node_ids.is_empty() || !max_norm.is_finite() || max_norm <= 0.0 {
21        return 0.0;
22    }
23
24    // Accumulate the total norm across all gradient tensors (read-only pass).
25    let mut total_norm: f32 = if norm_type == f32::INFINITY {
26        // Inf-norm: max absolute value across all gradients.
27        let mut max_val: f32 = 0.0;
28        for &id in node_ids {
29            if let Ok(Some(grad)) = graph.grad(id) {
30                for &v in grad.data() {
31                    let abs = v.abs();
32                    if abs > max_val {
33                        max_val = abs;
34                    }
35                }
36            }
37        }
38        max_val
39    } else {
40        // General p-norm: (sum |g_i|^p)^(1/p).
41        let mut acc: f32 = 0.0;
42        for &id in node_ids {
43            if let Ok(Some(grad)) = graph.grad(id) {
44                for &v in grad.data() {
45                    acc += v.abs().powf(norm_type);
46                }
47            }
48        }
49        acc.powf(1.0 / norm_type)
50    };
51
52    if !total_norm.is_finite() {
53        total_norm = 0.0;
54    }
55
56    // Scale gradients in-place if total norm exceeds max_norm.
57    if total_norm > max_norm {
58        let scale = max_norm / total_norm;
59        for &id in node_ids {
60            if let Ok(Some(grad)) = graph.grad_mut(id) {
61                for v in grad.data_mut() {
62                    *v *= scale;
63                }
64            }
65        }
66    }
67
68    total_norm
69}
70
71/// Clamps every gradient element to the range `[-max_val, max_val]` in-place.
72///
73/// Nodes without gradients are silently skipped.
74/// If `max_val` is not positive or `node_ids` is empty, no clamping is performed.
75pub fn clip_grad_value_(graph: &mut Graph, node_ids: &[NodeId], max_val: f32) {
76    if node_ids.is_empty() || !max_val.is_finite() || max_val <= 0.0 {
77        return;
78    }
79
80    for &id in node_ids {
81        if let Ok(Some(grad)) = graph.grad_mut(id) {
82            for v in grad.data_mut() {
83                if *v > max_val {
84                    *v = max_val;
85                } else if *v < -max_val {
86                    *v = -max_val;
87                }
88            }
89        }
90    }
91}