use async_trait::async_trait;
use ferrum_types::{Device, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Asynchronous interface for allocating and managing memory on a [`Device`].
///
/// Implementors must be `Send + Sync`. Only signatures are declared here, so the
/// per-method notes below describe the apparent intent taken from names and
/// types. NOTE(review): confirm exact semantics (units, error conditions,
/// blocking behaviour) against the concrete implementations.
#[async_trait]
pub trait DeviceMemoryManager: Send + Sync {
/// Requests an allocation of `size` (presumably bytes — TODO confirm) on
/// `device` and yields a [`MemoryHandle`] identifying it.
async fn allocate(&self, size: usize, device: &Device) -> Result<MemoryHandle>;
/// Same request as [`allocate`](Self::allocate) with an additional
/// `alignment` argument (presumably in bytes — TODO confirm).
async fn allocate_aligned(
&self,
size: usize,
alignment: usize,
device: &Device,
) -> Result<MemoryHandle>;
/// Releases the allocation identified by `handle`. The handle is taken by
/// value, but `MemoryHandle` is `Copy`, so callers may still hold copies.
async fn deallocate(&self, handle: MemoryHandle) -> Result<()>;
/// Copies `size` units from `src` (starting at `src_offset`) into `dst`
/// (starting at `dst_offset`). Presumably offsets and size are in bytes.
async fn copy(
&self,
src: MemoryHandle,
dst: MemoryHandle,
size: usize,
src_offset: usize,
dst_offset: usize,
) -> Result<()>;
/// Variant of [`copy`](Self::copy) that takes the parameters bundled in a
/// [`MemoryTransfer`] plus an optional [`StreamHandle`].
/// NOTE(review): what `None` means for `stream` is not visible here.
async fn copy_async(
&self,
transfer: MemoryTransfer,
stream: Option<StreamHandle>,
) -> Result<()>;
/// Returns a [`MemoryInfo`] snapshot for `device`.
async fn memory_info(&self, device: &Device) -> Result<MemoryInfo>;
/// Synchronous lookup of the metadata for `handle`; `None` presumably means
/// the handle is unknown to this manager.
fn handle_info(&self, handle: MemoryHandle) -> Option<MemoryHandleInfo>;
/// Applies `config` to the memory pool of `device`.
async fn configure_pool(&self, device: &Device, config: MemoryPoolConfig) -> Result<()>;
/// Runs defragmentation for `device` and reports the outcome as
/// [`DefragmentationStats`].
async fn defragment(&self, device: &Device) -> Result<DefragmentationStats>;
/// Installs a callback that receives a [`MemoryPressure`] value. It takes
/// `&self`, so implementations need interior mutability to store it.
fn set_pressure_callback(&self, callback: Box<dyn Fn(MemoryPressure) + Send + Sync>);
}
/// Lightweight, copyable identifier for a memory allocation.
///
/// The wrapped `u64` is public; the value `0` is treated as the invalid handle
/// (see [`MemoryHandle::is_valid`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MemoryHandle(pub u64);

impl MemoryHandle {
    /// Wraps the raw identifier `id` in a handle.
    pub fn new(id: u64) -> Self {
        MemoryHandle(id)
    }

    /// Returns the raw identifier stored in this handle.
    pub fn id(&self) -> u64 {
        let MemoryHandle(raw) = *self;
        raw
    }

    /// Returns `true` for every handle except the one wrapping `0`.
    pub fn is_valid(&self) -> bool {
        !matches!(self, MemoryHandle(0))
    }
}
/// Lightweight, copyable identifier for an execution stream.
///
/// `StreamHandle::default()` yields the handle with id `0`. It is provided by
/// the standard [`Default`] trait (derived), so the type also works with
/// `Default::default()`, `Option::unwrap_or_default()` and `#[derive(Default)]`
/// on structs that contain it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct StreamHandle(pub u64);

impl StreamHandle {
    /// Wraps the raw stream identifier `id` in a handle.
    pub fn new(id: u64) -> Self {
        Self(id)
    }
}
/// Parameters of a copy between two allocations, as accepted by
/// `DeviceMemoryManager::copy_async`.
///
/// Built with [`MemoryTransfer::new`] and optionally refined with the
/// `with_*_offset` builder methods.
#[derive(Debug, Clone)]
pub struct MemoryTransfer {
    /// Handle of the allocation to read from.
    pub src: MemoryHandle,
    /// Handle of the allocation to write to.
    pub dst: MemoryHandle,
    /// Amount to copy (presumably bytes — TODO confirm).
    pub size: usize,
    /// Start offset inside `src`; `0` unless overridden.
    pub src_offset: usize,
    /// Start offset inside `dst`; `0` unless overridden.
    pub dst_offset: usize,
}

impl MemoryTransfer {
    /// Creates a transfer of `size` from `src` to `dst` with both offsets at zero.
    pub fn new(src: MemoryHandle, dst: MemoryHandle, size: usize) -> Self {
        let (src_offset, dst_offset) = (0, 0);
        Self {
            src,
            dst,
            size,
            src_offset,
            dst_offset,
        }
    }

    /// Returns the transfer with its source offset replaced by `offset`.
    pub fn with_src_offset(self, offset: usize) -> Self {
        Self {
            src_offset: offset,
            ..self
        }
    }

    /// Returns the transfer with its destination offset replaced by `offset`.
    pub fn with_dst_offset(self, offset: usize) -> Self {
        Self {
            dst_offset: offset,
            ..self
        }
    }
}
/// Snapshot of memory usage figures for one device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryInfo {
    /// Total capacity in bytes.
    pub total_bytes: u64,
    /// Bytes currently in use; numerator of [`MemoryInfo::utilization_percent`].
    pub used_bytes: u64,
    /// Bytes currently free.
    pub free_bytes: u64,
    /// Bytes reserved (NOTE(review): relation to `used_bytes` is not visible here).
    pub reserved_bytes: u64,
    /// Number of live allocations.
    pub active_allocations: usize,
    /// Fragmentation measure (presumably in `0.0..=1.0` — TODO confirm).
    pub fragmentation_ratio: f32,
    /// Memory bandwidth in GB/s when known.
    pub bandwidth_gbps: Option<f32>,
}

impl MemoryInfo {
    /// Returns `used_bytes / total_bytes` as a percentage.
    ///
    /// A `total_bytes` of zero yields `0.0` instead of dividing by zero.
    pub fn utilization_percent(&self) -> f32 {
        if self.total_bytes == 0 {
            return 0.0;
        }
        let fraction = self.used_bytes as f32 / self.total_bytes as f32;
        fraction * 100.0
    }

    /// Buckets [`utilization_percent`](Self::utilization_percent) into a
    /// [`MemoryPressure`] level: `>= 95` is critical, `>= 85` high, `>= 70`
    /// medium, anything lower is low.
    pub fn pressure_level(&self) -> MemoryPressure {
        match self.utilization_percent() {
            pct if pct >= 95.0 => MemoryPressure::Critical,
            pct if pct >= 85.0 => MemoryPressure::High,
            pct if pct >= 70.0 => MemoryPressure::Medium,
            _ => MemoryPressure::Low,
        }
    }
}
/// Metadata describing a single allocation, as returned by
/// `DeviceMemoryManager::handle_info`.
#[derive(Debug, Clone)]
pub struct MemoryHandleInfo {
/// Handle this record describes.
pub handle: MemoryHandle,
/// Size of the allocation (presumably bytes — TODO confirm).
pub size: usize,
/// Device associated with the allocation.
pub device: Device,
/// Alignment of the allocation (presumably bytes — TODO confirm).
pub alignment: usize,
/// Monotonic timestamp; by its name, taken when the allocation was made.
pub allocated_at: std::time::Instant,
/// Whether the allocation is currently mapped.
/// NOTE(review): likely relates to `AdvancedMemoryManager::map_memory` — confirm.
pub is_mapped: bool,
/// Category the allocation belongs to.
pub memory_type: MemoryType,
}
/// Category tag for an allocation.
///
/// Hashable and serializable; used as the key of
/// `MemoryManagerConfig::pool_configs`. The meaning of each variant is not
/// spelled out in this file, so the notes below are inferred from the names —
/// NOTE(review): confirm against the implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MemoryType {
/// General-purpose memory; the only pool configured by
/// `MemoryManagerConfig::default()`.
General,
/// Presumably storage for tensor data.
Tensor,
/// Presumably cache storage.
Cache,
/// Presumably short-lived scratch memory.
Temporary,
/// Presumably page-locked (pinned) host memory.
Pinned,
/// Presumably memory mapped between address spaces.
Mapped,
}
/// Coarse memory-pressure level.
///
/// `PartialOrd`/`Ord` are derived, so levels compare in declaration order:
/// `Low < Medium < High < Critical`. `MemoryInfo::pressure_level` produces
/// these from utilization: below 70% is `Low`, from 70% `Medium`, from 85%
/// `High`, from 95% `Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryPressure {
/// Lowest level.
Low,
/// Second level.
Medium,
/// Third level.
High,
/// Highest level.
Critical,
}
/// Settings for one memory pool, passed to `DeviceMemoryManager::configure_pool`
/// and stored per `MemoryType` in `MemoryManagerConfig::pool_configs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPoolConfig {
    /// Initial pool size (default 1 GiB expressed in bytes).
    pub initial_size: u64,
    /// Optional upper bound on the pool size; `None` by default.
    pub max_size: Option<u64>,
    /// Growth step (default 512 MiB expressed in bytes).
    pub growth_increment: u64,
    /// Whether the pool may expand automatically.
    pub enable_auto_expansion: bool,
    /// Alignment used by the pool (default `256`; presumably bytes — TODO confirm).
    pub alignment: usize,
    /// Whether to pre-allocate the pool up front.
    pub pre_allocate: bool,
    /// Whether statistics collection is enabled.
    pub enable_stats: bool,
}

impl Default for MemoryPoolConfig {
    /// Defaults: 1 GiB initial size, no maximum, 512 MiB growth step,
    /// auto-expansion on, alignment 256, no pre-allocation, stats on.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        const GIB: u64 = 1024 * MIB;

        Self {
            initial_size: GIB,
            max_size: None,
            growth_increment: 512 * MIB,
            enable_auto_expansion: true,
            alignment: 256,
            pre_allocate: false,
            enable_stats: true,
        }
    }
}
/// Result record of a defragmentation run, returned by
/// `DeviceMemoryManager::defragment` and (per device) by
/// `GlobalMemoryMonitor::global_gc`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DefragmentationStats {
/// Amount of memory freed (presumably bytes — TODO confirm).
pub memory_freed: u64,
/// Number of blocks that were moved.
pub blocks_moved: usize,
/// Duration of the run in milliseconds.
pub time_taken_ms: u64,
/// Fragmentation measure before the run (scale not visible here; compare
/// `MemoryInfo::fragmentation_ratio`).
pub fragmentation_before: f32,
/// Fragmentation measure after the run.
pub fragmentation_after: f32,
}
/// Extension of [`DeviceMemoryManager`] with mapping, prefetching and
/// access-tracking operations.
///
/// Only signatures are declared here; the per-method notes describe the
/// apparent intent from names and types. NOTE(review): confirm against the
/// concrete implementations.
#[async_trait]
pub trait AdvancedMemoryManager: DeviceMemoryManager {
/// Maps the allocation behind `handle` with the requested `access` mode and
/// returns a raw pointer to it. Dereferencing the pointer requires `unsafe`;
/// its validity and lifetime rules are not visible here — presumably valid
/// until [`unmap_memory`](Self::unmap_memory) is called.
async fn map_memory(&self, handle: MemoryHandle, access: MemoryAccess) -> Result<*mut u8>;
/// Undoes a previous [`map_memory`](Self::map_memory) for `handle`.
async fn unmap_memory(&self, handle: MemoryHandle) -> Result<()>;
/// Creates a mapping of `size` between `src_device` and `dst_device` and
/// returns two handles (presumably one per device, in that order — TODO confirm).
async fn create_mapping(
&self,
src_device: &Device,
dst_device: &Device,
size: usize,
) -> Result<(MemoryHandle, MemoryHandle)>;
/// Requests that the data behind `handle` be prefetched to `target_device`.
async fn prefetch(&self, handle: MemoryHandle, target_device: &Device) -> Result<()>;
/// Synchronous lookup of access statistics for `handle`; `None` presumably
/// means no statistics are available.
fn access_stats(&self, handle: MemoryHandle) -> Option<MemoryAccessStats>;
/// Attaches a [`MemoryUsageHint`] to the allocation behind `handle`.
async fn set_usage_hint(&self, handle: MemoryHandle, hint: MemoryUsageHint) -> Result<()>;
}
/// Access mode requested when mapping memory via
/// `AdvancedMemoryManager::map_memory`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryAccess {
/// Mapping is intended for reads only.
ReadOnly,
/// Mapping is intended for writes only.
WriteOnly,
/// Mapping is intended for both reads and writes.
ReadWrite,
}
/// Usage hint passed to `AdvancedMemoryManager::set_usage_hint`.
///
/// How implementations react to each hint is not visible in this file; the
/// variant notes are inferred from the names — NOTE(review): confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryUsageHint {
/// Presumably: data will be accessed sequentially.
Sequential,
/// Presumably: data will be accessed in random order.
Random,
/// Presumably: reads dominate.
ReadMostly,
/// Presumably: writes dominate.
WriteMostly,
/// Presumably: the allocation is short-lived.
Temporary,
/// Presumably: the allocation should stay resident.
Resident,
}
/// Access statistics for one allocation, as returned by
/// `AdvancedMemoryManager::access_stats`.
#[derive(Debug, Clone)]
pub struct MemoryAccessStats {
/// Number of reads recorded.
pub read_count: u64,
/// Number of writes recorded.
pub write_count: u64,
/// Average read size (presumably bytes — TODO confirm).
pub avg_read_size: usize,
/// Average write size (presumably bytes — TODO confirm).
pub avg_write_size: usize,
/// Monotonic timestamp of the most recent access.
pub last_access: std::time::Instant,
/// Classification of the observed access pattern.
pub pattern_type: AccessPatternType,
}
/// Classification of an access pattern, stored in
/// `MemoryAccessStats::pattern_type`.
///
/// The criteria for each class are not visible in this file; the variant
/// notes are inferred from the names — NOTE(review): confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccessPatternType {
/// Presumably: accesses follow increasing addresses.
Sequential,
/// Presumably: accesses show no address locality.
Random,
/// Presumably: accesses arrive in bursts.
Burst,
/// Presumably: a combination of several patterns.
Mixed,
/// No classification available.
Unknown,
}
/// Asynchronous interface for creating and synchronizing execution streams and
/// events on a [`Device`].
///
/// Only signatures are declared here; the per-method notes describe the
/// apparent intent from names and types. NOTE(review): confirm against the
/// concrete implementations.
#[async_trait]
pub trait StreamManager: Send + Sync {
/// Creates a new stream on `device` and returns its handle.
async fn create_stream(&self, device: &Device) -> Result<StreamHandle>;
/// Destroys the stream identified by `stream`.
async fn destroy_stream(&self, stream: StreamHandle) -> Result<()>;
/// Completes once the work queued on `stream` has finished (presumably).
async fn synchronize_stream(&self, stream: StreamHandle) -> Result<()>;
/// Reports whether `stream` is ready; presumably `true` means no pending work.
async fn is_stream_ready(&self, stream: StreamHandle) -> Result<bool>;
/// Synchronous, infallible accessor for the default stream of `device`.
fn default_stream(&self, device: &Device) -> StreamHandle;
/// Records an event on `stream` and returns an [`EventHandle`] for it.
async fn record_event(&self, stream: StreamHandle) -> Result<EventHandle>;
/// Makes `stream` wait for `event` (presumably one obtained from
/// [`record_event`](Self::record_event)).
async fn wait_event(&self, stream: StreamHandle, event: EventHandle) -> Result<()>;
}
/// Lightweight, copyable identifier for a stream event, produced by
/// `StreamManager::record_event` and consumed by `StreamManager::wait_event`.
/// The wrapped `u64` is public; no constructor or validity helper is defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EventHandle(pub u64);
/// Asynchronous factory producing boxed memory and stream managers for a
/// [`Device`].
///
/// Only signatures are declared here. NOTE(review): whether each call builds
/// a fresh instance or returns a shared one is not visible — confirm against
/// the implementations.
#[async_trait]
pub trait MemoryManagerFactory: Send + Sync {
/// Builds a [`DeviceMemoryManager`] for `device` from `config`.
async fn create_memory_manager(
&self,
device: &Device,
config: &MemoryManagerConfig,
) -> Result<Box<dyn DeviceMemoryManager>>;
/// Builds an [`AdvancedMemoryManager`] for `device` from `config`.
async fn create_advanced_memory_manager(
&self,
device: &Device,
config: &MemoryManagerConfig,
) -> Result<Box<dyn AdvancedMemoryManager>>;
/// Builds a [`StreamManager`] for `device`; takes no configuration.
async fn create_stream_manager(&self, device: &Device) -> Result<Box<dyn StreamManager>>;
}
/// Configuration handed to `MemoryManagerFactory` when building a manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryManagerConfig {
    /// Pool settings keyed by memory category.
    pub pool_configs: HashMap<MemoryType, MemoryPoolConfig>,
    /// Whether allocation tracking is enabled.
    pub enable_tracking: bool,
    /// Whether automatic garbage collection is enabled.
    pub enable_auto_gc: bool,
    /// Threshold for automatic GC (default `0.85`; presumably a utilization
    /// fraction — TODO confirm).
    pub gc_threshold: f32,
    /// Whether debug features are enabled.
    pub enable_debug: bool,
    /// Upper bound on concurrently running transfers.
    pub max_concurrent_transfers: usize,
}

impl Default for MemoryManagerConfig {
    /// Defaults: a single `MemoryType::General` pool using
    /// `MemoryPoolConfig::default()`, tracking and auto-GC on, GC threshold
    /// `0.85`, debug off, at most 4 concurrent transfers.
    fn default() -> Self {
        let general_pool = (MemoryType::General, MemoryPoolConfig::default());

        Self {
            pool_configs: HashMap::from([general_pool]),
            enable_tracking: true,
            enable_auto_gc: true,
            gc_threshold: 0.85,
            enable_debug: false,
            max_concurrent_transfers: 4,
        }
    }
}
pub trait GlobalMemoryMonitor: Send + Sync {
fn global_memory_info(&self) -> HashMap<Device, MemoryInfo>;
fn global_memory_pressure(&self) -> MemoryPressure;
fn register_manager(&mut self, device: Device, manager: &dyn DeviceMemoryManager);
fn unregister_manager(&mut self, device: &Device);
fn set_global_pressure_callback(
&mut self,
callback: Box<dyn Fn(HashMap<Device, MemoryPressure>) + Send + Sync>,
);
async fn global_gc(&self) -> Result<HashMap<Device, DefragmentationStats>>;
}
/// Policy for choosing the device an allocation should be placed on.
/// `BestFitStrategy` is the implementation provided in this file.
pub trait AllocationStrategy: Send + Sync {
/// Picks one of `available_devices` for an allocation of `size`, or returns
/// `None` when no device is chosen.
///
/// `memory_info` supplies the current [`MemoryInfo`] per device and
/// `requirements` carries the caller's constraints and preferences. Which of
/// these inputs are honoured is up to the implementation.
fn select_device(
&self,
size: usize,
requirements: &AllocationRequirements,
available_devices: &[Device],
memory_info: &HashMap<Device, MemoryInfo>,
) -> Option<Device>;
/// Returns a name identifying the strategy (`"best_fit"` for
/// `BestFitStrategy`).
fn name(&self) -> &str;
}
/// Caller-supplied constraints and preferences for an allocation, passed to
/// `AllocationStrategy::select_device`.
///
/// `BestFitStrategy` in this file reads only `preferred_devices`; the other
/// fields are available to other strategies.
#[derive(Debug, Clone)]
pub struct AllocationRequirements {
/// Devices in order of preference; `BestFitStrategy` gives earlier entries a
/// larger score bonus.
pub preferred_devices: Vec<Device>,
/// Category of memory being requested.
pub memory_type: MemoryType,
/// Optional alignment requirement (presumably bytes — TODO confirm).
pub alignment: Option<usize>,
/// Whether the allocation is considered critical.
/// NOTE(review): the effect of this flag is not visible in this file.
pub is_critical: bool,
/// Optional estimate of how long the allocation will live.
pub expected_lifetime: Option<std::time::Duration>,
}
/// Allocation strategy that scores every candidate device and picks the
/// highest-scoring one.
pub struct BestFitStrategy;

impl AllocationStrategy for BestFitStrategy {
    /// Chooses the device with the best fit score for an allocation of `size`.
    ///
    /// Candidates without an entry in `memory_info`, or with fewer than `size`
    /// free bytes, are skipped. For the rest the score is
    /// `1 - leftover - |load - 0.5| * 0.5 + bonus`, where `leftover` is the free
    /// space remaining after the allocation as a fraction of total memory,
    /// `load` is the current utilization as a fraction, and `bonus` is
    /// `0.2 / (rank + 1)` for a device found at zero-based `rank` in
    /// `requirements.preferred_devices` (zero otherwise).
    ///
    /// On equal scores the device listed first in `available_devices` wins.
    /// A device reporting `total_bytes == 0` is never selected, because its
    /// score evaluates to NaN or negative infinity. Returns `None` when no
    /// candidate qualifies.
    fn select_device(
        &self,
        size: usize,
        requirements: &AllocationRequirements,
        available_devices: &[Device],
        memory_info: &HashMap<Device, MemoryInfo>,
    ) -> Option<Device> {
        let requested = size as u64;

        let (winner, _top_rating) = available_devices
            .iter()
            .filter_map(|candidate| {
                let stats = memory_info.get(candidate)?;
                if stats.free_bytes < requested {
                    return None;
                }

                let leftover = (stats.free_bytes - requested) as f32 / stats.total_bytes as f32;
                let load = stats.utilization_percent() / 100.0;
                let fit = 1.0 - leftover - (load - 0.5).abs() * 0.5;

                let rank_weight = match requirements
                    .preferred_devices
                    .iter()
                    .position(|preferred| preferred == candidate)
                {
                    Some(rank) => 1.0 / (rank as f32 + 1.0),
                    None => 0.0,
                };

                Some((candidate, fit + rank_weight * 0.2))
            })
            // Strict `>` keeps the earliest candidate on ties and rejects NaN.
            .fold(
                (None, f32::NEG_INFINITY),
                |(leader, top), (candidate, rating)| {
                    if rating > top {
                        (Some(candidate), rating)
                    } else {
                        (leader, top)
                    }
                },
            );

        winner.cloned()
    }

    /// Returns the strategy identifier `"best_fit"`.
    fn name(&self) -> &str {
        "best_fit"
    }
}