Skip to main content

ferrum_interfaces/
memory.rs

1//! Memory management interfaces for device memory operations
2//!
3//! This module provides device memory management abstractions, separate from
4//! KV cache management. It handles raw memory allocation, transfers, and
5//! memory pool management across different devices.
6
7use async_trait::async_trait;
8use ferrum_types::{Device, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Device memory manager for raw memory operations
///
/// Implementations own the lifecycle of device allocations: callers receive
/// opaque [`MemoryHandle`]s from the `allocate*` methods and are expected to
/// release them with [`DeviceMemoryManager::deallocate`]. Methods are async
/// because device/driver operations may block.
#[async_trait]
pub trait DeviceMemoryManager: Send + Sync {
    /// Allocate `size` bytes on `device`, returning an opaque handle
    async fn allocate(&self, size: usize, device: &Device) -> Result<MemoryHandle>;

    /// Allocate `size` bytes with the requested `alignment` on `device`
    async fn allocate_aligned(
        &self,
        size: usize,
        alignment: usize,
        device: &Device,
    ) -> Result<MemoryHandle>;

    /// Deallocate memory previously returned by `allocate`/`allocate_aligned`
    async fn deallocate(&self, handle: MemoryHandle) -> Result<()>;

    /// Copy `size` bytes from `src` (starting at `src_offset`) into `dst`
    /// (starting at `dst_offset`)
    async fn copy(
        &self,
        src: MemoryHandle,
        dst: MemoryHandle,
        size: usize,
        src_offset: usize,
        dst_offset: usize,
    ) -> Result<()>;

    /// Copy memory between devices asynchronously
    ///
    /// NOTE(review): behavior when `stream` is `None` (e.g. fall back to a
    /// default stream vs. synchronous copy) is backend-specific — confirm
    /// with the concrete implementation.
    async fn copy_async(
        &self,
        transfer: MemoryTransfer,
        stream: Option<StreamHandle>,
    ) -> Result<()>;

    /// Get memory information (capacity, usage, fragmentation) for `device`
    async fn memory_info(&self, device: &Device) -> Result<MemoryInfo>;

    /// Get information about a handle; `None` presumably means the handle is
    /// unknown or already freed — implementation-defined
    fn handle_info(&self, handle: MemoryHandle) -> Option<MemoryHandleInfo>;

    /// Set memory pool configuration for `device`
    async fn configure_pool(&self, device: &Device, config: MemoryPoolConfig) -> Result<()>;

    /// Defragment memory (if supported), returning statistics for the pass
    async fn defragment(&self, device: &Device) -> Result<DefragmentationStats>;

    /// Set the callback invoked on memory pressure changes
    fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);
}
61
/// Opaque identifier for a block of allocated device memory.
///
/// Zero is reserved as the "invalid handle" sentinel (see
/// [`MemoryHandle::is_valid`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MemoryHandle(pub u64);

impl MemoryHandle {
    /// Wrap a raw allocation ID in a handle.
    pub fn new(id: u64) -> Self {
        MemoryHandle(id)
    }

    /// Raw numeric ID backing this handle.
    pub fn id(&self) -> u64 {
        let MemoryHandle(raw) = self;
        *raw
    }

    /// A handle is valid exactly when its ID is non-zero.
    pub fn is_valid(&self) -> bool {
        !matches!(self.0, 0)
    }
}
82
/// Stream handle for asynchronous operations
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamHandle(pub u64);

impl StreamHandle {
    /// Create new stream handle from a raw stream ID
    pub fn new(id: u64) -> Self {
        Self(id)
    }
}

/// The default stream has ID 0 (usually the synchronous/legacy stream).
///
/// Implemented as the standard `Default` trait instead of an inherent
/// `default()` method (clippy: `should_implement_trait`); existing
/// `StreamHandle::default()` call sites keep working because `Default`
/// is in the prelude.
impl Default for StreamHandle {
    fn default() -> Self {
        Self(0)
    }
}
98
99/// Memory transfer specification
100#[derive(Debug, Clone)]
101pub struct MemoryTransfer {
102    /// Source memory handle
103    pub src: MemoryHandle,
104    /// Destination memory handle
105    pub dst: MemoryHandle,
106    /// Number of bytes to transfer
107    pub size: usize,
108    /// Offset in source memory
109    pub src_offset: usize,
110    /// Offset in destination memory
111    pub dst_offset: usize,
112}
113
114impl MemoryTransfer {
115    /// Create new memory transfer
116    pub fn new(src: MemoryHandle, dst: MemoryHandle, size: usize) -> Self {
117        Self {
118            src,
119            dst,
120            size,
121            src_offset: 0,
122            dst_offset: 0,
123        }
124    }
125
126    /// Set source offset
127    pub fn with_src_offset(mut self, offset: usize) -> Self {
128        self.src_offset = offset;
129        self
130    }
131
132    /// Set destination offset
133    pub fn with_dst_offset(mut self, offset: usize) -> Self {
134        self.dst_offset = offset;
135        self
136    }
137}
138
139/// Memory information for a device
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct MemoryInfo {
142    /// Total memory available on device (bytes)
143    pub total_bytes: u64,
144    /// Currently used memory (bytes)
145    pub used_bytes: u64,
146    /// Free memory available (bytes)
147    pub free_bytes: u64,
148    /// Memory reserved by the runtime/driver (bytes)
149    pub reserved_bytes: u64,
150    /// Number of active allocations
151    pub active_allocations: usize,
152    /// Memory fragmentation ratio (0.0 - 1.0)
153    pub fragmentation_ratio: f32,
154    /// Memory bandwidth (GB/s)
155    pub bandwidth_gbps: Option<f32>,
156}
157
158impl MemoryInfo {
159    /// Calculate memory utilization percentage
160    pub fn utilization_percent(&self) -> f32 {
161        if self.total_bytes > 0 {
162            (self.used_bytes as f32 / self.total_bytes as f32) * 100.0
163        } else {
164            0.0
165        }
166    }
167
168    /// Check if memory is under pressure
169    pub fn pressure_level(&self) -> MemoryPressure {
170        let utilization = self.utilization_percent();
171
172        if utilization >= 95.0 {
173            MemoryPressure::Critical
174        } else if utilization >= 85.0 {
175            MemoryPressure::High
176        } else if utilization >= 70.0 {
177            MemoryPressure::Medium
178        } else {
179            MemoryPressure::Low
180        }
181    }
182}
183
/// Information about a memory handle
///
/// Returned by [`DeviceMemoryManager::handle_info`] to describe a live
/// allocation. Note this type intentionally has no serde derives — it
/// contains a `std::time::Instant`, which is not serializable.
#[derive(Debug, Clone)]
pub struct MemoryHandleInfo {
    /// Memory handle this record describes
    pub handle: MemoryHandle,
    /// Size in bytes
    pub size: usize,
    /// Device where memory is allocated
    pub device: Device,
    /// Memory alignment
    pub alignment: usize,
    /// Allocation timestamp (monotonic clock)
    pub allocated_at: std::time::Instant,
    /// Whether memory is currently mapped
    pub is_mapped: bool,
    /// Memory type/usage hint
    pub memory_type: MemoryType,
}
202
/// Memory types for different use cases
///
/// Also used as the key for per-type pool configuration in
/// [`MemoryManagerConfig::pool_configs`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MemoryType {
    /// General purpose memory
    General,
    /// Memory optimized for tensor operations
    Tensor,
    /// Memory for KV cache
    Cache,
    /// Temporary/scratch memory
    Temporary,
    /// Pinned/page-locked memory for fast transfers
    Pinned,
    /// Mapped memory (shared between devices)
    Mapped,
}
219
/// Memory pressure levels
///
/// The derived `Ord` follows declaration order, so
/// `Low < Medium < High < Critical` — keep new variants in severity order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryPressure {
    /// Low pressure - plenty of memory available
    Low,
    /// Medium pressure - should be conservative
    Medium,
    /// High pressure - consider cleanup/eviction
    High,
    /// Critical pressure - must free memory or reject requests
    Critical,
}
232
233/// Memory pool configuration
234#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct MemoryPoolConfig {
236    /// Initial pool size in bytes
237    pub initial_size: u64,
238    /// Maximum pool size in bytes (None for unlimited)
239    pub max_size: Option<u64>,
240    /// Growth increment when expanding pool
241    pub growth_increment: u64,
242    /// Enable automatic pool expansion
243    pub enable_auto_expansion: bool,
244    /// Memory alignment for pool allocations
245    pub alignment: usize,
246    /// Pre-allocate entire pool upfront
247    pub pre_allocate: bool,
248    /// Enable pool statistics tracking
249    pub enable_stats: bool,
250}
251
252impl Default for MemoryPoolConfig {
253    fn default() -> Self {
254        Self {
255            initial_size: 1024 * 1024 * 1024, // 1GB
256            max_size: None,
257            growth_increment: 512 * 1024 * 1024, // 512MB
258            enable_auto_expansion: true,
259            alignment: 256,
260            pre_allocate: false,
261            enable_stats: true,
262        }
263    }
264}
265
/// Defragmentation statistics
///
/// Produced by [`DeviceMemoryManager::defragment`] to report the outcome
/// of a defragmentation pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DefragmentationStats {
    /// Memory freed by defragmentation (bytes)
    pub memory_freed: u64,
    /// Number of memory blocks moved
    pub blocks_moved: usize,
    /// Time taken for defragmentation, in milliseconds
    pub time_taken_ms: u64,
    /// Fragmentation ratio before defragmentation (0.0 - 1.0)
    pub fragmentation_before: f32,
    /// Fragmentation ratio after defragmentation (0.0 - 1.0)
    pub fragmentation_after: f32,
}
280
/// Advanced memory operations
///
/// Extension of [`DeviceMemoryManager`] for backends that support CPU
/// mapping, cross-device mappings, prefetching, and usage hints.
#[async_trait]
pub trait AdvancedMemoryManager: DeviceMemoryManager {
    /// Map memory for direct CPU access
    ///
    /// NOTE(review): returns a raw `*mut u8`; the pointer's validity window
    /// (presumably until `unmap_memory`/`deallocate`) is backend-specific —
    /// confirm with the concrete implementation before dereferencing.
    async fn map_memory(&self, handle: MemoryHandle, access: MemoryAccess) -> Result<*mut u8>;

    /// Unmap previously mapped memory
    async fn unmap_memory(&self, handle: MemoryHandle) -> Result<()>;

    /// Create memory mapping between devices
    ///
    /// Returns a pair of handles, one per device, for the shared region.
    async fn create_mapping(
        &self,
        src_device: &Device,
        dst_device: &Device,
        size: usize,
    ) -> Result<(MemoryHandle, MemoryHandle)>;

    /// Enable memory prefetching of `handle` toward `target_device`
    async fn prefetch(&self, handle: MemoryHandle, target_device: &Device) -> Result<()>;

    /// Get memory access pattern statistics; `None` when unavailable
    fn access_stats(&self, handle: MemoryHandle) -> Option<MemoryAccessStats>;

    /// Set memory usage hints to guide backend placement/eviction
    async fn set_usage_hint(&self, handle: MemoryHandle, hint: MemoryUsageHint) -> Result<()>;
}
307
308/// Memory access modes
309#[derive(Debug, Clone, Copy, PartialEq, Eq)]
310pub enum MemoryAccess {
311    /// Read-only access
312    ReadOnly,
313    /// Write-only access
314    WriteOnly,
315    /// Read-write access
316    ReadWrite,
317}
318
319/// Memory usage hints for optimization
320#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
321pub enum MemoryUsageHint {
322    /// Memory will be accessed sequentially
323    Sequential,
324    /// Memory will be accessed randomly
325    Random,
326    /// Memory will be read frequently
327    ReadMostly,
328    /// Memory will be written frequently
329    WriteMostly,
330    /// Memory is temporary and can be freed aggressively
331    Temporary,
332    /// Memory should be kept resident
333    Resident,
334}
335
336/// Memory access pattern statistics
337#[derive(Debug, Clone)]
338pub struct MemoryAccessStats {
339    /// Total number of reads
340    pub read_count: u64,
341    /// Total number of writes
342    pub write_count: u64,
343    /// Average read size
344    pub avg_read_size: usize,
345    /// Average write size
346    pub avg_write_size: usize,
347    /// Last access timestamp
348    pub last_access: std::time::Instant,
349    /// Access pattern type (detected)
350    pub pattern_type: AccessPatternType,
351}
352
353/// Detected access patterns
354#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
355pub enum AccessPatternType {
356    /// Sequential access pattern
357    Sequential,
358    /// Random access pattern
359    Random,
360    /// Burst access pattern
361    Burst,
362    /// Mixed access pattern
363    Mixed,
364    /// Unknown/undetected pattern
365    Unknown,
366}
367
/// Stream manager for asynchronous operations
///
/// Manages [`StreamHandle`]s and [`EventHandle`]s used by
/// [`DeviceMemoryManager::copy_async`]-style operations.
#[async_trait]
pub trait StreamManager: Send + Sync {
    /// Create new compute stream on `device`
    async fn create_stream(&self, device: &Device) -> Result<StreamHandle>;

    /// Destroy a stream created by `create_stream`
    async fn destroy_stream(&self, stream: StreamHandle) -> Result<()>;

    /// Synchronize stream (wait for all operations to complete)
    async fn synchronize_stream(&self, stream: StreamHandle) -> Result<()>;

    /// Check if stream operations are complete, without blocking
    async fn is_stream_ready(&self, stream: StreamHandle) -> Result<bool>;

    /// Get default stream for device
    fn default_stream(&self, device: &Device) -> StreamHandle;

    /// Record a synchronization point on `stream`, returning its event
    async fn record_event(&self, stream: StreamHandle) -> Result<EventHandle>;

    /// Make `stream` wait until `event` has been reached
    async fn wait_event(&self, stream: StreamHandle, event: EventHandle) -> Result<()>;
}
392
/// Event handle for stream synchronization
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EventHandle(pub u64);

impl EventHandle {
    /// Create new event handle
    ///
    /// Added for consistency with the sibling newtypes (`MemoryHandle::new`,
    /// `StreamHandle::new`); the public tuple field remains usable.
    pub fn new(id: u64) -> Self {
        Self(id)
    }

    /// Get handle ID
    pub fn id(&self) -> u64 {
        self.0
    }
}
396
397/// Memory manager factory
398#[async_trait]
399pub trait MemoryManagerFactory: Send + Sync {
400    /// Create memory manager for device
401    async fn create_memory_manager(
402        &self,
403        device: &Device,
404        config: &MemoryManagerConfig,
405    ) -> Result<Box<dyn DeviceMemoryManager>>;
406
407    /// Create advanced memory manager
408    async fn create_advanced_memory_manager(
409        &self,
410        device: &Device,
411        config: &MemoryManagerConfig,
412    ) -> Result<Box<dyn AdvancedMemoryManager>>;
413
414    /// Create stream manager
415    async fn create_stream_manager(&self, device: &Device) -> Result<Box<dyn StreamManager>>;
416}
417
418/// Memory manager configuration
419#[derive(Debug, Clone, Serialize, Deserialize)]
420pub struct MemoryManagerConfig {
421    /// Memory pool configurations per memory type
422    pub pool_configs: HashMap<MemoryType, MemoryPoolConfig>,
423    /// Enable memory tracking and statistics
424    pub enable_tracking: bool,
425    /// Enable automatic garbage collection
426    pub enable_auto_gc: bool,
427    /// Garbage collection trigger threshold
428    pub gc_threshold: f32,
429    /// Enable memory debugging
430    pub enable_debug: bool,
431    /// Maximum number of concurrent transfers
432    pub max_concurrent_transfers: usize,
433}
434
435impl Default for MemoryManagerConfig {
436    fn default() -> Self {
437        let mut pool_configs = HashMap::new();
438        pool_configs.insert(MemoryType::General, MemoryPoolConfig::default());
439
440        Self {
441            pool_configs,
442            enable_tracking: true,
443            enable_auto_gc: true,
444            gc_threshold: 0.85,
445            enable_debug: false,
446            max_concurrent_transfers: 4,
447        }
448    }
449}
450
451/// Global memory monitor for system-wide memory tracking
452pub trait GlobalMemoryMonitor: Send + Sync {
453    /// Get memory information across all devices
454    fn global_memory_info(&self) -> HashMap<Device, MemoryInfo>;
455
456    /// Get total system memory pressure
457    fn global_memory_pressure(&self) -> MemoryPressure;
458
459    /// Register memory manager for monitoring
460    fn register_manager(&mut self, device: Device, manager: &dyn DeviceMemoryManager);
461
462    /// Unregister memory manager
463    fn unregister_manager(&mut self, device: &Device);
464
465    /// Set global memory pressure callback
466    fn set_global_pressure_callback(
467        &mut self,
468        callback: Box<dyn Fn(HashMap<Device, MemoryPressure>) + Send + Sync>,
469    );
470
471    /// Force global garbage collection
472    async fn global_gc(&self) -> Result<HashMap<Device, DefragmentationStats>>;
473}
474
/// Memory allocation strategy
///
/// Pluggable device-selection policy; see [`BestFitStrategy`] for a
/// concrete implementation.
pub trait AllocationStrategy: Send + Sync {
    /// Select best device for an allocation of `size` bytes
    ///
    /// Returns `None` when no device in `available_devices` satisfies the
    /// request.
    fn select_device(
        &self,
        size: usize,
        requirements: &AllocationRequirements,
        available_devices: &[Device],
        memory_info: &HashMap<Device, MemoryInfo>,
    ) -> Option<Device>;

    /// Get strategy name (stable identifier, e.g. for logging/config)
    fn name(&self) -> &str;
}
489
/// Requirements for memory allocation
///
/// Input to [`AllocationStrategy::select_device`].
#[derive(Debug, Clone)]
pub struct AllocationRequirements {
    /// Preferred devices in order (earlier entries are preferred more)
    pub preferred_devices: Vec<Device>,
    /// Memory type hint
    pub memory_type: MemoryType,
    /// Required alignment, if any
    pub alignment: Option<usize>,
    /// Whether allocation is time-critical
    pub is_critical: bool,
    /// Expected lifetime of the allocation, if known
    pub expected_lifetime: Option<std::time::Duration>,
}
504
505/// Best-fit allocation strategy
506pub struct BestFitStrategy;
507
508impl AllocationStrategy for BestFitStrategy {
509    fn select_device(
510        &self,
511        size: usize,
512        requirements: &AllocationRequirements,
513        available_devices: &[Device],
514        memory_info: &HashMap<Device, MemoryInfo>,
515    ) -> Option<Device> {
516        let mut best_device = None;
517        let mut best_score = f32::NEG_INFINITY;
518
519        for device in available_devices {
520            if let Some(info) = memory_info.get(device) {
521                // Check if device has enough memory
522                if info.free_bytes < size as u64 {
523                    continue;
524                }
525
526                // Prefer devices with just enough memory (best fit)
527                let waste_ratio = (info.free_bytes - size as u64) as f32 / info.total_bytes as f32;
528                let utilization = info.utilization_percent() / 100.0;
529
530                // Score based on minimal waste and moderate utilization
531                let score = 1.0 - waste_ratio - (utilization - 0.5).abs() * 0.5;
532
533                // Bonus for preferred devices
534                let preference_bonus = requirements
535                    .preferred_devices
536                    .iter()
537                    .position(|d| d == device)
538                    .map(|pos| 1.0 / (pos as f32 + 1.0))
539                    .unwrap_or(0.0)
540                    * 0.2;
541
542                let final_score = score + preference_bonus;
543
544                if final_score > best_score {
545                    best_score = final_score;
546                    best_device = Some(device.clone());
547                }
548            }
549        }
550
551        best_device
552    }
553
554    fn name(&self) -> &str {
555        "best_fit"
556    }
557}