// optirs_learned/transformer/strategies/gradient_processing.rs

//! Gradient processing strategies for learned transformer optimizers.

use std::fmt::Debug;

#[allow(dead_code)]
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;
use std::collections::VecDeque;

use crate::error::Result;
/// Strategy applied to raw gradients before they are used for updates.
#[derive(Debug, Clone, Copy)]
pub enum GradientProcessingStrategy {
    /// Pass gradients through unchanged.
    Raw,
    /// Rescale gradients whose L2 norm exceeds `clip_threshold`.
    Clipping,
    /// Rescale gradients to unit L2 norm.
    Normalization,
    /// Scale gradients toward the running mean gradient magnitude.
    AdaptiveScaling,
    /// Alias for `AdaptiveScaling`; both dispatch to the same routine.
    Adaptive,
    /// Exponential moving average over recent gradients.
    Smoothing,
    /// Sum gradients and release their mean every `accumulation_steps` calls.
    Accumulation,
    /// Zero out a fixed fraction of gradient components.
    Dropout,
    /// Sparsify by zeroing entries below a norm-based threshold.
    Compression,
}

/// Applies a [`GradientProcessingStrategy`] to incoming gradients while
/// tracking running statistics about their magnitudes.
#[derive(Debug, Clone)]
pub struct GradientProcessor<
    T: Float
        + Debug
        + Default
        + Clone
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
> {
    /// Active processing strategy.
    strategy: GradientProcessingStrategy,

    /// Recent smoothed gradients retained for the smoothing strategy.
    gradient_history: VecDeque<Array1<T>>,

    /// Running sum used by the accumulation strategy.
    accumulated_gradients: Option<Array1<T>>,

    /// Online statistics over gradient magnitudes.
    gradient_stats: GradientStatistics<T>,

    /// Tunable parameters shared by all strategies.
    processing_params: GradientProcessingParams<T>,
}

/// Tunable parameters for gradient processing.
#[derive(Debug, Clone)]
pub struct GradientProcessingParams<T: Float + Debug + Send + Sync + 'static> {
    /// Maximum L2 norm allowed before clipping rescales the gradient.
    clip_threshold: T,

    /// EMA weight given to the newest gradient when smoothing.
    smoothing_factor: T,

    /// Number of steps to accumulate before releasing an averaged gradient.
    accumulation_steps: usize,

    /// Fraction of components zeroed by the dropout strategy.
    dropout_prob: f64,

    /// Fraction of the gradient norm used as the compression threshold.
    compression_ratio: f64,

    /// Numerical floor guarding against division by near-zero norms.
    norm_eps: T,
}

/// Online statistics over observed gradient magnitudes.
#[derive(Debug, Clone)]
pub struct GradientStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Running mean of the gradient L2 norm.
    mean_magnitude: T,

    /// Running sum of squared deviations (Welford's M2 term).
    var_magnitude: T,

    /// Largest gradient norm observed so far.
    max_magnitude: T,

    /// Smallest gradient norm observed so far.
    min_magnitude: T,

    /// Number of gradient updates observed.
    update_count: usize,

    /// Exponentially smoothed fraction of near-zero components.
    sparsity: T,
}
109
110impl<
111 T: Float
112 + Debug
113 + Default
114 + Clone
115 + std::iter::Sum
116 + scirs2_core::ndarray::ScalarOperand
117 + Send
118 + Sync
119 + 'static,
120 > GradientProcessor<T>
121{
122 pub fn new(strategy: GradientProcessingStrategy) -> Self {
124 Self {
125 strategy,
126 gradient_history: VecDeque::new(),
127 accumulated_gradients: None,
128 gradient_stats: GradientStatistics::new(),
129 processing_params: GradientProcessingParams::default(),
130 }
131 }

    /// Creates a processor with the given strategy and explicit parameters.
    pub fn new_with_params(
        strategy: GradientProcessingStrategy,
        params: GradientProcessingParams<T>,
    ) -> Self {
        Self {
            strategy,
            gradient_history: VecDeque::new(),
            accumulated_gradients: None,
            gradient_stats: GradientStatistics::new(),
            processing_params: params,
        }
    }

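    /// Dispatches `gradients` to the handler for the active strategy and
    /// refreshes the running statistics first. A minimal usage sketch
    /// (hypothetical values; marked `ignore` because it relies on the crate's
    /// `Result` alias being in scope):
    ///
    /// ```ignore
    /// let mut processor =
    ///     GradientProcessor::<f64>::new(GradientProcessingStrategy::Normalization);
    /// let grads = Array1::from_vec(vec![3.0, 4.0]); // L2 norm = 5.0
    /// let unit = processor.process_gradients(&grads)?; // rescaled to norm 1.0
    /// ```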
    pub fn process_gradients(&mut self, gradients: &Array1<T>) -> Result<Array1<T>> {
        self.gradient_stats.update(gradients);

        match self.strategy {
            GradientProcessingStrategy::Raw => Ok(gradients.clone()),
            GradientProcessingStrategy::Clipping => self.clip_gradients(gradients),
            GradientProcessingStrategy::Normalization => self.normalize_gradients(gradients),
            GradientProcessingStrategy::AdaptiveScaling => self.adaptive_scale_gradients(gradients),
            GradientProcessingStrategy::Adaptive => self.adaptive_scale_gradients(gradients),
            GradientProcessingStrategy::Smoothing => self.smooth_gradients(gradients),
            GradientProcessingStrategy::Accumulation => self.accumulate_gradients(gradients),
            GradientProcessingStrategy::Dropout => self.dropout_gradients(gradients),
            GradientProcessingStrategy::Compression => self.compress_gradients(gradients),
        }
    }

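    /// Global-norm clipping: if the L2 norm exceeds `clip_threshold`, the
    /// gradient is rescaled by `clip_threshold / norm` so the result has norm
    /// exactly `clip_threshold`; otherwise it passes through unchanged.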
    fn clip_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
        let grad_norm = self.compute_gradient_norm(gradients);

        if grad_norm > self.processing_params.clip_threshold {
            let scale = self.processing_params.clip_threshold / grad_norm;
            Ok(gradients * scale)
        } else {
            Ok(gradients.clone())
        }
    }

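    /// Rescales to unit L2 norm; gradients with norm at or below `norm_eps`
    /// pass through unchanged to avoid amplifying numerical noise.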
    fn normalize_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
        let grad_norm = self.compute_gradient_norm(gradients);

        if grad_norm > self.processing_params.norm_eps {
            Ok(gradients / grad_norm)
        } else {
            Ok(gradients.clone())
        }
    }

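    /// Pulls the gradient magnitude toward the running mean:
    /// `scale = 0.9 * mean_norm / current_norm + 0.1`, so an unusually large
    /// gradient is damped and an unusually small one is amplified.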
    fn adaptive_scale_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
        let current_norm = self.compute_gradient_norm(gradients);
        let mean_norm = self.gradient_stats.mean_magnitude;

        // Guard both norms so the ratio below cannot divide by zero.
        if mean_norm > T::zero() && current_norm > T::zero() {
            let adaptive_scale =
                scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()) * mean_norm
                    / current_norm
                    + scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
            Ok(gradients * adaptive_scale)
        } else {
            Ok(gradients.clone())
        }
    }

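    /// Exponential moving average: `s_t = alpha * g_t + (1 - alpha) * s_{t-1}`,
    /// where `alpha` is `smoothing_factor` and `s_{t-1}` is the most recent
    /// smoothed gradient in the history buffer.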
    fn smooth_gradients(&mut self, gradients: &Array1<T>) -> Result<Array1<T>> {
        let alpha = self.processing_params.smoothing_factor;

        if let Some(prev_grad) = self.gradient_history.back() {
            let smoothed = gradients * alpha + prev_grad * (T::one() - alpha);
            self.gradient_history.push_back(smoothed.clone());

            // Keep the history bounded to the last ten smoothed gradients.
            if self.gradient_history.len() > 10 {
                self.gradient_history.pop_front();
            }

            Ok(smoothed)
        } else {
            self.gradient_history.push_back(gradients.clone());
            Ok(gradients.clone())
        }
    }

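    /// Accumulates gradients and releases their mean once every
    /// `accumulation_steps` updates; off-cycle calls return a zero vector so
    /// the caller can skip the parameter update. Example timeline with
    /// `accumulation_steps = 4`: steps 1-3 return zeros while the running sum
    /// grows, step 4 returns the mean of all four gradients and clears the sum.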
    fn accumulate_gradients(&mut self, gradients: &Array1<T>) -> Result<Array1<T>> {
        if let Some(ref mut accumulated) = self.accumulated_gradients {
            *accumulated = accumulated.clone() + gradients;
        } else {
            self.accumulated_gradients = Some(gradients.clone());
        }

        if self
            .gradient_stats
            .update_count
            .is_multiple_of(self.processing_params.accumulation_steps)
        {
            if let Some(accumulated) = self.accumulated_gradients.take() {
                // Average rather than sum, so the effective step size is
                // independent of the accumulation window.
                let scale = scirs2_core::numeric::NumCast::from(
                    1.0 / self.processing_params.accumulation_steps as f64,
                )
                .unwrap_or_else(|| T::zero());
                Ok(accumulated * scale)
            } else {
                Ok(gradients.clone())
            }
        } else {
            Ok(Array1::zeros(gradients.len()))
        }
    }

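    /// Deterministic stand-in for gradient dropout: zeroes a fixed pattern of
    /// indices rather than sampling randomly, which keeps results reproducible
    /// across runs.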
    fn dropout_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
        let mut result = gradients.clone();

        for (i, elem) in result.iter_mut().enumerate() {
            // With dropout_prob = 0.1 this zeroes one index out of every ten.
            if (i % 10) < (self.processing_params.dropout_prob * 10.0) as usize {
                *elem = T::zero();
            }
        }

        Ok(result)
    }

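    /// Magnitude-based sparsification: entries whose absolute value falls
    /// below `compression_ratio * norm(g)` are zeroed.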
    fn compress_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
        let mut result = gradients.clone();
        let threshold = self.compute_gradient_norm(gradients)
            * scirs2_core::numeric::NumCast::from(self.processing_params.compression_ratio)
                .unwrap_or_else(|| T::zero());

        for elem in result.iter_mut() {
            if elem.abs() < threshold {
                *elem = T::zero();
            }
        }

        Ok(result)
    }

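    /// L2 (Euclidean) norm: `sqrt(sum(g_i^2))`.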
    fn compute_gradient_norm(&self, gradients: &Array1<T>) -> T {
        let sum_squares = gradients
            .iter()
            .map(|&x| x * x)
            .fold(T::zero(), |a, b| a + b);
        sum_squares.sqrt()
    }

    /// Returns the running statistics collected so far.
    pub fn statistics(&self) -> &GradientStatistics<T> {
        &self.gradient_stats
    }

    /// Switches the active processing strategy.
    pub fn set_strategy(&mut self, strategy: GradientProcessingStrategy) {
        self.strategy = strategy;
    }

    /// Replaces the processing parameters.
    pub fn set_parameters(&mut self, params: GradientProcessingParams<T>) {
        self.processing_params = params;
    }

    /// Clears history, accumulation state, and statistics.
    pub fn reset(&mut self) {
        self.gradient_history.clear();
        self.accumulated_gradients = None;
        self.gradient_stats = GradientStatistics::new();
    }
}

impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> Default for GradientStatistics<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> GradientStatistics<T> {
    /// Creates empty statistics; `min_magnitude` starts at positive infinity
    /// so the first observed norm always replaces it.
    pub fn new() -> Self {
        Self {
            mean_magnitude: T::zero(),
            var_magnitude: T::zero(),
            max_magnitude: T::zero(),
            min_magnitude: T::infinity(),
            update_count: 0,
            sparsity: T::zero(),
        }
    }

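    /// Folds one gradient observation into the running statistics using
    /// Welford's online algorithm:
    /// `mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k` and
    /// `M2_k = M2_{k-1} + (x_k - mean_{k-1}) * (x_k - mean_k)`.
    /// Sparsity is tracked as an EMA of the fraction of near-zero components.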
    pub fn update(&mut self, gradients: &Array1<T>) {
        // An empty gradient has no norm or sparsity to record.
        if gradients.is_empty() {
            return;
        }

        let magnitude = gradients
            .iter()
            .map(|&x| x * x)
            .fold(T::zero(), |a, b| a + b)
            .sqrt();

        self.update_count += 1;
        // Fall back to one so the division below stays well defined even if
        // the cast fails.
        let count = scirs2_core::numeric::NumCast::from(self.update_count as f64)
            .unwrap_or_else(|| T::one());

        // Welford's update for the running mean.
        let delta = magnitude - self.mean_magnitude;
        self.mean_magnitude = self.mean_magnitude + delta / count;

        // Welford's update for the sum of squared deviations.
        let delta2 = magnitude - self.mean_magnitude;
        self.var_magnitude = self.var_magnitude + delta * delta2;

        if magnitude > self.max_magnitude {
            self.max_magnitude = magnitude;
        }
        if magnitude < self.min_magnitude {
            self.min_magnitude = magnitude;
        }

        // Track sparsity as an EMA of the fraction of near-zero entries.
        let zero_count = gradients
            .iter()
            .filter(|&&x| {
                x.abs() < scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero())
            })
            .count();
        let current_sparsity =
            scirs2_core::numeric::NumCast::from(zero_count as f64 / gradients.len() as f64)
                .unwrap_or_else(|| T::zero());
        let alpha = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
        self.sparsity = self.sparsity * (T::one() - alpha) + current_sparsity * alpha;
    }

376
377 pub fn mean_magnitude(&self) -> T {
379 self.mean_magnitude
380 }
381
382 pub fn variance_magnitude(&self) -> T {
384 if self.update_count > 1 {
385 self.var_magnitude / T::from((self.update_count - 1) as f64).unwrap()
386 } else {
387 T::zero()
388 }
389 }
390
391 pub fn std_magnitude(&self) -> T {
393 self.variance_magnitude().sqrt()
394 }
395
396 pub fn sparsity(&self) -> T {
398 self.sparsity
399 }
400}

impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> Default
    for GradientProcessingParams<T>
{
    fn default() -> Self {
        Self {
            clip_threshold: scirs2_core::numeric::NumCast::from(1.0).unwrap_or_else(|| T::zero()),
            smoothing_factor: scirs2_core::numeric::NumCast::from(0.9)
                .unwrap_or_else(|| T::zero()),
            accumulation_steps: 4,
            dropout_prob: 0.1,
            compression_ratio: 0.1,
            norm_eps: scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()),
        }
    }
}
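
// A minimal test sketch. Assumptions: `scirs2_core::ndarray` re-exports the
// standard `ndarray` constructors, and `GradientProcessingParams::default`
// sets `clip_threshold` to 1.0 as defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clipping_caps_the_l2_norm() {
        let mut processor =
            GradientProcessor::<f64>::new(GradientProcessingStrategy::Clipping);
        // Norm of [3, 4] is 5.0, above the default clip_threshold of 1.0.
        let grads = Array1::from_vec(vec![3.0, 4.0]);
        let clipped = processor.process_gradients(&grads).unwrap();
        let norm = clipped.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-9);
    }

    #[test]
    fn raw_strategy_still_updates_statistics() {
        let mut processor = GradientProcessor::<f64>::new(GradientProcessingStrategy::Raw);
        let grads = Array1::from_vec(vec![3.0, 4.0]);
        let _ = processor.process_gradients(&grads).unwrap();
        // After one observation the running mean equals the observed norm.
        assert!((processor.statistics().mean_magnitude() - 5.0).abs() < 1e-9);
    }
}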