use yscv_autograd::{Graph, NodeId};
/// Rescales the gradients of `node_ids` in place so that their combined norm
/// does not exceed `max_norm`, and returns the norm measured before scaling.
///
/// The gradients of all listed nodes are treated as one flat vector.
/// `norm_type` is the exponent `p` of the p-norm; `f32::INFINITY` selects the
/// largest absolute component instead. Nodes for which `graph.grad` yields
/// anything other than `Ok(Some(_))` are skipped.
///
/// Returns `0.0` and leaves every gradient untouched when:
/// * `node_ids` is empty,
/// * `max_norm` is not a finite, strictly positive number, or
/// * the norm is not finite (for example a gradient holds NaN or infinity).
///
/// # Overflow handling
///
/// The p-norm is first computed directly. If that result is not finite and
/// `norm_type` is positive, the computation is repeated with every magnitude
/// divided by the largest absolute gradient component. This recovers the norm
/// when only the intermediate sum of powers overflowed `f32` (very large but
/// finite gradients), so such gradients are still clipped instead of being
/// silently left as they are.
pub fn clip_grad_norm_(
    graph: &mut Graph,
    node_ids: &[NodeId],
    max_norm: f32,
    norm_type: f32,
) -> f32 {
    if node_ids.is_empty() || !max_norm.is_finite() || max_norm <= 0.0 {
        return 0.0;
    }

    let measured = if norm_type == f32::INFINITY {
        grad_abs_max(graph, node_ids)
    } else {
        let direct = grad_p_norm(graph, node_ids, norm_type, 1.0);
        if !direct.is_finite() && norm_type > 0.0 {
            // Either the sum of powers overflowed or the gradients themselves
            // are non-finite. Retrying with normalised magnitudes yields a
            // finite value in the first case and stays non-finite in the second.
            let unit = grad_abs_max(graph, node_ids);
            grad_p_norm(graph, node_ids, norm_type, unit)
        } else {
            direct
        }
    };

    // A non-finite norm cannot be turned into a meaningful scale factor.
    let total_norm = if measured.is_finite() { measured } else { 0.0 };

    if total_norm > max_norm {
        let scale = max_norm / total_norm;
        for &id in node_ids {
            if let Ok(Some(grad)) = graph.grad_mut(id) {
                for v in grad.data_mut().iter_mut() {
                    *v *= scale;
                }
            }
        }
    }
    total_norm
}

/// Largest absolute gradient component over `node_ids`; NaN components are
/// ignored and `0.0` is returned when there is nothing to look at.
fn grad_abs_max(graph: &Graph, node_ids: &[NodeId]) -> f32 {
    let mut largest = 0.0_f32;
    for &id in node_ids {
        if let Ok(Some(grad)) = graph.grad(id) {
            // `f32::max` returns the non-NaN operand, so NaN never replaces `largest`.
            largest = grad.data().iter().fold(largest, |m, v| m.max(v.abs()));
        }
    }
    largest
}

/// p-norm of all gradient components over `node_ids`, computed as
/// `unit * (sum((|v| / unit)^p))^(1/p)`.
///
/// With `unit == 1.0` the division and the final multiplication are exact, so
/// this is the plain p-norm.
fn grad_p_norm(graph: &Graph, node_ids: &[NodeId], norm_type: f32, unit: f32) -> f32 {
    let mut acc = 0.0_f32;
    for &id in node_ids {
        if let Ok(Some(grad)) = graph.grad(id) {
            acc = grad
                .data()
                .iter()
                .fold(acc, |sum, v| sum + (v.abs() / unit).powf(norm_type));
        }
    }
    unit * acc.powf(1.0 / norm_type)
}
/// Clamps every gradient component of the listed nodes into
/// `[-max_val, max_val]`, modifying the gradients in place.
///
/// Nothing happens when `node_ids` is empty or when `max_val` is not a finite,
/// strictly positive number. Nodes for which `graph.grad_mut` yields anything
/// other than `Ok(Some(_))` are skipped, and NaN components stay NaN.
pub fn clip_grad_value_(graph: &mut Graph, node_ids: &[NodeId], max_val: f32) {
    let bound_is_usable = max_val.is_finite() && max_val > 0.0;
    if !bound_is_usable || node_ids.is_empty() {
        return;
    }

    // `max_val` is finite and positive here, so `lo < hi` and `clamp` cannot panic.
    let (lo, hi) = (-max_val, max_val);
    for &node in node_ids {
        if let Ok(Some(tensor)) = graph.grad_mut(node) {
            tensor
                .data_mut()
                .iter_mut()
                .for_each(|component| *component = component.clamp(lo, hi));
        }
    }
}