use super::{TQModule, TQParameter};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;

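/// Accumulates parameter gradients over several backward passes so that a
/// large effective batch size can be simulated with small micro-batches.
///
/// Gradients are summed into a map keyed by parameter name. Once
/// `accumulation_steps` batches have been recorded, `is_ready` returns
/// `true` and `get_and_reset` yields the (optionally averaged) totals.
///
/// A minimal usage sketch (ignored doctest; mirrors
/// `test_gradient_accumulator` below):
///
/// ```ignore
/// let mut acc = GradientAccumulator::new(3);
/// for _ in 0..3 {
///     acc.accumulate(&params).unwrap(); // params: &[TQParameter] with grads set
/// }
/// if acc.is_ready() {
///     let grads = acc.get_and_reset(); // averaged over the 3 steps
/// }
/// ```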
#[derive(Debug, Clone)]
pub struct GradientAccumulator {
    /// Number of batches to accumulate before gradients are considered ready.
    pub accumulation_steps: usize,
    /// Batches accumulated since the last reset.
    current_step: usize,
    /// Running gradient totals, keyed by parameter name.
    accumulated_grads: HashMap<String, ArrayD<f64>>,
    /// If `true`, `get_and_reset` averages the totals instead of summing them.
    average: bool,
}

impl GradientAccumulator {
    /// Creates an accumulator that averages gradients over `accumulation_steps`.
    pub fn new(accumulation_steps: usize) -> Self {
        Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            average: true,
        }
    }

    /// Creates an accumulator that sums gradients instead of averaging them.
    pub fn with_sum(accumulation_steps: usize) -> Self {
        Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            average: false,
        }
    }

    /// Adds the current gradient of each parameter in `params` to the
    /// running totals. Parameters without a gradient, or with
    /// `requires_grad == false`, are skipped.
    pub fn accumulate(&mut self, params: &[TQParameter]) -> Result<()> {
        for param in params {
            if !param.requires_grad {
                continue;
            }

            if let Some(grad) = &param.grad {
                let entry = self
                    .accumulated_grads
                    .entry(param.name.clone())
                    .or_insert_with(|| ArrayD::zeros(grad.raw_dim()));

                *entry = &*entry + grad;
            }
        }

        self.current_step += 1;
        Ok(())
    }

    /// Returns `true` once `accumulation_steps` batches have been accumulated.
    pub fn is_ready(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    /// Takes the accumulated gradients, averaging them if configured to, and
    /// resets the step counter for the next accumulation cycle.
    pub fn get_and_reset(&mut self) -> HashMap<String, ArrayD<f64>> {
        let mut result = std::mem::take(&mut self.accumulated_grads);

        if self.average && self.accumulation_steps > 1 {
            let scale = 1.0 / self.accumulation_steps as f64;
            for grad in result.values_mut() {
                *grad = &*grad * scale;
            }
        }

        self.current_step = 0;
        result
    }

    /// Discards any accumulated gradients and resets the step counter.
    pub fn reset(&mut self) {
        self.accumulated_grads.clear();
        self.current_step = 0;
    }

    /// Returns the number of batches accumulated so far.
    pub fn step_count(&self) -> usize {
        self.current_step
    }
}

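/// Central registry of model parameters, supporting lookup by name,
/// freezing and unfreezing for fine-tuning, and aggregate statistics.
///
/// A minimal usage sketch (ignored doctest; mirrors
/// `test_parameter_registry` below):
///
/// ```ignore
/// let mut registry = ParameterRegistry::new();
/// registry.register(TQParameter::new(ArrayD::zeros(IxDyn(&[5])), "layer1"));
/// registry.freeze("layer1").unwrap(); // exclude from training
/// assert_eq!(registry.trainable_count(), 0);
/// ```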
#[derive(Debug)]
pub struct ParameterRegistry {
    /// All registered parameters, keyed by name.
    parameters: HashMap<String, TQParameter>,
    /// Names of parameters currently excluded from training.
    frozen: Vec<String>,
}

impl ParameterRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
            frozen: Vec::new(),
        }
    }

    /// Registers every parameter exposed by `module`. Existing entries with
    /// the same names are overwritten.
    pub fn register_module(&mut self, module: &dyn TQModule) -> Result<()> {
        let params = module.parameters();
        for param in params {
            self.parameters.insert(param.name.clone(), param);
        }
        Ok(())
    }

    /// Registers a single parameter, overwriting any existing entry with the
    /// same name.
    pub fn register(&mut self, param: TQParameter) {
        self.parameters.insert(param.name.clone(), param);
    }

    /// Returns the parameter with the given name, if registered.
    pub fn get(&self, name: &str) -> Option<&TQParameter> {
        self.parameters.get(name)
    }

    /// Returns a mutable reference to the parameter with the given name.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut TQParameter> {
        self.parameters.get_mut(name)
    }

    /// Returns all parameters that require gradients and are not frozen.
    pub fn trainable_parameters(&self) -> Vec<&TQParameter> {
        self.parameters
            .values()
            .filter(|p| p.requires_grad && !self.frozen.contains(&p.name))
            .collect()
    }

    /// Returns the names of all registered parameters.
    pub fn parameter_names(&self) -> Vec<&str> {
        self.parameters.keys().map(|s| s.as_str()).collect()
    }

    /// Returns the total number of scalar parameters (elements) registered.
    pub fn count(&self) -> usize {
        self.parameters.values().map(|p| p.numel()).sum()
    }

    /// Returns the number of scalar parameters that are currently trainable.
    pub fn trainable_count(&self) -> usize {
        self.trainable_parameters().iter().map(|p| p.numel()).sum()
    }

    /// Excludes the named parameter from training. Returns an error if no
    /// parameter with that name is registered.
    pub fn freeze(&mut self, name: &str) -> Result<()> {
        if !self.parameters.contains_key(name) {
            return Err(MLError::InvalidConfiguration(format!(
                "Parameter '{}' not found",
                name
            )));
        }
        if !self.frozen.contains(&name.to_string()) {
            self.frozen.push(name.to_string());
        }
        Ok(())
    }

    /// Re-enables training for the named parameter. A no-op if it was not
    /// frozen.
    pub fn unfreeze(&mut self, name: &str) -> Result<()> {
        self.frozen.retain(|n| n != name);
        Ok(())
    }

    /// Freezes every registered parameter.
    pub fn freeze_all(&mut self) {
        self.frozen = self.parameters.keys().cloned().collect();
    }

    /// Unfreezes every registered parameter.
    pub fn unfreeze_all(&mut self) {
        self.frozen.clear();
    }

    /// Clears the gradient of every registered parameter.
    pub fn zero_grad(&mut self) {
        for param in self.parameters.values_mut() {
            param.zero_grad();
        }
    }

    /// Estimates the memory held by parameter data, assuming 8 bytes per
    /// `f64` element (gradients and bookkeeping are not counted).
    pub fn memory_bytes(&self) -> usize {
        self.parameters.values().map(|p| p.numel() * 8).sum()
    }

    /// Summarizes parameter counts and estimated memory usage.
    pub fn statistics(&self) -> ParameterStatistics {
        let total_params = self.count();
        let trainable_params = self.trainable_count();
        let memory_mb = self.memory_bytes() as f64 / (1024.0 * 1024.0);

        ParameterStatistics {
            total_params,
            trainable_params,
            frozen_params: total_params - trainable_params,
            memory_mb,
        }
    }
}

impl Default for ParameterRegistry {
    fn default() -> Self {
        Self::new()
    }
}

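/// Summary of registered parameters, as produced by
/// `ParameterRegistry::statistics`.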
#[derive(Debug, Clone)]
pub struct ParameterStatistics {
    /// Total number of scalar parameters.
    pub total_params: usize,
    /// Number of trainable scalar parameters.
    pub trainable_params: usize,
    /// Number of frozen scalar parameters.
    pub frozen_params: usize,
    /// Estimated parameter memory in mebibytes.
    pub memory_mb: f64,
}

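/// Strategy used by `GradientClipper` to bound gradient magnitudes.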
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClippingStrategy {
    /// Rescale all gradients so their global L2 norm is at most `max_norm`.
    Norm { max_norm: f64 },
    /// Clamp each gradient element to the range `[-clip_value, clip_value]`.
    Value { clip_value: f64 },
    /// Bound each gradient's norm by `clip_factor` times its parameter's norm.
    Adaptive { clip_factor: f64 },
}

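/// Clips parameter gradients in place according to a `ClippingStrategy`,
/// recording the last observed global norm and whether clipping fired.
///
/// A minimal usage sketch (ignored doctest; mirrors
/// `test_gradient_clipper_by_norm` below):
///
/// ```ignore
/// let mut clipper = GradientClipper::by_norm(1.0);
/// clipper.clip(&mut params).unwrap(); // params: &mut [TQParameter]
/// if clipper.was_clipped {
///     println!("gradient norm was {:?}", clipper.last_norm);
/// }
/// ```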
pub struct GradientClipper {
    /// The clipping strategy to apply.
    strategy: ClippingStrategy,
    /// Global gradient norm observed during the last norm-based clip.
    pub last_norm: Option<f64>,
    /// Whether the last call to `clip` modified any gradient.
    pub was_clipped: bool,
}

impl GradientClipper {
    /// Clips by global L2 norm: if the combined gradient norm exceeds
    /// `max_norm`, all gradients are rescaled to bring it back down.
    pub fn by_norm(max_norm: f64) -> Self {
        Self {
            strategy: ClippingStrategy::Norm { max_norm },
            last_norm: None,
            was_clipped: false,
        }
    }

    /// Clips element-wise: each gradient value is clamped to
    /// `[-clip_value, clip_value]`.
    pub fn by_value(clip_value: f64) -> Self {
        Self {
            strategy: ClippingStrategy::Value { clip_value },
            last_norm: None,
            was_clipped: false,
        }
    }

    /// Clips adaptively: each gradient's norm is bounded by `clip_factor`
    /// times the norm of its parameter.
    pub fn adaptive(clip_factor: f64) -> Self {
        Self {
            strategy: ClippingStrategy::Adaptive { clip_factor },
            last_norm: None,
            was_clipped: false,
        }
    }

    /// Applies the configured strategy to the gradients of `params` in place.
    pub fn clip(&mut self, params: &mut [TQParameter]) -> Result<()> {
        match self.strategy {
            ClippingStrategy::Norm { max_norm } => self.clip_by_norm(params, max_norm),
            ClippingStrategy::Value { clip_value } => self.clip_by_value(params, clip_value),
            ClippingStrategy::Adaptive { clip_factor } => self.clip_adaptive(params, clip_factor),
        }
    }

    fn clip_by_norm(&mut self, params: &mut [TQParameter], max_norm: f64) -> Result<()> {
        // Compute the global L2 norm over all gradients.
        let mut total_norm_sq = 0.0;
        for param in params.iter() {
            if let Some(grad) = &param.grad {
                for &val in grad.iter() {
                    total_norm_sq += val * val;
                }
            }
        }

        let total_norm = total_norm_sq.sqrt();
        self.last_norm = Some(total_norm);

        if total_norm > max_norm {
            let scale = max_norm / (total_norm + 1e-10);
            for param in params {
                if let Some(grad) = &mut param.grad {
                    *grad = &*grad * scale;
                }
            }
            self.was_clipped = true;
        } else {
            self.was_clipped = false;
        }

        Ok(())
    }

    fn clip_by_value(&mut self, params: &mut [TQParameter], clip_value: f64) -> Result<()> {
        self.was_clipped = false;

        for param in params {
            if let Some(grad) = &mut param.grad {
                for val in grad.iter_mut() {
                    if val.abs() > clip_value {
                        *val = val.signum() * clip_value;
                        self.was_clipped = true;
                    }
                }
            }
        }

        Ok(())
    }

    fn clip_adaptive(&mut self, params: &mut [TQParameter], clip_factor: f64) -> Result<()> {
        self.was_clipped = false;

        for param in params {
            if let Some(grad) = &mut param.grad {
                // Bound each gradient's norm relative to its parameter's norm.
                let param_norm: f64 = param.data.iter().map(|&v| v * v).sum::<f64>().sqrt();
                let max_grad = param_norm * clip_factor;

                let grad_norm: f64 = grad.iter().map(|&v| v * v).sum::<f64>().sqrt();

                if grad_norm > max_grad {
                    let scale = max_grad / (grad_norm + 1e-10);
                    *grad = &*grad * scale;
                    self.was_clipped = true;
                }
            }
        }

        Ok(())
    }

    /// Reports the outcome of the most recent `clip` call.
    pub fn statistics(&self) -> ClippingStatistics {
        ClippingStatistics {
            was_clipped: self.was_clipped,
            last_norm: self.last_norm,
            strategy: self.strategy,
        }
    }
}

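/// Outcome of the most recent `GradientClipper::clip` call.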
#[derive(Debug, Clone)]
pub struct ClippingStatistics {
    /// Whether any gradient was modified.
    pub was_clipped: bool,
    /// Global gradient norm from the last norm-based clip, if computed.
    pub last_norm: Option<f64>,
    /// The strategy that was applied.
    pub strategy: ClippingStrategy,
}

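/// Verifies analytical gradients against a central finite-difference
/// estimate, `(loss(x + eps) - loss(x - eps)) / (2 * eps)`.
///
/// A gradient is accepted when the absolute difference is within `atol` or
/// the relative difference is within `rtol`.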
pub struct GradientChecker {
    /// Finite-difference step size.
    pub epsilon: f64,
    /// Relative tolerance for accepting a gradient.
    pub rtol: f64,
    /// Absolute tolerance for accepting a gradient.
    pub atol: f64,
}

impl GradientChecker {
    /// Creates a checker with default tolerances
    /// (`epsilon = 1e-5`, `rtol = 1e-3`, `atol = 1e-5`).
    pub fn new() -> Self {
        Self {
            epsilon: 1e-5,
            rtol: 1e-3,
            atol: 1e-5,
        }
    }

    /// Creates a checker with a custom step size and default tolerances.
    pub fn with_epsilon(epsilon: f64) -> Self {
        Self {
            epsilon,
            rtol: 1e-3,
            atol: 1e-5,
        }
    }

    /// Creates a checker with fully custom step size and tolerances.
    pub fn with_tolerances(epsilon: f64, rtol: f64, atol: f64) -> Self {
        Self { epsilon, rtol, atol }
    }

    /// Estimates the gradient of `loss_fn` with respect to the
    /// `param_idx`-th element of `param` via central differences, restoring
    /// the original value before returning.
    pub fn numerical_gradient<F>(
        &self,
        param: &mut TQParameter,
        param_idx: usize,
        loss_fn: &mut F,
    ) -> Result<f64>
    where
        F: FnMut() -> Result<f64>,
    {
        let flat_idx = self.flat_index(param_idx, param.shape());
        let original =
            param.data.as_slice_mut().ok_or_else(|| {
                MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
            })?[flat_idx];

        // Evaluate the loss at x + epsilon.
        param.data.as_slice_mut().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
        })?[flat_idx] = original + self.epsilon;
        let loss_plus = loss_fn()?;

        // Evaluate the loss at x - epsilon.
        param.data.as_slice_mut().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
        })?[flat_idx] = original - self.epsilon;
        let loss_minus = loss_fn()?;

        // Restore the original value.
        param.data.as_slice_mut().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
        })?[flat_idx] = original;

        // Central difference approximation.
        Ok((loss_plus - loss_minus) / (2.0 * self.epsilon))
    }

    /// Compares an analytical gradient against a numerical estimate and
    /// reports whether they agree within `atol` (absolute) or `rtol`
    /// (relative).
    pub fn check_gradient(&self, analytical: f64, numerical: f64) -> GradientCheckResult {
        let abs_diff = (analytical - numerical).abs();
        let rel_diff = if numerical.abs() > 1e-10 {
            abs_diff / numerical.abs()
        } else {
            abs_diff
        };

        let matches = abs_diff <= self.atol || rel_diff <= self.rtol;

        GradientCheckResult {
            analytical,
            numerical,
            abs_diff,
            rel_diff,
            matches,
        }
    }

    /// Maps a logical element index to a flat index into the underlying
    /// buffer. The data is assumed to be contiguous, so the index is used
    /// as-is.
    fn flat_index(&self, idx: usize, _shape: &[usize]) -> usize {
        idx
    }
}

impl Default for GradientChecker {
    fn default() -> Self {
        Self::new()
    }
}

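/// Result of comparing an analytical gradient with a numerical estimate.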
#[derive(Debug, Clone)]
pub struct GradientCheckResult {
    /// Gradient computed analytically (e.g., via backpropagation).
    pub analytical: f64,
    /// Gradient estimated by central finite differences.
    pub numerical: f64,
    /// Absolute difference between the two.
    pub abs_diff: f64,
    /// Relative difference between the two.
    pub rel_diff: f64,
    /// Whether the difference is within tolerance.
    pub matches: bool,
}

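/// A named group of parameters that share optimizer settings, such as a
/// learning-rate multiplier and a weight-decay coefficient (the usual
/// parameter-group concept from deep-learning optimizers).
///
/// A minimal builder-style sketch (ignored doctest; mirrors
/// `test_parameter_group` below):
///
/// ```ignore
/// let mut backbone = ParameterGroup::new("backbone")
///     .with_lr_multiplier(0.1)
///     .with_weight_decay(0.01);
/// backbone.add_param("layer1");
/// ```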
#[derive(Debug, Clone)]
pub struct ParameterGroup {
    /// Group name.
    pub name: String,
    /// Names of the parameters belonging to this group.
    pub param_names: Vec<String>,
    /// Factor applied to the base learning rate for this group.
    pub lr_multiplier: f64,
    /// Weight-decay coefficient for this group.
    pub weight_decay: f64,
    /// Whether parameters in this group are trainable.
    pub requires_grad: bool,
}

impl ParameterGroup {
    /// Creates an empty group with a learning-rate multiplier of 1.0, no
    /// weight decay, and gradients enabled.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            param_names: Vec::new(),
            lr_multiplier: 1.0,
            weight_decay: 0.0,
            requires_grad: true,
        }
    }

    /// Adds a parameter name to the group.
    pub fn add_param(&mut self, param_name: impl Into<String>) {
        self.param_names.push(param_name.into());
    }

    /// Sets the learning-rate multiplier (builder style).
    pub fn with_lr_multiplier(mut self, multiplier: f64) -> Self {
        self.lr_multiplier = multiplier;
        self
    }

    /// Sets the weight-decay coefficient (builder style).
    pub fn with_weight_decay(mut self, decay: f64) -> Self {
        self.weight_decay = decay;
        self
    }

    /// Sets whether the group's parameters require gradients (builder style).
    pub fn with_requires_grad(mut self, requires_grad: bool) -> Self {
        self.requires_grad = requires_grad;
        self
    }

    /// Returns `true` if the named parameter belongs to this group.
    pub fn contains(&self, param_name: &str) -> bool {
        self.param_names.iter().any(|n| n == param_name)
    }
}

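/// Resolves per-parameter optimizer settings across a set of
/// `ParameterGroup`s; parameters in no group fall back to the defaults
/// (multiplier 1.0, no weight decay, gradients enabled).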
#[derive(Debug)]
pub struct ParameterGroupManager {
    /// Registered groups, searched in insertion order.
    groups: Vec<ParameterGroup>,
}

impl ParameterGroupManager {
    /// Creates a manager with no groups.
    pub fn new() -> Self {
        Self { groups: Vec::new() }
    }

    /// Adds a parameter group. On lookup, earlier groups take precedence.
    pub fn add_group(&mut self, group: ParameterGroup) {
        self.groups.push(group);
    }

    /// Returns the first group containing the named parameter, if any.
    pub fn get_group(&self, param_name: &str) -> Option<&ParameterGroup> {
        self.groups.iter().find(|g| g.contains(param_name))
    }

    /// Returns all groups.
    pub fn groups(&self) -> &[ParameterGroup] {
        &self.groups
    }

    /// Returns the learning-rate multiplier for the named parameter
    /// (1.0 if it belongs to no group).
    pub fn lr_multiplier(&self, param_name: &str) -> f64 {
        self.get_group(param_name)
            .map(|g| g.lr_multiplier)
            .unwrap_or(1.0)
    }

    /// Returns the weight decay for the named parameter (0.0 if it belongs
    /// to no group).
    pub fn weight_decay(&self, param_name: &str) -> f64 {
        self.get_group(param_name)
            .map(|g| g.weight_decay)
            .unwrap_or(0.0)
    }

    /// Returns whether the named parameter requires gradients (`true` if it
    /// belongs to no group).
    pub fn requires_grad(&self, param_name: &str) -> bool {
        self.get_group(param_name)
            .map(|g| g.requires_grad)
            .unwrap_or(true)
    }
}

impl Default for ParameterGroupManager {
    fn default() -> Self {
        Self::new()
    }
}

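/// Computes the global L2 norm over the gradients of all `params`, i.e.
/// the square root of the sum of squares of every gradient element.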
pub fn gradient_norm(params: &[TQParameter]) -> f64 {
    let mut norm_sq = 0.0;
    for param in params {
        if let Some(grad) = &param.grad {
            for &val in grad.iter() {
                norm_sq += val * val;
            }
        }
    }
    norm_sq.sqrt()
}

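/// Computes element-wise summary statistics (mean, standard deviation,
/// min, max) and the global L2 norm over all parameter gradients.
/// Returns all-zero defaults when no parameter has a gradient.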
pub fn gradient_statistics(params: &[TQParameter]) -> GradientStatistics {
    let mut all_grads = Vec::new();
    for param in params {
        if let Some(grad) = &param.grad {
            all_grads.extend(grad.iter().copied());
        }
    }

    if all_grads.is_empty() {
        return GradientStatistics::default();
    }

    let n = all_grads.len() as f64;
    let mean = all_grads.iter().sum::<f64>() / n;
    let variance = all_grads.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / n;
    let std = variance.sqrt();

    let min = all_grads
        .iter()
        .copied()
        .min_by(|a, b| a.partial_cmp(b).unwrap())
        .unwrap_or(0.0);
    let max = all_grads
        .iter()
        .copied()
        .max_by(|a, b| a.partial_cmp(b).unwrap())
        .unwrap_or(0.0);

    let norm = gradient_norm(params);

    GradientStatistics {
        mean,
        std,
        min,
        max,
        norm,
    }
}

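/// Summary statistics over gradient elements, as produced by
/// `gradient_statistics`.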
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
    /// Mean of all gradient elements.
    pub mean: f64,
    /// Standard deviation of all gradient elements.
    pub std: f64,
    /// Smallest gradient element.
    pub min: f64,
    /// Largest gradient element.
    pub max: f64,
    /// Global L2 norm of all gradients.
    pub norm: f64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    #[test]
    fn test_gradient_accumulator() {
        let mut acc = GradientAccumulator::new(3);

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "test");
        param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap());

        for _ in 0..3 {
            acc.accumulate(&[param.clone()]).unwrap();
        }

        assert!(acc.is_ready());

        let grads = acc.get_and_reset();
        let test_grad = &grads["test"];

        // Averaged over 3 steps: (3 * 1.0) / 3 == 1.0 and (3 * 2.0) / 3 == 2.0.
        assert!((test_grad[[0]] - 1.0).abs() < 1e-10);
        assert!((test_grad[[1]] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_parameter_registry() {
        let mut registry = ParameterRegistry::new();

        let param1 = TQParameter::new(ArrayD::zeros(IxDyn(&[5])), "layer1");
        let param2 = TQParameter::new(ArrayD::zeros(IxDyn(&[10])), "layer2");

        registry.register(param1);
        registry.register(param2);

        assert_eq!(registry.count(), 15);
        assert_eq!(registry.trainable_count(), 15);

        registry.freeze("layer1").unwrap();
        assert_eq!(registry.trainable_count(), 10);

        let stats = registry.statistics();
        assert_eq!(stats.total_params, 15);
        assert_eq!(stats.trainable_params, 10);
        assert_eq!(stats.frozen_params, 5);
    }

    #[test]
    fn test_gradient_clipper_by_norm() {
        let mut clipper = GradientClipper::by_norm(1.0);

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "test");
        param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, 4.0]).unwrap());

        clipper.clip(&mut [param]).unwrap();

        // The gradient norm is sqrt(3^2 + 4^2) = 5.0, above the max of 1.0.
        assert!(clipper.was_clipped);
        assert!((clipper.last_norm.unwrap() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_gradient_clipper_by_value() {
        let mut clipper = GradientClipper::by_value(2.0);

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "test");
        param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, -4.0]).unwrap());

        clipper.clip(&mut [param]).unwrap();

        assert!(clipper.was_clipped);
    }

    #[test]
    fn test_parameter_group() {
        let mut manager = ParameterGroupManager::new();

        let mut group1 = ParameterGroup::new("backbone")
            .with_lr_multiplier(0.1)
            .with_weight_decay(0.01);
        group1.add_param("layer1");
        group1.add_param("layer2");

        let mut group2 = ParameterGroup::new("head")
            .with_lr_multiplier(1.0)
            .with_weight_decay(0.0);
        group2.add_param("output");

        manager.add_group(group1);
        manager.add_group(group2);

        assert_eq!(manager.lr_multiplier("layer1"), 0.1);
        assert_eq!(manager.lr_multiplier("output"), 1.0);
        assert_eq!(manager.weight_decay("layer1"), 0.01);
        assert_eq!(manager.weight_decay("output"), 0.0);
    }

    #[test]
    fn test_gradient_statistics() {
        let mut param1 = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "p1");
        param1.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 2.0]).unwrap());

        let mut param2 = TQParameter::new(ArrayD::zeros(IxDyn(&[2])), "p2");
        param2.grad = Some(ArrayD::from_shape_vec(IxDyn(&[2]), vec![3.0, 4.0]).unwrap());

        let stats = gradient_statistics(&[param1, param2]);

        assert!((stats.mean - 2.5).abs() < 1e-10);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 4.0);
    }
}