// cuda_rust_wasm/memory/memory_pool.rs

//! High-performance memory pool for WASM targets.
//!
//! This module provides buffer-recycling allocation patterns optimized for
//! WASM environments, keeping per-allocation overhead low.

6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use crate::error::Result;
9
/// Tuning knobs for a [`MemoryPool`].
///
/// Every size below is measured in bytes.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Byte budget for a single size class: once a class holds
    /// `max_pool_size / class_size` free buffers, returned buffers are dropped.
    pub max_pool_size: usize,
    /// Requests below this size bypass the pool entirely.
    pub min_pooled_size: usize,
    /// Requests above this size bypass the pool entirely.
    pub max_pooled_size: usize,
    /// How many buffers to create up front for each common size class.
    pub prealloc_count: usize,
}

20    pub prealloc_count: usize,
21}
22
23impl Default for PoolConfig {
24    fn default() -> Self {
25        Self {
26            max_pool_size: 16 * 1024 * 1024, // 16MB max per pool
27            min_pooled_size: 1024,            // 1KB min
28            max_pooled_size: 4 * 1024 * 1024, // 4MB max
29            prealloc_count: 8,                // Pre-allocate 8 buffers
30        }
31    }
32}
33
34/// High-performance memory pool optimized for WASM
35#[derive(Debug)]
36pub struct MemoryPool {
37    /// Pools organized by power-of-2 sizes
38    pools: Arc<Mutex<HashMap<usize, Vec<Vec<u8>>>>>,
39    /// Configuration
40    config: PoolConfig,
41    /// Statistics for performance monitoring
42    stats: Arc<Mutex<PoolStats>>,
43}
44
45/// Performance statistics for the memory pool
46#[derive(Debug, Clone, Default)]
47pub struct PoolStats {
48    /// Total allocations requested
49    pub total_allocations: u64,
50    /// Cache hits (allocations served from pool)
51    pub cache_hits: u64,
52    /// Cache misses (new allocations)
53    pub cache_misses: u64,
54    /// Total bytes allocated
55    pub total_bytes_allocated: u64,
56    /// Total bytes served from pool
57    pub pooled_bytes_served: u64,
58    /// Peak memory usage
59    pub peak_memory_usage: usize,
60    /// Current memory usage
61    pub current_memory_usage: usize,
62}
63
64impl MemoryPool {
65    /// Create a new memory pool with default configuration
66    pub fn new() -> Self {
67        Self::with_config(PoolConfig::default())
68    }
69
70    /// Create a new memory pool with custom configuration
71    pub fn with_config(config: PoolConfig) -> Self {
72        let pool = Self {
73            pools: Arc::new(Mutex::new(HashMap::new())),
74            config,
75            stats: Arc::new(Mutex::new(PoolStats::default())),
76        };
77        
78        // Pre-allocate common sizes
79        pool.preallocate_common_sizes();
80        pool
81    }
82
83    /// Pre-allocate buffers for common sizes to reduce allocation overhead
84    fn preallocate_common_sizes(&self) {
85        let common_sizes = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072];
86        
87        for &size in &common_sizes {
88            if size >= self.config.min_pooled_size && size <= self.config.max_pooled_size {
89                let pool_size = self.round_to_power_of_2(size);
90                let mut pools = self.pools.lock().unwrap();
91                let pool = pools.entry(pool_size).or_default();
92                
93                for _ in 0..self.config.prealloc_count {
94                    pool.push(vec![0; pool_size]);
95                }
96            }
97        }
98    }
99
100    /// Allocate a buffer of the specified size
101    pub fn allocate(&self, size: usize) -> Vec<u8> {
102        let mut stats = self.stats.lock().unwrap();
103        stats.total_allocations += 1;
104        stats.total_bytes_allocated += size as u64;
105
106        // Don't pool very small or very large allocations
107        if size < self.config.min_pooled_size || size > self.config.max_pooled_size {
108            stats.cache_misses += 1;
109            stats.current_memory_usage += size;
110            if stats.current_memory_usage > stats.peak_memory_usage {
111                stats.peak_memory_usage = stats.current_memory_usage;
112            }
113            drop(stats);
114            return vec![0; size];
115        }
116
117        let pool_size = self.round_to_power_of_2(size);
118        let mut pools = self.pools.lock().unwrap();
119        
120        if let Some(pool) = pools.get_mut(&pool_size) {
121            if let Some(mut buffer) = pool.pop() {
122                // Cache hit
123                stats.cache_hits += 1;
124                stats.pooled_bytes_served += pool_size as u64;
125                drop(stats);
126                drop(pools);
127                
128                // Resize buffer to exact size needed
129                buffer.resize(size, 0);
130                return buffer;
131            }
132        }
133
134        // Cache miss - create new buffer
135        stats.cache_misses += 1;
136        stats.current_memory_usage += pool_size;
137        if stats.current_memory_usage > stats.peak_memory_usage {
138            stats.peak_memory_usage = stats.current_memory_usage;
139        }
140        drop(stats);
141        drop(pools);
142        
143        vec![0; size]
144    }
145
146    /// Return a buffer to the pool for reuse
147    pub fn deallocate(&self, mut buffer: Vec<u8>) {
148        let original_size = buffer.len();
149        
150        // Don't pool very small or very large allocations
151        if original_size < self.config.min_pooled_size || original_size > self.config.max_pooled_size {
152            let mut stats = self.stats.lock().unwrap();
153            stats.current_memory_usage = stats.current_memory_usage.saturating_sub(original_size);
154            return;
155        }
156
157        let pool_size = self.round_to_power_of_2(original_size);
158        buffer.resize(pool_size, 0);
159        buffer.clear(); // Clear but keep capacity
160        buffer.resize(pool_size, 0);
161
162        let mut pools = self.pools.lock().unwrap();
163        let pool = pools.entry(pool_size).or_default();
164        
165        // Limit pool size to prevent memory bloat
166        if pool.len() < self.config.max_pool_size / pool_size {
167            pool.push(buffer);
168        } else {
169            // Pool is full, just drop the buffer
170            let mut stats = self.stats.lock().unwrap();
171            stats.current_memory_usage = stats.current_memory_usage.saturating_sub(pool_size);
172        }
173    }
174
175    /// Round size up to the next power of 2 for efficient pooling
176    fn round_to_power_of_2(&self, size: usize) -> usize {
177        if size <= 1 {
178            return 1;
179        }
180        
181        let mut power = 1;
182        while power < size {
183            power <<= 1;
184        }
185        power
186    }
187
188    /// Get current pool statistics
189    pub fn stats(&self) -> PoolStats {
190        self.stats.lock().unwrap().clone()
191    }
192
193    /// Get cache hit ratio as a percentage
194    pub fn hit_ratio(&self) -> f64 {
195        let stats = self.stats.lock().unwrap();
196        if stats.total_allocations == 0 {
197            return 0.0;
198        }
199        (stats.cache_hits as f64 / stats.total_allocations as f64) * 100.0
200    }
201
202    /// Clear all pools and reset statistics
203    pub fn clear(&self) {
204        self.pools.lock().unwrap().clear();
205        let mut stats = self.stats.lock().unwrap();
206        *stats = PoolStats::default();
207    }
208
209    /// Get total memory usage across all pools
210    pub fn total_pooled_memory(&self) -> usize {
211        let pools = self.pools.lock().unwrap();
212        pools.iter()
213            .map(|(&size, pool)| size * pool.len())
214            .sum()
215    }
216
217    /// Shrink pools to release unused memory
218    pub fn shrink_to_fit(&self) {
219        let mut pools = self.pools.lock().unwrap();
220        for pool in pools.values_mut() {
221            pool.shrink_to_fit();
222        }
223    }
224}
225
226impl Default for MemoryPool {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
232/// Global memory pool instance for efficient allocation
233static GLOBAL_POOL: std::sync::OnceLock<MemoryPool> = std::sync::OnceLock::new();
234
235/// Get or initialize the global memory pool
236pub fn global_pool() -> &'static MemoryPool {
237    GLOBAL_POOL.get_or_init(MemoryPool::new)
238}
239
240/// Allocate from the global pool
241pub fn allocate(size: usize) -> Vec<u8> {
242    global_pool().allocate(size)
243}
244
245/// Deallocate to the global pool
246pub fn deallocate(buffer: Vec<u8>) {
247    global_pool().deallocate(buffer);
248}
249
250/// Get global pool statistics
251pub fn global_stats() -> PoolStats {
252    global_pool().stats()
253}
254
255/// High-level memory management for kernel operations
256pub struct KernelMemoryManager {
257    pool: Arc<MemoryPool>,
258    allocations: Mutex<HashMap<*const u8, usize>>,
259}
260
261impl KernelMemoryManager {
262    /// Create a new kernel memory manager
263    pub fn new() -> Self {
264        Self {
265            pool: Arc::new(MemoryPool::new()),
266            allocations: Mutex::new(HashMap::new()),
267        }
268    }
269
270    /// Allocate aligned memory for kernel operations
271    pub fn allocate_kernel_memory(&self, size: usize, alignment: usize) -> Result<*mut u8> {
272        // For WASM, alignment is typically handled by the allocator
273        let buffer = self.pool.allocate(size + alignment - 1);
274        let ptr = buffer.as_ptr() as *mut u8;
275        
276        // Store allocation info for tracking
277        {
278            let mut allocations = self.allocations.lock().unwrap();
279            allocations.insert(ptr, size);
280        }
281        
282        // Prevent buffer from being dropped
283        std::mem::forget(buffer);
284        
285        Ok(ptr)
286    }
287
288    /// Deallocate kernel memory
289    /// 
290    /// # Safety
291    /// The caller must ensure that the pointer was allocated by this memory pool
292    /// and is not used after this function returns.
293    pub unsafe fn deallocate_kernel_memory(&self, ptr: *mut u8) -> Result<()> {
294        let size = {
295            let mut allocations = self.allocations.lock().unwrap();
296            allocations.remove(&(ptr as *const u8))
297                .ok_or_else(|| crate::error::CudaRustError::MemoryError("Invalid pointer for deallocation".to_string()))?
298        };
299        
300        // Reconstruct the Vec from the raw pointer
301        let buffer = Vec::from_raw_parts(ptr, size, size);
302        self.pool.deallocate(buffer);
303        
304        Ok(())
305    }
306
307    /// Get total allocated kernel memory
308    pub fn total_kernel_memory(&self) -> usize {
309        let allocations = self.allocations.lock().unwrap();
310        allocations.values().sum()
311    }
312}
313
314impl Default for KernelMemoryManager {
315    fn default() -> Self {
316        Self::new()
317    }
318}
319
320#[cfg(test)]
321mod tests {
322    use super::*;
323
324    #[test]
325    fn test_memory_pool_basic() {
326        let pool = MemoryPool::new();
327        
328        // Test allocation
329        let buffer1 = pool.allocate(1024);
330        assert_eq!(buffer1.len(), 1024);
331        
332        // Test deallocation and reuse
333        pool.deallocate(buffer1);
334        let buffer2 = pool.allocate(1024);
335        assert_eq!(buffer2.len(), 1024);
336        
337        // Should be a cache hit
338        assert!(pool.hit_ratio() > 0.0);
339    }
340
341    #[test]
342    fn test_power_of_2_rounding() {
343        let pool = MemoryPool::new();
344        assert_eq!(pool.round_to_power_of_2(1000), 1024);
345        assert_eq!(pool.round_to_power_of_2(1024), 1024);
346        assert_eq!(pool.round_to_power_of_2(1500), 2048);
347    }
348
349    #[test]
350    fn test_global_pool() {
351        let buffer = allocate(2048);
352        assert_eq!(buffer.len(), 2048);
353        
354        deallocate(buffer);
355        let stats = global_stats();
356        assert!(stats.total_allocations > 0);
357    }
358
359    #[test]
360    fn test_kernel_memory_manager() {
361        let manager = KernelMemoryManager::new();
362        
363        unsafe {
364            let ptr = manager.allocate_kernel_memory(4096, 16).unwrap();
365            assert!(!ptr.is_null());
366            
367            manager.deallocate_kernel_memory(ptr).unwrap();
368        }
369    }
370}