// quantrs2_sim/memory_optimization.rs

//! Advanced Memory Optimization for Quantum Simulation
//!
//! This module provides sophisticated memory management strategies to optimize
//! memory usage patterns for large quantum state vector simulations.

use scirs2_core::Complex64;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

/// Advanced memory pool with intelligent allocation strategies
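///
/// A minimal usage sketch; the doctest is `ignore`d because the crate/module
/// path `quantrs2_sim::memory_optimization` is assumed from this file's
/// location rather than confirmed:
///
/// ```ignore
/// use std::time::Duration;
/// use quantrs2_sim::memory_optimization::AdvancedMemoryPool;
///
/// // Keep at most 8 cached buffers per size class; clean up every 30 seconds.
/// let pool = AdvancedMemoryPool::new(8, Duration::from_secs(30));
///
/// // Buffers are zero-initialized to the requested length.
/// let scratch = pool.get_buffer(1 << 10);
/// assert_eq!(scratch.len(), 1 << 10);
///
/// // Returning the buffer lets later requests of the same size class reuse it.
/// pool.return_buffer(scratch);
/// ```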
#[derive(Debug)]
pub struct AdvancedMemoryPool {
    /// Stratified buffers organized by size classes
    size_pools: RwLock<HashMap<usize, VecDeque<Vec<Complex64>>>>,
    /// Maximum number of buffers per size class
    max_buffers_per_size: usize,
    /// Memory usage statistics
    stats: Arc<Mutex<MemoryStats>>,
    /// Automatic cleanup threshold
    cleanup_threshold: Duration,
    /// Last cleanup time
    last_cleanup: Mutex<Instant>,
}

/// Memory usage statistics for optimization
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Total allocations requested
    pub total_allocations: u64,
    /// Cache hits (buffer reused)
    pub cache_hits: u64,
    /// Cache misses (new allocation)
    pub cache_misses: u64,
    /// Peak memory usage in bytes
    pub peak_memory_bytes: u64,
    /// Current memory usage in bytes
    pub current_memory_bytes: u64,
    /// Total cleanup operations
    pub cleanup_operations: u64,
    /// Average allocation size
    pub average_allocation_size: f64,
    /// Buffer size distribution
    pub size_distribution: HashMap<usize, u64>,
}

impl MemoryStats {
    /// Calculate cache hit ratio
    #[must_use]
    pub fn cache_hit_ratio(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.cache_hits as f64 / self.total_allocations as f64
        }
    }

    /// Update statistics for a new allocation
    pub fn record_allocation(&mut self, size: usize, cache_hit: bool) {
        self.total_allocations += 1;
        if cache_hit {
            self.cache_hits += 1;
        } else {
            self.cache_misses += 1;
        }

        // Update average allocation size
        let total_size = self
            .average_allocation_size
            .mul_add((self.total_allocations - 1) as f64, size as f64);
        self.average_allocation_size = total_size / self.total_allocations as f64;

        // Update size distribution
        *self.size_distribution.entry(size).or_insert(0) += 1;

        // Update memory usage (approximation)
        let allocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes += allocation_bytes as u64;
        if self.current_memory_bytes > self.peak_memory_bytes {
            self.peak_memory_bytes = self.current_memory_bytes;
        }
    }

    /// Record memory deallocation
    pub const fn record_deallocation(&mut self, size: usize) {
        let deallocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes = self
            .current_memory_bytes
            .saturating_sub(deallocation_bytes as u64);
    }
}

impl AdvancedMemoryPool {
    /// Create new advanced memory pool
    #[must_use]
    pub fn new(max_buffers_per_size: usize, cleanup_threshold: Duration) -> Self {
        Self {
            size_pools: RwLock::new(HashMap::new()),
            max_buffers_per_size,
            stats: Arc::new(Mutex::new(MemoryStats::default())),
            cleanup_threshold,
            last_cleanup: Mutex::new(Instant::now()),
        }
    }

    /// Get optimal size class for a requested size (power of 2 buckets)
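    ///
    /// For example, `get_size_class(50) == 64`, `get_size_class(100) == 128`,
    /// and sizes above 8192 round up to the next power of two, so
    /// `get_size_class(10_000) == 16_384`.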
    const fn get_size_class(size: usize) -> usize {
        if size <= 64 {
            64
        } else if size <= 128 {
            128
        } else if size <= 256 {
            256
        } else if size <= 512 {
            512
        } else if size <= 1024 {
            1024
        } else if size <= 2048 {
            2048
        } else if size <= 4096 {
            4096
        } else if size <= 8192 {
            8192
        } else {
            // For large sizes, round up to next power of 2
            let mut power = 1;
            while power < size {
                power <<= 1;
            }
            power
        }
    }

    /// Get buffer from pool with intelligent allocation
    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        let size_class = Self::get_size_class(size);

        // Fast path: check under the read lock whether a cached buffer exists
        let has_cached = {
            let pools = self
                .size_pools
                .read()
                .expect("Size pools read lock poisoned");
            pools
                .get(&size_class)
                .is_some_and(|pool| !pool.is_empty())
        };

        // Take the write lock only if a cached buffer was seen; another thread
        // may have taken it in the meantime, so the hit is confirmed by the pop
        let buffer = if has_cached {
            let mut pools = self
                .size_pools
                .write()
                .expect("Size pools write lock poisoned");
            pools.get_mut(&size_class).and_then(VecDeque::pop_front)
        } else {
            None
        };
        let cache_hit = buffer.is_some();

        let buffer = if let Some(mut buffer) = buffer {
            // Reuse existing buffer, zeroed to the requested length
            buffer.clear();
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        } else {
            // Allocate new buffer with size class capacity
            let mut buffer = Vec::with_capacity(size_class);
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        };

        // Update statistics
        if let Ok(mut stats) = self.stats.lock() {
            stats.record_allocation(size, cache_hit);
        }

        // Trigger cleanup if needed
        self.maybe_cleanup();

        buffer
    }

    /// Return buffer to appropriate size pool
    pub fn return_buffer(&self, buffer: Vec<Complex64>) {
        let capacity = buffer.capacity();
        let size_class = Self::get_size_class(capacity);

        // Only cache if capacity matches size class to avoid memory waste
        if capacity == size_class {
            let mut pools = self
                .size_pools
                .write()
                .expect("Size pools write lock poisoned");
            let pool = pools.entry(size_class).or_default();

            if pool.len() < self.max_buffers_per_size {
                pool.push_back(buffer);
                return;
            }
        }

        // Update deallocation stats
        if let Ok(mut stats) = self.stats.lock() {
            stats.record_deallocation(capacity);
        }

        // Buffer will be dropped here if not cached
    }

    /// Periodic cleanup of unused buffers
    fn maybe_cleanup(&self) {
        if let Ok(mut last_cleanup) = self.last_cleanup.try_lock() {
            if last_cleanup.elapsed() > self.cleanup_threshold {
                self.cleanup_unused_buffers();
                *last_cleanup = Instant::now();

                if let Ok(mut stats) = self.stats.lock() {
                    stats.cleanup_operations += 1;
                }
            }
        }
    }

    /// Clean up unused buffers to free memory
    pub fn cleanup_unused_buffers(&self) {
        let mut pools = self
            .size_pools
            .write()
            .expect("Size pools write lock poisoned");
        let mut freed_memory = 0u64;

        for pool in pools.values_mut() {
            // Keep only half the buffers in each pool during cleanup
            let target_size = pool.len() / 2;
            while pool.len() > target_size {
                if let Some(buffer) = pool.pop_back() {
                    freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
                }
            }
        }

        // Update memory stats
        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }

    /// Get memory statistics
    pub fn get_stats(&self) -> MemoryStats {
        self.stats.lock().expect("Stats lock poisoned").clone()
    }

    /// Clear all cached buffers
    pub fn clear(&self) {
        let mut pools = self
            .size_pools
            .write()
            .expect("Size pools write lock poisoned");
        let mut freed_memory = 0u64;

        for pool in pools.values() {
            for buffer in pool {
                freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
            }
        }

        pools.clear();

        // Update memory stats
        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }
}

/// NUMA-aware memory optimization strategies
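///
/// A minimal usage sketch; the doctest is `ignore`d under the same crate-path
/// assumption as above. Note that the node pools are logical round-robin pools;
/// no OS-level NUMA binding is performed here:
///
/// ```ignore
/// use quantrs2_sim::memory_optimization::NumaAwareAllocator;
///
/// // Two logical nodes, up to 4 cached buffers per size class on each.
/// let allocator = NumaAwareAllocator::new(2, 4);
///
/// // Round-robin allocation alternates between node pools.
/// let a = allocator.get_buffer(256);
/// let b = allocator.get_buffer(256);
///
/// // Buffers can be returned to a preferred node's pool.
/// allocator.return_buffer(a, Some(0));
/// allocator.return_buffer(b, Some(1));
/// ```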
pub struct NumaAwareAllocator {
    /// Node-specific memory pools
    node_pools: Vec<AdvancedMemoryPool>,
    /// Current allocation node
    current_node: Mutex<usize>,
}

impl NumaAwareAllocator {
    /// Create NUMA-aware allocator
    #[must_use]
    pub fn new(num_nodes: usize, max_buffers_per_size: usize) -> Self {
        let node_pools = (0..num_nodes)
            .map(|_| AdvancedMemoryPool::new(max_buffers_per_size, Duration::from_secs(30)))
            .collect();

        Self {
            node_pools,
            current_node: Mutex::new(0),
        }
    }

    /// Get buffer from specific NUMA node
    pub fn get_buffer_from_node(&self, size: usize, node: usize) -> Option<Vec<Complex64>> {
        if node < self.node_pools.len() {
            Some(self.node_pools[node].get_buffer(size))
        } else {
            None
        }
    }

    /// Get buffer with automatic load balancing
    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        let mut current_node = self
            .current_node
            .lock()
            .expect("Current node lock poisoned");
        let node = *current_node;
        *current_node = (*current_node + 1) % self.node_pools.len();
        drop(current_node);

        self.node_pools[node].get_buffer(size)
    }

    /// Return buffer to appropriate node
    pub fn return_buffer(&self, buffer: Vec<Complex64>, preferred_node: Option<usize>) {
        let node = preferred_node.unwrap_or(0).min(self.node_pools.len() - 1);
        self.node_pools[node].return_buffer(buffer);
    }

    /// Get combined statistics from all nodes
    pub fn get_combined_stats(&self) -> MemoryStats {
        let mut combined = MemoryStats::default();

        for pool in &self.node_pools {
            let stats = pool.get_stats();
            combined.total_allocations += stats.total_allocations;
            combined.cache_hits += stats.cache_hits;
            combined.cache_misses += stats.cache_misses;
            combined.current_memory_bytes += stats.current_memory_bytes;
            combined.peak_memory_bytes = combined.peak_memory_bytes.max(stats.peak_memory_bytes);
            combined.cleanup_operations += stats.cleanup_operations;

            // Merge size distributions
            for (size, count) in stats.size_distribution {
                *combined.size_distribution.entry(size).or_insert(0) += count;
            }
        }

        // Recalculate average allocation size
        if combined.total_allocations > 0 {
            let total_size: u64 = combined
                .size_distribution
                .iter()
                .map(|(size, count)| *size as u64 * count)
                .sum();
            combined.average_allocation_size =
                total_size as f64 / combined.total_allocations as f64;
        }

        combined
    }
}

/// Memory optimization utility functions
pub mod utils {
    use super::Complex64;

    /// Estimate memory requirements for a given number of qubits
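    ///
    /// With `Complex64` at 16 bytes per amplitude (two `f64`s) and the 3x
    /// overhead factor below, a 20-qubit state (2^20 amplitudes) is estimated
    /// at 2^20 * 16 * 3 bytes = 48 MiB.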
    #[must_use]
    pub const fn estimate_memory_requirements(num_qubits: usize) -> u64 {
        let state_size = 1usize << num_qubits;
        let bytes_per_amplitude = std::mem::size_of::<Complex64>();
        let state_memory = state_size * bytes_per_amplitude;

        // Add overhead for temporary buffers (estimated 3x for gates)
        let overhead_factor = 3;
        (state_memory * overhead_factor) as u64
    }

    /// Check if system has sufficient memory for simulation
    #[must_use]
    pub const fn check_memory_availability(num_qubits: usize) -> bool {
        let required_memory = estimate_memory_requirements(num_qubits);

        // Get available system memory (this is a simplified check)
        // In practice, you'd use system-specific APIs
        let available_memory = get_available_memory();

        available_memory > required_memory
    }

    /// Get available system memory (placeholder implementation)
    const fn get_available_memory() -> u64 {
        // This would use platform-specific APIs in practice
        // For now, return a conservative estimate
        8 * 1024 * 1024 * 1024 // 8 GB
    }

    /// Optimize buffer size for cache efficiency
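    ///
    /// With 16-byte `Complex64` elements and 64-byte cache lines, sizes are
    /// rounded up to a multiple of 4 elements, e.g. `optimize_buffer_size(99) == 100`
    /// and `optimize_buffer_size(1000) == 1000`.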
    #[must_use]
    pub const fn optimize_buffer_size(target_size: usize) -> usize {
        // Align to cache line size (typically 64 bytes)
        let cache_line_size = 64;
        let element_size = std::mem::size_of::<Complex64>();
        let elements_per_cache_line = cache_line_size / element_size;

        // Round up to nearest multiple of cache line elements
        target_size.div_ceil(elements_per_cache_line) * elements_per_cache_line
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_memory_pool() {
        let pool = AdvancedMemoryPool::new(4, Duration::from_secs(1));

        // Test buffer allocation and reuse
        let buffer1 = pool.get_buffer(100);
        assert_eq!(buffer1.len(), 100);

        pool.return_buffer(buffer1);

        let buffer2 = pool.get_buffer(100);
        assert_eq!(buffer2.len(), 100);

        // Check cache hit ratio
        let stats = pool.get_stats();
        assert!(stats.cache_hit_ratio() > 0.0);
    }

    #[test]
    fn test_size_class_allocation() {
        assert_eq!(AdvancedMemoryPool::get_size_class(50), 64);
        assert_eq!(AdvancedMemoryPool::get_size_class(100), 128);
        assert_eq!(AdvancedMemoryPool::get_size_class(1000), 1024);
        assert_eq!(AdvancedMemoryPool::get_size_class(5000), 8192);
    }

    #[test]
    fn test_numa_aware_allocator() {
        let allocator = NumaAwareAllocator::new(2, 4);

        let buffer1 = allocator.get_buffer(100);
        let buffer2 = allocator.get_buffer(200);

        allocator.return_buffer(buffer1, Some(0));
        allocator.return_buffer(buffer2, Some(1));

        let stats = allocator.get_combined_stats();
        assert_eq!(stats.total_allocations, 2);
    }

    #[test]
    fn test_memory_estimation() {
        let memory_4_qubits = utils::estimate_memory_requirements(4);
        let memory_8_qubits = utils::estimate_memory_requirements(8);

        // 8-qubit simulation should require much more memory than 4-qubit
        assert!(memory_8_qubits > memory_4_qubits * 10);
    }
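
    // The following tests were added as illustrative checks of the statistics
    // bookkeeping and the cache-line alignment helper; the concrete sizes used
    // are arbitrary examples.
    #[test]
    fn test_memory_stats_bookkeeping() {
        let mut stats = MemoryStats::default();

        // One miss followed by one hit gives a 50% hit ratio.
        stats.record_allocation(100, false);
        stats.record_allocation(100, true);
        assert_eq!(stats.total_allocations, 2);
        assert!((stats.cache_hit_ratio() - 0.5).abs() < 1e-12);
        assert!((stats.average_allocation_size - 100.0).abs() < 1e-12);

        // Deallocation never underflows the current usage counter.
        stats.record_deallocation(100);
        stats.record_deallocation(100);
        stats.record_deallocation(100);
        assert_eq!(stats.current_memory_bytes, 0);
    }

    #[test]
    fn test_optimize_buffer_size_alignment() {
        // 64-byte cache lines hold four 16-byte Complex64 elements.
        assert_eq!(utils::optimize_buffer_size(99), 100);
        assert_eq!(utils::optimize_buffer_size(100), 100);
        assert_eq!(utils::optimize_buffer_size(1), 4);
    }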
}