//! Gradient Clipping callback
//!
//! This module provides a callback for gradient clipping during training.
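//!
//! A minimal usage sketch (the `f32` type parameter, the thresholds, and the
//! import path are illustrative; the callback is assumed to be registered with
//! whatever callback runner this crate's training loop provides):
//!
//! ```rust,ignore
//! use crate::callbacks::GradientClipping;
//!
//! // Rescale gradients whenever their global L2 norm exceeds 1.0, with logging.
//! let by_norm = GradientClipping::<f32>::by_global_norm(1.0, true);
//!
//! // Alternatively, clamp each gradient element to [-0.5, 0.5], without logging.
//! let by_value = GradientClipping::<f32>::by_value(0.5, false);
//! ```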
use super::{Callback, CallbackContext, CallbackTiming};
use crate::error::Result;
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::{Debug, Display};
/// Gradient clipping method
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientClippingMethod {
    /// Clip by global norm (scales every gradient by `max_norm / global_norm` when the global norm exceeds `max_norm`)
ClipByGlobalNorm,
    /// Clip by value (clamp each element to `[-max_value, max_value]`)
ClipByValue,
}
/// Gradient clipping callback
#[derive(Debug)]
pub struct GradientClipping<F: Float + NumAssign + Debug + ScalarOperand + Display> {
    /// Maximum gradient norm (global norm method) or maximum absolute value (value method)
pub max_norm: F,
/// Clipping method
pub method: GradientClippingMethod,
/// Whether to log clipping statistics
pub log_stats: bool,
/// Whether clipping was applied in the last step
clipping_applied: bool,
    /// Scaling factor (`max_norm / global_norm`) applied in the last step (global norm method only)
clipping_ratio: Option<F>,
}
impl<F: Float + NumAssign + Debug + ScalarOperand + Display> GradientClipping<F> {
/// Create a new gradient clipping callback using global norm
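    ///
    /// A small sketch (the type parameter and threshold are illustrative):
    ///
    /// ```rust,ignore
    /// let clip = GradientClipping::<f32>::by_global_norm(5.0, false);
    /// assert!(!clip.was_clipping_applied()); // nothing has been clipped yet
    /// assert!(clip.get_clipping_ratio().is_none());
    /// ```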
pub fn by_global_norm(max_norm: F, log_stats: bool) -> Self {
Self {
max_norm,
method: GradientClippingMethod::ClipByGlobalNorm,
log_stats,
clipping_applied: false,
clipping_ratio: None,
}
}
/// Create a new gradient clipping callback using value clipping
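    ///
    /// A small sketch (the type parameter and threshold are illustrative):
    ///
    /// ```rust,ignore
    /// // Each gradient element will be clamped to [-0.5, 0.5] when clipping runs.
    /// let clip = GradientClipping::<f32>::by_value(0.5, false);
    /// assert!(!clip.was_clipping_applied());
    /// ```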
pub fn by_value(max_value: F, log_stats: bool) -> Self {
Self {
max_norm: max_value,
method: GradientClippingMethod::ClipByValue,
log_stats,
clipping_applied: false,
clipping_ratio: None,
}
}
/// Returns whether clipping was applied in the last step
pub fn was_clipping_applied(&self) -> bool {
self.clipping_applied
}
/// Returns the clipping ratio from the last step (if global norm method was used)
pub fn get_clipping_ratio(&self) -> Option<F> {
self.clipping_ratio
}
/// Clip gradients by global norm
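    ///
    /// Computes the global L2 norm over every element of every parameter
    /// gradient and, when it exceeds `max_norm`, multiplies all gradients by
    /// `max_norm / global_norm` so that the clipped norm equals `max_norm`.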
fn clip_by_global_norm<L: Layer<F> + ?Sized>(&mut self, model: &mut L) -> Result<()> {
let gradients = model.gradients();
// Compute global norm
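        // global_norm = sqrt(sum of val^2 over all elements of all parameter gradients)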
let mut global_norm_sq = F::zero();
for grad in &gradients {
for &val in grad.iter() {
global_norm_sq += val * val;
}
}
let global_norm = global_norm_sq.sqrt();
// Clip if necessary
if global_norm > self.max_norm {
let scale = self.max_norm / global_norm;
self.clipping_applied = true;
self.clipping_ratio = Some(scale);
let clipped_gradients: Vec<Array<F, IxDyn>> =
gradients.iter().map(|grad| grad.clone() * scale).collect();
// Apply clipped gradients
model.set_gradients(&clipped_gradients)?;
if self.log_stats {
println!(
"Gradient clipping applied - global norm: {:.4}, scale: {:.4}",
global_norm, scale
);
}
} else {
self.clipping_applied = false;
self.clipping_ratio = None;
}
Ok(())
}
/// Clip gradients by value
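    ///
    /// Each gradient element is clamped to `[-max_norm, max_norm]`; for this
    /// method `max_norm` holds the maximum absolute value passed to `by_value`.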
fn clip_by_value<L: Layer<F> + ?Sized>(&mut self, model: &mut L) -> Result<()> {
let gradients = model.gradients();
// Check if any value exceeds the maximum
let mut clipping_needed = false;
for grad in &gradients {
for &val in grad.iter() {
if val.abs() > self.max_norm {
clipping_needed = true;
break;
}
}
if clipping_needed {
break;
}
}
if clipping_needed {
let clipped_gradients: Vec<Array<F, IxDyn>> = gradients
.iter()
.map(|grad| {
let mut clipped = grad.clone();
for val in clipped.iter_mut() {
if *val > self.max_norm {
*val = self.max_norm;
} else if *val < -self.max_norm {
*val = -self.max_norm;
}
}
clipped
})
.collect();
model.set_gradients(&clipped_gradients)?;
self.clipping_applied = true;
if self.log_stats {
println!(
"Gradient value clipping applied - max value: {:.4}",
self.max_norm
);
}
} else {
self.clipping_applied = false;
}
Ok(())
}
}
impl<F: Float + NumAssign + Debug + ScalarOperand + Display> Callback<F> for GradientClipping<F> {
fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
        // Gradient clipping runs after each batch's backward pass, before the optimizer step.
if timing == CallbackTiming::AfterBatch {
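            // Clipping only proceeds when a batch loss has been recorded for this batch.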
if let Some(_batch_loss) = context.batch_loss {
// Access the model from the context
if let Some(model) = context.model.as_mut() {
match self.method {
GradientClippingMethod::ClipByGlobalNorm => {
if let Err(e) = self.clip_by_global_norm(&mut **model) {
eprintln!("Error in clip_by_global_norm: {}", e);
}
}
GradientClippingMethod::ClipByValue => {
if let Err(e) = self.clip_by_value(&mut **model) {
eprintln!("Error in clip_by_value: {}", e);
}
}
}
} else {
// Fallback behavior if model is not available
if self.log_stats {
println!("Gradient clipping: model not available in context");
}
self.clipping_applied = false;
}
}
}
Ok(())
}
}