1use async_trait::async_trait;
4use ferrum_interfaces::memory::{
5 DefragmentationStats, DeviceMemoryManager, MemoryHandle, MemoryHandleInfo, MemoryInfo,
6 MemoryPoolConfig as InterfaceMemoryPoolConfig, MemoryPressure, MemoryTransfer, MemoryType,
7 StreamHandle,
8};
9use ferrum_types::{Device, Result};
10use parking_lot::Mutex;
11use std::collections::{HashMap, VecDeque};
12use tracing::{debug, warn};
13
14#[derive(Debug, Clone)]
16struct MemoryBlock {
17 handle: MemoryHandle,
18 size: usize,
19 is_free: bool,
20 allocated_at: std::time::Instant,
21}
22
23pub struct MemoryPool {
25 device: Device,
26 blocks: Mutex<VecDeque<MemoryBlock>>,
27 free_blocks: Mutex<HashMap<usize, VecDeque<usize>>>, total_allocated: Mutex<usize>,
29 peak_allocated: Mutex<usize>,
30 allocation_count: Mutex<u64>,
31 config: InternalMemoryPoolConfig,
32}
33
/// Tuning knobs for [`MemoryPool`].
///
/// NOTE(review): within this file only `max_size` and `enable_defragmentation` are read by the
/// pool; the remaining fields look unused here — confirm before relying on them.
#[derive(Clone, Debug)]
pub struct InternalMemoryPoolConfig {
    /// Intended initial pool size in bytes.
    pub initial_size: usize,
    /// Upper bound in bytes; creating a block that would exceed it fails the allocation.
    pub max_size: usize,
    /// Intended growth multiplier when the pool expands.
    pub growth_factor: f32,
    /// When `false`, `MemoryPool::defragment` is a no-op.
    pub enable_defragmentation: bool,
    /// Intended lower size bound (bytes) for blocks kept for reuse.
    pub min_pooled_size: usize,
    /// Intended upper size bound (bytes) for blocks kept for reuse.
    pub max_pooled_size: usize,
    /// Intended number of size buckets in the free list.
    pub size_buckets: usize,
}
56
57impl Default for InternalMemoryPoolConfig {
58 fn default() -> Self {
59 Self {
60 initial_size: 256 * 1024 * 1024, max_size: 8 * 1024 * 1024 * 1024, growth_factor: 1.5,
63 enable_defragmentation: true,
64 min_pooled_size: 256, max_pooled_size: 128 * 1024 * 1024, size_buckets: 64,
67 }
68 }
69}
70
71impl MemoryPool {
72 pub fn new(device: Device, config: InternalMemoryPoolConfig) -> Self {
74 Self {
75 device,
76 blocks: Mutex::new(VecDeque::new()),
77 free_blocks: Mutex::new(HashMap::new()),
78 total_allocated: Mutex::new(0),
79 peak_allocated: Mutex::new(0),
80 allocation_count: Mutex::new(0),
81 config,
82 }
83 }
84
85 pub fn allocate(&self, size: usize) -> Result<MemoryHandle> {
87 let aligned_size = align_size(size, 256); if let Some(handle) = self.try_allocate_from_pool(aligned_size) {
91 return Ok(handle);
92 }
93
94 self.allocate_new_block(aligned_size)
96 }
97
98 pub fn deallocate(&self, handle: MemoryHandle) -> Result<()> {
100 let mut blocks = self.blocks.lock();
101
102 for (index, block) in blocks.iter_mut().enumerate() {
104 if block.handle.id() == handle.id() {
105 block.is_free = true;
106
107 let size = block.size;
109 drop(blocks);
110
111 let mut free_blocks = self.free_blocks.lock();
112 free_blocks.entry(size).or_default().push_back(index);
113
114 debug!("Deallocated block of size {} bytes", size);
115 return Ok(());
116 }
117 }
118
119 warn!(
120 "Attempted to deallocate unknown memory handle: {:?}",
121 handle
122 );
123 Ok(())
124 }
125
126 pub fn stats(&self) -> MemoryInfo {
128 let blocks = self.blocks.lock();
129 let total_allocated = *self.total_allocated.lock();
130
131 let used_memory = blocks
132 .iter()
133 .filter(|b| !b.is_free)
134 .map(|b| b.size)
135 .sum::<usize>();
136
137 let free_memory = blocks
138 .iter()
139 .filter(|b| b.is_free)
140 .map(|b| b.size)
141 .sum::<usize>();
142
143 let fragmentation_ratio = if total_allocated > 0 {
144 let free_blocks_count = blocks.iter().filter(|b| b.is_free).count();
145 free_blocks_count as f32 / blocks.len() as f32
146 } else {
147 0.0
148 };
149
150 MemoryInfo {
151 total_bytes: total_allocated as u64,
152 used_bytes: used_memory as u64,
153 free_bytes: free_memory as u64,
154 reserved_bytes: 0,
155 active_allocations: blocks.iter().filter(|b| !b.is_free).count(),
156 fragmentation_ratio,
157 bandwidth_gbps: None,
158 }
159 }
160
161 pub fn defragment(&self) -> Result<()> {
163 if !self.config.enable_defragmentation {
164 return Ok(());
165 }
166
167 debug!(
168 "Starting memory pool defragmentation for device {:?}",
169 self.device
170 );
171
172 let mut blocks = self.blocks.lock();
174 let mut free_blocks = self.free_blocks.lock();
175
176 blocks.retain(|b| !b.is_free);
178 free_blocks.clear();
179
180 for (index, block) in blocks.iter().enumerate() {
182 if block.is_free {
183 free_blocks.entry(block.size).or_default().push_back(index);
184 }
185 }
186
187 debug!("Memory pool defragmentation completed");
188 Ok(())
189 }
190
191 fn try_allocate_from_pool(&self, size: usize) -> Option<MemoryHandle> {
192 let mut free_blocks = self.free_blocks.lock();
193
194 if let Some(indices) = free_blocks.get_mut(&size) {
196 if let Some(index) = indices.pop_front() {
197 let mut blocks = self.blocks.lock();
198 if let Some(block) = blocks.get_mut(index) {
199 block.is_free = false;
200 return Some(block.handle);
201 }
202 }
203 }
204
205 let mut best_fit: Option<(usize, usize)> = None; for (&block_size, indices) in free_blocks.iter() {
209 if block_size >= size && (best_fit.is_none() || block_size < best_fit.unwrap().0) {
210 if let Some(&index) = indices.front() {
211 best_fit = Some((block_size, index));
212 }
213 }
214 }
215
216 if let Some((block_size, index)) = best_fit {
217 free_blocks.get_mut(&block_size)?.pop_front();
219
220 let mut blocks = self.blocks.lock();
221 if let Some(block) = blocks.get_mut(index) {
222 block.is_free = false;
223 return Some(block.handle);
224 }
225 }
226
227 None
228 }
229
230 fn allocate_new_block(&self, size: usize) -> Result<MemoryHandle> {
231 let current_total = *self.total_allocated.lock();
233 if current_total + size > self.config.max_size {
234 return Err(ferrum_types::FerrumError::backend(format!(
235 "Memory pool size limit exceeded: {} + {} > {}",
236 current_total, size, self.config.max_size
237 )));
238 }
239
240 let handle_id = {
242 let mut count = self.allocation_count.lock();
243 *count += 1;
244 *count
245 };
246
247 let handle = MemoryHandle::new(handle_id);
248
249 let block = MemoryBlock {
251 handle,
252 size,
253 is_free: false,
254 allocated_at: std::time::Instant::now(),
255 };
256
257 let mut blocks = self.blocks.lock();
258 blocks.push_back(block);
259
260 {
262 let mut total = self.total_allocated.lock();
263 *total += size;
264
265 let mut peak = self.peak_allocated.lock();
266 if *total > *peak {
267 *peak = *total;
268 }
269 }
270
271 debug!("Allocated new memory block of size {} bytes", size);
272 Ok(handle)
273 }
274}
275
276#[async_trait]
277impl DeviceMemoryManager for MemoryPool {
278 async fn allocate(&self, size: usize, _device: &Device) -> Result<MemoryHandle> {
279 self.allocate(size)
280 }
281
282 async fn allocate_aligned(
283 &self,
284 size: usize,
285 alignment: usize,
286 _device: &Device,
287 ) -> Result<MemoryHandle> {
288 let aligned_size = align_size(size, alignment);
289 self.allocate(aligned_size)
290 }
291
292 async fn deallocate(&self, handle: MemoryHandle) -> Result<()> {
293 self.deallocate(handle)
294 }
295
296 async fn copy(
297 &self,
298 _src: MemoryHandle,
299 _dst: MemoryHandle,
300 _size: usize,
301 _src_offset: usize,
302 _dst_offset: usize,
303 ) -> Result<()> {
304 Ok(())
306 }
307
308 async fn copy_async(
309 &self,
310 _transfer: MemoryTransfer,
311 _stream: Option<StreamHandle>,
312 ) -> Result<()> {
313 Ok(())
315 }
316
317 async fn memory_info(&self, _device: &Device) -> Result<MemoryInfo> {
318 Ok(self.stats())
319 }
320
321 fn handle_info(&self, handle: MemoryHandle) -> Option<MemoryHandleInfo> {
322 let blocks = self.blocks.lock();
323 blocks
324 .iter()
325 .find(|b| b.handle.id() == handle.id())
326 .map(|block| {
327 MemoryHandleInfo {
328 handle: block.handle,
329 size: block.size,
330 device: self.device.clone(),
331 alignment: 256, allocated_at: block.allocated_at,
333 is_mapped: false,
334 memory_type: MemoryType::General,
335 }
336 })
337 }
338
339 async fn configure_pool(
340 &self,
341 _device: &Device,
342 _config: InterfaceMemoryPoolConfig,
343 ) -> Result<()> {
344 Ok(())
346 }
347
348 async fn defragment(&self, _device: &Device) -> Result<DefragmentationStats> {
349 let before_fragmentation = self.stats().fragmentation_ratio;
350 self.defragment()?;
351 let after_fragmentation = self.stats().fragmentation_ratio;
352
353 Ok(DefragmentationStats {
354 memory_freed: 0, blocks_moved: 0,
356 time_taken_ms: 0,
357 fragmentation_before: before_fragmentation,
358 fragmentation_after: after_fragmentation,
359 })
360 }
361
362 fn set_pressure_callback(&self, _callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
363 }
365}
366
/// Rounds `size` up to the nearest multiple of `alignment`.
///
/// - `alignment` of 0 or 1 imposes no constraint and returns `size` unchanged (the previous
///   bitmask form underflowed on 0).
/// - Any positive alignment works, not only powers of two.
/// - If rounding up would overflow, the result saturates at `usize::MAX`. Wrapping would
///   instead turn a huge request into a tiny (often zero-byte) one.
fn align_size(size: usize, alignment: usize) -> usize {
    if alignment <= 1 {
        return size;
    }

    match size % alignment {
        0 => size,
        remainder => size.saturating_add(alignment - remainder),
    }
}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU-backed pool with the default configuration, shared by most tests.
    fn cpu_pool() -> MemoryPool {
        MemoryPool::new(Device::CPU, InternalMemoryPoolConfig::default())
    }

    #[test]
    fn test_align_size() {
        // (size, alignment, expected)
        let cases = [
            (100, 256, 256),
            (256, 256, 256),
            (257, 256, 512),
            (500, 256, 512),
            (1, 64, 64),
            (64, 64, 64),
            (65, 64, 128),
        ];
        for &(size, alignment, expected) in cases.iter() {
            assert_eq!(align_size(size, alignment), expected);
        }
    }

    #[test]
    fn test_memory_pool_creation() {
        let snapshot = cpu_pool().stats();
        assert_eq!(snapshot.used_bytes, 0);
        assert_eq!(snapshot.active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_allocation() {
        let pool = cpu_pool();

        let first = pool.allocate(1024).unwrap();
        let after_first = pool.stats();
        assert_eq!(after_first.active_allocations, 1);
        assert!(after_first.used_bytes > 0);

        let second = pool.allocate(2048).unwrap();
        assert_eq!(pool.stats().active_allocations, 2);

        assert_ne!(first.id(), second.id());
    }

    #[test]
    fn test_memory_pool_deallocation() {
        let pool = cpu_pool();

        let handle = pool.allocate(1024).unwrap();
        assert_eq!(pool.stats().active_allocations, 1);

        pool.deallocate(handle).unwrap();
        assert_eq!(pool.stats().active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_reuse() {
        let pool = cpu_pool();

        let released = pool.allocate(1024).unwrap();
        pool.deallocate(released).unwrap();

        let _reused = pool.allocate(1024).unwrap();
        assert_eq!(pool.stats().active_allocations, 1);
    }

    #[test]
    fn test_memory_pool_size_limit() {
        let config = InternalMemoryPoolConfig {
            max_size: 1024,
            ..InternalMemoryPoolConfig::default()
        };
        let pool = MemoryPool::new(Device::CPU, config);

        assert!(pool.allocate(2048).is_err());
    }

    #[test]
    fn test_memory_pool_multiple_allocations() {
        let pool = cpu_pool();

        let handles: Vec<MemoryHandle> = (1..=5)
            .map(|multiple| pool.allocate(1024 * multiple).unwrap())
            .collect();
        assert_eq!(pool.stats().active_allocations, 5);

        handles
            .into_iter()
            .for_each(|handle| pool.deallocate(handle).unwrap());
        assert_eq!(pool.stats().active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_stats() {
        let pool = cpu_pool();

        let empty = pool.stats();
        assert_eq!(empty.used_bytes, 0);
        assert_eq!(empty.active_allocations, 0);
        assert_eq!(empty.fragmentation_ratio, 0.0);

        let _small = pool.allocate(1024).unwrap();
        let _large = pool.allocate(2048).unwrap();

        let populated = pool.stats();
        assert!(populated.total_bytes >= 1024 + 2048);
        assert_eq!(populated.active_allocations, 2);
        assert!(populated.used_bytes > 0);
    }

    #[test]
    fn test_memory_pool_defragment() {
        let pool = cpu_pool();

        let first = pool.allocate(1024).unwrap();
        let middle = pool.allocate(2048).unwrap();
        let last = pool.allocate(512).unwrap();

        // Leave a hole in the middle of the block list before compacting.
        pool.deallocate(middle).unwrap();
        let active_before = pool.stats().active_allocations;
        pool.defragment().unwrap();
        let active_after = pool.stats().active_allocations;

        assert_eq!(active_before, active_after);

        pool.deallocate(first).ok();
        pool.deallocate(last).ok();
    }

    #[tokio::test]
    async fn test_device_memory_manager_trait() {
        use ferrum_interfaces::memory::DeviceMemoryManager;

        let device = Device::CPU;
        let pool = MemoryPool::new(device.clone(), InternalMemoryPoolConfig::default());

        let plain = DeviceMemoryManager::allocate(&pool, 1024, &device)
            .await
            .unwrap();
        assert_ne!(plain.id(), 0);

        let aligned = DeviceMemoryManager::allocate_aligned(&pool, 1000, 256, &device)
            .await
            .unwrap();
        assert_ne!(aligned.id(), 0);

        let info = DeviceMemoryManager::memory_info(&pool, &device)
            .await
            .unwrap();
        assert_eq!(info.active_allocations, 2);

        DeviceMemoryManager::deallocate(&pool, plain).await.unwrap();
        let info = DeviceMemoryManager::memory_info(&pool, &device)
            .await
            .unwrap();
        assert_eq!(info.active_allocations, 1);

        DeviceMemoryManager::deallocate(&pool, aligned).await.ok();
    }

    #[tokio::test]
    async fn test_device_memory_manager_defragment() {
        use ferrum_interfaces::memory::DeviceMemoryManager;

        let device = Device::CPU;
        let pool = MemoryPool::new(device.clone(), InternalMemoryPoolConfig::default());

        for &bytes in [1024usize, 2048].iter() {
            let _held = DeviceMemoryManager::allocate(&pool, bytes, &device)
                .await
                .unwrap();
        }

        let report = DeviceMemoryManager::defragment(&pool, &device)
            .await
            .unwrap();
        assert!(report.fragmentation_before >= 0.0);
        assert!(report.fragmentation_after >= 0.0);
    }

    #[test]
    fn test_handle_info() {
        let pool = cpu_pool();
        let handle = pool.allocate(1024).unwrap();

        let info = pool
            .handle_info(handle)
            .expect("a live handle must have handle info");
        assert_eq!(info.handle.id(), handle.id());
        assert!(info.size >= 1024);
        assert_eq!(info.alignment, 256);
        assert!(!info.is_mapped);

        let unknown = MemoryHandle::new(99999);
        assert!(pool.handle_info(unknown).is_none());
    }
}