quantrs2_tytan/gpu_memory_pool.rs

//! GPU memory pooling for efficient allocation and reuse.
//!
//! This module provides memory pooling functionality to reduce allocation
//! overhead in GPU computations, particularly for iterative algorithms.
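//!
//! # Example
//!
//! A minimal usage sketch (requires the `scirs` feature; the device id and
//! sizes below are illustrative, not taken from a real workload):
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! let context = Arc::new(GpuContext::new(0)?);
//! let mut pool = GpuMemoryPool::new(context, 64 * 1024 * 1024); // 64 MiB cap
//!
//! let buffer = pool.allocate(4096)?; // rounded up to the next power of two
//! // ... use the buffer in GPU computations ...
//! pool.release(buffer);              // the block is kept in the pool for reuse
//! ```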

#![allow(dead_code)]

use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};

#[cfg(feature = "scirs")]
use scirs2_core::gpu;

// Stub for missing GPU functionality
#[cfg(feature = "scirs")]
pub struct GpuContext;

#[cfg(feature = "scirs")]
impl GpuContext {
    pub fn new(_device_id: u32) -> Result<Self, Box<dyn std::error::Error>> {
        Ok(Self)
    }
}

#[cfg(feature = "scirs")]
#[derive(Clone, Default)]
pub struct GpuMemory {
    id: usize,
    size: usize,
}

/// Memory block information
#[cfg(feature = "scirs")]
#[derive(Clone)]
struct MemoryBlock {
    /// Unique ID for this block
    id: usize,
    /// Size in bytes
    size: usize,
    /// Whether the block is currently in use
    in_use: bool,
    /// Last access time for LRU eviction
    last_access: std::time::Instant,
}

/// GPU memory pool for efficient allocation
pub struct GpuMemoryPool {
    /// GPU context
    #[cfg(feature = "scirs")]
    context: Arc<GpuContext>,
    /// Pool of memory blocks by size
    #[cfg(feature = "scirs")]
    blocks_by_size: HashMap<usize, VecDeque<MemoryBlock>>,
    /// All allocated blocks
    #[cfg(feature = "scirs")]
    all_blocks: Vec<MemoryBlock>,
    /// Maximum pool size in bytes
    max_size: usize,
    /// Current allocated size
    current_size: usize,
    /// Allocation statistics
    stats: AllocationStats,
    /// Mutex for thread safety
    mutex: Arc<Mutex<()>>,
    /// Next block ID
    next_block_id: usize,
}

/// Allocation statistics
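///
/// The cache hit rate can be derived from these counters; a small sketch
/// (assumes a populated `stats` value, as in the tests below):
///
/// ```ignore
/// let hit_rate = stats.cache_hits as f64 / stats.total_allocations as f64;
/// ```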
#[derive(Default, Clone)]
pub struct AllocationStats {
    /// Total allocations
    pub total_allocations: usize,
    /// Cache hits (reused blocks)
    pub cache_hits: usize,
    /// Cache misses (new allocations)
    pub cache_misses: usize,
    /// Total bytes allocated
    pub total_bytes_allocated: usize,
    /// Peak memory usage
    pub peak_memory_usage: usize,
    /// Number of evictions
    pub evictions: usize,
}

#[cfg(feature = "scirs")]
impl GpuMemoryPool {
    /// Create a new memory pool
    pub fn new(context: Arc<GpuContext>, max_size: usize) -> Self {
        Self {
            context,
            blocks_by_size: HashMap::new(),
            all_blocks: Vec::new(),
            max_size,
            current_size: 0,
            stats: AllocationStats::default(),
            mutex: Arc::new(Mutex::new(())),
            next_block_id: 0,
        }
    }

    /// Allocate memory from the pool
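    ///
    /// Requests are rounded up to the next power of two so that freed blocks can
    /// be reused across similarly sized requests. A minimal sketch (assumes a
    /// `pool` created with [`GpuMemoryPool::new`]):
    ///
    /// ```ignore
    /// let mem = pool.allocate(1000)?; // rounded up to 1024 bytes internally
    /// pool.release(mem);              // block is kept in the pool for reuse
    /// ```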
    #[cfg(feature = "scirs")]
    pub fn allocate(&mut self, size: usize) -> Result<GpuMemory, String> {
        let _lock = self
            .mutex
            .lock()
            .map_err(|e| format!("Failed to acquire lock in allocate: {e}"))?;

        self.stats.total_allocations += 1;

        // Round up to nearest power of 2 for better reuse
        let aligned_size = size.next_power_of_two();

        // Check if we have a free block of the right size
        if let Some(blocks) = self.blocks_by_size.get_mut(&aligned_size) {
            if let Some(mut block) = blocks.pop_front() {
                if !block.in_use {
                    block.in_use = true;
                    block.last_access = std::time::Instant::now();
                    self.stats.cache_hits += 1;

                    // Update the block in all_blocks
                    for b in &mut self.all_blocks {
                        if b.id == block.id {
                            b.in_use = true;
                            b.last_access = block.last_access;
                            break;
                        }
                    }

                    return Ok(GpuMemory {
                        id: block.id,
                        size: block.size,
                    });
                }
            }
        }

        // No suitable block found, allocate new
        self.stats.cache_misses += 1;

        // Check if we need to evict blocks to stay within the pool limit
        let _lock = if self.current_size + aligned_size > self.max_size {
            // Release the guard before calling the eviction helper (it does not
            // take the lock itself), then re-acquire it so the rest of the
            // allocation remains protected.
            drop(_lock);
            self.evict_lru_blocks(aligned_size)?;
            self.mutex
                .lock()
                .map_err(|e| format!("Failed to re-acquire lock after eviction: {e}"))?
        } else {
            _lock
        };

        // Allocate new block
        let block_id = self.next_block_id;
        self.next_block_id += 1;

        let block = MemoryBlock {
            id: block_id,
            size: aligned_size,
            in_use: true,
            last_access: std::time::Instant::now(),
        };

        self.all_blocks.push(block);
        self.current_size += aligned_size;
        self.stats.total_bytes_allocated += aligned_size;

        if self.current_size > self.stats.peak_memory_usage {
            self.stats.peak_memory_usage = self.current_size;
        }

        Ok(GpuMemory {
            id: block_id,
            size: aligned_size,
        })
    }

    /// Release memory back to the pool
    #[cfg(feature = "scirs")]
    pub fn release(&mut self, memory: GpuMemory) {
        // Use if let to gracefully handle lock poisoning
        if let Ok(_lock) = self.mutex.lock() {
            // Find the block and mark it as free
            for block in &mut self.all_blocks {
                if block.id == memory.id {
                    block.in_use = false;
                    block.last_access = std::time::Instant::now();

                    // Add to the pool for reuse
                    self.blocks_by_size
                        .entry(block.size)
                        .or_default()
                        .push_back(block.clone());

                    break;
                }
            }
        }
        // If lock is poisoned, we silently skip releasing to avoid panic
    }

    /// Evict least recently used blocks to make space
    #[cfg(feature = "scirs")]
    fn evict_lru_blocks(&mut self, required_size: usize) -> Result<(), String> {
        let mut freed_size = 0;
        let mut blocks_to_evict = Vec::new();

        // Sort blocks by last access time
        let mut free_blocks: Vec<_> = self.all_blocks.iter().filter(|b| !b.in_use).collect();
        free_blocks.sort_by_key(|b| b.last_access);

        // Evict blocks until we have enough space
        for block in free_blocks {
            if freed_size >= required_size {
                break;
            }

            blocks_to_evict.push(block.id);
            freed_size += block.size;
            self.stats.evictions += 1;
        }

        if freed_size < required_size {
            return Err("Insufficient memory in pool even after eviction".to_string());
        }

        // Actually evict the blocks
        for block_id in blocks_to_evict {
            self.all_blocks.retain(|b| b.id != block_id);

            // Remove from size-based pools
            for blocks in self.blocks_by_size.values_mut() {
                blocks.retain(|b| b.id != block_id);
            }

            // Free GPU memory
            // TODO: Implement free_raw in GPU stub
            // unsafe {
            //     self.context
            //         .free_raw(ptr)
            //         .map_err(|e| format!("Failed to free GPU memory: {}", e))?;
            // }
        }

        self.current_size -= freed_size;

        Ok(())
    }

    /// Get allocation statistics
    pub fn stats(&self) -> AllocationStats {
        self.stats.clone()
    }

    /// Clear the entire pool
    #[cfg(feature = "scirs")]
    pub fn clear(&mut self) -> Result<(), String> {
        let _lock = self
            .mutex
            .lock()
            .map_err(|e| format!("Failed to acquire lock in clear: {e}"))?;

        // Clear all blocks (in a real implementation, this would free GPU memory)
        // For our stub implementation, we just clear the tracking structures

        self.blocks_by_size.clear();
        self.all_blocks.clear();
        self.current_size = 0;

        Ok(())
    }

    /// Defragment the pool to reduce fragmentation
    #[cfg(feature = "scirs")]
    pub fn defragment(&mut self) -> Result<(), String> {
        let _lock = self
            .mutex
            .lock()
            .map_err(|e| format!("Failed to acquire lock in defragment: {e}"))?;

        // This is a complex operation that would involve:
        // 1. Identifying fragmented regions
        // 2. Allocating new contiguous blocks
        // 3. Copying data
        // 4. Updating pointers
        // 5. Freeing old blocks

        // For now, we just compact the free block lists
        for blocks in self.blocks_by_size.values_mut() {
            blocks.retain(|b| !b.in_use);
        }

        Ok(())
    }
}

/// Scoped memory allocation that automatically returns to pool
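///
/// The allocation is released back to the pool when the value is dropped (RAII).
/// A minimal sketch (assumes `pool` is an `Arc<Mutex<GpuMemoryPool>>`):
///
/// ```ignore
/// {
///     let scoped = ScopedGpuMemory::new(Arc::clone(&pool), 4096)?;
///     // ... use scoped.memory() ...
/// } // dropped here; the block is returned to `pool`
/// ```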
pub struct ScopedGpuMemory {
    memory: Option<GpuMemory>,
    pool: Arc<Mutex<GpuMemoryPool>>,
}

impl ScopedGpuMemory {
    /// Create a new scoped allocation
    #[cfg(feature = "scirs")]
    pub fn new(pool: Arc<Mutex<GpuMemoryPool>>, size: usize) -> Result<Self, String> {
        let memory = pool
            .lock()
            .map_err(|e| format!("Failed to acquire pool lock: {e}"))?
            .allocate(size)?;
        Ok(Self {
            memory: Some(memory),
            pool,
        })
    }

    /// Get the underlying memory
    ///
    /// # Panics
    /// Panics if called after the memory has been released (should never happen in normal use)
    #[cfg(feature = "scirs")]
    pub fn memory(&self) -> &GpuMemory {
        self.memory
            .as_ref()
            .expect("ScopedGpuMemory::memory called after memory was released - this is a bug")
    }

    /// Get mutable access to memory
    ///
    /// # Panics
    /// Panics if called after the memory has been released (should never happen in normal use)
    #[cfg(feature = "scirs")]
    pub fn memory_mut(&mut self) -> &mut GpuMemory {
        self.memory
            .as_mut()
            .expect("ScopedGpuMemory::memory_mut called after memory was released - this is a bug")
    }
}

#[cfg(feature = "scirs")]
impl Drop for ScopedGpuMemory {
    fn drop(&mut self) {
        if let Some(memory) = self.memory.take() {
            // Use if let to gracefully handle lock poisoning during drop
            if let Ok(mut pool) = self.pool.lock() {
                pool.release(memory);
            }
            // If lock is poisoned, we silently skip releasing to avoid panic in Drop
        }
    }
}

/// Memory pool manager for multiple devices
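///
/// A minimal sketch (device 0 and the 64 MiB limit are illustrative values):
///
/// ```ignore
/// let mut pools = MultiDeviceMemoryPool::new();
/// pools.add_device(0, Arc::new(GpuContext::new(0)?), 64 * 1024 * 1024);
/// let mem = pools.allocate(0, 4096)?; // ScopedGpuMemory, auto-released on drop
/// ```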
pub struct MultiDeviceMemoryPool {
    /// Pools for each device
    device_pools: HashMap<usize, Arc<Mutex<GpuMemoryPool>>>,
}

impl Default for MultiDeviceMemoryPool {
    fn default() -> Self {
        Self::new()
    }
}

impl MultiDeviceMemoryPool {
    /// Create a new multi-device pool
    pub fn new() -> Self {
        Self {
            device_pools: HashMap::new(),
        }
    }

    /// Add a device pool
    #[cfg(feature = "scirs")]
    pub fn add_device(&mut self, device_id: usize, context: Arc<GpuContext>, max_size: usize) {
        let pool = Arc::new(Mutex::new(GpuMemoryPool::new(context, max_size)));
        self.device_pools.insert(device_id, pool);
    }

    /// Get pool for a device
    pub fn get_pool(&self, device_id: usize) -> Option<Arc<Mutex<GpuMemoryPool>>> {
        self.device_pools.get(&device_id).cloned()
    }

    /// Allocate from a specific device
    #[cfg(feature = "scirs")]
    pub fn allocate(&self, device_id: usize, size: usize) -> Result<ScopedGpuMemory, String> {
        let pool = self
            .get_pool(device_id)
            .ok_or_else(|| format!("No pool for device {device_id}"))?;

        ScopedGpuMemory::new(pool, size)
    }

    /// Get combined statistics
    ///
    /// Note: Skips any device pools that cannot be locked (e.g., due to lock poisoning)
    pub fn combined_stats(&self) -> AllocationStats {
        let mut combined = AllocationStats::default();

        for pool in self.device_pools.values() {
            // Use if let to gracefully handle lock poisoning
            if let Ok(pool_guard) = pool.lock() {
                let stats = pool_guard.stats();
                combined.total_allocations += stats.total_allocations;
                combined.cache_hits += stats.cache_hits;
                combined.cache_misses += stats.cache_misses;
                combined.total_bytes_allocated += stats.total_bytes_allocated;
                combined.peak_memory_usage += stats.peak_memory_usage;
                combined.evictions += stats.evictions;
            }
            // Silently skip pools we can't lock to avoid panic
        }

        combined
    }
}

// Placeholder implementations when SciRS2 is not available
#[cfg(not(feature = "scirs"))]
pub struct GpuMemory;

#[cfg(not(feature = "scirs"))]
impl GpuMemoryPool {
    pub fn new(max_size: usize) -> Self {
        Self {
            max_size,
            current_size: 0,
            stats: AllocationStats::default(),
            mutex: Arc::new(Mutex::new(())),
            next_block_id: 0,
        }
    }

    pub fn allocate(&mut self, _size: usize) -> Result<GpuMemory, String> {
        Err("GPU memory pooling requires SciRS2 feature".to_string())
    }

    pub fn stats(&self) -> AllocationStats {
        self.stats.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_stats() {
        let stats = AllocationStats {
            total_allocations: 100,
            cache_hits: 80,
            cache_misses: 20,
            total_bytes_allocated: 1024 * 1024,
            peak_memory_usage: 512 * 1024,
            evictions: 5,
        };

        assert_eq!(stats.total_allocations, 100);
        assert_eq!(stats.cache_hits, 80);

        let hit_rate = stats.cache_hits as f64 / stats.total_allocations as f64;
        assert!(hit_rate > 0.79 && hit_rate < 0.81);
    }
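
    // Additional sketch: verifies the power-of-two rounding performed by
    // `allocate`, using the stub `GpuContext` defined above.
    #[test]
    #[cfg(feature = "scirs")]
    fn test_allocation_rounds_to_power_of_two() {
        let context = Arc::new(GpuContext::new(0).expect("Failed to create GPU context for test"));
        let mut pool = GpuMemoryPool::new(context, 1024 * 1024);

        // A 1000-byte request is rounded up to the next power of two (1024).
        let mem = pool.allocate(1000).expect("Allocation should succeed");
        assert_eq!(mem.size, 1024);
        assert_eq!(pool.stats().total_bytes_allocated, 1024);
    }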
468
469    #[test]
470    #[cfg(feature = "scirs")]
471    fn test_memory_pool_basic() {
472        use crate::gpu_memory_pool::GpuContext;
473
474        let context = Arc::new(GpuContext::new(0).expect("Failed to create GPU context for test"));
475        let mut pool = GpuMemoryPool::new(context, 1024 * 1024); // 1MB pool
476
477        // First allocation should be a cache miss
478        let mem1 = pool
479            .allocate(1024)
480            .expect("First allocation should succeed");
481        assert_eq!(pool.stats().cache_misses, 1);
482        assert_eq!(pool.stats().cache_hits, 0);
483
484        // Release and reallocate should be a cache hit
485        pool.release(mem1);
486        let _mem2 = pool
487            .allocate(1024)
488            .expect("Second allocation should succeed");
489        assert_eq!(pool.stats().cache_misses, 1);
490        assert_eq!(pool.stats().cache_hits, 1);
491    }
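
    // Further illustrative tests, sketched against the stub `GpuContext` above;
    // pool sizes are chosen only to exercise the eviction and reuse paths.

    #[test]
    #[cfg(feature = "scirs")]
    fn test_memory_pool_eviction() {
        let context = Arc::new(GpuContext::new(0).expect("Failed to create GPU context for test"));
        let mut pool = GpuMemoryPool::new(context, 1024); // tiny pool to force eviction

        let mem = pool.allocate(1024).expect("First allocation should succeed");
        pool.release(mem);

        // The free 1024-byte block must be evicted to keep the pool under its cap.
        let _mem2 = pool
            .allocate(512)
            .expect("Allocation after eviction should succeed");
        assert_eq!(pool.stats().evictions, 1);
    }

    #[test]
    #[cfg(feature = "scirs")]
    fn test_scoped_memory_returns_to_pool() {
        let context = Arc::new(GpuContext::new(0).expect("Failed to create GPU context for test"));
        let pool = Arc::new(Mutex::new(GpuMemoryPool::new(context, 1024 * 1024)));

        {
            let scoped = ScopedGpuMemory::new(Arc::clone(&pool), 2048)
                .expect("Scoped allocation should succeed");
            assert_eq!(scoped.memory().size, 2048);
        } // dropped here; the block is returned to the pool

        // Re-allocating the same size should now be served from the cache.
        let _mem = pool
            .lock()
            .expect("Pool lock should not be poisoned")
            .allocate(2048)
            .expect("Allocation should succeed");
        let stats = pool.lock().expect("Pool lock should not be poisoned").stats();
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    #[cfg(feature = "scirs")]
    fn test_multi_device_pool() {
        let mut pools = MultiDeviceMemoryPool::new();
        assert!(pools.get_pool(0).is_none());

        let context = Arc::new(GpuContext::new(0).expect("Failed to create GPU context for test"));
        pools.add_device(0, context, 1024 * 1024);

        let _mem = pools
            .allocate(0, 4096)
            .expect("Allocation on a registered device should succeed");
        assert!(pools.allocate(1, 4096).is_err()); // no pool registered for device 1
        assert_eq!(pools.combined_stats().total_allocations, 1);
    }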
}