scirs2-autograd 0.3.2

Automatic differentiation module for SciRS2 (scirs2-autograd)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
//! Memory optimization for computation graphs
//!
//! This module provides memory optimization techniques including gradient
//! checkpointing, memory pooling, and tensor lifetime analysis.

use super::OptimizationError;
use crate::graph::Graph;
use crate::Float;
use std::collections::HashMap;

/// Memory optimizer for computation graphs
pub struct MemoryOptimizer<F: Float> {
    /// Configuration for memory optimization
    config: MemoryOptimizationConfig,
    /// Analysis results
    analysis: Option<MemoryAnalysis>,
    _phantom: std::marker::PhantomData<F>,
}

impl<F: Float> MemoryOptimizer<F> {
    /// Create a new memory optimizer
    pub fn new() -> Self {
        Self {
            config: MemoryOptimizationConfig::default(),
            analysis: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Create a memory optimizer with custom configuration
    pub fn with_config(config: MemoryOptimizationConfig) -> Self {
        Self {
            config,
            analysis: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Optimize memory usage in a computation graph
    pub fn optimize(
        &mut self,
        graph: &mut Graph<F>,
    ) -> Result<MemoryOptimizationReport, OptimizationError> {
        let mut report = MemoryOptimizationReport::new();

        // Analyze memory usage patterns
        self.analysis = Some(self.analyze_memory_usage(graph)?);

        if self.config.enable_gradient_checkpointing {
            let checkpoints = self.apply_gradient_checkpointing(graph)?;
            report.gradient_checkpoints_added = checkpoints;
        }

        if self.config.enable_memory_pooling {
            let pools = self.setup_memory_pooling(graph)?;
            report.memory_pools_created = pools;
        }

        if self.config.enable_in_place_operations {
            let in_place_ops = self.apply_in_place_operations(graph)?;
            report.in_place_operations_applied = in_place_ops;
        }

        if self.config.enable_tensor_reuse {
            let reused = self.apply_tensor_reuse(graph)?;
            report.tensors_reused = reused;
        }

        if self.config.enable_lifetime_optimization {
            let optimized = self.optimize_tensor_lifetimes(graph)?;
            report.lifetime_optimizations = optimized;
        }

        Ok(report)
    }

    /// Analyze memory usage patterns in the graph
    fn analyze_memory_usage(&self, graph: &Graph<F>) -> Result<MemoryAnalysis, OptimizationError> {
        let mut analysis = MemoryAnalysis::new();

        // Analyze:
        // - Tensor sizes and lifetimes
        // - Memory allocation patterns
        // - Peak memory usage
        // - Opportunities for optimization

        analysis.total_memory_allocated = 1024 * 1024; // Placeholder
        analysis.peak_memory_usage = 512 * 1024; // Placeholder
        analysis.num_allocations = 100; // Placeholder
        analysis.num_deallocations = 90; // Placeholder

        Ok(analysis)
    }

    /// Apply gradient checkpointing
    fn apply_gradient_checkpointing(
        &self,
        graph: &mut Graph<F>,
    ) -> Result<usize, OptimizationError> {
        let mut checkpoints_added = 0;

        // Strategy: Insert checkpoints at points where:
        // 1. Memory usage is high
        // 2. Recomputation cost is relatively low
        // 3. It provides significant memory savings

        let candidates = self.find_checkpoint_candidates(graph)?;

        for candidate in candidates {
            if self.should_checkpoint(&candidate) {
                self.insert_checkpoint(graph, &candidate)?;
                checkpoints_added += 1;
            }
        }

        Ok(checkpoints_added)
    }

    /// Find candidates for gradient checkpointing
    fn find_checkpoint_candidates(
        &self,
        graph: &Graph<F>,
    ) -> Result<Vec<CheckpointCandidate<F>>, OptimizationError> {
        let candidates = Vec::new();

        // Look for:
        // - Nodes with large memory footprint
        // - Nodes in long computation chains
        // - Nodes where recomputation is cheaper than storage

        Ok(candidates)
    }

    /// Check if a node should be checkpointed
    fn should_checkpoint(&self, candidate: &CheckpointCandidate<F>) -> bool {
        // Decision criteria:
        // - Memory savings > threshold
        // - Recomputation cost < threshold
        // - Not already checkpointed

        candidate.memory_savings > self.config.checkpoint_memory_threshold
            && candidate.recomputation_cost < self.config.checkpoint_compute_threshold
    }

    /// Insert a checkpoint at a specific location
    fn insert_checkpoint(
        &self,
        graph: &mut Graph<F>,
        _candidate: &CheckpointCandidate<F>,
    ) -> Result<(), OptimizationError> {
        // Insert a checkpoint operation that:
        // 1. Saves the forward pass result
        // 2. Releases intermediate computations
        // 3. Recomputes on backward pass when needed

        Ok(())
    }

    /// Setup memory pooling
    fn setup_memory_pooling(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
        let mut pools_created = 0;

        // Analyze tensor size patterns
        let size_patterns = self.analyze_tensor_sizes(graph)?;

        // Create pools for common sizes
        for (size, frequency) in size_patterns {
            if frequency >= self.config.pool_frequency_threshold {
                MemoryOptimizer::<F>::create_memory_pool(size)?;
                pools_created += 1;
            }
        }

        Ok(pools_created)
    }

    /// Analyze tensor size patterns
    fn analyze_tensor_sizes(
        &self,
        graph: &Graph<F>,
    ) -> Result<HashMap<usize, usize>, OptimizationError> {
        let size_frequency = HashMap::new();

        // Count frequency of different tensor sizes
        // This would traverse the graph and collect size information

        Ok(size_frequency)
    }

    /// Create a memory pool for a specific size
    fn create_memory_pool(size: usize) -> Result<(), OptimizationError> {
        // Create a memory pool that can reuse buffers of the specified _size
        Ok(())
    }

    /// Apply in-place operations where safe
    fn apply_in_place_operations(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
        let mut in_place_applied = 0;

        // Find operations that can be done in-place:
        // - Element-wise operations where the input won't be used again
        // - Operations where the output has the same shape as input
        // - No aliasing issues

        let candidates = self.find_in_place_candidates(graph)?;

        for candidate in candidates {
            if MemoryOptimizer::<F>::can_apply_in_place(&candidate) {
                self.convert_to_in_place(graph, &candidate)?;
                in_place_applied += 1;
            }
        }

        Ok(in_place_applied)
    }

    /// Find candidates for in-place operations
    fn find_in_place_candidates(
        &self,
        graph: &Graph<F>,
    ) -> Result<Vec<InPlaceCandidate<F>>, OptimizationError> {
        // Look for operations like:
        // - Element-wise arithmetic
        // - Activation functions
        // - Normalization operations
        // where the input tensor is not used elsewhere

        Ok(Vec::new())
    }

    /// Check if an operation can be safely converted to in-place
    fn can_apply_in_place(candidate: &InPlaceCandidate<F>) -> bool {
        // Safety checks:
        // - Input tensor is not used by other operations
        // - No gradient computation conflicts
        // - Compatible tensor layouts
        // - No aliasing issues

        true
    }

    /// Convert an operation to in-place
    fn convert_to_in_place(
        &self,
        graph: &mut Graph<F>,
        _candidate: &InPlaceCandidate<F>,
    ) -> Result<(), OptimizationError> {
        // Replace the operation with an in-place version
        Ok(())
    }

    /// Apply tensor reuse optimization
    fn apply_tensor_reuse(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
        let mut reused_count = 0;

        // Find tensors that can be reused:
        // - Tensors with non-overlapping lifetimes
        // - Compatible shapes and types
        // - No aliasing conflicts

        let reuse_groups = self.find_tensor_reuse_opportunities(graph)?;

        for group in reuse_groups {
            self.apply_tensor_reuse_group(graph, &group)?;
            reused_count += group.tensors.len() - 1; // All but one reuse the same memory
        }

        Ok(reused_count)
    }

    /// Find opportunities for tensor reuse
    fn find_tensor_reuse_opportunities(
        &self,
        graph: &Graph<F>,
    ) -> Result<Vec<TensorReuseGroup<F>>, OptimizationError> {
        // Analyze tensor lifetimes and find non-overlapping tensors
        // that can share the same memory

        Ok(Vec::new())
    }

    /// Apply tensor reuse for a group of tensors
    fn apply_tensor_reuse_group(
        &self,
        graph: &mut Graph<F>,
        _group: &TensorReuseGroup<F>,
    ) -> Result<(), OptimizationError> {
        // Modify the graph to reuse memory for tensors in the _group
        Ok(())
    }

    /// Optimize tensor lifetimes
    fn optimize_tensor_lifetimes(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
        let mut optimizations = 0;

        // Strategies:
        // - Release tensors as early as possible
        // - Defer allocations as late as possible
        // - Reorder operations to minimize peak memory

        optimizations += self.apply_early_release(graph)?;
        optimizations += self.apply_late_allocation(graph)?;
        optimizations += self.reorder_for_memory(graph)?;

        Ok(optimizations)
    }

    /// Apply early release of tensors
    fn apply_early_release(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
        // Insert explicit release operations as soon as tensors are no longer needed
        Ok(0)
    }

    /// Apply late allocation of tensors
    fn apply_late_allocation(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
        // Delay tensor allocation until just before they're needed
        Ok(0)
    }

    /// Reorder operations to minimize peak memory usage
    fn reorder_for_memory(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
        // Reorder operations (where dependencies allow) to reduce peak memory
        Ok(0)
    }

    /// Get the current memory analysis
    pub fn get_analysis(&self) -> Option<&MemoryAnalysis> {
        self.analysis.as_ref()
    }
}

impl<F: Float> Default for MemoryOptimizer<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Configuration for memory optimization
#[derive(Debug, Clone)]
pub struct MemoryOptimizationConfig {
    /// Enable gradient checkpointing
    pub enable_gradient_checkpointing: bool,
    /// Enable memory pooling
    pub enable_memory_pooling: bool,
    /// Enable in-place operations
    pub enable_in_place_operations: bool,
    /// Enable tensor reuse
    pub enable_tensor_reuse: bool,
    /// Enable tensor lifetime optimization
    pub enable_lifetime_optimization: bool,
    /// Memory threshold for checkpointing (bytes)
    pub checkpoint_memory_threshold: usize,
    /// Compute threshold for checkpointing (relative cost)
    pub checkpoint_compute_threshold: f32,
    /// Frequency threshold for creating memory pools
    pub pool_frequency_threshold: usize,
    /// Maximum memory usage target (bytes)
    pub max_memory_usage: Option<usize>,
}

impl Default for MemoryOptimizationConfig {
    fn default() -> Self {
        Self {
            enable_gradient_checkpointing: true,
            enable_memory_pooling: true,
            enable_in_place_operations: true,
            enable_tensor_reuse: true,
            enable_lifetime_optimization: true,
            checkpoint_memory_threshold: 1024 * 1024, // 1MB
            checkpoint_compute_threshold: 2.0,        // 2x recomputation cost
            pool_frequency_threshold: 5,              // At least 5 uses
            max_memory_usage: None,
        }
    }
}

/// Results of memory analysis
#[derive(Debug, Clone, Default)]
pub struct MemoryAnalysis {
    /// Total memory allocated (bytes)
    pub total_memory_allocated: usize,
    /// Peak memory usage (bytes)
    pub peak_memory_usage: usize,
    /// Number of allocations
    pub num_allocations: usize,
    /// Number of deallocations
    pub num_deallocations: usize,
    /// Average tensor size
    pub average_tensor_size: usize,
    /// Largest tensor size
    pub largest_tensor_size: usize,
    /// Memory fragmentation ratio
    pub fragmentation_ratio: f32,
    /// Opportunities for optimization
    pub optimization_opportunities: Vec<String>,
}

impl MemoryAnalysis {
    /// Create a new memory analysis
    pub fn new() -> Self {
        Self::default()
    }

    /// Calculate memory efficiency
    pub fn memory_efficiency(&self) -> f32 {
        if self.total_memory_allocated == 0 {
            return 1.0;
        }
        self.peak_memory_usage as f32 / self.total_memory_allocated as f32
    }

    /// Get allocation/deallocation balance
    pub fn allocation_balance(&self) -> i32 {
        self.num_allocations as i32 - self.num_deallocations as i32
    }
}

/// Report of memory optimization results
#[derive(Debug, Clone, Default)]
pub struct MemoryOptimizationReport {
    /// Number of gradient checkpoints added
    pub gradient_checkpoints_added: usize,
    /// Number of memory pools created
    pub memory_pools_created: usize,
    /// Number of in-place operations applied
    pub in_place_operations_applied: usize,
    /// Number of tensors reused
    pub tensors_reused: usize,
    /// Number of lifetime optimizations
    pub lifetime_optimizations: usize,
    /// Estimated memory savings (bytes)
    pub estimated_memory_savings: usize,
}

impl MemoryOptimizationReport {
    /// Create a new optimization report
    pub fn new() -> Self {
        Self::default()
    }

    /// Get total optimizations applied
    pub fn total_optimizations(&self) -> usize {
        self.gradient_checkpoints_added
            + self.memory_pools_created
            + self.in_place_operations_applied
            + self.tensors_reused
            + self.lifetime_optimizations
    }

    /// Print a summary of the memory optimization results
    pub fn print_summary(&self) {
        println!("Memory Optimization Report:");
        println!("==========================");
        println!("Total optimizations: {}", self.total_optimizations());

        if self.gradient_checkpoints_added > 0 {
            println!(
                "  Gradient checkpoints: {}",
                self.gradient_checkpoints_added
            );
        }
        if self.memory_pools_created > 0 {
            println!("  Memory pools created: {}", self.memory_pools_created);
        }
        if self.in_place_operations_applied > 0 {
            println!(
                "  In-place operations: {}",
                self.in_place_operations_applied
            );
        }
        if self.tensors_reused > 0 {
            println!("  Tensors reused: {}", self.tensors_reused);
        }
        if self.lifetime_optimizations > 0 {
            println!("  Lifetime optimizations: {}", self.lifetime_optimizations);
        }
        if self.estimated_memory_savings > 0 {
            println!(
                "  Estimated memory savings: {} bytes",
                self.estimated_memory_savings
            );
        }
    }
}

/// Candidate for gradient checkpointing
#[derive(Debug)]
pub(crate) struct CheckpointCandidate<F: Float> {
    /// Node to potentially checkpoint
    #[allow(dead_code)]
    pub node: *const crate::tensor::TensorInternal<F>,
    /// Estimated memory savings
    pub memory_savings: usize,
    /// Estimated recomputation cost
    pub recomputation_cost: f32,
    /// Priority for checkpointing
    #[allow(dead_code)]
    pub priority: f32,
}

/// Candidate for in-place operation
#[derive(Debug)]
pub(crate) struct InPlaceCandidate<F: Float> {
    /// Node to convert to in-place
    #[allow(dead_code)]
    pub node: *const crate::tensor::TensorInternal<F>,
    /// Estimated memory savings
    #[allow(dead_code)]
    pub memory_savings: usize,
    /// Safety score (higher is safer)
    #[allow(dead_code)]
    pub safety_score: f32,
}

/// Group of tensors that can reuse memory
#[derive(Debug)]
pub(crate) struct TensorReuseGroup<F: Float> {
    /// Tensors that can share memory
    pub tensors: Vec<*const crate::tensor::TensorInternal<F>>,
    /// Total memory that can be saved
    #[allow(dead_code)]
    pub memory_savings: usize,
}

/// Tensor lifetime analyzer
pub struct TensorLifetimeAnalyzer<F: Float> {
    _phantom: std::marker::PhantomData<F>,
}

impl<F: Float> TensorLifetimeAnalyzer<F> {
    /// Create a new tensor lifetime analyzer
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }

    /// Analyze tensor lifetimes in a graph
    #[allow(dead_code)]
    pub(crate) fn analyze(
        &self,
        graph: &Graph<F>,
    ) -> Result<HashMap<*const crate::tensor::TensorInternal<F>, TensorLifetime>, OptimizationError>
    {
        let lifetimes = HashMap::new();

        // For each tensor, determine:
        // - When it's first created/allocated
        // - When it's last used
        // - Peak memory usage contribution
        // - Overlap with other tensors

        Ok(lifetimes)
    }

    /// Find overlapping tensor lifetimes
    #[allow(dead_code)]
    pub(crate) fn find_overlapping_lifetimes(
        self_lifetimes: &HashMap<*const crate::tensor::TensorInternal<F>, TensorLifetime>,
    ) -> Vec<Vec<*const crate::tensor::TensorInternal<F>>> {
        // Group tensors with overlapping _lifetimes
        // These cannot share memory
        Vec::new()
    }

    /// Find non-overlapping tensor groups
    #[allow(dead_code)]
    pub(crate) fn find_reusable_groups(
        self_lifetimes: &HashMap<*const crate::tensor::TensorInternal<F>, TensorLifetime>,
    ) -> Vec<Vec<*const crate::tensor::TensorInternal<F>>> {
        // Group tensors with non-overlapping _lifetimes
        // These can potentially share memory
        Vec::new()
    }
}

impl<F: Float> Default for TensorLifetimeAnalyzer<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Information about a tensor's lifetime
#[derive(Debug, Clone)]
pub struct TensorLifetime {
    /// When the tensor is allocated
    pub allocation_time: usize,
    /// When the tensor is last used
    pub deallocation_time: usize,
    /// Size of the tensor
    pub size: usize,
    /// Peak usage during lifetime
    pub peak_usage: usize,
}

impl TensorLifetime {
    /// Check if this lifetime overlaps with another
    pub fn overlaps_with(&self, other: &TensorLifetime) -> bool {
        !(self.deallocation_time <= other.allocation_time
            || other.deallocation_time <= self.allocation_time)
    }

    /// Get the duration of this lifetime
    pub fn duration(&self) -> usize {
        self.deallocation_time.saturating_sub(self.allocation_time)
    }
}

/// Memory pool manager
pub struct MemoryPoolManager<F: Float> {
    /// Pools organized by size
    pools: HashMap<usize, Vec<Vec<F>>>,
    /// Pool usage statistics
    stats: MemoryPoolStats,
}

impl<F: Float> MemoryPoolManager<F> {
    /// Create a new memory pool manager
    pub fn new() -> Self {
        Self {
            pools: HashMap::new(),
            stats: MemoryPoolStats::default(),
        }
    }

    /// Get a buffer from the pool
    pub fn get_buffer(&mut self, size: usize) -> Vec<F> {
        if let Some(pool) = self.pools.get_mut(&size) {
            if let Some(buffer) = pool.pop() {
                self.stats.pool_hits += 1;
                return buffer;
            }
        }

        self.stats.pool_misses += 1;
        vec![F::zero(); size]
    }

    /// Return a buffer to the pool
    pub fn return_buffer(&mut self, mut buffer: Vec<F>) {
        let size = buffer.len();
        buffer.clear();
        buffer.resize(size, F::zero());

        self.pools.entry(size).or_default().push(buffer);
        self.stats.buffers_returned += 1;
    }

    /// Get pool statistics
    pub fn get_stats(&self) -> &MemoryPoolStats {
        &self.stats
    }

    /// Clear all pools
    pub fn clear(&mut self) {
        self.pools.clear();
        self.stats = MemoryPoolStats::default();
    }
}

impl<F: Float> Default for MemoryPoolManager<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Statistics for memory pools
#[derive(Debug, Clone, Default)]
pub struct MemoryPoolStats {
    /// Number of pool hits
    pub pool_hits: usize,
    /// Number of pool misses
    pub pool_misses: usize,
    /// Number of buffers returned
    pub buffers_returned: usize,
    /// Total memory pooled
    pub total_pooled_memory: usize,
}

impl MemoryPoolStats {
    /// Calculate pool hit ratio
    pub fn hit_ratio(&self) -> f32 {
        let total_requests = self.pool_hits + self.pool_misses;
        if total_requests == 0 {
            return 0.0;
        }
        self.pool_hits as f32 / total_requests as f32
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_optimizer_creation() {
        let _optimizer = MemoryOptimizer::<f32>::new();
        let _optimizer_with_config =
            MemoryOptimizer::<f32>::with_config(MemoryOptimizationConfig::default());
    }

    #[test]
    fn test_memory_optimization_config() {
        let config = MemoryOptimizationConfig::default();
        assert!(config.enable_gradient_checkpointing);
        assert!(config.enable_memory_pooling);
        assert!(config.enable_in_place_operations);
        assert!(config.enable_tensor_reuse);
        assert!(config.enable_lifetime_optimization);
    }

    #[test]
    fn test_memory_analysis() {
        let mut analysis = MemoryAnalysis::new();
        analysis.total_memory_allocated = 1000;
        analysis.peak_memory_usage = 800;
        analysis.num_allocations = 10;
        analysis.num_deallocations = 8;

        assert_eq!(analysis.memory_efficiency(), 0.8);
        assert_eq!(analysis.allocation_balance(), 2);
    }

    #[test]
    fn test_memory_optimization_report() {
        let mut report = MemoryOptimizationReport::new();
        report.gradient_checkpoints_added = 5;
        report.memory_pools_created = 3;
        report.in_place_operations_applied = 10;

        assert_eq!(report.total_optimizations(), 18);
    }

    #[test]
    fn test_tensor_lifetime() {
        let lifetime1 = TensorLifetime {
            allocation_time: 0,
            deallocation_time: 10,
            size: 100,
            peak_usage: 100,
        };

        let lifetime2 = TensorLifetime {
            allocation_time: 5,
            deallocation_time: 15,
            size: 200,
            peak_usage: 200,
        };

        let lifetime3 = TensorLifetime {
            allocation_time: 20,
            deallocation_time: 30,
            size: 150,
            peak_usage: 150,
        };

        assert!(lifetime1.overlaps_with(&lifetime2));
        assert!(!lifetime1.overlaps_with(&lifetime3));
        assert_eq!(lifetime1.duration(), 10);
    }

    #[test]
    fn test_memory_pool_manager() {
        let mut manager = MemoryPoolManager::<f32>::new();

        // Get a buffer
        let buffer = manager.get_buffer(100);
        assert_eq!(buffer.len(), 100);
        assert_eq!(manager.get_stats().pool_misses, 1);

        // Return the buffer
        manager.return_buffer(buffer);
        assert_eq!(manager.get_stats().buffers_returned, 1);

        // Get another buffer of the same size - should come from pool
        let buffer2 = manager.get_buffer(100);
        assert_eq!(buffer2.len(), 100);
        assert_eq!(manager.get_stats().pool_hits, 1);
    }

    #[test]
    fn test_memory_pool_stats() {
        let stats = MemoryPoolStats {
            pool_hits: 8,
            pool_misses: 2,
            ..Default::default()
        };

        assert_eq!(stats.hit_ratio(), 0.8);
    }

    #[test]
    fn test_tensor_lifetime_analyzer() {
        let _analyzer = TensorLifetimeAnalyzer::<f32>::new();
    }
}