Skip to main content

ferrum_runtime/memory/
pool.rs

1//! Memory pool implementation for efficient allocation
2
3use async_trait::async_trait;
4use ferrum_interfaces::memory::{
5    DefragmentationStats, DeviceMemoryManager, MemoryHandle, MemoryHandleInfo, MemoryInfo,
6    MemoryPoolConfig as InterfaceMemoryPoolConfig, MemoryPressure, MemoryTransfer, MemoryType,
7    StreamHandle,
8};
9use ferrum_types::{Device, Result};
10use parking_lot::Mutex;
11use std::collections::{HashMap, VecDeque};
12use tracing::{debug, warn};
13
14/// Memory block in the pool
15#[derive(Debug, Clone)]
16struct MemoryBlock {
17    handle: MemoryHandle,
18    size: usize,
19    is_free: bool,
20    allocated_at: std::time::Instant,
21}
22
23/// Memory pool for efficient allocation/deallocation
24pub struct MemoryPool {
25    device: Device,
26    blocks: Mutex<VecDeque<MemoryBlock>>,
27    free_blocks: Mutex<HashMap<usize, VecDeque<usize>>>, // size -> block indices
28    total_allocated: Mutex<usize>,
29    peak_allocated: Mutex<usize>,
30    allocation_count: Mutex<u64>,
31    config: InternalMemoryPoolConfig,
32}
33
/// Runtime-side configuration of a memory pool.
///
/// This type is separate from `ferrum_interfaces::memory::MemoryPoolConfig`:
/// that one describes the interface-level configuration, whereas this one
/// carries the implementation-specific knobs of the runtime pool.
#[derive(Debug, Clone)]
pub struct InternalMemoryPoolConfig {
    /// Size of the pool at start-up, in bytes.
    pub initial_size: usize,
    /// Upper bound on the pool size, in bytes.
    pub max_size: usize,
    /// Multiplier applied when the pool has to grow.
    pub growth_factor: f32,
    /// Enables defragmentation when `true`.
    pub enable_defragmentation: bool,
    /// Smallest block size, in bytes, that is eligible for pooling.
    pub min_pooled_size: usize,
    /// Largest block size, in bytes, that is eligible for pooling.
    pub max_pooled_size: usize,
    /// How many buckets are used for size-based pooling.
    pub size_buckets: usize,
}
56
57impl Default for InternalMemoryPoolConfig {
58    fn default() -> Self {
59        Self {
60            initial_size: 256 * 1024 * 1024,  // 256MB
61            max_size: 8 * 1024 * 1024 * 1024, // 8GB
62            growth_factor: 1.5,
63            enable_defragmentation: true,
64            min_pooled_size: 256,               // 256B
65            max_pooled_size: 128 * 1024 * 1024, // 128MB
66            size_buckets: 64,
67        }
68    }
69}
70
71impl MemoryPool {
72    /// Create new memory pool
73    pub fn new(device: Device, config: InternalMemoryPoolConfig) -> Self {
74        Self {
75            device,
76            blocks: Mutex::new(VecDeque::new()),
77            free_blocks: Mutex::new(HashMap::new()),
78            total_allocated: Mutex::new(0),
79            peak_allocated: Mutex::new(0),
80            allocation_count: Mutex::new(0),
81            config,
82        }
83    }
84
85    /// Allocate memory from pool
86    pub fn allocate(&self, size: usize) -> Result<MemoryHandle> {
87        let aligned_size = align_size(size, 256); // 256-byte alignment
88
89        // Try to find a free block of appropriate size
90        if let Some(handle) = self.try_allocate_from_pool(aligned_size) {
91            return Ok(handle);
92        }
93
94        // Allocate new block
95        self.allocate_new_block(aligned_size)
96    }
97
98    /// Deallocate memory back to pool
99    pub fn deallocate(&self, handle: MemoryHandle) -> Result<()> {
100        let mut blocks = self.blocks.lock();
101
102        // Find the block and mark it as free
103        for (index, block) in blocks.iter_mut().enumerate() {
104            if block.handle.id() == handle.id() {
105                block.is_free = true;
106
107                // Add to free blocks index
108                let size = block.size;
109                drop(blocks);
110
111                let mut free_blocks = self.free_blocks.lock();
112                free_blocks.entry(size).or_default().push_back(index);
113
114                debug!("Deallocated block of size {} bytes", size);
115                return Ok(());
116            }
117        }
118
119        warn!(
120            "Attempted to deallocate unknown memory handle: {:?}",
121            handle
122        );
123        Ok(())
124    }
125
126    /// Get memory statistics
127    pub fn stats(&self) -> MemoryInfo {
128        let blocks = self.blocks.lock();
129        let total_allocated = *self.total_allocated.lock();
130
131        let used_memory = blocks
132            .iter()
133            .filter(|b| !b.is_free)
134            .map(|b| b.size)
135            .sum::<usize>();
136
137        let free_memory = blocks
138            .iter()
139            .filter(|b| b.is_free)
140            .map(|b| b.size)
141            .sum::<usize>();
142
143        let fragmentation_ratio = if total_allocated > 0 {
144            let free_blocks_count = blocks.iter().filter(|b| b.is_free).count();
145            free_blocks_count as f32 / blocks.len() as f32
146        } else {
147            0.0
148        };
149
150        MemoryInfo {
151            total_bytes: total_allocated as u64,
152            used_bytes: used_memory as u64,
153            free_bytes: free_memory as u64,
154            reserved_bytes: 0,
155            active_allocations: blocks.iter().filter(|b| !b.is_free).count(),
156            fragmentation_ratio,
157            bandwidth_gbps: None,
158        }
159    }
160
161    /// Defragment memory pool
162    pub fn defragment(&self) -> Result<()> {
163        if !self.config.enable_defragmentation {
164            return Ok(());
165        }
166
167        debug!(
168            "Starting memory pool defragmentation for device {:?}",
169            self.device
170        );
171
172        // Simple defragmentation: compact free blocks
173        let mut blocks = self.blocks.lock();
174        let mut free_blocks = self.free_blocks.lock();
175
176        // Remove freed blocks and rebuild free index
177        blocks.retain(|b| !b.is_free);
178        free_blocks.clear();
179
180        // Rebuild free blocks index
181        for (index, block) in blocks.iter().enumerate() {
182            if block.is_free {
183                free_blocks.entry(block.size).or_default().push_back(index);
184            }
185        }
186
187        debug!("Memory pool defragmentation completed");
188        Ok(())
189    }
190
191    fn try_allocate_from_pool(&self, size: usize) -> Option<MemoryHandle> {
192        let mut free_blocks = self.free_blocks.lock();
193
194        // Look for exact size match first
195        if let Some(indices) = free_blocks.get_mut(&size) {
196            if let Some(index) = indices.pop_front() {
197                let mut blocks = self.blocks.lock();
198                if let Some(block) = blocks.get_mut(index) {
199                    block.is_free = false;
200                    return Some(block.handle);
201                }
202            }
203        }
204
205        // Look for larger blocks that can be split
206        let mut best_fit: Option<(usize, usize)> = None; // (size, index)
207
208        for (&block_size, indices) in free_blocks.iter() {
209            if block_size >= size && (best_fit.is_none() || block_size < best_fit.unwrap().0) {
210                if let Some(&index) = indices.front() {
211                    best_fit = Some((block_size, index));
212                }
213            }
214        }
215
216        if let Some((block_size, index)) = best_fit {
217            // Remove from free list
218            free_blocks.get_mut(&block_size)?.pop_front();
219
220            let mut blocks = self.blocks.lock();
221            if let Some(block) = blocks.get_mut(index) {
222                block.is_free = false;
223                return Some(block.handle);
224            }
225        }
226
227        None
228    }
229
230    fn allocate_new_block(&self, size: usize) -> Result<MemoryHandle> {
231        // Check if we would exceed max pool size
232        let current_total = *self.total_allocated.lock();
233        if current_total + size > self.config.max_size {
234            return Err(ferrum_types::FerrumError::backend(format!(
235                "Memory pool size limit exceeded: {} + {} > {}",
236                current_total, size, self.config.max_size
237            )));
238        }
239
240        // Create new memory handle (simplified - real implementation would allocate actual memory)
241        let handle_id = {
242            let mut count = self.allocation_count.lock();
243            *count += 1;
244            *count
245        };
246
247        let handle = MemoryHandle::new(handle_id);
248
249        // Add to blocks
250        let block = MemoryBlock {
251            handle,
252            size,
253            is_free: false,
254            allocated_at: std::time::Instant::now(),
255        };
256
257        let mut blocks = self.blocks.lock();
258        blocks.push_back(block);
259
260        // Update statistics
261        {
262            let mut total = self.total_allocated.lock();
263            *total += size;
264
265            let mut peak = self.peak_allocated.lock();
266            if *total > *peak {
267                *peak = *total;
268            }
269        }
270
271        debug!("Allocated new memory block of size {} bytes", size);
272        Ok(handle)
273    }
274}
275
276#[async_trait]
277impl DeviceMemoryManager for MemoryPool {
278    async fn allocate(&self, size: usize, _device: &Device) -> Result<MemoryHandle> {
279        self.allocate(size)
280    }
281
282    async fn allocate_aligned(
283        &self,
284        size: usize,
285        alignment: usize,
286        _device: &Device,
287    ) -> Result<MemoryHandle> {
288        let aligned_size = align_size(size, alignment);
289        self.allocate(aligned_size)
290    }
291
292    async fn deallocate(&self, handle: MemoryHandle) -> Result<()> {
293        self.deallocate(handle)
294    }
295
296    async fn copy(
297        &self,
298        _src: MemoryHandle,
299        _dst: MemoryHandle,
300        _size: usize,
301        _src_offset: usize,
302        _dst_offset: usize,
303    ) -> Result<()> {
304        // Simplified implementation - real version would do actual copy
305        Ok(())
306    }
307
308    async fn copy_async(
309        &self,
310        _transfer: MemoryTransfer,
311        _stream: Option<StreamHandle>,
312    ) -> Result<()> {
313        // Simplified implementation
314        Ok(())
315    }
316
317    async fn memory_info(&self, _device: &Device) -> Result<MemoryInfo> {
318        Ok(self.stats())
319    }
320
321    fn handle_info(&self, handle: MemoryHandle) -> Option<MemoryHandleInfo> {
322        let blocks = self.blocks.lock();
323        blocks
324            .iter()
325            .find(|b| b.handle.id() == handle.id())
326            .map(|block| {
327                MemoryHandleInfo {
328                    handle: block.handle,
329                    size: block.size,
330                    device: self.device.clone(),
331                    alignment: 256, // Default alignment
332                    allocated_at: block.allocated_at,
333                    is_mapped: false,
334                    memory_type: MemoryType::General,
335                }
336            })
337    }
338
339    async fn configure_pool(
340        &self,
341        _device: &Device,
342        _config: InterfaceMemoryPoolConfig,
343    ) -> Result<()> {
344        // For now, pool config is set at construction
345        Ok(())
346    }
347
348    async fn defragment(&self, _device: &Device) -> Result<DefragmentationStats> {
349        let before_fragmentation = self.stats().fragmentation_ratio;
350        self.defragment()?;
351        let after_fragmentation = self.stats().fragmentation_ratio;
352
353        Ok(DefragmentationStats {
354            memory_freed: 0, // Simplified
355            blocks_moved: 0,
356            time_taken_ms: 0,
357            fragmentation_before: before_fragmentation,
358            fragmentation_after: after_fragmentation,
359        })
360    }
361
362    fn set_pressure_callback(&self, _callback: Box<dyn Fn(MemoryPressure) + Send + Sync>) {
363        // Simplified - real implementation would store and use callback
364    }
365}
366
/// Round `size` up to the next multiple of `alignment`.
///
/// * An `alignment` of `0` or `1` imposes no constraint, so `size` is
///   returned unchanged (a zero alignment previously underflowed).
/// * Any alignment is supported, not only powers of two.
/// * If the rounded value does not fit in `usize`, the result saturates to
///   `usize::MAX` instead of overflowing; such a value is not aligned and is
///   expected to be rejected by the caller's size limits.
fn align_size(size: usize, alignment: usize) -> usize {
    if alignment <= 1 {
        return size;
    }

    match size % alignment {
        0 => size,
        remainder => size.saturating_add(alignment - remainder),
    }
}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU pool that uses the default configuration.
    fn default_pool() -> MemoryPool {
        MemoryPool::new(Device::CPU, InternalMemoryPoolConfig::default())
    }

    #[test]
    fn test_align_size() {
        // (size, alignment, expected)
        let cases = [
            (100, 256, 256),
            (256, 256, 256),
            (257, 256, 512),
            (500, 256, 512),
            (1, 64, 64),
            (64, 64, 64),
            (65, 64, 128),
        ];

        for &(size, alignment, expected) in cases.iter() {
            assert_eq!(align_size(size, alignment), expected);
        }
    }

    #[test]
    fn test_memory_pool_creation() {
        let pool = default_pool();

        // A fresh pool holds nothing.
        let snapshot = pool.stats();
        assert_eq!(snapshot.used_bytes, 0);
        assert_eq!(snapshot.active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_allocation() {
        let pool = default_pool();

        // First allocation shows up in the statistics.
        let first = pool.allocate(1024).unwrap();
        let after_first = pool.stats();
        assert_eq!(after_first.active_allocations, 1);
        assert!(after_first.used_bytes > 0);

        // So does the second one.
        let second = pool.allocate(2048).unwrap();
        assert_eq!(pool.stats().active_allocations, 2);

        // Each allocation gets its own handle.
        assert_ne!(first.id(), second.id());
    }

    #[test]
    fn test_memory_pool_deallocation() {
        let pool = default_pool();

        let handle = pool.allocate(1024).unwrap();
        assert_eq!(pool.stats().active_allocations, 1);

        // Releasing the block leaves no active allocation behind.
        pool.deallocate(handle).unwrap();
        assert_eq!(pool.stats().active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_reuse() {
        let pool = default_pool();

        // Release a block...
        let released = pool.allocate(1024).unwrap();
        pool.deallocate(released).unwrap();

        // ...then request the same size again - should reuse.
        let _reused = pool.allocate(1024).unwrap();
        assert_eq!(pool.stats().active_allocations, 1);
    }

    #[test]
    fn test_memory_pool_size_limit() {
        // Cap the pool at a very small size.
        let config = InternalMemoryPoolConfig {
            max_size: 1024,
            ..InternalMemoryPoolConfig::default()
        };
        let pool = MemoryPool::new(Device::CPU, config);

        // A request above the cap must be refused.
        assert!(pool.allocate(2048).is_err());
    }

    #[test]
    fn test_memory_pool_multiple_allocations() {
        let pool = default_pool();

        // 1 KiB, 2 KiB, ... 5 KiB
        let handles: Vec<_> = (1..=5)
            .map(|multiple| pool.allocate(1024 * multiple).unwrap())
            .collect();
        assert_eq!(pool.stats().active_allocations, 5);

        // Release every block again.
        for handle in handles {
            pool.deallocate(handle).unwrap();
        }
        assert_eq!(pool.stats().active_allocations, 0);
    }

    #[test]
    fn test_memory_pool_stats() {
        let pool = default_pool();

        // Nothing allocated yet.
        let empty = pool.stats();
        assert_eq!(empty.used_bytes, 0);
        assert_eq!(empty.active_allocations, 0);
        assert_eq!(empty.fragmentation_ratio, 0.0);

        // Two live allocations.
        let _small = pool.allocate(1024).unwrap();
        let _large = pool.allocate(2048).unwrap();

        let populated = pool.stats();
        assert!(populated.total_bytes >= 1024 + 2048);
        assert_eq!(populated.active_allocations, 2);
        assert!(populated.used_bytes > 0);
    }

    #[test]
    fn test_memory_pool_defragment() {
        let pool = default_pool();

        // Three blocks, then free the middle one to create a hole.
        let first = pool.allocate(1024).unwrap();
        let middle = pool.allocate(2048).unwrap();
        let last = pool.allocate(512).unwrap();
        pool.deallocate(middle).unwrap();

        let active_before = pool.stats().active_allocations;
        pool.defragment().unwrap();
        let active_after = pool.stats().active_allocations;

        // Defragmentation must not touch live allocations.
        assert_eq!(active_before, active_after);

        // Clean up
        pool.deallocate(first).ok();
        pool.deallocate(last).ok();
    }

    #[tokio::test]
    async fn test_device_memory_manager_trait() {
        let device = Device::CPU;
        let pool = MemoryPool::new(device.clone(), InternalMemoryPoolConfig::default());

        // Plain allocation through the trait.
        let plain = DeviceMemoryManager::allocate(&pool, 1024, &device)
            .await
            .unwrap();
        assert_ne!(plain.id(), 0);

        // Aligned allocation through the trait.
        let aligned = DeviceMemoryManager::allocate_aligned(&pool, 1000, 256, &device)
            .await
            .unwrap();
        assert_ne!(aligned.id(), 0);

        // Both allocations are visible in the memory info.
        let info = DeviceMemoryManager::memory_info(&pool, &device)
            .await
            .unwrap();
        assert_eq!(info.active_allocations, 2);

        // Releasing one leaves the other active.
        DeviceMemoryManager::deallocate(&pool, plain).await.unwrap();
        let info = DeviceMemoryManager::memory_info(&pool, &device)
            .await
            .unwrap();
        assert_eq!(info.active_allocations, 1);

        // Clean up
        DeviceMemoryManager::deallocate(&pool, aligned).await.ok();
    }

    #[tokio::test]
    async fn test_device_memory_manager_defragment() {
        let device = Device::CPU;
        let pool = MemoryPool::new(device.clone(), InternalMemoryPoolConfig::default());

        // Populate the pool.
        let _first = DeviceMemoryManager::allocate(&pool, 1024, &device)
            .await
            .unwrap();
        let _second = DeviceMemoryManager::allocate(&pool, 2048, &device)
            .await
            .unwrap();

        // Defragmentation reports non-negative ratios.
        let report = DeviceMemoryManager::defragment(&pool, &device)
            .await
            .unwrap();
        assert!(report.fragmentation_before >= 0.0);
        assert!(report.fragmentation_after >= 0.0);
    }

    #[test]
    fn test_handle_info() {
        let pool = default_pool();
        let handle = pool.allocate(1024).unwrap();

        // A tracked handle yields a description of its block.
        let described = pool.handle_info(handle);
        assert!(described.is_some());
        let described = described.unwrap();
        assert_eq!(described.handle.id(), handle.id());
        assert!(described.size >= 1024);
        assert_eq!(described.alignment, 256);
        assert!(!described.is_mapped);

        // A handle the pool never issued yields nothing.
        let unknown = MemoryHandle::new(99999);
        assert!(pool.handle_info(unknown).is_none());
    }
}