
Function compute_gradient_norm 

pub fn compute_gradient_norm(gradients: &[&[f64]]) -> f64

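Compute the L2 (Euclidean) norm of all gradient slices treated as one concatenated vector, that is, the square root of the sum of squares of every element. When each slice is a flattened weight matrix, this equals the Frobenius norm taken over all of them jointly.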
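A minimal sketch of what this computation looks like, written against the signature above. It is an illustration, not the crate's actual source: it accumulates squared elements across every slice and takes one square root at the end, so no intermediate concatenated buffer is allocated.

```rust
/// Illustrative sketch only (not the crate's source): global L2 norm of
/// all slices viewed as one concatenated vector.
pub fn compute_gradient_norm(gradients: &[&[f64]]) -> f64 {
    gradients
        .iter()
        .flat_map(|g| g.iter()) // walk every element of every slice in order
        .map(|x| x * x) // square each element
        .sum::<f64>() // sum of squares over the whole "concatenated" vector
        .sqrt() // a single square root gives the global norm
}

fn main() {
    // Two parameter groups, e.g. a weight slice and a bias slice.
    let w: &[f64] = &[3.0, 0.0];
    let b: &[f64] = &[4.0];
    let norm = compute_gradient_norm(&[w, b]);
    println!("{norm}"); // 5
}
```

An empty input yields `0.0`, since the sum over no elements is zero.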

Used for gradient clipping: if the norm exceeds a threshold, the caller should scale all gradients by threshold / norm.
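The caller-side clipping step described above might look like the following sketch. `clip_by_global_norm` is a hypothetical helper, not part of this crate, and it assumes gradients are stored as `Vec<f64>` per parameter group and that `threshold > 0`. A local stand-in for `compute_gradient_norm` with the documented signature is included only so the block compiles on its own.

```rust
// Stand-in with the documented signature so this sketch is self-contained;
// it is not the crate's source.
fn compute_gradient_norm(gradients: &[&[f64]]) -> f64 {
    gradients.iter().flat_map(|g| g.iter()).map(|x| x * x).sum::<f64>().sqrt()
}

/// Hypothetical helper: if the global norm exceeds `threshold` (assumed > 0),
/// scales every gradient in place by `threshold / norm`.
/// Returns the norm measured before any clipping.
fn clip_by_global_norm(gradients: &mut [Vec<f64>], threshold: f64) -> f64 {
    // Borrow each Vec as a slice to match the `&[&[f64]]` parameter.
    let views: Vec<&[f64]> = gradients.iter().map(|g| g.as_slice()).collect();
    let norm = compute_gradient_norm(&views);
    if norm > threshold {
        // norm > threshold > 0, so this division is safe.
        let scale = threshold / norm;
        for g in gradients.iter_mut() {
            for x in g.iter_mut() {
                *x *= scale;
            }
        }
    }
    norm
}

fn main() {
    let mut grads = vec![vec![3.0, 0.0], vec![4.0]];
    let before = clip_by_global_norm(&mut grads, 1.0);
    let views: Vec<&[f64]> = grads.iter().map(|g| g.as_slice()).collect();
    let after = compute_gradient_norm(&views);
    println!("before = {before}, after = {after:.6}");
}
```

Because one shared factor is applied to every slice, the concatenated gradient keeps its direction and only its length is capped at the threshold.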