//! Memory pool for efficient buffer reuse (cuda_rust_wasm/memory/memory_pool.rs).

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::error::Result;
/// Tuning knobs for [`MemoryPool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound, in bytes, on the memory retained by each size-class pool
    /// (caps how many free buffers a class keeps in `deallocate`).
    pub max_pool_size: usize,
    /// Requests smaller than this bypass the pool and go straight to the allocator.
    pub min_pooled_size: usize,
    /// Requests larger than this bypass the pool and go straight to the allocator.
    pub max_pooled_size: usize,
    /// Number of zeroed buffers pre-allocated per common size class at construction.
    pub prealloc_count: usize,
}
22
23impl Default for PoolConfig {
24 fn default() -> Self {
25 Self {
26 max_pool_size: 16 * 1024 * 1024, min_pooled_size: 1024, max_pooled_size: 4 * 1024 * 1024, prealloc_count: 8, }
31 }
32}
33
/// A size-class buffer pool: freed `Vec<u8>` buffers are kept, keyed by their
/// power-of-two rounded size, and reused on later allocations.
#[derive(Debug)]
pub struct MemoryPool {
    /// Free buffers per size class: rounded size -> stack of zeroed buffers.
    pools: Arc<Mutex<HashMap<usize, Vec<Vec<u8>>>>>,
    /// Pooling thresholds and retention limits.
    config: PoolConfig,
    /// Allocation counters; see [`PoolStats`].
    stats: Arc<Mutex<PoolStats>>,
}
44
/// Counters describing pool behavior; snapshot via [`MemoryPool::stats`].
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Every call to `allocate`, pooled or not.
    pub total_allocations: u64,
    /// Allocations served by reusing a pooled buffer.
    pub cache_hits: u64,
    /// Allocations that had to go to the system allocator.
    pub cache_misses: u64,
    /// Sum of all requested sizes, in bytes.
    pub total_bytes_allocated: u64,
    /// Bytes served from the pool (counted at the rounded size class).
    pub pooled_bytes_served: u64,
    /// High-water mark of `current_memory_usage`.
    pub peak_memory_usage: usize,
    /// Bytes believed outstanding (allocated minus returned/dropped).
    pub current_memory_usage: usize,
}
63
64impl MemoryPool {
65 pub fn new() -> Self {
67 Self::with_config(PoolConfig::default())
68 }
69
70 pub fn with_config(config: PoolConfig) -> Self {
72 let pool = Self {
73 pools: Arc::new(Mutex::new(HashMap::new())),
74 config,
75 stats: Arc::new(Mutex::new(PoolStats::default())),
76 };
77
78 pool.preallocate_common_sizes();
80 pool
81 }
82
83 fn preallocate_common_sizes(&self) {
85 let common_sizes = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072];
86
87 for &size in &common_sizes {
88 if size >= self.config.min_pooled_size && size <= self.config.max_pooled_size {
89 let pool_size = self.round_to_power_of_2(size);
90 let mut pools = self.pools.lock().unwrap();
91 let pool = pools.entry(pool_size).or_default();
92
93 for _ in 0..self.config.prealloc_count {
94 pool.push(vec![0; pool_size]);
95 }
96 }
97 }
98 }
99
100 pub fn allocate(&self, size: usize) -> Vec<u8> {
102 let mut stats = self.stats.lock().unwrap();
103 stats.total_allocations += 1;
104 stats.total_bytes_allocated += size as u64;
105
106 if size < self.config.min_pooled_size || size > self.config.max_pooled_size {
108 stats.cache_misses += 1;
109 stats.current_memory_usage += size;
110 if stats.current_memory_usage > stats.peak_memory_usage {
111 stats.peak_memory_usage = stats.current_memory_usage;
112 }
113 drop(stats);
114 return vec![0; size];
115 }
116
117 let pool_size = self.round_to_power_of_2(size);
118 let mut pools = self.pools.lock().unwrap();
119
120 if let Some(pool) = pools.get_mut(&pool_size) {
121 if let Some(mut buffer) = pool.pop() {
122 stats.cache_hits += 1;
124 stats.pooled_bytes_served += pool_size as u64;
125 drop(stats);
126 drop(pools);
127
128 buffer.resize(size, 0);
130 return buffer;
131 }
132 }
133
134 stats.cache_misses += 1;
136 stats.current_memory_usage += pool_size;
137 if stats.current_memory_usage > stats.peak_memory_usage {
138 stats.peak_memory_usage = stats.current_memory_usage;
139 }
140 drop(stats);
141 drop(pools);
142
143 vec![0; size]
144 }
145
146 pub fn deallocate(&self, mut buffer: Vec<u8>) {
148 let original_size = buffer.len();
149
150 if original_size < self.config.min_pooled_size || original_size > self.config.max_pooled_size {
152 let mut stats = self.stats.lock().unwrap();
153 stats.current_memory_usage = stats.current_memory_usage.saturating_sub(original_size);
154 return;
155 }
156
157 let pool_size = self.round_to_power_of_2(original_size);
158 buffer.resize(pool_size, 0);
159 buffer.clear(); buffer.resize(pool_size, 0);
161
162 let mut pools = self.pools.lock().unwrap();
163 let pool = pools.entry(pool_size).or_default();
164
165 if pool.len() < self.config.max_pool_size / pool_size {
167 pool.push(buffer);
168 } else {
169 let mut stats = self.stats.lock().unwrap();
171 stats.current_memory_usage = stats.current_memory_usage.saturating_sub(pool_size);
172 }
173 }
174
175 fn round_to_power_of_2(&self, size: usize) -> usize {
177 if size <= 1 {
178 return 1;
179 }
180
181 let mut power = 1;
182 while power < size {
183 power <<= 1;
184 }
185 power
186 }
187
188 pub fn stats(&self) -> PoolStats {
190 self.stats.lock().unwrap().clone()
191 }
192
193 pub fn hit_ratio(&self) -> f64 {
195 let stats = self.stats.lock().unwrap();
196 if stats.total_allocations == 0 {
197 return 0.0;
198 }
199 (stats.cache_hits as f64 / stats.total_allocations as f64) * 100.0
200 }
201
202 pub fn clear(&self) {
204 self.pools.lock().unwrap().clear();
205 let mut stats = self.stats.lock().unwrap();
206 *stats = PoolStats::default();
207 }
208
209 pub fn total_pooled_memory(&self) -> usize {
211 let pools = self.pools.lock().unwrap();
212 pools.iter()
213 .map(|(&size, pool)| size * pool.len())
214 .sum()
215 }
216
217 pub fn shrink_to_fit(&self) {
219 let mut pools = self.pools.lock().unwrap();
220 for pool in pools.values_mut() {
221 pool.shrink_to_fit();
222 }
223 }
224}
225
226impl Default for MemoryPool {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
232static GLOBAL_POOL: std::sync::OnceLock<MemoryPool> = std::sync::OnceLock::new();
234
235pub fn global_pool() -> &'static MemoryPool {
237 GLOBAL_POOL.get_or_init(MemoryPool::new)
238}
239
240pub fn allocate(size: usize) -> Vec<u8> {
242 global_pool().allocate(size)
243}
244
245pub fn deallocate(buffer: Vec<u8>) {
247 global_pool().deallocate(buffer);
248}
249
250pub fn global_stats() -> PoolStats {
252 global_pool().stats()
253}
254
/// Tracks raw-pointer allocations handed out for kernel use, backed by a
/// private [`MemoryPool`].
pub struct KernelMemoryManager {
    /// Pool providing the backing buffers.
    pool: Arc<MemoryPool>,
    /// Live allocations: base pointer -> tracked size in bytes.
    // NOTE(review): raw-pointer keys make this type !Send/!Sync by default —
    // confirm whether cross-thread use is intended by callers.
    allocations: Mutex<HashMap<*const u8, usize>>,
}
260
261impl KernelMemoryManager {
262 pub fn new() -> Self {
264 Self {
265 pool: Arc::new(MemoryPool::new()),
266 allocations: Mutex::new(HashMap::new()),
267 }
268 }
269
270 pub fn allocate_kernel_memory(&self, size: usize, alignment: usize) -> Result<*mut u8> {
272 let buffer = self.pool.allocate(size + alignment - 1);
274 let ptr = buffer.as_ptr() as *mut u8;
275
276 {
278 let mut allocations = self.allocations.lock().unwrap();
279 allocations.insert(ptr, size);
280 }
281
282 std::mem::forget(buffer);
284
285 Ok(ptr)
286 }
287
288 pub unsafe fn deallocate_kernel_memory(&self, ptr: *mut u8) -> Result<()> {
294 let size = {
295 let mut allocations = self.allocations.lock().unwrap();
296 allocations.remove(&(ptr as *const u8))
297 .ok_or_else(|| crate::error::CudaRustError::MemoryError("Invalid pointer for deallocation".to_string()))?
298 };
299
300 let buffer = Vec::from_raw_parts(ptr, size, size);
302 self.pool.deallocate(buffer);
303
304 Ok(())
305 }
306
307 pub fn total_kernel_memory(&self) -> usize {
309 let allocations = self.allocations.lock().unwrap();
310 allocations.values().sum()
311 }
312}
313
314impl Default for KernelMemoryManager {
315 fn default() -> Self {
316 Self::new()
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    // Allocate, return, and re-request the same size: the pool should serve
    // at least one request from cache, yielding a non-zero hit ratio.
    #[test]
    fn test_memory_pool_basic() {
        let pool = MemoryPool::new();

        let first = pool.allocate(1024);
        assert_eq!(first.len(), 1024);
        pool.deallocate(first);

        let second = pool.allocate(1024);
        assert_eq!(second.len(), 1024);

        assert!(pool.hit_ratio() > 0.0);
    }

    // Rounding behavior: exact powers stay put, everything else rounds up.
    #[test]
    fn test_power_of_2_rounding() {
        let pool = MemoryPool::new();
        for (input, expected) in [(1000, 1024), (1024, 1024), (1500, 2048)] {
            assert_eq!(pool.round_to_power_of_2(input), expected);
        }
    }

    // The free-function API routes through the shared global pool.
    #[test]
    fn test_global_pool() {
        let buffer = allocate(2048);
        assert_eq!(buffer.len(), 2048);
        deallocate(buffer);
        assert!(global_stats().total_allocations > 0);
    }

    // Round-trip a raw kernel allocation through the manager.
    #[test]
    fn test_kernel_memory_manager() {
        let manager = KernelMemoryManager::new();
        unsafe {
            let ptr = manager.allocate_kernel_memory(4096, 16).unwrap();
            assert!(!ptr.is_null());
            manager.deallocate_kernel_memory(ptr).unwrap();
        }
    }
}