use anyhow::{bail, Result};
use candle_core::Device;
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Discrete VRAM pressure levels, ordered from least (`Low`) to most
/// (`Critical`) severe; the derived `Ord` follows the discriminant values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum MemoryPressure {
    Low = 0,
    Moderate = 1,
    High = 2,
    Critical = 3,
}

impl MemoryPressure {
    /// Classifies the fraction of free VRAM (`free / total`) into a pressure level.
    ///
    /// Thresholds are exclusive: exactly 0.80 free classifies as `Moderate`,
    /// exactly 0.60 as `High`, and anything at or below 0.40 — as well as NaN,
    /// which fails every comparison — as `Critical`.
    pub fn from_ratio(free_ratio: f32) -> Self {
        if free_ratio > 0.80 {
            MemoryPressure::Low
        } else if free_ratio > 0.60 {
            MemoryPressure::Moderate
        } else if free_ratio > 0.40 {
            MemoryPressure::High
        } else {
            MemoryPressure::Critical
        }
    }
}
/// Where a model layer's weights currently reside.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerPlacement {
    // Resident in GPU memory.
    GPU,
    // Resident in host (CPU) memory.
    CPU,
    // Being moved between devices. Note: nothing in the code visible here ever
    // sets this state — presumably set by the transfer path elsewhere; confirm.
    InTransit,
}
/// Counters describing offload/transfer activity.
///
/// NOTE(review): only `offload_triggers` and `preload_triggers` are updated in
/// the code visible here; `transfers`, `bytes_transferred`,
/// `avg_transfer_time_ms`, and `peak_vram_used` appear to be filled in by the
/// actual transfer machinery elsewhere — confirm.
#[derive(Debug, Clone, Default)]
pub struct IsomorphicStats {
    // Number of completed device-to-device transfers.
    pub transfers: usize,
    // Total bytes moved across devices.
    pub bytes_transferred: u64,
    // Rolling average transfer time, in milliseconds.
    pub avg_transfer_time_ms: f64,
    // Highest observed VRAM usage, in bytes.
    pub peak_vram_used: u64,
    // Times a layer was demoted GPU -> CPU due to memory pressure.
    pub offload_triggers: usize,
    // Times a layer was promoted CPU -> GPU for execution.
    pub preload_triggers: usize,
}
/// Tracks free/used VRAM for a single GPU and derives a [`MemoryPressure`]
/// level from the free-memory ratio.
pub struct MemoryMonitor {
    // Index of the GPU queried via `crate::device_utils::get_vram_info`.
    gpu_device_id: usize,
    // Total VRAM in bytes, captured once at construction.
    total_vram: u64,
    // Free bytes as of the most recent `update()` call.
    last_free_vram: u64,
    // High-water mark of used bytes, maintained by `update()`.
    peak_used_vram: u64,
    // Caller-supplied low-VRAM threshold (bytes). Not read in the code visible
    // here — presumably consumed by threshold-based strategies; confirm.
    #[allow(dead_code)]
    low_threshold: u64,
    // 75% of total VRAM (bytes). Also unread in the visible code.
    #[allow(dead_code)]
    high_threshold: u64,
}
impl MemoryMonitor {
    /// Builds a monitor for `gpu_device_id`, snapshotting the device's VRAM
    /// totals at construction time.
    ///
    /// # Errors
    /// Fails if VRAM info cannot be queried, or if the device reports zero
    /// total VRAM (no usable GPU).
    pub fn new(gpu_device_id: usize, vram_threshold_bytes: u64) -> Result<Self> {
        let (free, total) = crate::device_utils::get_vram_info(gpu_device_id)?;
        if total == 0 {
            bail!("No GPU VRAM detected. CPU-only mode.");
        }
        let total_vram = total as u64;
        Ok(Self {
            gpu_device_id,
            total_vram,
            last_free_vram: free as u64,
            peak_used_vram: 0,
            low_threshold: vram_threshold_bytes,
            // High watermark threshold fixed at 75% of total VRAM.
            high_threshold: (total_vram as f64 * 0.75) as u64,
        })
    }

    /// Re-queries free VRAM, refreshes the cached free value and the
    /// peak-usage watermark, and returns `(free_bytes, pressure)`.
    pub fn update(&mut self) -> Result<(u64, MemoryPressure)> {
        let (free_raw, _) = crate::device_utils::get_vram_info(self.gpu_device_id)?;
        let free = free_raw as u64;
        self.last_free_vram = free;
        let used = self.total_vram.saturating_sub(free);
        self.peak_used_vram = self.peak_used_vram.max(used);
        let free_ratio = free as f32 / self.total_vram as f32;
        let pressure = MemoryPressure::from_ratio(free_ratio);
        debug!(
            "Memory update: {} MB free / {} MB total ({:.1}%) - {:?}",
            free / 1024 / 1024,
            self.total_vram / 1024 / 1024,
            free_ratio * 100.0,
            pressure
        );
        Ok((free, pressure))
    }

    /// Convenience wrapper around [`Self::update`] that returns only the
    /// pressure level.
    pub fn pressure(&mut self) -> Result<MemoryPressure> {
        let (_, pressure) = self.update()?;
        Ok(pressure)
    }

    /// Returns `(free, used, peak_used)` in bytes. Unlike [`Self::update`],
    /// this does not refresh the peak watermark or the cached free value.
    pub fn get_status(&mut self) -> Result<(u64, u64, u64)> {
        let (free_raw, _) = crate::device_utils::get_vram_info(self.gpu_device_id)?;
        let free = free_raw as u64;
        let used = self.total_vram.saturating_sub(free);
        Ok((free, used, self.peak_used_vram))
    }
}
/// A planned (not yet executed) movement of one layer between devices.
/// NOTE(review): not constructed or consumed in the code visible here —
/// presumably used by a scheduler/transfer component elsewhere; confirm.
#[derive(Debug, Clone)]
pub struct LayerTransferPlan {
    // Index of the layer to move.
    pub layer_id: usize,
    // Placement the layer is moving from.
    pub from_device: LayerPlacement,
    // Placement the layer is moving to.
    pub to_device: LayerPlacement,
    // Expected size of the transfer, in bytes.
    pub estimated_bytes: u64,
}
/// Policy controlling how layers are distributed between GPU and CPU.
///
/// `Debug`/`Clone` are derived for consistency with the other public types in
/// this module (all fields are plain `usize`/`u64`/`String`, so the derives
/// are free and backward-compatible).
#[derive(Debug, Clone)]
pub enum OffloadStrategy {
    /// Every layer lives on the GPU.
    GPUOnly,
    /// The first `n_gpu` layers live on the GPU, the rest on the CPU.
    Hybrid { n_gpu: usize },
    /// Layers start on CPU and are promoted on demand while keeping at least
    /// `min_free_vram` bytes free, with `staging_reserve` bytes held back for
    /// transfer staging.
    Dynamic {
        min_free_vram: u64,
        staging_reserve: u64,
    },
    /// Placement driven by a recorded profile. NOTE(review): no constructor or
    /// handling for this variant exists in the code visible here — confirm it
    /// is wired up elsewhere.
    Adaptive { profile_file: String },
}
/// Manages per-layer GPU/CPU placement bookkeeping under a chosen
/// [`OffloadStrategy`], driven by live VRAM pressure readings.
pub struct IsomorphicOffloader {
    // Target device for GPU-resident layers (falls back to CPU when CUDA
    // device 0 is unavailable — see the constructors).
    gpu_device: Device,
    // Host device used for offloaded layers.
    cpu_device: Device,
    // Placement policy in effect.
    strategy: OffloadStrategy,
    // VRAM monitor for GPU 0.
    monitor: MemoryMonitor,
    // Current placement per layer id; absent keys are treated as CPU.
    placement: HashMap<usize, LayerPlacement>,
    // Offload/preload activity counters.
    stats: IsomorphicStats,
}
impl IsomorphicOffloader {
pub fn new_dynamic(min_free_vram_bytes: u64, staging_reserve_bytes: u64) -> Result<Self> {
let gpu_device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
let cpu_device = Device::Cpu;
let monitor = MemoryMonitor::new(0, min_free_vram_bytes)?;
Ok(Self {
gpu_device,
cpu_device,
strategy: OffloadStrategy::Dynamic {
min_free_vram: min_free_vram_bytes,
staging_reserve: staging_reserve_bytes,
},
monitor,
placement: HashMap::new(),
stats: IsomorphicStats::default(),
})
}
pub fn new_hybrid(n_gpu: usize) -> Result<Self> {
let gpu_device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
let cpu_device = Device::Cpu;
let monitor = MemoryMonitor::new(0, 1024 * 1024 * 1024)?;
Ok(Self {
gpu_device,
cpu_device,
strategy: OffloadStrategy::Hybrid { n_gpu },
monitor,
placement: HashMap::new(),
stats: IsomorphicStats::default(),
})
}
pub fn init_layers(&mut self, n_layers: usize) -> Result<()> {
match &self.strategy {
OffloadStrategy::Hybrid { n_gpu } => {
for i in 0..n_layers {
let placement = if i < *n_gpu {
LayerPlacement::GPU
} else {
LayerPlacement::CPU
};
self.placement.insert(i, placement);
debug!("Layer {} initialized at {:?}", i, placement);
}
}
OffloadStrategy::Dynamic { .. } => {
for i in 0..n_layers {
self.placement.insert(i, LayerPlacement::CPU);
debug!("Layer {} initialized at {:?}", i, LayerPlacement::CPU);
}
}
_ => {
for i in 0..n_layers {
self.placement.insert(i, LayerPlacement::GPU);
}
}
}
Ok(())
}
pub fn ensure_layer_ready(&mut self, layer_id: usize, layer_bytes: u64) -> Result<()> {
let current_placement = self
.placement
.get(&layer_id)
.copied()
.unwrap_or(LayerPlacement::CPU);
if current_placement == LayerPlacement::GPU {
return Ok(()); }
let (_free_vram, pressure) = self.monitor.update()?;
match pressure {
MemoryPressure::Critical => {
self.aggressive_offload(layer_bytes)?;
}
MemoryPressure::High => {
self.moderate_offload(layer_bytes)?;
}
_ => {}
}
if self.placement.get(&layer_id) == Some(&LayerPlacement::CPU) {
info!(
"🚀 Moving layer {} to GPU ({} MB)",
layer_id,
layer_bytes / 1024 / 1024
);
self.placement.insert(layer_id, LayerPlacement::GPU);
self.stats.preload_triggers += 1;
}
Ok(())
}
fn moderate_offload(&mut self, needed_bytes: u64) -> Result<()> {
let (free_vram, _) = self.monitor.update()?;
if free_vram > needed_bytes {
return Ok(()); }
let mut gpu_layers: Vec<usize> = self
.placement
.iter()
.filter(|(_, p)| **p == LayerPlacement::GPU)
.map(|(id, _)| *id)
.collect();
gpu_layers.sort_by(|a, b| b.cmp(a));
for layer_id in gpu_layers.iter().take(2) {
self.placement.insert(*layer_id, LayerPlacement::CPU);
debug!("Offloaded layer {} to CPU", layer_id);
self.stats.offload_triggers += 1;
let (free_vram, _) = self.monitor.update()?;
if free_vram > needed_bytes {
break;
}
}
Ok(())
}
fn aggressive_offload(&mut self, _needed_bytes: u64) -> Result<()> {
warn!("⚠️ Aggressive offloading triggered");
let mut gpu_layers: Vec<usize> = self
.placement
.iter()
.filter(|(_, p)| **p == LayerPlacement::GPU)
.map(|(id, _)| *id)
.collect();
gpu_layers.sort_by(|a, b| b.cmp(a));
for layer_id in gpu_layers.iter().skip(1) {
self.placement.insert(*layer_id, LayerPlacement::CPU);
debug!("Aggressive offload: layer {} to CPU", layer_id);
self.stats.offload_triggers += 1;
}
Ok(())
}
pub fn get_target_device(&self, layer_id: usize) -> Device {
match self.placement.get(&layer_id) {
Some(LayerPlacement::GPU) => self.gpu_device.clone(),
_ => self.cpu_device.clone(),
}
}
pub fn get_placement(&self, layer_id: usize) -> LayerPlacement {
self.placement
.get(&layer_id)
.copied()
.unwrap_or(LayerPlacement::CPU)
}
pub fn gpu_layers(&self) -> Vec<usize> {
self.placement
.iter()
.filter(|(_, p)| **p == LayerPlacement::GPU)
.map(|(id, _)| *id)
.collect()
}
pub fn stats(&self) -> &IsomorphicStats {
&self.stats
}
pub fn memory_status(&mut self) -> Result<(u64, u64, u64)> {
self.monitor.get_status()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mid-band ratios map to the expected pressure levels.
    #[test]
    fn test_memory_pressure_calculation() {
        assert_eq!(MemoryPressure::from_ratio(0.9), MemoryPressure::Low);
        assert_eq!(MemoryPressure::from_ratio(0.7), MemoryPressure::Moderate);
        assert_eq!(MemoryPressure::from_ratio(0.5), MemoryPressure::High);
        assert_eq!(MemoryPressure::from_ratio(0.3), MemoryPressure::Critical);
    }

    /// Thresholds are exclusive: hitting a boundary exactly yields the more
    /// severe level, and NaN (which fails every `>` guard) is Critical.
    #[test]
    fn test_memory_pressure_boundaries() {
        assert_eq!(MemoryPressure::from_ratio(0.80), MemoryPressure::Moderate);
        assert_eq!(MemoryPressure::from_ratio(0.60), MemoryPressure::High);
        assert_eq!(MemoryPressure::from_ratio(0.40), MemoryPressure::Critical);
        assert_eq!(MemoryPressure::from_ratio(f32::NAN), MemoryPressure::Critical);
    }

    /// Derived `Ord` follows severity order.
    #[test]
    fn test_memory_pressure_ordering() {
        assert!(MemoryPressure::Low < MemoryPressure::Moderate);
        assert!(MemoryPressure::Moderate < MemoryPressure::High);
        assert!(MemoryPressure::High < MemoryPressure::Critical);
    }
}