use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
/// Coarse classification of how close memory usage is to exhaustion.
///
/// Variants are declared in increasing order of severity, so the derived
/// `Ord` lets callers write comparisons such as
/// `pressure >= MemoryPressure::Critical`. The explicit `u8` discriminants are
/// the values decoded by `from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub enum MemoryPressure {
    /// Least severe level.
    Normal = 0,
    /// Second level.
    Warning = 1,
    /// Third level.
    Critical = 2,
    /// Most severe level.
    Emergency = 3,
}

impl MemoryPressure {
    /// Decodes a raw discriminant back into a pressure level.
    ///
    /// `0`, `1` and `2` map to `Normal`, `Warning` and `Critical`; every other
    /// value (including `3`) is treated as `Emergency`.
    fn from_u8(value: u8) -> Self {
        // Index == discriminant for the three non-terminal levels.
        const BELOW_EMERGENCY: [MemoryPressure; 3] = [
            MemoryPressure::Normal,
            MemoryPressure::Warning,
            MemoryPressure::Critical,
        ];
        BELOW_EMERGENCY
            .get(usize::from(value))
            .copied()
            .unwrap_or(Self::Emergency)
    }
}

impl std::fmt::Display for MemoryPressure {
    /// Writes the lowercase name of the level (`"normal"`, `"warning"`, ...).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Normal => "normal",
            Self::Warning => "warning",
            Self::Critical => "critical",
            Self::Emergency => "emergency",
        };
        f.write_str(label)
    }
}
/// Tunable thresholds and switches for memory monitoring.
///
/// Thresholds are usage fractions in `0.0..=1.0` and must be strictly
/// increasing (`warning < critical < emergency`); use [`MemoryConfig::validate`]
/// to check a hand-built configuration.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Usage fraction at or above which pressure is reported as `Warning`.
    pub warning_threshold: f32,
    /// Usage fraction at or above which pressure is reported as `Critical`.
    pub critical_threshold: f32,
    /// Usage fraction at or above which pressure is reported as `Emergency`.
    pub emergency_threshold: f32,
    /// Pause between two samples of the background monitor, in milliseconds.
    pub check_interval_ms: u64,
    /// When `false`, the monitor never reports that recovery is needed.
    pub enable_auto_recovery: bool,
    /// Whether GPU memory is sampled.
    pub monitor_gpu: bool,
    /// Whether system memory is sampled.
    pub monitor_system: bool,
}

impl Default for MemoryConfig {
    /// Balanced preset: 70 % / 85 % / 95 % thresholds, sampled every 100 ms.
    fn default() -> Self {
        Self {
            warning_threshold: 0.70,
            critical_threshold: 0.85,
            emergency_threshold: 0.95,
            check_interval_ms: 100,
            enable_auto_recovery: true,
            monitor_gpu: true,
            monitor_system: true,
        }
    }
}

impl MemoryConfig {
    /// Preset with lower thresholds (60 % / 75 % / 90 %) and a 50 ms interval.
    pub fn memory_constrained() -> Self {
        Self {
            warning_threshold: 0.60,
            critical_threshold: 0.75,
            emergency_threshold: 0.90,
            check_interval_ms: 50,
            enable_auto_recovery: true,
            monitor_gpu: true,
            monitor_system: true,
        }
    }

    /// Preset with higher thresholds (80 % / 90 % / 97 %) and a 200 ms interval.
    pub fn high_memory() -> Self {
        Self {
            warning_threshold: 0.80,
            critical_threshold: 0.90,
            emergency_threshold: 0.97,
            check_interval_ms: 200,
            enable_auto_recovery: true,
            monitor_gpu: true,
            monitor_system: true,
        }
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first violated rule:
    /// * a threshold outside `0.0..=1.0` (NaN counts as outside),
    /// * thresholds that are not strictly increasing,
    /// * a `check_interval_ms` of zero.
    pub fn validate(&self) -> Result<(), String> {
        let thresholds = [
            ("warning_threshold", self.warning_threshold),
            ("critical_threshold", self.critical_threshold),
            ("emergency_threshold", self.emergency_threshold),
        ];
        for (name, value) in thresholds {
            // `contains` is false for NaN, so NaN is rejected here; plain
            // `<` / `>` comparisons would let it slip through because every
            // comparison with NaN evaluates to false.
            if !(0.0..=1.0).contains(&value) {
                return Err(format!("{} must be between 0.0 and 1.0", name));
            }
        }

        if self.warning_threshold >= self.critical_threshold {
            return Err("warning_threshold must be less than critical_threshold".to_string());
        }
        if self.critical_threshold >= self.emergency_threshold {
            return Err("critical_threshold must be less than emergency_threshold".to_string());
        }
        if self.check_interval_ms == 0 {
            return Err("check_interval_ms must be greater than 0".to_string());
        }
        Ok(())
    }
}
/// Point-in-time snapshot of GPU and system memory counters.
///
/// A `total` of zero means "no sample available"; the usage helpers report
/// `0.0` in that case instead of dividing by zero.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryStats {
    /// GPU memory currently in use.
    pub gpu_used: u64,
    /// Total GPU memory.
    pub gpu_total: u64,
    /// System memory currently in use.
    pub system_used: u64,
    /// Total system memory.
    pub system_total: u64,
}

impl MemoryStats {
    /// Fraction of GPU memory in use (`used / total`), or `0.0` when the total is zero.
    pub fn gpu_usage(&self) -> f32 {
        match self.gpu_total {
            0 => 0.0,
            total => self.gpu_used as f32 / total as f32,
        }
    }

    /// Fraction of system memory in use (`used / total`), or `0.0` when the total is zero.
    pub fn system_usage(&self) -> f32 {
        match self.system_total {
            0 => 0.0,
            total => self.system_used as f32 / total as f32,
        }
    }

    /// The larger of the GPU and system usage fractions.
    pub fn max_usage(&self) -> f32 {
        let gpu = self.gpu_usage();
        let system = self.system_usage();
        gpu.max(system)
    }

    /// GPU memory not in use; clamps to zero if `used` exceeds `total`.
    pub fn gpu_available(&self) -> u64 {
        let Self { gpu_used, gpu_total, .. } = *self;
        gpu_total.saturating_sub(gpu_used)
    }

    /// System memory not in use; clamps to zero if `used` exceeds `total`.
    pub fn system_available(&self) -> u64 {
        let Self { system_used, system_total, .. } = *self;
        system_total.saturating_sub(system_used)
    }
}
/// Samples system and GPU memory and classifies the result as a [`MemoryPressure`].
///
/// All mutable state lives in atomics, so a monitor is shared behind an
/// [`Arc`] and can be read from any thread without locking.
pub struct MemoryMonitor {
    /// Thresholds, sampling interval and feature switches; fixed at construction.
    config: MemoryConfig,
    /// Latest pressure level, stored as the `u8` discriminant of [`MemoryPressure`].
    current_pressure: AtomicU8,
    /// GPU memory in use according to the most recent successful probe.
    gpu_memory_used: AtomicU64,
    /// Total GPU memory according to the most recent successful probe.
    gpu_memory_total: AtomicU64,
    /// System memory in use according to the most recent successful probe.
    system_memory_used: AtomicU64,
    /// Total system memory according to the most recent successful probe.
    system_memory_total: AtomicU64,
    /// `true` while the background sampling loop is allowed to keep running.
    running: AtomicBool,
}

impl MemoryMonitor {
    /// Creates a stopped monitor with zeroed counters and `Normal` pressure.
    ///
    /// The configuration is stored as given; it is not validated here.
    pub fn new(config: MemoryConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            current_pressure: AtomicU8::new(MemoryPressure::Normal as u8),
            gpu_memory_used: AtomicU64::new(0),
            gpu_memory_total: AtomicU64::new(0),
            system_memory_used: AtomicU64::new(0),
            system_memory_total: AtomicU64::new(0),
            running: AtomicBool::new(false),
        })
    }

    /// Creates a monitor using [`MemoryConfig::default`].
    pub fn with_defaults() -> Arc<Self> {
        Self::new(MemoryConfig::default())
    }

    /// Spawns the background sampling thread and returns its join handle.
    ///
    /// The thread calls [`update_stats`](Self::update_stats), sleeps for
    /// `check_interval_ms`, and repeats until [`stop`](Self::stop) is called
    /// or the last `Arc` to this monitor is dropped.
    ///
    /// Each call spawns a new thread: calling `start` again while a previous
    /// thread is still running results in two concurrent samplers.
    pub fn start(self: &Arc<Self>) -> JoinHandle<()> {
        // Hold only a weak reference between samples. A strong `Arc` captured
        // by the thread would keep the monitor alive forever, so `Drop` could
        // never run and a forgotten `stop()` would leak the thread.
        let weak = Arc::downgrade(self);
        self.running.store(true, Ordering::SeqCst);
        thread::spawn(move || loop {
            let monitor = match weak.upgrade() {
                Some(monitor) => monitor,
                // Every owner is gone; nobody can observe further samples.
                None => break,
            };
            if !monitor.running.load(Ordering::SeqCst) {
                break;
            }
            monitor.update_stats();
            // `validate()` forbids a zero interval but `new()` does not call
            // it; clamp so a zero value cannot turn this into a busy loop.
            let interval = Duration::from_millis(monitor.config.check_interval_ms.max(1));
            // Release the strong reference before sleeping so the monitor can
            // be freed while this thread is idle.
            drop(monitor);
            thread::sleep(interval);
        })
    }

    /// Asks the background thread to exit; it does so after its current sleep.
    pub fn stop(&self) {
        self.running.store(false, Ordering::SeqCst);
    }

    /// Returns whether the running flag is currently set.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }

    /// Returns the pressure level computed by the most recent update.
    pub fn pressure(&self) -> MemoryPressure {
        MemoryPressure::from_u8(self.current_pressure.load(Ordering::Relaxed))
    }

    /// Returns a snapshot of the four memory counters.
    ///
    /// Each counter is loaded independently, so a snapshot taken while an
    /// update is in progress may mix values from two consecutive samples.
    pub fn stats(&self) -> MemoryStats {
        let (gpu_used, gpu_total) = self.gpu_memory();
        let (system_used, system_total) = self.system_memory();
        MemoryStats {
            gpu_used,
            gpu_total,
            system_used,
            system_total,
        }
    }

    /// Returns `(used, total)` GPU memory from the last successful probe.
    pub fn gpu_memory(&self) -> (u64, u64) {
        (
            self.gpu_memory_used.load(Ordering::Relaxed),
            self.gpu_memory_total.load(Ordering::Relaxed),
        )
    }

    /// Returns `(used, total)` system memory from the last successful probe.
    pub fn system_memory(&self) -> (u64, u64) {
        (
            self.system_memory_used.load(Ordering::Relaxed),
            self.system_memory_total.load(Ordering::Relaxed),
        )
    }

    /// Takes one sample of the enabled sources and recomputes the pressure level.
    ///
    /// A source whose probe returns `None` keeps its previously stored values.
    /// Pressure is derived from the larger of the GPU and system usage fractions.
    pub fn update_stats(&self) {
        if self.config.monitor_system {
            if let Some((used, total)) = get_system_memory() {
                self.system_memory_used.store(used, Ordering::Relaxed);
                self.system_memory_total.store(total, Ordering::Relaxed);
            }
        }
        if self.config.monitor_gpu {
            if let Some((used, total)) = get_gpu_memory() {
                self.gpu_memory_used.store(used, Ordering::Relaxed);
                self.gpu_memory_total.store(total, Ordering::Relaxed);
            }
        }

        let usage = self.stats().max_usage();
        let pressure = self.calculate_pressure(usage);
        self.current_pressure
            .store(pressure as u8, Ordering::Relaxed);
    }

    /// Maps a usage fraction onto a pressure level using the configured thresholds.
    ///
    /// Thresholds are inclusive and checked from most to least severe.
    fn calculate_pressure(&self, usage: f32) -> MemoryPressure {
        if usage >= self.config.emergency_threshold {
            MemoryPressure::Emergency
        } else if usage >= self.config.critical_threshold {
            MemoryPressure::Critical
        } else if usage >= self.config.warning_threshold {
            MemoryPressure::Warning
        } else {
            MemoryPressure::Normal
        }
    }

    /// Returns `true` when auto-recovery is enabled and pressure is `Critical` or worse.
    pub fn should_recover(&self) -> bool {
        self.config.enable_auto_recovery && self.pressure() >= MemoryPressure::Critical
    }

    /// Returns the configuration this monitor was built with.
    pub fn config(&self) -> &MemoryConfig {
        &self.config
    }
}

impl Drop for MemoryMonitor {
    /// Clears the running flag when the last owner goes away.
    fn drop(&mut self) {
        self.stop();
    }
}
/// A single action that can be attempted to relieve memory pressure.
///
/// NOTE(review): the code that executes these strategies is not part of this
/// file, so the per-variant notes below are inferred from the names and field
/// names only; confirm them against the executor.
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    /// Presumably drops cached state.
    ClearCache,
    /// Presumably shrinks the context by `factor`, never below `min_size`.
    ReduceContext {
        factor: f32,
        min_size: u32,
    },
    /// Presumably evicts least-recently-used entries, keeping `keep_count`.
    EvictLRU {
        keep_count: usize,
    },
    /// Presumably shifts the context window, keeping `keep_ratio` of it.
    ShiftContext {
        keep_ratio: f32,
    },
    /// Presumably gives up on the current operation.
    Abort,
}

impl Default for RecoveryStrategy {
    /// The default strategy is `ShiftContext` with a `keep_ratio` of `0.5`.
    fn default() -> Self {
        let keep_ratio = 0.5;
        RecoveryStrategy::ShiftContext { keep_ratio }
    }
}
/// Outcome of a recovery attempt.
///
/// NOTE(review): nothing in this file constructs these values; the variant
/// notes are inferred from the names and should be confirmed against the
/// code that produces them.
#[derive(Debug)]
pub enum RecoveryResult {
    /// Recovery finished with nothing further to report.
    Success,
    /// Carries the context size after the reduction.
    ContextReduced {
        new_size: u32,
    },
    /// Carries how many sequences were evicted.
    SequencesEvicted {
        evicted: usize,
    },
    /// Carries how many tokens were removed by the shift.
    ContextShifted {
        tokens_removed: u32,
    },
    /// Carries a description of why recovery did not succeed.
    Failed {
        reason: String,
    },
}
/// Holds an ordered list of recovery strategies and an optional memory monitor.
///
/// The manager itself does not execute strategies; it answers "is recovery
/// needed?" (via the monitor) and "what should attempt N try?" (via the list).
pub struct RecoveryManager {
    /// Strategies in the order they should be attempted.
    strategies: Vec<RecoveryStrategy>,
    /// Source of pressure readings; `None` until `with_monitor` is called.
    monitor: Option<Arc<MemoryMonitor>>,
    /// Configured attempt limit, exposed through `max_attempts()`.
    max_attempts: u32,
}

impl RecoveryManager {
    /// Attempt limit used by `new` and `with_strategies`.
    const DEFAULT_MAX_ATTEMPTS: u32 = 3;

    /// Creates a manager with the built-in escalation ladder:
    /// shift keeping 75 %, shift keeping 50 %, clear cache, abort.
    pub fn new() -> Self {
        Self::with_strategies(vec![
            RecoveryStrategy::ShiftContext { keep_ratio: 0.75 },
            RecoveryStrategy::ShiftContext { keep_ratio: 0.5 },
            RecoveryStrategy::ClearCache,
            RecoveryStrategy::Abort,
        ])
    }

    /// Creates a manager with a caller-supplied strategy list, no monitor and
    /// the default attempt limit of 3.
    pub fn with_strategies(strategies: Vec<RecoveryStrategy>) -> Self {
        Self {
            strategies,
            monitor: None,
            max_attempts: Self::DEFAULT_MAX_ATTEMPTS,
        }
    }

    /// Attaches the monitor consulted by `needs_recovery` and `pressure`.
    pub fn with_monitor(mut self, monitor: Arc<MemoryMonitor>) -> Self {
        self.monitor = Some(monitor);
        self
    }

    /// Overrides the attempt limit.
    pub fn with_max_attempts(mut self, max: u32) -> Self {
        self.max_attempts = max;
        self
    }

    /// Returns the configured attempt limit.
    ///
    /// Nothing in this type enforces the limit (`get_strategy` ignores it);
    /// it is exposed so the code driving the retry loop can honor it.
    pub fn max_attempts(&self) -> u32 {
        self.max_attempts
    }

    /// Returns the strategies in attempt order.
    pub fn strategies(&self) -> &[RecoveryStrategy] {
        &self.strategies
    }

    /// Returns `true` when a monitor is attached and it reports that recovery
    /// should run; always `false` without a monitor.
    pub fn needs_recovery(&self) -> bool {
        match &self.monitor {
            Some(monitor) => monitor.should_recover(),
            None => false,
        }
    }

    /// Returns the attached monitor's current pressure, or `None` without a monitor.
    pub fn pressure(&self) -> Option<MemoryPressure> {
        self.monitor.as_ref().map(|monitor| monitor.pressure())
    }

    /// Returns the strategy for the zero-based `attempt`, or `None` once the
    /// list is exhausted.
    pub fn get_strategy(&self, attempt: usize) -> Option<&RecoveryStrategy> {
        self.strategies.get(attempt)
    }
}

impl Default for RecoveryManager {
    /// Equivalent to [`RecoveryManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Reads `/proc/meminfo` and returns `(used_bytes, total_bytes)` for system memory.
///
/// "Used" is `MemTotal - MemAvailable`. When `MemAvailable` is missing the
/// estimate falls back to `MemFree + Buffers + Cached`, so an absent field is
/// not misreported as 100 % usage.
///
/// Returns `None` when the file cannot be read, when `MemTotal` or
/// `MemAvailable` cannot be parsed, when `MemTotal` is absent or zero, or when
/// neither `MemAvailable` nor `MemFree` is present.
#[cfg(target_os = "linux")]
pub fn get_system_memory() -> Option<(u64, u64)> {
    let meminfo = std::fs::read_to_string("/proc/meminfo").ok()?;

    let mut total: Option<u64> = None;
    let mut available: Option<u64> = None;
    // Fallback inputs; only consulted when MemAvailable is missing.
    let mut free: Option<u64> = None;
    let mut buffers: u64 = 0;
    let mut cached: u64 = 0;

    for line in meminfo.lines() {
        if line.starts_with("MemTotal:") {
            total = Some(parse_meminfo_value(line)?);
        } else if line.starts_with("MemAvailable:") {
            available = Some(parse_meminfo_value(line)?);
        } else if line.starts_with("MemFree:") {
            // Parse failures on fallback fields are tolerated so they can
            // never break a reading that has MemAvailable.
            free = parse_meminfo_value(line);
        } else if line.starts_with("Buffers:") {
            buffers = parse_meminfo_value(line).unwrap_or(0);
        } else if line.starts_with("Cached:") {
            // `starts_with` keeps "SwapCached:" from matching here.
            cached = parse_meminfo_value(line).unwrap_or(0);
        }
    }

    let total = total.filter(|&bytes| bytes > 0)?;
    let available = match available {
        Some(bytes) => bytes,
        None => free?.saturating_add(buffers).saturating_add(cached),
    };
    Some((total.saturating_sub(available), total))
}
/// Extracts the numeric value from one `/proc/meminfo` line and converts it
/// from kibibytes to bytes.
///
/// The value is the second whitespace-separated token (`"MemTotal:  16384 kB"`
/// yields `16384 * 1024`). Returns `None` when the token is missing, is not an
/// unsigned integer, or the byte count would overflow `u64`.
#[cfg(target_os = "linux")]
fn parse_meminfo_value(line: &str) -> Option<u64> {
    let kib: u64 = line.split_whitespace().nth(1)?.parse().ok()?;
    // Checked so an absurd value yields `None` instead of panicking in debug
    // builds or silently wrapping in release builds.
    kib.checked_mul(1024)
}
/// Queries `sysctl` and `vm_stat` and returns `(used_bytes, total_bytes)` for
/// system memory.
///
/// Total comes from `sysctl -n hw.memsize`. Available memory is
/// `(free pages + inactive pages) * page size`, and "used" is the remainder.
///
/// The page size is read from the `vm_stat` header line
/// (`... (page size of N bytes)`) because it differs between machines;
/// 4096 is used only when the header cannot be parsed.
///
/// Returns `None` when either command cannot be run or a required value
/// cannot be parsed.
#[cfg(target_os = "macos")]
pub fn get_system_memory() -> Option<(u64, u64)> {
    use std::process::Command;

    let total_output = Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .ok()?;
    let total_str = String::from_utf8_lossy(&total_output.stdout);
    let total: u64 = total_str.trim().parse().ok()?;

    let vm_output = Command::new("vm_stat").output().ok()?;
    let vm_str = String::from_utf8_lossy(&vm_output.stdout);

    // vm_stat reports counts in pages and states the page size in its header.
    let page_size: u64 = vm_str
        .lines()
        .next()
        .and_then(|header| header.split("page size of ").nth(1))
        .and_then(|rest| rest.split_whitespace().next())
        .and_then(|digits| digits.parse::<u64>().ok())
        .filter(|&size| size > 0)
        .unwrap_or(4096);

    let mut free_pages: u64 = 0;
    let mut inactive_pages: u64 = 0;
    for line in vm_str.lines() {
        if line.starts_with("Pages free:") {
            free_pages = parse_vm_stat_value(line)?;
        } else if line.starts_with("Pages inactive:") {
            inactive_pages = parse_vm_stat_value(line)?;
        }
    }

    let available = free_pages
        .saturating_add(inactive_pages)
        .saturating_mul(page_size);
    Some((total.saturating_sub(available), total))
}
/// Extracts the page count from one `vm_stat` line.
///
/// Takes the text between the first and second `:`, trims whitespace and any
/// trailing `.` characters, and parses it as `u64` (`"Pages free:   123."`
/// yields `123`). Returns `None` when there is no `:` or the text is not an
/// unsigned integer.
#[cfg(target_os = "macos")]
fn parse_vm_stat_value(line: &str) -> Option<u64> {
    let raw_value = line.split(':').nth(1)?;
    let digits = raw_value.trim().trim_end_matches('.');
    digits.parse().ok()
}
/// Calls `GlobalMemoryStatusEx` from `kernel32` and returns
/// `(used_bytes, total_bytes)` for physical memory.
///
/// "Used" is total physical memory minus available physical memory.
/// Returns `None` when the call reports failure (returns zero).
#[cfg(target_os = "windows")]
pub fn get_system_memory() -> Option<(u64, u64)> {
    /// Buffer passed to `GlobalMemoryStatusEx`; `dw_length` carries the
    /// struct size and is filled in before the call.
    #[repr(C)]
    struct MemoryStatusEx {
        dw_length: u32,
        dw_memory_load: u32,
        ull_total_phys: u64,
        ull_avail_phys: u64,
        ull_total_page_file: u64,
        ull_avail_page_file: u64,
        ull_total_virtual: u64,
        ull_avail_virtual: u64,
        ull_avail_extended_virtual: u64,
    }

    #[link(name = "kernel32")]
    extern "system" {
        fn GlobalMemoryStatusEx(buffer: *mut MemoryStatusEx) -> i32;
    }

    let mut status = MemoryStatusEx {
        dw_length: std::mem::size_of::<MemoryStatusEx>() as u32,
        dw_memory_load: 0,
        ull_total_phys: 0,
        ull_avail_phys: 0,
        ull_total_page_file: 0,
        ull_avail_page_file: 0,
        ull_total_virtual: 0,
        ull_avail_virtual: 0,
        ull_avail_extended_virtual: 0,
    };

    // SAFETY: `status` is a live, exclusively borrowed, fully initialized
    // `#[repr(C)]` value whose `dw_length` holds its own size, and the pointer
    // is only used for the duration of the call.
    let succeeded = unsafe { GlobalMemoryStatusEx(&mut status) } != 0;
    if !succeeded {
        return None;
    }

    let total = status.ull_total_phys;
    let used = total.saturating_sub(status.ull_avail_phys);
    Some((used, total))
}
/// Fallback for targets other than Linux, macOS and Windows.
///
/// No probe is implemented for these platforms, so this always returns `None`.
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
pub fn get_system_memory() -> Option<(u64, u64)> {
None
}
/// Returns `(used, total)` GPU memory from the first probe that succeeds.
///
/// The `nvidia-smi` probe is tried first on every platform. If it yields
/// nothing, macOS falls back to `get_metal_memory`; all other targets
/// return `None`.
fn get_gpu_memory() -> Option<(u64, u64)> {
    let cuda = get_cuda_memory_via_nvidia_smi();
    if cuda.is_some() {
        return cuda;
    }

    #[cfg(target_os = "macos")]
    {
        get_metal_memory()
    }
    #[cfg(not(target_os = "macos"))]
    {
        None
    }
}
/// Runs `nvidia-smi` and returns `(used_bytes, total_bytes)` for the first
/// GPU it lists.
///
/// The tool is asked for `memory.used,memory.total` as header-less, unit-less
/// CSV; the two numbers are multiplied by `1024 * 1024` to obtain bytes.
/// Only the first output line is considered, so additional GPUs are ignored.
///
/// Returns `None` when the command cannot be run, prints nothing, or the
/// first line is not exactly two comma-separated unsigned integers.
fn get_cuda_memory_via_nvidia_smi() -> Option<(u64, u64)> {
    use std::process::Command;

    const BYTES_PER_MIB: u64 = 1024 * 1024;

    let output = Command::new("nvidia-smi")
        .args([
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output()
        .ok()?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let first_gpu = stdout.lines().next()?.trim();

    // Exactly two fields are accepted; a third field rejects the line.
    let mut fields = first_gpu.split(',');
    let (used_field, total_field) = match (fields.next(), fields.next(), fields.next()) {
        (Some(used), Some(total), None) => (used, total),
        _ => return None,
    };

    let used_mb: u64 = used_field.trim().parse().ok()?;
    let total_mb: u64 = total_field.trim().parse().ok()?;
    Some((used_mb * BYTES_PER_MIB, total_mb * BYTES_PER_MIB))
}
/// Estimates `(used, total)` GPU memory on macOS from system-memory figures.
///
/// Wired memory (from `get_macos_wired_memory`, `0` if unavailable) is
/// reported as "used". The total is that amount plus whatever system memory
/// remains available after subtracting it.
///
/// NOTE(review): treating wired memory as GPU usage looks like a heuristic
/// for unified-memory machines; confirm the intent with the author.
///
/// Returns `None` when system memory cannot be determined.
#[cfg(target_os = "macos")]
fn get_metal_memory() -> Option<(u64, u64)> {
    let (sys_used, sys_total) = get_system_memory()?;
    let wired = get_macos_wired_memory().unwrap_or(0);

    let headroom = sys_total
        .saturating_sub(sys_used)
        .saturating_sub(wired);
    // The total is never reported below the used amount.
    let effective_total = headroom.saturating_add(wired).max(wired);

    Some((wired, effective_total))
}
/// Runs `vm_stat` and returns the amount of wired-down memory in bytes.
///
/// The value is `Pages wired down` multiplied by the page size. The page size
/// is read from the `vm_stat` header line (`... (page size of N bytes)`)
/// because it differs between machines; 16384 is used only when the header
/// cannot be parsed.
///
/// Returns `None` when `vm_stat` cannot be run or the wired-page line cannot
/// be parsed. If the line is absent altogether the result is `Some(0)`.
#[cfg(target_os = "macos")]
pub fn get_macos_wired_memory() -> Option<u64> {
    use std::process::Command;

    let output = Command::new("vm_stat").output().ok()?;
    let vm_str = String::from_utf8_lossy(&output.stdout);

    // vm_stat reports counts in pages and states the page size in its header.
    let page_size: u64 = vm_str
        .lines()
        .next()
        .and_then(|header| header.split("page size of ").nth(1))
        .and_then(|rest| rest.split_whitespace().next())
        .and_then(|digits| digits.parse::<u64>().ok())
        .filter(|&size| size > 0)
        .unwrap_or(16384);

    let mut wired_pages: u64 = 0;
    for line in vm_str.lines() {
        if line.starts_with("Pages wired down:") {
            wired_pages = parse_vm_stat_value(line)?;
        }
    }
    Some(wired_pages.saturating_mul(page_size))
}
/// Stub for non-macOS targets so callers compile everywhere.
///
/// Wired-memory statistics are only gathered on macOS; this always returns `None`.
#[cfg(not(target_os = "macos"))]
pub fn get_macos_wired_memory() -> Option<u64> {
None
}
// Unit tests for the pure, platform-independent parts of this file:
// configuration validation, pressure ordering, stats arithmetic and the
// initial state of the manager and monitor. None of them start the
// background sampling thread.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration validates; a configuration whose warning
// threshold is above its critical threshold is rejected.
#[test]
fn test_memory_config_validation() {
let config = MemoryConfig::default();
assert!(config.validate().is_ok());
let bad_config = MemoryConfig {
warning_threshold: 0.9,
critical_threshold: 0.8,
..Default::default()
};
assert!(bad_config.validate().is_err());
}
// Pressure levels compare in order of severity.
#[test]
fn test_memory_pressure_ordering() {
assert!(MemoryPressure::Normal < MemoryPressure::Warning);
assert!(MemoryPressure::Warning < MemoryPressure::Critical);
assert!(MemoryPressure::Critical < MemoryPressure::Emergency);
}
// Half-full GPU and system memory yield a usage of 0.5 each, and the
// available GPU memory is the other half.
#[test]
fn test_memory_stats() {
let stats = MemoryStats {
gpu_used: 4_000_000_000,
gpu_total: 8_000_000_000,
system_used: 8_000_000_000,
system_total: 16_000_000_000,
};
assert!((stats.gpu_usage() - 0.5).abs() < 0.001);
assert!((stats.system_usage() - 0.5).abs() < 0.001);
assert_eq!(stats.gpu_available(), 4_000_000_000);
}
// A fresh manager has a non-empty strategy list and, having no monitor
// attached, never reports that recovery is needed.
#[test]
fn test_recovery_manager() {
let manager = RecoveryManager::new();
assert!(!manager.strategies().is_empty());
assert!(manager.get_strategy(0).is_some());
assert!(!manager.needs_recovery()); }
// A freshly created monitor reports normal pressure and is not running.
#[test]
fn test_monitor_creation() {
let monitor = MemoryMonitor::with_defaults();
assert_eq!(monitor.pressure(), MemoryPressure::Normal);
assert!(!monitor.is_running());
}
}