tensorlogic-train 0.1.0

Training loops, loss composition, and optimization schedules for TensorLogic
//! Gradient monitoring and accumulation callbacks.

use crate::callbacks::core::Callback;
use crate::{TrainError, TrainResult, TrainingState};
use std::collections::HashMap;

/// Gradient flow monitor for tracking gradient statistics during training.
///
/// This callback tracks gradient norms, means, and standard deviations, and flags vanishing or exploding gradients.
/// Useful for debugging training issues and understanding gradient flow through the network.
///
/// # Example
/// ```rust,ignore
/// use tensorlogic_train::{GradientMonitor, CallbackList};
///
/// let mut callbacks = CallbackList::new();
/// callbacks.add(Box::new(GradientMonitor::new(
///     10,      // log_frequency
///     1e-7,    // vanishing_threshold
///     100.0,   // exploding_threshold
/// )));
/// ```
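///
/// After training, `summary()` aggregates what was recorded. A minimal
/// sketch (how the monitor is driven by the training loop is up to the
/// caller):
/// ```rust,ignore
/// let monitor = GradientMonitor::new(10, 1e-7, 100.0);
/// // ... training invokes `on_batch_end` for each batch ...
/// let summary = monitor.summary();
/// println!(
///     "avg norm {:.3e}, vanishing {}, exploding {}",
///     summary.average_norm, summary.vanishing_count, summary.exploding_count
/// );
/// ```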
pub struct GradientMonitor {
    /// Frequency of logging (every N batches).
    log_frequency: usize,
    /// Threshold for detecting vanishing gradients.
    vanishing_threshold: f64,
    /// Threshold for detecting exploding gradients.
    exploding_threshold: f64,
    /// History of gradient norms.
    pub gradient_norms: Vec<f64>,
    /// History of gradient means.
    pub gradient_means: Vec<f64>,
    /// History of gradient stds.
    pub gradient_stds: Vec<f64>,
    /// Count of vanishing gradient warnings.
    pub vanishing_count: usize,
    /// Count of exploding gradient warnings.
    pub exploding_count: usize,
    /// Current batch counter.
    batch_counter: usize,
}

impl GradientMonitor {
    /// Create a new gradient monitor.
    ///
    /// # Arguments
    /// * `log_frequency` - Log statistics every N batches
    /// * `vanishing_threshold` - Threshold below which gradients are considered vanishing
    /// * `exploding_threshold` - Threshold above which gradients are considered exploding
    pub fn new(log_frequency: usize, vanishing_threshold: f64, exploding_threshold: f64) -> Self {
        Self {
            log_frequency,
            vanishing_threshold,
            exploding_threshold,
            gradient_norms: Vec::new(),
            gradient_means: Vec::new(),
            gradient_stds: Vec::new(),
            vanishing_count: 0,
            exploding_count: 0,
            batch_counter: 0,
        }
    }

    /// Compute gradient statistics (placeholder; a real implementation would read gradients from the training state).
    fn compute_gradient_stats(&mut self, _state: &TrainingState) -> (f64, f64, f64) {
        // In a real implementation, this would access actual gradients
        // For now, return placeholder values
        // (norm, mean, std)
        (1.0, 0.0, 0.1)
    }

    /// Check for vanishing gradients.
    fn check_vanishing(&mut self, norm: f64) -> bool {
        if norm < self.vanishing_threshold {
            self.vanishing_count += 1;
            return true;
        }
        false
    }

    /// Check for exploding gradients.
    fn check_exploding(&mut self, norm: f64) -> bool {
        if norm > self.exploding_threshold {
            self.exploding_count += 1;
            return true;
        }
        false
    }

    /// Print gradient statistics.
    fn print_stats(&self, norm: f64, mean: f64, std: f64) {
        println!("Gradient Stats [Batch {}]:", self.batch_counter);
        println!("  Norm: {:.6e}, Mean: {:.6e}, Std: {:.6e}", norm, mean, std);

        if self.vanishing_count > 0 {
            println!(
                "  Warning: Vanishing gradient warnings: {}",
                self.vanishing_count
            );
        }

        if self.exploding_count > 0 {
            println!(
                "  Warning: Exploding gradient warnings: {}",
                self.exploding_count
            );
        }
    }

    /// Get summary statistics.
    pub fn summary(&self) -> GradientSummary {
        let avg_norm = if !self.gradient_norms.is_empty() {
            self.gradient_norms.iter().sum::<f64>() / self.gradient_norms.len() as f64
        } else {
            0.0
        };

        GradientSummary {
            total_batches: self.batch_counter,
            average_norm: avg_norm,
            vanishing_count: self.vanishing_count,
            exploding_count: self.exploding_count,
        }
    }
}

/// Summary of gradient statistics.
#[derive(Debug, Clone)]
pub struct GradientSummary {
    /// Total number of batches monitored.
    pub total_batches: usize,
    /// Average gradient norm.
    pub average_norm: f64,
    /// Number of vanishing gradient warnings.
    pub vanishing_count: usize,
    /// Number of exploding gradient warnings.
    pub exploding_count: usize,
}

impl Callback for GradientMonitor {
    fn on_batch_end(&mut self, _batch: usize, state: &TrainingState) -> TrainResult<()> {
        self.batch_counter += 1;

        // Compute gradient statistics
        let (norm, mean, std) = self.compute_gradient_stats(state);

        // Record statistics
        self.gradient_norms.push(norm);
        self.gradient_means.push(mean);
        self.gradient_stds.push(std);

        // Check for issues
        let vanishing = self.check_vanishing(norm);
        let exploding = self.check_exploding(norm);

        // Log if needed
        if self.batch_counter.is_multiple_of(self.log_frequency) {
            self.print_stats(norm, mean, std);
        } else if vanishing || exploding {
            // Always log warnings immediately
            self.print_stats(norm, mean, std);
        }

        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        let summary = self.summary();
        println!("\n=== Gradient Monitoring Summary ===");
        println!("Total batches: {}", summary.total_batches);
        println!("Average gradient norm: {:.6e}", summary.average_norm);
        println!("Vanishing gradient warnings: {}", summary.vanishing_count);
        println!("Exploding gradient warnings: {}", summary.exploding_count);
        println!("====================================\n");
        Ok(())
    }
}

/// Gradient scaling strategy for accumulation.
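///
/// For example, with `accumulation_steps = 4`, `Average` scales the summed
/// gradients by 1/4, `Sum` leaves them unscaled, and `Dynamic` scales by
/// 1/n where n is the number of steps actually accumulated (so a partial
/// cycle of 3 steps is scaled by 1/3).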
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradientScalingStrategy {
    /// Divide by accumulation steps (default, maintains gradient magnitude)
    Average,
    /// Sum gradients without scaling (useful for some optimizers)
    Sum,
    /// Scale by the number of steps actually accumulated (handles partial final cycles)
    Dynamic,
}

/// Gradient Accumulation callback with advanced features.
///
/// Simulates larger batch sizes by accumulating gradients over multiple
/// mini-batches before updating parameters. This is useful when GPU memory
/// is limited but you want to train with effectively larger batches.
///
/// Effective batch size = mini_batch_size * accumulation_steps
///
/// # Features
/// - Memory-efficient in-place accumulation
/// - Multiple scaling strategies
/// - Gradient overflow detection
/// - Memory usage tracking
/// - Automatic gradient zeroing
///
/// # Example
/// ```rust,ignore
/// use tensorlogic_train::{GradientAccumulationCallback, GradientScalingStrategy};
///
/// let mut grad_accum = GradientAccumulationCallback::with_strategy(
///     4, // accumulate over 4 mini-batches
///     GradientScalingStrategy::Average,
/// ).expect("accumulation steps must be non-zero");
/// ```
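///
/// A minimal sketch of the accumulate/update cycle; `batches`, `model`,
/// `compute_gradients`, and `optimizer.step` below are hypothetical
/// placeholders, not part of this crate:
/// ```rust,ignore
/// for batch in batches {
///     // Hypothetical helper returning per-parameter gradients for this mini-batch.
///     let grads = compute_gradients(&model, &batch)?;
///     grad_accum.accumulate(&grads)?;
///
///     if grad_accum.should_update() {
///         // Scaled gradients representing the effective (larger) batch.
///         let scaled = grad_accum.get_and_reset();
///         optimizer.step(&scaled)?; // hypothetical optimizer call
///     }
/// }
/// ```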
pub struct GradientAccumulationCallback {
    /// Number of steps to accumulate gradients before updating.
    accumulation_steps: usize,
    /// Current accumulation counter.
    current_step: usize,
    /// Accumulated gradients.
    accumulated_grads: HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    /// Whether gradients are initialized.
    initialized: bool,
    /// Gradient scaling strategy.
    scaling_strategy: GradientScalingStrategy,
    /// Track maximum gradient norm seen during accumulation.
    max_grad_norm: f64,
    /// Track if overflow was detected.
    overflow_detected: bool,
    /// Total number of accumulation cycles completed.
    total_cycles: usize,
    /// Enable gradient clipping during accumulation.
    clip_grad_norm: Option<f64>,
}

impl GradientAccumulationCallback {
    /// Create a new Gradient Accumulation callback with default average scaling.
    ///
    /// # Arguments
    /// * `accumulation_steps` - Number of mini-batches to accumulate (e.g., 4, 8, 16)
    pub fn new(accumulation_steps: usize) -> TrainResult<Self> {
        Self::with_strategy(accumulation_steps, GradientScalingStrategy::Average)
    }

    /// Create a new Gradient Accumulation callback with specified scaling strategy.
    ///
    /// # Arguments
    /// * `accumulation_steps` - Number of mini-batches to accumulate
    /// * `scaling_strategy` - How to scale accumulated gradients
    pub fn with_strategy(
        accumulation_steps: usize,
        scaling_strategy: GradientScalingStrategy,
    ) -> TrainResult<Self> {
        if accumulation_steps == 0 {
            return Err(TrainError::CallbackError(
                "Accumulation steps must be greater than 0".to_string(),
            ));
        }

        Ok(Self {
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
            initialized: false,
            scaling_strategy,
            max_grad_norm: 0.0,
            overflow_detected: false,
            total_cycles: 0,
            clip_grad_norm: None,
        })
    }

    /// Enable gradient clipping during accumulation.
    ///
    /// # Arguments
    /// * `max_norm` - Maximum gradient norm before clipping
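    ///
    /// A minimal builder-style sketch:
    /// ```rust,ignore
    /// let grad_accum = GradientAccumulationCallback::new(4)
    ///     .expect("accumulation steps must be non-zero")
    ///     .with_grad_clipping(5.0); // clip each mini-batch's gradients to an L2 norm of 5.0
    /// ```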
    pub fn with_grad_clipping(mut self, max_norm: f64) -> Self {
        self.clip_grad_norm = Some(max_norm);
        self
    }

    /// Accumulate gradients with optional clipping and overflow detection.
    pub fn accumulate(
        &mut self,
        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> TrainResult<()> {
        // Check for NaN/Inf before accumulation
        for grad in gradients.values() {
            if grad.iter().any(|&x| x.is_nan() || x.is_infinite()) {
                self.overflow_detected = true;
                return Err(TrainError::CallbackError(
                    "Gradient overflow detected (NaN or Inf)".to_string(),
                ));
            }
        }

        // Compute gradient norm for monitoring
        let grad_norm = self.compute_total_norm(gradients);
        self.max_grad_norm = self.max_grad_norm.max(grad_norm);

        if !self.initialized {
            // Initialize accumulation buffers on the first call (cloning, with optional clipping)
            for (name, grad) in gradients {
                let clipped_grad = if let Some(max_norm) = self.clip_grad_norm {
                    if grad_norm > max_norm {
                        let scale = max_norm / grad_norm;
                        grad * scale
                    } else {
                        grad.clone()
                    }
                } else {
                    grad.clone()
                };
                self.accumulated_grads.insert(name.clone(), clipped_grad);
            }
            self.initialized = true;
        } else {
            // In-place accumulation for memory efficiency
            for (name, grad) in gradients {
                if let Some(acc_grad) = self.accumulated_grads.get_mut(name) {
                    let grad_to_add = if let Some(max_norm) = self.clip_grad_norm {
                        if grad_norm > max_norm {
                            let scale = max_norm / grad_norm;
                            grad * scale
                        } else {
                            grad.clone()
                        }
                    } else {
                        grad.clone()
                    };

                    // Element-wise in-place addition (avoids allocating a new array)
                    *acc_grad += &grad_to_add;
                }
            }
        }

        self.current_step += 1;
        Ok(())
    }

    /// Compute the total L2 norm of all gradients.
    fn compute_total_norm(
        &self,
        gradients: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>,
    ) -> f64 {
        let mut total_norm_sq = 0.0;
        for grad in gradients.values() {
            total_norm_sq += grad.iter().map(|&x| x * x).sum::<f64>();
        }
        total_norm_sq.sqrt()
    }

    /// Check if we should perform an optimizer step.
    pub fn should_update(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    /// Get scaled accumulated gradients and reset state.
    pub fn get_and_reset(
        &mut self,
    ) -> HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>> {
        let scale = match self.scaling_strategy {
            GradientScalingStrategy::Average => 1.0 / self.accumulation_steps as f64,
            GradientScalingStrategy::Sum => 1.0,
            GradientScalingStrategy::Dynamic => {
                // Dynamic scaling based on actual steps accumulated
                1.0 / self.current_step.max(1) as f64
            }
        };

        let mut scaled_grads = HashMap::new();
        for (name, grad) in &self.accumulated_grads {
            scaled_grads.insert(name.clone(), grad * scale);
        }

        // Update statistics
        self.total_cycles += 1;

        // Reset state
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        self.max_grad_norm = 0.0;
        self.overflow_detected = false;

        scaled_grads
    }

    /// Get statistics about gradient accumulation.
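    ///
    /// A minimal sketch of reading the returned [`GradientAccumulationStats`]:
    /// ```rust,ignore
    /// let stats = grad_accum.get_stats();
    /// println!(
    ///     "{}/{} steps accumulated, ~{:.2} MB buffered",
    ///     stats.current_step, stats.accumulation_steps, stats.memory_usage_mb
    /// );
    /// ```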
    pub fn get_stats(&self) -> GradientAccumulationStats {
        let memory_usage = self.estimate_memory_usage();

        GradientAccumulationStats {
            accumulation_steps: self.accumulation_steps,
            current_step: self.current_step,
            total_cycles: self.total_cycles,
            max_grad_norm: self.max_grad_norm,
            overflow_detected: self.overflow_detected,
            num_parameters: self.accumulated_grads.len(),
            memory_usage_mb: memory_usage,
        }
    }

    /// Estimate memory usage of accumulated gradients in MB.
    fn estimate_memory_usage(&self) -> f64 {
        let mut total_elements = 0usize;
        for grad in self.accumulated_grads.values() {
            total_elements += grad.len();
        }
        // f64 = 8 bytes
        (total_elements * 8) as f64 / (1024.0 * 1024.0)
    }

    /// Reset all state without returning gradients (useful for error recovery).
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        self.max_grad_norm = 0.0;
        self.overflow_detected = false;
    }
}

/// Statistics for gradient accumulation.
#[derive(Debug, Clone)]
pub struct GradientAccumulationStats {
    /// Configured accumulation steps.
    pub accumulation_steps: usize,
    /// Current step in accumulation.
    pub current_step: usize,
    /// Total completed cycles.
    pub total_cycles: usize,
    /// Maximum gradient norm seen.
    pub max_grad_norm: f64,
    /// Whether overflow was detected.
    pub overflow_detected: bool,
    /// Number of parameters being accumulated.
    pub num_parameters: usize,
    /// Estimated memory usage in MB.
    pub memory_usage_mb: f64,
}

impl Callback for GradientAccumulationCallback {
    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        // Reset at the beginning of each epoch
        self.current_step = 0;
        self.initialized = false;
        self.accumulated_grads.clear();
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    fn create_test_gradients() -> HashMap<String, Array2<f64>> {
        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("unwrap"),
        );
        grads.insert(
            "layer2".to_string(),
            Array2::from_shape_vec((2, 2), vec![0.5, 1.0, 1.5, 2.0]).expect("unwrap"),
        );
        grads
    }

    #[test]
    fn test_gradient_accumulation_average_strategy() {
        let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
        let grads = create_test_gradients();

        // First accumulation
        accum.accumulate(&grads).expect("unwrap");
        assert_eq!(accum.current_step, 1);
        assert!(!accum.should_update());

        // Second accumulation
        accum.accumulate(&grads).expect("unwrap");
        assert_eq!(accum.current_step, 2);
        assert!(accum.should_update());

        // Get averaged gradients
        let averaged = accum.get_and_reset();
        let layer1 = averaged.get("layer1").expect("unwrap");

        // Should be average of 2 accumulations (same gradient twice)
        assert_eq!(layer1[[0, 0]], 1.0); // (1.0 + 1.0) / 2
        assert_eq!(layer1[[0, 1]], 2.0); // (2.0 + 2.0) / 2

        // Should be reset
        assert_eq!(accum.current_step, 0);
    }

    #[test]
    fn test_gradient_accumulation_sum_strategy() {
        let mut accum =
            GradientAccumulationCallback::with_strategy(2, GradientScalingStrategy::Sum)
                .expect("unwrap");
        let grads = create_test_gradients();

        accum.accumulate(&grads).expect("unwrap");
        accum.accumulate(&grads).expect("unwrap");

        let summed = accum.get_and_reset();
        let layer1 = summed.get("layer1").expect("unwrap");

        // Should be sum (no scaling)
        assert_eq!(layer1[[0, 0]], 2.0); // 1.0 + 1.0
        assert_eq!(layer1[[0, 1]], 4.0); // 2.0 + 2.0
    }

    #[test]
    fn test_gradient_accumulation_dynamic_strategy() {
        let mut accum =
            GradientAccumulationCallback::with_strategy(4, GradientScalingStrategy::Dynamic)
                .expect("unwrap");
        let grads = create_test_gradients();

        // Accumulate only 3 times (less than configured 4)
        accum.accumulate(&grads).expect("unwrap");
        accum.accumulate(&grads).expect("unwrap");
        accum.accumulate(&grads).expect("unwrap");

        let scaled = accum.get_and_reset();
        let layer1 = scaled.get("layer1").expect("unwrap");

        // Should scale by actual steps (3) not configured steps (4)
        assert_eq!(layer1[[0, 0]], 1.0); // (1.0 + 1.0 + 1.0) / 3
    }

    #[test]
    fn test_gradient_clipping_during_accumulation() {
        let mut accum = GradientAccumulationCallback::new(2)
            .expect("unwrap")
            .with_grad_clipping(1.0); // Very small max norm

        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![10.0, 10.0, 10.0, 10.0]).expect("unwrap"),
        );

        // Large gradients should be clipped
        accum.accumulate(&grads).expect("unwrap");
        assert!(accum.max_grad_norm > 0.0);

        // Accumulated gradients should be clipped
        let accumulated = &accum.accumulated_grads["layer1"];
        let norm_sq: f64 = accumulated.iter().map(|&x| x * x).sum();
        let norm = norm_sq.sqrt();

        // Norm should be at or below clip threshold
        assert!(norm <= 1.1); // Small tolerance
    }

    #[test]
    fn test_overflow_detection() {
        let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");

        let mut grads = HashMap::new();
        grads.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![f64::NAN, 1.0, 2.0, 3.0]).expect("unwrap"),
        );

        // Should detect NaN
        let result = accum.accumulate(&grads);
        assert!(result.is_err());
        assert!(accum.overflow_detected);
    }

    #[test]
    fn test_gradient_accumulation_stats() {
        let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
        let grads = create_test_gradients();

        accum.accumulate(&grads).expect("unwrap");
        accum.accumulate(&grads).expect("unwrap");
        accum.get_and_reset();

        let stats = accum.get_stats();
        assert_eq!(stats.accumulation_steps, 2);
        assert_eq!(stats.total_cycles, 1);
        assert!(!stats.overflow_detected);
    }

    #[test]
    fn test_memory_usage_estimation() {
        let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
        let grads = create_test_gradients();

        accum.accumulate(&grads).expect("unwrap");

        let stats = accum.get_stats();
        assert!(stats.memory_usage_mb > 0.0);
        assert_eq!(stats.num_parameters, 2); // 2 layers
    }

    #[test]
    fn test_gradient_accumulation_reset() {
        let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
        let grads = create_test_gradients();

        accum.accumulate(&grads).expect("unwrap");
        assert_eq!(accum.current_step, 1);

        accum.reset();
        assert_eq!(accum.current_step, 0);
        assert!(!accum.initialized);
        assert_eq!(accum.accumulated_grads.len(), 0);
    }

    #[test]
    fn test_gradient_accumulation_zero_steps_error() {
        let result = GradientAccumulationCallback::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_gradient_accumulation_multiple_cycles() {
        let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");
        let grads = create_test_gradients();

        // First cycle
        accum.accumulate(&grads).expect("unwrap");
        accum.accumulate(&grads).expect("unwrap");
        accum.get_and_reset();

        // Second cycle
        accum.accumulate(&grads).expect("unwrap");
        accum.accumulate(&grads).expect("unwrap");
        accum.get_and_reset();

        let stats = accum.get_stats();
        assert_eq!(stats.total_cycles, 2);
    }

    #[test]
    fn test_non_square_gradient_shape() {
        let mut accum = GradientAccumulationCallback::new(2).expect("unwrap");

        let mut grads1 = HashMap::new();
        grads1.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("unwrap"),
        );

        let mut grads2 = HashMap::new();
        grads2.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 3), vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0]).expect("unwrap"),
        );

        accum.accumulate(&grads1).expect("unwrap");
        accum.accumulate(&grads2).expect("unwrap");

        let averaged = accum.get_and_reset();
        let layer1 = averaged.get("layer1").expect("unwrap");

        assert_eq!(layer1.dim(), (2, 3));
        assert_eq!(layer1[[0, 0]], 0.75); // (1.0 + 0.5) / 2
    }
}