// optirs_learned/transformer/strategies/gradient_processing.rs

1use std::fmt::Debug;
2// Gradient processing strategies for transformer optimization
3//
4// This module implements various gradient transformation and processing strategies
5// used by the transformer optimizer to improve optimization performance.
6
7#[allow(dead_code)]
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11
12use crate::error::{OptimError, Result};
13
/// Gradient processing strategies
///
/// Selects how `GradientProcessor::process_gradients` transforms an
/// incoming gradient vector before the optimizer applies it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GradientProcessingStrategy {
    /// Raw gradients without processing
    Raw,
    /// Gradient clipping
    Clipping,
    /// Gradient normalization
    Normalization,
    /// Adaptive gradient scaling
    AdaptiveScaling,
    /// Adaptive processing (general)
    Adaptive,
    /// Gradient smoothing
    Smoothing,
    /// Gradient accumulation
    Accumulation,
    /// Gradient dropout
    Dropout,
    /// Gradient compression
    Compression,
}
36
37/// Gradient processor for transformer optimizer
38#[derive(Debug, Clone)]
39pub struct GradientProcessor<
40    T: Float
41        + Debug
42        + Default
43        + Clone
44        + std::iter::Sum
45        + scirs2_core::ndarray::ScalarOperand
46        + Send
47        + Sync
48        + 'static,
49> {
50    /// Processing strategy
51    strategy: GradientProcessingStrategy,
52
53    /// Gradient history for smoothing
54    gradient_history: VecDeque<Array1<T>>,
55
56    /// Accumulated gradients
57    accumulated_gradients: Option<Array1<T>>,
58
59    /// Gradient statistics
60    gradient_stats: GradientStatistics<T>,
61
62    /// Processing parameters
63    processing_params: GradientProcessingParams<T>,
64}
65
66/// Gradient processing parameters
67#[derive(Debug, Clone)]
68pub struct GradientProcessingParams<T: Float + Debug + Send + Sync + 'static> {
69    /// Clipping threshold
70    clip_threshold: T,
71
72    /// Smoothing factor
73    smoothing_factor: T,
74
75    /// Accumulation steps
76    accumulation_steps: usize,
77
78    /// Dropout probability
79    dropout_prob: f64,
80
81    /// Compression ratio
82    compression_ratio: f64,
83
84    /// Normalization epsilon
85    norm_eps: T,
86}
87
88/// Gradient statistics tracking
89#[derive(Debug, Clone)]
90pub struct GradientStatistics<T: Float + Debug + Send + Sync + 'static> {
91    /// Running mean of gradient magnitudes
92    mean_magnitude: T,
93
94    /// Running variance of gradient magnitudes
95    var_magnitude: T,
96
97    /// Maximum gradient magnitude seen
98    max_magnitude: T,
99
100    /// Minimum gradient magnitude seen
101    min_magnitude: T,
102
103    /// Update count
104    update_count: usize,
105
106    /// Gradient sparsity
107    sparsity: T,
108}
109
110impl<
111        T: Float
112            + Debug
113            + Default
114            + Clone
115            + std::iter::Sum
116            + scirs2_core::ndarray::ScalarOperand
117            + Send
118            + Sync
119            + 'static,
120    > GradientProcessor<T>
121{
122    /// Create new gradient processor
123    pub fn new(strategy: GradientProcessingStrategy) -> Self {
124        Self {
125            strategy,
126            gradient_history: VecDeque::new(),
127            accumulated_gradients: None,
128            gradient_stats: GradientStatistics::new(),
129            processing_params: GradientProcessingParams::default(),
130        }
131    }
132
133    /// Create with custom parameters
134    pub fn new_with_params(
135        strategy: GradientProcessingStrategy,
136        params: GradientProcessingParams<T>,
137    ) -> Self {
138        Self {
139            strategy,
140            gradient_history: VecDeque::new(),
141            accumulated_gradients: None,
142            gradient_stats: GradientStatistics::new(),
143            processing_params: params,
144        }
145    }
146
147    /// Process gradients according to the selected strategy
148    pub fn process_gradients(&mut self, gradients: &Array1<T>) -> Result<Array1<T>> {
149        // Update statistics first
150        self.gradient_stats.update(gradients);
151
152        match self.strategy {
153            GradientProcessingStrategy::Raw => Ok(gradients.clone()),
154            GradientProcessingStrategy::Clipping => self.clip_gradients(gradients),
155            GradientProcessingStrategy::Normalization => self.normalize_gradients(gradients),
156            GradientProcessingStrategy::AdaptiveScaling => self.adaptive_scale_gradients(gradients),
157            GradientProcessingStrategy::Adaptive => self.adaptive_scale_gradients(gradients), // Use adaptive scaling as default
158            GradientProcessingStrategy::Smoothing => self.smooth_gradients(gradients),
159            GradientProcessingStrategy::Accumulation => self.accumulate_gradients(gradients),
160            GradientProcessingStrategy::Dropout => self.dropout_gradients(gradients),
161            GradientProcessingStrategy::Compression => self.compress_gradients(gradients),
162        }
163    }
164
165    /// Clip gradients to prevent explosion
166    fn clip_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
167        let grad_norm = self.compute_gradient_norm(gradients);
168
169        if grad_norm > self.processing_params.clip_threshold {
170            let scale = self.processing_params.clip_threshold / grad_norm;
171            Ok(gradients * scale)
172        } else {
173            Ok(gradients.clone())
174        }
175    }
176
177    /// Normalize gradients
178    fn normalize_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
179        let grad_norm = self.compute_gradient_norm(gradients);
180
181        if grad_norm > self.processing_params.norm_eps {
182            Ok(gradients / grad_norm)
183        } else {
184            Ok(gradients.clone())
185        }
186    }
187
188    /// Adaptively scale gradients based on statistics
189    fn adaptive_scale_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
190        let current_norm = self.compute_gradient_norm(gradients);
191        let mean_norm = self.gradient_stats.mean_magnitude;
192
193        if mean_norm > T::zero() {
194            let adaptive_scale =
195                scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()) * mean_norm
196                    / current_norm
197                    + scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
198            Ok(gradients * adaptive_scale)
199        } else {
200            Ok(gradients.clone())
201        }
202    }
203
204    /// Smooth gradients using exponential moving average
205    fn smooth_gradients(&mut self, gradients: &Array1<T>) -> Result<Array1<T>> {
206        let alpha = self.processing_params.smoothing_factor;
207
208        if let Some(prev_grad) = self.gradient_history.back() {
209            let smoothed = gradients * alpha + prev_grad * (T::one() - alpha);
210            self.gradient_history.push_back(smoothed.clone());
211
212            // Keep only recent history
213            if self.gradient_history.len() > 10 {
214                self.gradient_history.pop_front();
215            }
216
217            Ok(smoothed)
218        } else {
219            self.gradient_history.push_back(gradients.clone());
220            Ok(gradients.clone())
221        }
222    }
223
224    /// Accumulate gradients over multiple steps
225    fn accumulate_gradients(&mut self, gradients: &Array1<T>) -> Result<Array1<T>> {
226        if let Some(ref mut accumulated) = self.accumulated_gradients {
227            *accumulated = accumulated.clone() + gradients;
228        } else {
229            self.accumulated_gradients = Some(gradients.clone());
230        }
231
232        // Return accumulated gradients if we've reached the target steps
233        if self
234            .gradient_stats
235            .update_count
236            .is_multiple_of(self.processing_params.accumulation_steps)
237        {
238            if let Some(accumulated) = self.accumulated_gradients.take() {
239                let scale = scirs2_core::numeric::NumCast::from(
240                    1.0 / self.processing_params.accumulation_steps as f64,
241                )
242                .unwrap_or_else(|| T::zero());
243                Ok(accumulated * scale)
244            } else {
245                Ok(gradients.clone())
246            }
247        } else {
248            // Return zero gradients for intermediate steps
249            Ok(Array1::zeros(gradients.len()))
250        }
251    }
252
253    /// Apply dropout to gradients
254    fn dropout_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
255        // Simplified dropout - in practice would use proper random sampling
256        let mut result = gradients.clone();
257
258        // Apply deterministic "dropout" pattern for reproducibility
259        for (i, elem) in result.iter_mut().enumerate() {
260            if (i % 10) < (self.processing_params.dropout_prob * 10.0) as usize {
261                *elem = T::zero();
262            }
263        }
264
265        Ok(result)
266    }
267
268    /// Compress gradients (simplified sparsification)
269    fn compress_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
270        let mut result = gradients.clone();
271        let threshold = self.compute_gradient_norm(gradients)
272            * scirs2_core::numeric::NumCast::from(self.processing_params.compression_ratio)
273                .unwrap_or_else(|| T::zero());
274
275        // Zero out small gradients
276        for elem in result.iter_mut() {
277            if elem.abs() < threshold {
278                *elem = T::zero();
279            }
280        }
281
282        Ok(result)
283    }
284
285    /// Compute L2 norm of gradients
286    fn compute_gradient_norm(&self, gradients: &Array1<T>) -> T {
287        let sum_squares = gradients
288            .iter()
289            .map(|&x| x * x)
290            .fold(T::zero(), |a, b| a + b);
291        sum_squares.sqrt()
292    }
293
294    /// Get gradient statistics
295    pub fn statistics(&self) -> &GradientStatistics<T> {
296        &self.gradient_stats
297    }
298
299    /// Update processing strategy
300    pub fn set_strategy(&mut self, strategy: GradientProcessingStrategy) {
301        self.strategy = strategy;
302    }
303
304    /// Update processing parameters
305    pub fn set_parameters(&mut self, params: GradientProcessingParams<T>) {
306        self.processing_params = params;
307    }
308
309    /// Reset processor state
310    pub fn reset(&mut self) {
311        self.gradient_history.clear();
312        self.accumulated_gradients = None;
313        self.gradient_stats = GradientStatistics::new();
314    }
315}
316
// `Default` delegates to `GradientStatistics::new`, which seeds
// `min_magnitude` with +inf so the first update establishes the minimum.
impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> Default for GradientStatistics<T> {
    fn default() -> Self {
        Self::new()
    }
}
322
323impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> GradientStatistics<T> {
324    /// Create new gradient statistics
325    pub fn new() -> Self {
326        Self {
327            mean_magnitude: T::zero(),
328            var_magnitude: T::zero(),
329            max_magnitude: T::zero(),
330            min_magnitude: scirs2_core::numeric::NumCast::from(f64::INFINITY)
331                .unwrap_or_else(|| T::zero()),
332            update_count: 0,
333            sparsity: T::zero(),
334        }
335    }
336
337    /// Update statistics with new gradients
338    pub fn update(&mut self, gradients: &Array1<T>) {
339        let magnitude = gradients
340            .iter()
341            .map(|&x| x * x)
342            .fold(T::zero(), |a, b| a + b)
343            .sqrt();
344
345        self.update_count += 1;
346        let count = scirs2_core::numeric::NumCast::from(self.update_count as f64)
347            .unwrap_or_else(|| T::zero());
348
349        // Update running mean
350        let delta = magnitude - self.mean_magnitude;
351        self.mean_magnitude = self.mean_magnitude + delta / count;
352
353        // Update running variance
354        let delta2 = magnitude - self.mean_magnitude;
355        self.var_magnitude = self.var_magnitude + delta * delta2;
356
357        // Update min/max
358        if magnitude > self.max_magnitude {
359            self.max_magnitude = magnitude;
360        }
361        if magnitude < self.min_magnitude {
362            self.min_magnitude = magnitude;
363        }
364
365        // Update sparsity (fraction of near-zero elements)
366        let zero_count = gradients
367            .iter()
368            .filter(|&&x| {
369                x.abs() < scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero())
370            })
371            .count();
372        let current_sparsity = T::from(zero_count as f64 / gradients.len() as f64).unwrap();
373        let alpha = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
374        self.sparsity = self.sparsity * (T::one() - alpha) + current_sparsity * alpha;
375    }
376
377    /// Get mean magnitude
378    pub fn mean_magnitude(&self) -> T {
379        self.mean_magnitude
380    }
381
382    /// Get variance of magnitude
383    pub fn variance_magnitude(&self) -> T {
384        if self.update_count > 1 {
385            self.var_magnitude / T::from((self.update_count - 1) as f64).unwrap()
386        } else {
387            T::zero()
388        }
389    }
390
391    /// Get standard deviation of magnitude
392    pub fn std_magnitude(&self) -> T {
393        self.variance_magnitude().sqrt()
394    }
395
396    /// Get gradient sparsity
397    pub fn sparsity(&self) -> T {
398        self.sparsity
399    }
400}
401
402impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> Default
403    for GradientProcessingParams<T>
404{
405    fn default() -> Self {
406        Self {
407            clip_threshold: scirs2_core::numeric::NumCast::from(1.0).unwrap_or_else(|| T::zero()),
408            smoothing_factor: scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()),
409            accumulation_steps: 4,
410            dropout_prob: 0.1,
411            compression_ratio: 0.1,
412            norm_eps: scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()),
413        }
414    }
415}