quantrs2_ml/torchquantum/autograd.rs

//! Autograd helpers for TorchQuantum-compatible quantum machine learning
//!
//! This module provides automatic differentiation utilities:
//! - **GradientAccumulator**: Accumulate gradients across multiple backward passes
//! - **ParameterRegistry**: Track and manage all parameters in a quantum model
//! - **GradientClipper**: Prevent exploding gradients with various clipping strategies
//! - **GradientChecker**: Numerical gradient verification for debugging
//! - **ParameterGroup**: Organize parameters with different optimization settings
//!
//! ## TorchQuantum Compatibility
//!
//! These utilities mirror PyTorch's autograd functionality, adapted for quantum circuits:
//! - Parameter tracking similar to `torch.nn.parameter.Parameter`
//! - Gradient accumulation like PyTorch's backward pass
//! - Gradient clipping similar to `torch.nn.utils.clip_grad_*`

use super::{TQModule, TQParameter};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;
// ============================================================================
// GradientAccumulator - Accumulate gradients across multiple passes
// ============================================================================

/// Gradient accumulator for mini-batch training
///
/// Accumulates gradients from multiple forward-backward passes before
/// applying parameter updates. Useful for:
/// - Simulating larger batch sizes with limited memory
/// - Gradient accumulation across multiple quantum circuit executions
/// - Variance reduction in parameter-shift rule calculations
#[derive(Debug, Clone)]
pub struct GradientAccumulator {
    /// Number of accumulation steps
    pub accumulation_steps: usize,
    /// Current step counter
    current_step: usize,
    /// Accumulated gradients for each parameter (keyed by parameter name)
    accumulated_grads: HashMap<String, ArrayD<f64>>,
    /// Whether to average gradients (vs sum)
    average: bool,
}

impl GradientAccumulator {
    /// Create new gradient accumulator
    pub fn new(accumulation_steps: usize) -> Self {
        Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            average: true,
        }
    }

    /// Create accumulator with sum (no averaging)
    pub fn with_sum(accumulation_steps: usize) -> Self {
        Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            average: false,
        }
    }

    /// Accumulate gradients from parameters
    pub fn accumulate(&mut self, params: &[TQParameter]) -> Result<()> {
        for param in params {
            if !param.requires_grad {
                continue;
            }

            if let Some(grad) = &param.grad {
                let entry = self
                    .accumulated_grads
                    .entry(param.name.clone())
                    .or_insert_with(|| ArrayD::zeros(grad.raw_dim()));

                *entry = &*entry + grad;
            }
        }

        self.current_step += 1;
        Ok(())
    }

    /// Check if ready to apply gradients
    pub fn is_ready(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    /// Get accumulated gradients and reset
    pub fn get_and_reset(&mut self) -> HashMap<String, ArrayD<f64>> {
        let mut result = std::mem::take(&mut self.accumulated_grads);

        if self.average && self.accumulation_steps > 1 {
            let scale = 1.0 / self.accumulation_steps as f64;
            for grad in result.values_mut() {
                *grad = &*grad * scale;
            }
        }

        self.current_step = 0;
        result
    }

    /// Reset accumulator without returning gradients
    pub fn reset(&mut self) {
        self.accumulated_grads.clear();
        self.current_step = 0;
    }

    /// Get current step count
    pub fn step_count(&self) -> usize {
        self.current_step
    }
}
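
// Illustrative sketch (not part of the original API surface): the intended
// accumulate -> is_ready -> get_and_reset cycle for simulating a larger batch.
// The forward/backward step that populates `param.grad` is elided; the
// constructor usage mirrors the tests at the bottom of this file.
#[allow(dead_code)]
fn example_gradient_accumulation(params: &[TQParameter]) -> Result<()> {
    let mut accumulator = GradientAccumulator::new(4);

    for _micro_batch in 0..4 {
        // ... forward/backward pass that fills each `param.grad` goes here ...
        accumulator.accumulate(params)?;
    }

    if accumulator.is_ready() {
        // Averaged gradients keyed by parameter name, ready for an optimizer step.
        let grads: HashMap<String, ArrayD<f64>> = accumulator.get_and_reset();
        let _ = grads;
    }
    Ok(())
}
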
// ============================================================================
// ParameterRegistry - Track all parameters in a model
// ============================================================================

/// Parameter registry for tracking and managing quantum model parameters
///
/// Provides centralized parameter management:
/// - Track all parameters across multiple quantum modules
/// - Freeze/unfreeze specific parameters
/// - Get parameter statistics (count, memory usage)
/// - Named parameter access
#[derive(Debug)]
pub struct ParameterRegistry {
    /// Map of parameter name to parameter
    parameters: HashMap<String, TQParameter>,
    /// Frozen parameter names (not trainable)
    frozen: Vec<String>,
}

impl ParameterRegistry {
    /// Create new parameter registry
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
            frozen: Vec::new(),
        }
    }

    /// Register parameters from a module
    pub fn register_module(&mut self, module: &dyn TQModule) -> Result<()> {
        let params = module.parameters();
        for param in params {
            self.parameters.insert(param.name.clone(), param);
        }
        Ok(())
    }

    /// Register a single parameter
    pub fn register(&mut self, param: TQParameter) {
        self.parameters.insert(param.name.clone(), param);
    }

    /// Get parameter by name
    pub fn get(&self, name: &str) -> Option<&TQParameter> {
        self.parameters.get(name)
    }

    /// Get mutable parameter by name
    pub fn get_mut(&mut self, name: &str) -> Option<&mut TQParameter> {
        self.parameters.get_mut(name)
    }

    /// Get all trainable parameters
    pub fn trainable_parameters(&self) -> Vec<&TQParameter> {
        self.parameters
            .values()
            .filter(|p| p.requires_grad && !self.frozen.contains(&p.name))
            .collect()
    }

    /// Get all parameter names
    pub fn parameter_names(&self) -> Vec<&str> {
        self.parameters.keys().map(|s| s.as_str()).collect()
    }

    /// Total number of parameters
    pub fn count(&self) -> usize {
        self.parameters.values().map(|p| p.numel()).sum()
    }

    /// Number of trainable parameters
    pub fn trainable_count(&self) -> usize {
        self.trainable_parameters().iter().map(|p| p.numel()).sum()
    }

    /// Freeze parameter (make non-trainable)
    pub fn freeze(&mut self, name: &str) -> Result<()> {
        if !self.parameters.contains_key(name) {
            return Err(MLError::InvalidConfiguration(format!(
                "Parameter '{}' not found",
                name
            )));
        }
        if !self.frozen.contains(&name.to_string()) {
            self.frozen.push(name.to_string());
        }
        Ok(())
    }

    /// Unfreeze parameter (make trainable)
    pub fn unfreeze(&mut self, name: &str) -> Result<()> {
        self.frozen.retain(|n| n != name);
        Ok(())
    }

    /// Freeze all parameters
    pub fn freeze_all(&mut self) {
        self.frozen = self.parameters.keys().cloned().collect();
    }

    /// Unfreeze all parameters
    pub fn unfreeze_all(&mut self) {
        self.frozen.clear();
    }

    /// Zero all gradients
    pub fn zero_grad(&mut self) {
        for param in self.parameters.values_mut() {
            param.zero_grad();
        }
    }

    /// Get memory usage in bytes
    pub fn memory_bytes(&self) -> usize {
        self.parameters.values().map(|p| p.numel() * 8).sum() // 8 bytes per f64
    }

    /// Get parameter statistics
    pub fn statistics(&self) -> ParameterStatistics {
        let total_params = self.count();
        let trainable_params = self.trainable_count();
        let memory_mb = self.memory_bytes() as f64 / (1024.0 * 1024.0);

        ParameterStatistics {
            total_params,
            trainable_params,
            frozen_params: total_params - trainable_params,
            memory_mb,
        }
    }
}

impl Default for ParameterRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Parameter statistics
#[derive(Debug, Clone)]
pub struct ParameterStatistics {
    pub total_params: usize,
    pub trainable_params: usize,
    pub frozen_params: usize,
    pub memory_mb: f64,
}
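
// Illustrative sketch: registering parameters, freezing part of the model, and
// reading back statistics. The parameter names ("encoder", "head") and shapes
// are hypothetical; the constructor follows the usage in the tests below.
#[allow(dead_code)]
fn example_parameter_registry() -> Result<()> {
    let mut registry = ParameterRegistry::new();
    registry.register(TQParameter::new(ArrayD::zeros(IxDyn(&[8])), "encoder"));
    registry.register(TQParameter::new(ArrayD::zeros(IxDyn(&[4])), "head"));

    // Freeze the encoder so only the head remains trainable.
    registry.freeze("encoder")?;
    assert_eq!(registry.trainable_count(), 4);

    let stats = registry.statistics();
    assert_eq!(stats.total_params, 12);
    assert_eq!(stats.frozen_params, 8);
    Ok(())
}
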
// ============================================================================
// GradientClipper - Prevent exploding gradients
// ============================================================================

/// Gradient clipping strategy
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClippingStrategy {
    /// Clip by global norm (scale all gradients by same factor)
    Norm { max_norm: f64 },
    /// Clip each gradient individually by value
    Value { clip_value: f64 },
    /// Adaptive clipping based on parameter norm
    Adaptive { clip_factor: f64 },
}

/// Gradient clipper to prevent exploding gradients
///
/// Provides various clipping strategies:
/// - **Norm clipping**: Scales all gradients if total norm exceeds threshold
/// - **Value clipping**: Clips individual gradient values to [-clip_value, clip_value]
/// - **Adaptive clipping**: Clips based on parameter magnitude
pub struct GradientClipper {
    strategy: ClippingStrategy,
    /// Statistics about last clipping operation
    pub last_norm: Option<f64>,
    pub was_clipped: bool,
}

impl GradientClipper {
    /// Create clipper with norm-based strategy
    pub fn by_norm(max_norm: f64) -> Self {
        Self {
            strategy: ClippingStrategy::Norm { max_norm },
            last_norm: None,
            was_clipped: false,
        }
    }

    /// Create clipper with value-based strategy
    pub fn by_value(clip_value: f64) -> Self {
        Self {
            strategy: ClippingStrategy::Value { clip_value },
            last_norm: None,
            was_clipped: false,
        }
    }

    /// Create clipper with adaptive strategy
    pub fn adaptive(clip_factor: f64) -> Self {
        Self {
            strategy: ClippingStrategy::Adaptive { clip_factor },
            last_norm: None,
            was_clipped: false,
        }
    }

    /// Clip gradients in place
    pub fn clip(&mut self, params: &mut [TQParameter]) -> Result<()> {
        match self.strategy {
            ClippingStrategy::Norm { max_norm } => self.clip_by_norm(params, max_norm),
            ClippingStrategy::Value { clip_value } => self.clip_by_value(params, clip_value),
            ClippingStrategy::Adaptive { clip_factor } => self.clip_adaptive(params, clip_factor),
        }
    }

    fn clip_by_norm(&mut self, params: &mut [TQParameter], max_norm: f64) -> Result<()> {
        // Calculate total gradient norm
        let mut total_norm_sq = 0.0;
        for param in params.iter() {
            if let Some(grad) = &param.grad {
                for &val in grad.iter() {
                    total_norm_sq += val * val;
                }
            }
        }

        let total_norm = total_norm_sq.sqrt();
        self.last_norm = Some(total_norm);

        if total_norm > max_norm {
            let scale = max_norm / (total_norm + 1e-10);
            for param in params {
                if let Some(grad) = &mut param.grad {
                    *grad = &*grad * scale;
                }
            }
            self.was_clipped = true;
        } else {
            self.was_clipped = false;
        }

        Ok(())
    }

    fn clip_by_value(&mut self, params: &mut [TQParameter], clip_value: f64) -> Result<()> {
        self.was_clipped = false;

        for param in params {
            if let Some(grad) = &mut param.grad {
                for val in grad.iter_mut() {
                    if val.abs() > clip_value {
                        *val = val.signum() * clip_value;
                        self.was_clipped = true;
                    }
                }
            }
        }

        Ok(())
    }

    fn clip_adaptive(&mut self, params: &mut [TQParameter], clip_factor: f64) -> Result<()> {
        self.was_clipped = false;

        for param in params {
            if let Some(grad) = &mut param.grad {
                // Calculate parameter norm
                let param_norm: f64 = param.data.iter().map(|&v| v * v).sum::<f64>().sqrt();
                let max_grad = param_norm * clip_factor;

                // Calculate gradient norm
                let grad_norm: f64 = grad.iter().map(|&v| v * v).sum::<f64>().sqrt();

                if grad_norm > max_grad {
                    let scale = max_grad / (grad_norm + 1e-10);
                    *grad = &*grad * scale;
                    self.was_clipped = true;
                }
            }
        }

        Ok(())
    }

    /// Get clipping statistics
    pub fn statistics(&self) -> ClippingStatistics {
        ClippingStatistics {
            was_clipped: self.was_clipped,
            last_norm: self.last_norm,
            strategy: self.strategy,
        }
    }
}

/// Clipping statistics
#[derive(Debug, Clone)]
pub struct ClippingStatistics {
    pub was_clipped: bool,
    pub last_norm: Option<f64>,
    pub strategy: ClippingStrategy,
}
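
// Illustrative sketch: clipping by global norm before an optimizer step and
// inspecting the clipping statistics. The gradient values and parameter name
// are arbitrary and mirror the tests below.
#[allow(dead_code)]
fn example_gradient_clipping() -> Result<()> {
    let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "theta");
    param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, 4.0]).unwrap());

    let mut clipper = GradientClipper::by_norm(1.0);
    let mut params = [param];
    clipper.clip(&mut params)?;

    // The pre-clipping norm was 5.0 (sqrt(3^2 + 4^2)), so the gradient was
    // rescaled down to the max norm (up to the 1e-10 stabilizer).
    let stats = clipper.statistics();
    assert!(stats.was_clipped);
    Ok(())
}
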
// ============================================================================
// GradientChecker - Numerical gradient verification
// ============================================================================

/// Gradient checker for numerical verification
///
/// Compares analytical gradients (from parameter-shift rule or adjoint method)
/// with numerical gradients (finite differences) to verify correctness.
pub struct GradientChecker {
    /// Epsilon for finite differences
    pub epsilon: f64,
    /// Relative tolerance for comparison
    pub rtol: f64,
    /// Absolute tolerance for comparison
    pub atol: f64,
}

impl GradientChecker {
    /// Create new gradient checker with default tolerances
    pub fn new() -> Self {
        Self {
            epsilon: 1e-5,
            rtol: 1e-3,
            atol: 1e-5,
        }
    }

    /// Create with custom epsilon
    pub fn with_epsilon(epsilon: f64) -> Self {
        Self {
            epsilon,
            rtol: 1e-3,
            atol: 1e-5,
        }
    }

    /// Create with custom tolerances
    pub fn with_tolerances(epsilon: f64, rtol: f64, atol: f64) -> Self {
        Self {
            epsilon,
            rtol,
            atol,
        }
    }

    /// Compute numerical gradient using finite differences
    ///
    /// For function f and parameter θ:
    /// ∂f/∂θ ≈ [f(θ + ε) - f(θ - ε)] / (2ε)
    pub fn numerical_gradient<F>(
        &self,
        param: &mut TQParameter,
        param_idx: usize,
        loss_fn: &mut F,
    ) -> Result<f64>
    where
        F: FnMut() -> Result<f64>,
    {
        // Get original value
        let flat_idx = self.flat_index(param_idx, param.shape());
        let original =
            param.data.as_slice_mut().ok_or_else(|| {
                MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
            })?[flat_idx];

        // f(θ + ε)
        param.data.as_slice_mut().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
        })?[flat_idx] = original + self.epsilon;
        let loss_plus = loss_fn()?;

        // f(θ - ε)
        param.data.as_slice_mut().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
        })?[flat_idx] = original - self.epsilon;
        let loss_minus = loss_fn()?;

        // Restore original value
        param.data.as_slice_mut().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
        })?[flat_idx] = original;

        // Compute numerical gradient
        Ok((loss_plus - loss_minus) / (2.0 * self.epsilon))
    }

    /// Check if analytical and numerical gradients match
    pub fn check_gradient(&self, analytical: f64, numerical: f64) -> GradientCheckResult {
        let abs_diff = (analytical - numerical).abs();
        let rel_diff = if numerical.abs() > 1e-10 {
            abs_diff / numerical.abs()
        } else {
            abs_diff
        };

        let matches = abs_diff <= self.atol || rel_diff <= self.rtol;

        GradientCheckResult {
            analytical,
            numerical,
            abs_diff,
            rel_diff,
            matches,
        }
    }

    /// Map an element index to a flat offset into the parameter data
    ///
    /// Indices are already flat (row-major) offsets, so this is the identity;
    /// the shape argument is currently unused.
    fn flat_index(&self, idx: usize, _shape: &[usize]) -> usize {
        idx
    }
}

impl Default for GradientChecker {
    fn default() -> Self {
        Self::new()
    }
}

/// Result of gradient check
#[derive(Debug, Clone)]
pub struct GradientCheckResult {
    pub analytical: f64,
    pub numerical: f64,
    pub abs_diff: f64,
    pub rel_diff: f64,
    pub matches: bool,
}
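
// Illustrative sketch: comparing an analytical gradient (e.g. from the
// parameter-shift rule) against a numerical estimate. The two values below are
// made up solely to show a passing check under the default-style tolerances.
#[allow(dead_code)]
fn example_gradient_check() {
    let checker = GradientChecker::with_tolerances(1e-5, 1e-3, 1e-5);
    let analytical = 0.5000;
    let numerical = 0.5002;

    let result = checker.check_gradient(analytical, numerical);
    // abs_diff = 2e-4 exceeds atol, but rel_diff ≈ 4e-4 <= rtol, so it matches.
    assert!(result.matches);
}
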
// ============================================================================
// ParameterGroup - Group parameters with different settings
// ============================================================================

/// Parameter group for organizing parameters
///
/// Similar to PyTorch's parameter groups in optimizers.
/// Allows different learning rates, weight decay, etc. for different parameter sets.
#[derive(Debug, Clone)]
pub struct ParameterGroup {
    /// Group name
    pub name: String,
    /// Parameter names in this group
    pub param_names: Vec<String>,
    /// Learning rate multiplier for this group
    pub lr_multiplier: f64,
    /// Weight decay for this group
    pub weight_decay: f64,
    /// Whether gradients are enabled for this group
    pub requires_grad: bool,
}

impl ParameterGroup {
    /// Create new parameter group
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            param_names: Vec::new(),
            lr_multiplier: 1.0,
            weight_decay: 0.0,
            requires_grad: true,
        }
    }

    /// Add parameter to group
    pub fn add_param(&mut self, param_name: impl Into<String>) {
        self.param_names.push(param_name.into());
    }

    /// Set learning rate multiplier
    pub fn with_lr_multiplier(mut self, multiplier: f64) -> Self {
        self.lr_multiplier = multiplier;
        self
    }

    /// Set weight decay
    pub fn with_weight_decay(mut self, decay: f64) -> Self {
        self.weight_decay = decay;
        self
    }

    /// Set requires_grad
    pub fn with_requires_grad(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        self
    }

    /// Check if parameter belongs to this group
    pub fn contains(&self, param_name: &str) -> bool {
        self.param_names.iter().any(|n| n == param_name)
    }
}

/// Manager for multiple parameter groups
#[derive(Debug)]
pub struct ParameterGroupManager {
    groups: Vec<ParameterGroup>,
}

impl ParameterGroupManager {
    /// Create new manager
    pub fn new() -> Self {
        Self { groups: Vec::new() }
    }

    /// Add a parameter group
    pub fn add_group(&mut self, group: ParameterGroup) {
        self.groups.push(group);
    }

    /// Get group for parameter
    pub fn get_group(&self, param_name: &str) -> Option<&ParameterGroup> {
        self.groups.iter().find(|g| g.contains(param_name))
    }

    /// Get all groups
    pub fn groups(&self) -> &[ParameterGroup] {
        &self.groups
    }

    /// Get learning rate multiplier for parameter
    pub fn lr_multiplier(&self, param_name: &str) -> f64 {
        self.get_group(param_name)
            .map(|g| g.lr_multiplier)
            .unwrap_or(1.0)
    }

    /// Get weight decay for parameter
    pub fn weight_decay(&self, param_name: &str) -> f64 {
        self.get_group(param_name)
            .map(|g| g.weight_decay)
            .unwrap_or(0.0)
    }

    /// Check if parameter requires grad
    pub fn requires_grad(&self, param_name: &str) -> bool {
        self.get_group(param_name)
            .map(|g| g.requires_grad)
            .unwrap_or(true)
    }
}

impl Default for ParameterGroupManager {
    fn default() -> Self {
        Self::new()
    }
}
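
// Illustrative sketch: applying group settings during a manual gradient-descent
// update. `base_lr`, the group name, and the parameter name are hypothetical;
// parameters without a group fall back to the default multiplier of 1.0.
#[allow(dead_code)]
fn example_grouped_learning_rates(params: &mut [TQParameter]) {
    let mut manager = ParameterGroupManager::new();
    let mut encoder_group = ParameterGroup::new("encoder").with_lr_multiplier(0.1);
    encoder_group.add_param("encoder.rotation");
    manager.add_group(encoder_group);

    let base_lr = 1e-2;
    for param in params.iter_mut() {
        // Scale the base learning rate by the group's multiplier.
        let lr = base_lr * manager.lr_multiplier(&param.name);
        if let Some(grad) = &param.grad {
            let update = grad * lr;
            param.data = &param.data - &update;
        }
    }
}
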
// ============================================================================
// Gradient utilities
// ============================================================================

/// Compute gradient norm (L2 norm of all gradients)
pub fn gradient_norm(params: &[TQParameter]) -> f64 {
    let mut norm_sq = 0.0;
    for param in params {
        if let Some(grad) = &param.grad {
            for &val in grad.iter() {
                norm_sq += val * val;
            }
        }
    }
    norm_sq.sqrt()
}

/// Compute gradient statistics
pub fn gradient_statistics(params: &[TQParameter]) -> GradientStatistics {
    let mut all_grads = Vec::new();
    for param in params {
        if let Some(grad) = &param.grad {
            all_grads.extend(grad.iter().copied());
        }
    }

    if all_grads.is_empty() {
        return GradientStatistics::default();
    }

    let n = all_grads.len() as f64;
    let mean = all_grads.iter().sum::<f64>() / n;
    let variance = all_grads.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / n;
    let std = variance.sqrt();

    let min = all_grads
        .iter()
        .copied()
        .min_by(|a, b| a.partial_cmp(b).unwrap())
        .unwrap_or(0.0);
    let max = all_grads
        .iter()
        .copied()
        .max_by(|a, b| a.partial_cmp(b).unwrap())
        .unwrap_or(0.0);

    let norm = gradient_norm(params);

    GradientStatistics {
        mean,
        std,
        min,
        max,
        norm,
    }
}

/// Gradient statistics
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
    pub mean: f64,
    pub std: f64,
    pub min: f64,
    pub max: f64,
    pub norm: f64,
}
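
// Illustrative sketch: using the helpers above as a lightweight training
// diagnostic. The thresholds are arbitrary and only meant to show how the
// statistics might be consumed.
#[allow(dead_code)]
fn example_gradient_diagnostics(params: &[TQParameter]) {
    let stats = gradient_statistics(params);
    if stats.norm > 10.0 {
        eprintln!("warning: large gradient norm {:.3}, consider clipping", stats.norm);
    } else if stats.norm < 1e-6 {
        eprintln!("warning: near-zero gradients (possible barren plateau)");
    }
    eprintln!(
        "grad mean {:.3e}, std {:.3e}, range [{:.3e}, {:.3e}]",
        stats.mean, stats.std, stats.min, stats.max
    );
}
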
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    #[test]
    fn test_gradient_accumulator() {
        let mut acc = GradientAccumulator::new(3);

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "test");
        param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap());

        // Accumulate 3 times
        for _ in 0..3 {
            acc.accumulate(&[param.clone()]).unwrap();
        }

        assert!(acc.is_ready());

        let grads = acc.get_and_reset();
        let test_grad = &grads["test"];

        // Should be averaged: (1+1+1)/3 = 1, (2+2+2)/3 = 2
        assert!((test_grad[[0]] - 1.0).abs() < 1e-10);
        assert!((test_grad[[1]] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_parameter_registry() {
        let mut registry = ParameterRegistry::new();

        let param1 = TQParameter::new(ArrayD::zeros(IxDyn(&[5])), "layer1");
        let param2 = TQParameter::new(ArrayD::zeros(IxDyn(&[10])), "layer2");

        registry.register(param1);
        registry.register(param2);

        assert_eq!(registry.count(), 15);
        assert_eq!(registry.trainable_count(), 15);

        registry.freeze("layer1").unwrap();
        assert_eq!(registry.trainable_count(), 10);

        let stats = registry.statistics();
        assert_eq!(stats.total_params, 15);
        assert_eq!(stats.trainable_params, 10);
        assert_eq!(stats.frozen_params, 5);
    }

    #[test]
    fn test_gradient_clipper_by_norm() {
        let mut clipper = GradientClipper::by_norm(1.0);

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "test");
        param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, 4.0]).unwrap());

        // Gradient norm is 5.0, should be clipped to 1.0
        clipper.clip(&mut [param]).unwrap();

        assert!(clipper.was_clipped);
        assert!((clipper.last_norm.unwrap() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_gradient_clipper_by_value() {
        let mut clipper = GradientClipper::by_value(2.0);

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "test");
        param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, -4.0]).unwrap());

        clipper.clip(&mut [param]).unwrap();

        assert!(clipper.was_clipped);
    }

    #[test]
    fn test_parameter_group() {
        let mut manager = ParameterGroupManager::new();

        let mut group1 = ParameterGroup::new("backbone")
            .with_lr_multiplier(0.1)
            .with_weight_decay(0.01);
        group1.add_param("layer1");
        group1.add_param("layer2");

        let mut group2 = ParameterGroup::new("head")
            .with_lr_multiplier(1.0)
            .with_weight_decay(0.0);
        group2.add_param("output");

        manager.add_group(group1);
        manager.add_group(group2);

        assert_eq!(manager.lr_multiplier("layer1"), 0.1);
        assert_eq!(manager.lr_multiplier("output"), 1.0);
        assert_eq!(manager.weight_decay("layer1"), 0.01);
        assert_eq!(manager.weight_decay("output"), 0.0);
    }

    #[test]
    fn test_gradient_statistics() {
        let mut param1 = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "p1");
        param1.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap());

        let mut param2 = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "p2");
        param2.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, 4.0]).unwrap());

        let stats = gradient_statistics(&[param1, param2]);

        assert!((stats.mean - 2.5).abs() < 1e-10);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 4.0);
    }
}
847}