1use async_trait::async_trait;
8use ferrum_types::{Device, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Asynchronous allocator interface for device memory.
///
/// All methods take `&self` while the trait requires `Send + Sync`, so
/// implementations are expected to use interior mutability for any state.
#[async_trait]
pub trait DeviceMemoryManager: Send + Sync {
    /// Allocates `size` bytes on `device`, returning an opaque handle.
    async fn allocate(&self, size: usize, device: &Device) -> Result<MemoryHandle>;

    /// Allocates `size` bytes with the requested byte `alignment` on `device`.
    async fn allocate_aligned(
        &self,
        size: usize,
        alignment: usize,
        device: &Device,
    ) -> Result<MemoryHandle>;

    /// Releases the allocation behind `handle`.
    async fn deallocate(&self, handle: MemoryHandle) -> Result<()>;

    /// Copies `size` bytes from `src` (starting at `src_offset`) into `dst`
    /// (starting at `dst_offset`).
    async fn copy(
        &self,
        src: MemoryHandle,
        dst: MemoryHandle,
        size: usize,
        src_offset: usize,
        dst_offset: usize,
    ) -> Result<()>;

    /// Queues `transfer` for asynchronous execution on `stream`.
    /// NOTE(review): `None` presumably selects the default stream — confirm
    /// against implementations.
    async fn copy_async(
        &self,
        transfer: MemoryTransfer,
        stream: Option<StreamHandle>,
    ) -> Result<()>;

    /// Current usage/capacity statistics for `device`.
    async fn memory_info(&self, device: &Device) -> Result<MemoryInfo>;

    /// Metadata for `handle`, or `None` if the handle is not known.
    fn handle_info(&self, handle: MemoryHandle) -> Option<MemoryHandleInfo>;

    /// Applies `config` to the memory pool backing `device`.
    async fn configure_pool(&self, device: &Device, config: MemoryPoolConfig) -> Result<()>;

    /// Compacts fragmented memory on `device` and reports what was done.
    async fn defragment(&self, device: &Device) -> Result<DefragmentationStats>;

    /// Installs a callback invoked on memory-pressure changes.
    /// NOTE(review): takes `&self`, so storing the callback requires interior
    /// mutability in the implementation.
    fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);
}
61
/// Opaque identifier for a device memory allocation.
///
/// The zero id is reserved as the "invalid" sentinel (see [`Self::is_valid`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MemoryHandle(pub u64);

impl MemoryHandle {
    /// Wraps a raw numeric id in a handle.
    pub fn new(id: u64) -> Self {
        MemoryHandle(id)
    }

    /// Returns the raw numeric id backing this handle.
    pub fn id(&self) -> u64 {
        let MemoryHandle(raw) = *self;
        raw
    }

    /// A handle is valid exactly when its id is non-zero.
    pub fn is_valid(&self) -> bool {
        !matches!(self.0, 0)
    }
}
82
/// Opaque identifier for a device execution stream.
///
/// Id `0` denotes the device's default stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamHandle(pub u64);

impl StreamHandle {
    /// Wraps a raw numeric id in a stream handle.
    pub fn new(id: u64) -> Self {
        Self(id)
    }

    /// Returns the default stream (id `0`).
    ///
    /// Kept as an inherent method for backward compatibility; prefer the
    /// [`Default`] trait impl below in new code (Clippy
    /// `should_implement_trait`).
    pub fn default() -> Self {
        Self(0)
    }
}

impl Default for StreamHandle {
    /// Same value as the inherent `StreamHandle::default()`: the default
    /// stream with id `0`.
    fn default() -> Self {
        Self(0)
    }
}
98
99#[derive(Debug, Clone)]
101pub struct MemoryTransfer {
102 pub src: MemoryHandle,
104 pub dst: MemoryHandle,
106 pub size: usize,
108 pub src_offset: usize,
110 pub dst_offset: usize,
112}
113
114impl MemoryTransfer {
115 pub fn new(src: MemoryHandle, dst: MemoryHandle, size: usize) -> Self {
117 Self {
118 src,
119 dst,
120 size,
121 src_offset: 0,
122 dst_offset: 0,
123 }
124 }
125
126 pub fn with_src_offset(mut self, offset: usize) -> Self {
128 self.src_offset = offset;
129 self
130 }
131
132 pub fn with_dst_offset(mut self, offset: usize) -> Self {
134 self.dst_offset = offset;
135 self
136 }
137}
138
/// Point-in-time memory statistics for a single device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryInfo {
    /// Total capacity in bytes.
    pub total_bytes: u64,
    /// Bytes currently allocated.
    pub used_bytes: u64,
    /// Bytes currently available for allocation.
    pub free_bytes: u64,
    /// Bytes reserved but not handed out — TODO confirm exact semantics.
    pub reserved_bytes: u64,
    /// Number of live allocations.
    pub active_allocations: usize,
    /// Fragmentation measure; presumably in [0.0, 1.0] — confirm with
    /// implementations.
    pub fragmentation_ratio: f32,
    /// Measured memory bandwidth in GB/s, when known.
    pub bandwidth_gbps: Option<f32>,
}
157
158impl MemoryInfo {
159 pub fn utilization_percent(&self) -> f32 {
161 if self.total_bytes > 0 {
162 (self.used_bytes as f32 / self.total_bytes as f32) * 100.0
163 } else {
164 0.0
165 }
166 }
167
168 pub fn pressure_level(&self) -> MemoryPressure {
170 let utilization = self.utilization_percent();
171
172 if utilization >= 95.0 {
173 MemoryPressure::Critical
174 } else if utilization >= 85.0 {
175 MemoryPressure::High
176 } else if utilization >= 70.0 {
177 MemoryPressure::Medium
178 } else {
179 MemoryPressure::Low
180 }
181 }
182}
183
/// Bookkeeping metadata attached to a live [`MemoryHandle`].
#[derive(Debug, Clone)]
pub struct MemoryHandleInfo {
    /// The handle this metadata describes.
    pub handle: MemoryHandle,
    /// Allocation size in bytes.
    pub size: usize,
    /// Device the memory lives on.
    pub device: Device,
    /// Alignment the allocation was made with, in bytes.
    pub alignment: usize,
    /// When the allocation was created.
    pub allocated_at: std::time::Instant,
    /// Whether the memory is currently mapped (see `AdvancedMemoryManager`).
    pub is_mapped: bool,
    /// Pool/category the allocation belongs to.
    pub memory_type: MemoryType,
}
202
/// Categories of allocations; each may be backed by its own pool
/// (see `MemoryManagerConfig::pool_configs`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MemoryType {
    /// Default, uncategorized allocations.
    General,
    /// Tensor data.
    Tensor,
    /// Cache storage.
    Cache,
    /// Short-lived scratch space.
    Temporary,
    /// Pinned (page-locked) memory.
    Pinned,
    /// Mapped memory.
    Mapped,
}
219
/// Coarse memory-pressure levels.
///
/// `Ord` is derived, so variants compare in declaration order:
/// `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryPressure {
    /// Utilization below 70% (thresholds per `MemoryInfo::pressure_level`).
    Low,
    /// Utilization in [70%, 85%).
    Medium,
    /// Utilization in [85%, 95%).
    High,
    /// Utilization at or above 95%.
    Critical,
}
232
/// Tuning parameters for a single memory pool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPoolConfig {
    /// Bytes reserved when the pool is created.
    pub initial_size: u64,
    /// Hard upper bound in bytes; `None` means unbounded.
    pub max_size: Option<u64>,
    /// Bytes added each time the pool grows.
    pub growth_increment: u64,
    /// Whether the pool may grow beyond `initial_size` on demand.
    pub enable_auto_expansion: bool,
    /// Alignment applied to allocations from this pool, in bytes.
    pub alignment: usize,
    /// Whether to eagerly commit backing memory at creation time.
    pub pre_allocate: bool,
    /// Whether to record per-pool statistics.
    pub enable_stats: bool,
}
251
252impl Default for MemoryPoolConfig {
253 fn default() -> Self {
254 Self {
255 initial_size: 1024 * 1024 * 1024, max_size: None,
257 growth_increment: 512 * 1024 * 1024, enable_auto_expansion: true,
259 alignment: 256,
260 pre_allocate: false,
261 enable_stats: true,
262 }
263 }
264}
265
/// Outcome report of a defragmentation pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DefragmentationStats {
    /// Bytes reclaimed by the pass.
    pub memory_freed: u64,
    /// Number of allocations relocated.
    pub blocks_moved: usize,
    /// Duration of the pass in milliseconds.
    pub time_taken_ms: u64,
    /// Fragmentation ratio before the pass.
    pub fragmentation_before: f32,
    /// Fragmentation ratio after the pass.
    pub fragmentation_after: f32,
}
280
/// Extension of [`DeviceMemoryManager`] with mapping, prefetching, and
/// access-pattern introspection.
#[async_trait]
pub trait AdvancedMemoryManager: DeviceMemoryManager {
    /// Maps `handle` for direct access with the given `access` mode,
    /// returning a raw pointer.
    ///
    /// NOTE(review): the `*mut u8` carries no lifetime; callers must not use
    /// it after `unmap_memory` (or deallocation). Consider a guard type.
    async fn map_memory(&self, handle: MemoryHandle, access: MemoryAccess) -> Result<*mut u8>;

    /// Undoes a previous `map_memory` for `handle`.
    async fn unmap_memory(&self, handle: MemoryHandle) -> Result<()>;

    /// Creates a paired mapping of `size` bytes between two devices,
    /// returning the source-side and destination-side handles.
    async fn create_mapping(
        &self,
        src_device: &Device,
        dst_device: &Device,
        size: usize,
    ) -> Result<(MemoryHandle, MemoryHandle)>;

    /// Hints that `handle`'s contents will soon be needed on `target_device`.
    async fn prefetch(&self, handle: MemoryHandle, target_device: &Device) -> Result<()>;

    /// Recorded access statistics for `handle`, if tracked.
    fn access_stats(&self, handle: MemoryHandle) -> Option<MemoryAccessStats>;

    /// Attaches a usage `hint` that implementations may use for placement
    /// or eviction decisions.
    async fn set_usage_hint(&self, handle: MemoryHandle, hint: MemoryUsageHint) -> Result<()>;
}
307
/// Access mode requested when mapping memory (see
/// `AdvancedMemoryManager::map_memory`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryAccess {
    /// Mapping will only be read.
    ReadOnly,
    /// Mapping will only be written.
    WriteOnly,
    /// Mapping may be both read and written.
    ReadWrite,
}
318
/// Caller-supplied hints about how an allocation will be used
/// (see `AdvancedMemoryManager::set_usage_hint`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryUsageHint {
    /// Accesses proceed in order.
    Sequential,
    /// Accesses jump around unpredictably.
    Random,
    /// Mostly read after initialization.
    ReadMostly,
    /// Mostly written.
    WriteMostly,
    /// Short-lived allocation.
    Temporary,
    /// Should remain resident on its device.
    Resident,
}
335
/// Observed access statistics for one allocation
/// (returned by `AdvancedMemoryManager::access_stats`).
#[derive(Debug, Clone)]
pub struct MemoryAccessStats {
    /// Total reads recorded.
    pub read_count: u64,
    /// Total writes recorded.
    pub write_count: u64,
    /// Mean read size in bytes.
    pub avg_read_size: usize,
    /// Mean write size in bytes.
    pub avg_write_size: usize,
    /// Timestamp of the most recent access.
    pub last_access: std::time::Instant,
    /// Classified access pattern.
    pub pattern_type: AccessPatternType,
}
352
/// Classification of an allocation's observed access pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccessPatternType {
    /// Ordered, contiguous accesses.
    Sequential,
    /// Unpredictable accesses.
    Random,
    /// Clustered bursts of activity.
    Burst,
    /// Combination of patterns.
    Mixed,
    /// Not enough data to classify.
    Unknown,
}
367
/// Manages asynchronous execution streams and events on a device.
#[async_trait]
pub trait StreamManager: Send + Sync {
    /// Creates a new stream on `device`.
    async fn create_stream(&self, device: &Device) -> Result<StreamHandle>;

    /// Destroys `stream`.
    async fn destroy_stream(&self, stream: StreamHandle) -> Result<()>;

    /// Completes once all work queued on `stream` has finished.
    async fn synchronize_stream(&self, stream: StreamHandle) -> Result<()>;

    /// Returns whether `stream` has finished all queued work.
    async fn is_stream_ready(&self, stream: StreamHandle) -> Result<bool>;

    /// The device's default stream (infallible lookup).
    fn default_stream(&self, device: &Device) -> StreamHandle;

    /// Records an event at the current position of `stream`.
    async fn record_event(&self, stream: StreamHandle) -> Result<EventHandle>;

    /// Makes `stream` wait until `event` has been reached.
    async fn wait_event(&self, stream: StreamHandle, event: EventHandle) -> Result<()>;
}
392
/// Opaque identifier for an event recorded on a stream
/// (see [`StreamManager::record_event`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EventHandle(pub u64);
396
/// Factory for constructing the memory and stream managers of a backend.
#[async_trait]
pub trait MemoryManagerFactory: Send + Sync {
    /// Builds a basic [`DeviceMemoryManager`] for `device` using `config`.
    async fn create_memory_manager(
        &self,
        device: &Device,
        config: &MemoryManagerConfig,
    ) -> Result<Box<dyn DeviceMemoryManager>>;

    /// Builds an [`AdvancedMemoryManager`] for `device` using `config`.
    async fn create_advanced_memory_manager(
        &self,
        device: &Device,
        config: &MemoryManagerConfig,
    ) -> Result<Box<dyn AdvancedMemoryManager>>;

    /// Builds a [`StreamManager`] for `device`.
    async fn create_stream_manager(&self, device: &Device) -> Result<Box<dyn StreamManager>>;
}
417
/// Top-level configuration for a device memory manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryManagerConfig {
    /// Per-[`MemoryType`] pool settings.
    pub pool_configs: HashMap<MemoryType, MemoryPoolConfig>,
    /// Whether to track live allocations.
    pub enable_tracking: bool,
    /// Whether to run garbage collection automatically.
    pub enable_auto_gc: bool,
    /// Utilization fraction that triggers auto-GC; presumably in
    /// [0.0, 1.0] — confirm with implementations.
    pub gc_threshold: f32,
    /// Whether to enable extra debug checks/logging.
    pub enable_debug: bool,
    /// Maximum number of transfers allowed in flight at once.
    pub max_concurrent_transfers: usize,
}
434
435impl Default for MemoryManagerConfig {
436 fn default() -> Self {
437 let mut pool_configs = HashMap::new();
438 pool_configs.insert(MemoryType::General, MemoryPoolConfig::default());
439
440 Self {
441 pool_configs,
442 enable_tracking: true,
443 enable_auto_gc: true,
444 gc_threshold: 0.85,
445 enable_debug: false,
446 max_concurrent_transfers: 4,
447 }
448 }
449}
450
451pub trait GlobalMemoryMonitor: Send + Sync {
453 fn global_memory_info(&self) -> HashMap<Device, MemoryInfo>;
455
456 fn global_memory_pressure(&self) -> MemoryPressure;
458
459 fn register_manager(&mut self, device: Device, manager: &dyn DeviceMemoryManager);
461
462 fn unregister_manager(&mut self, device: &Device);
464
465 fn set_global_pressure_callback(
467 &mut self,
468 callback: Box<dyn Fn(HashMap<Device, MemoryPressure>) + Send + Sync>,
469 );
470
471 async fn global_gc(&self) -> Result<HashMap<Device, DefragmentationStats>>;
473}
474
/// Pluggable policy for choosing which device should host a new allocation.
pub trait AllocationStrategy: Send + Sync {
    /// Picks a device for an allocation of `size` bytes, or `None` when no
    /// candidate in `available_devices` can satisfy the request.
    fn select_device(
        &self,
        size: usize,
        requirements: &AllocationRequirements,
        available_devices: &[Device],
        memory_info: &HashMap<Device, MemoryInfo>,
    ) -> Option<Device>;

    /// Short, stable identifier for this strategy (e.g. for logs/config).
    fn name(&self) -> &str;
}
489
/// Constraints and preferences attached to an allocation request.
#[derive(Debug, Clone)]
pub struct AllocationRequirements {
    /// Devices to favor, in descending priority order
    /// (earlier entries score higher in `BestFitStrategy`).
    pub preferred_devices: Vec<Device>,
    /// Pool/category the allocation belongs to.
    pub memory_type: MemoryType,
    /// Required alignment in bytes, if any.
    pub alignment: Option<usize>,
    /// Marks the allocation as critical — TODO confirm how strategies
    /// are expected to treat this flag (currently unused in this file).
    pub is_critical: bool,
    /// Expected lifetime of the allocation, when known.
    pub expected_lifetime: Option<std::time::Duration>,
}
504
505pub struct BestFitStrategy;
507
508impl AllocationStrategy for BestFitStrategy {
509 fn select_device(
510 &self,
511 size: usize,
512 requirements: &AllocationRequirements,
513 available_devices: &[Device],
514 memory_info: &HashMap<Device, MemoryInfo>,
515 ) -> Option<Device> {
516 let mut best_device = None;
517 let mut best_score = f32::NEG_INFINITY;
518
519 for device in available_devices {
520 if let Some(info) = memory_info.get(device) {
521 if info.free_bytes < size as u64 {
523 continue;
524 }
525
526 let waste_ratio = (info.free_bytes - size as u64) as f32 / info.total_bytes as f32;
528 let utilization = info.utilization_percent() / 100.0;
529
530 let score = 1.0 - waste_ratio - (utilization - 0.5).abs() * 0.5;
532
533 let preference_bonus = requirements
535 .preferred_devices
536 .iter()
537 .position(|d| d == device)
538 .map(|pos| 1.0 / (pos as f32 + 1.0))
539 .unwrap_or(0.0)
540 * 0.2;
541
542 let final_score = score + preference_bonus;
543
544 if final_score > best_score {
545 best_score = final_score;
546 best_device = Some(device.clone());
547 }
548 }
549 }
550
551 best_device
552 }
553
554 fn name(&self) -> &str {
555 "best_fit"
556 }
557}