use crate::memory_profiler::allocation::{MemoryType, PressureLevel};
use crate::Device;
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// A single memory-pressure observation, plus any mitigation steps recorded against it.
#[derive(Debug, Clone)]
pub struct MemoryPressureEvent {
/// When the event was created (`MemoryPressureEvent::new` stamps `Instant::now()`).
pub timestamp: Instant,
/// Severity assigned to this event.
pub pressure_level: PressureLevel,
/// Device the event refers to; `None` for host-wide events (the host and bandwidth
/// checks in `MemoryPressureMonitor` pass `None`).
pub device: Option<Device>,
/// Kind of memory the event concerns.
pub memory_type: MemoryType,
/// Total capacity observed when the event was raised. Presumably bytes for memory
/// events — NOTE(review): bandwidth events store a scaled bandwidth figure here instead.
pub total_memory: usize,
/// Capacity still available when the event was raised (same units as `total_memory`).
pub available_memory: usize,
/// Mitigation steps appended via `add_action`, in insertion order.
pub actions_taken: Vec<PressureAction>,
/// Time from `timestamp` until `mark_resolved` was called; `None` while unresolved.
pub resolution_time: Option<Duration>,
}
/// A mitigation step taken in response to memory pressure.
///
/// `MemoryPressureEvent::total_memory_freed` sums the amounts of `FreedUnusedMemory`,
/// `SwappedToDisk` and `ReducedCaches`; the other variants contribute nothing to it.
#[derive(Debug, Clone)]
pub enum PressureAction {
/// Unused memory was released (`Display` reports `amount` as bytes).
FreedUnusedMemory { amount: usize },
/// `pools_affected` memory pools were compacted.
CompactedPools { pools_affected: usize },
/// Records that garbage collection was triggered.
TriggeredGarbageCollection,
/// Caches were shrunk by `cache_reduction` (reported as bytes by `Display`).
ReducedCaches { cache_reduction: usize },
/// `amount` was swapped to disk (reported as bytes by `Display`).
SwappedToDisk { amount: usize },
/// `count` allocations were killed (`Display` calls them low-priority).
KilledAllocations { count: usize },
/// `amount` additional memory was requested from the system (reported as bytes).
RequestedMoreMemory { amount: usize },
}
/// Point-in-time view of host and device memory state.
#[derive(Debug, Clone)]
pub struct MemorySnapshot {
/// When the snapshot was created (`MemorySnapshot::new` uses `Instant::now()`).
pub timestamp: Instant,
/// Per-device usage. `MemoryPressureMonitor::check_memory_pressure` only inspects the
/// entry keyed by the device this snapshot was registered under.
pub device_usage: HashMap<Device, DeviceMemoryUsage>,
/// Host-side memory usage.
pub host_usage: HostMemoryUsage,
/// Aggregate pressure written by `calculate_system_pressure`: the mean of the host
/// pressure score and each device's utilisation fraction. The host score is scaled by a
/// per-level multiplier, so this can exceed 1.0.
pub memory_pressure: f64,
/// Presumably the number of live allocations. NOTE(review): only initialised (to 0) in
/// this module.
pub active_allocations: usize,
/// NOTE(review): only initialised (to 0) in this module; units unconfirmed.
pub total_allocated: usize,
/// NOTE(review): only initialised (to 0.0) in this module.
pub fragmentation_level: f64,
/// Bandwidth usage summary for this snapshot.
pub bandwidth_utilization: BandwidthUtilization,
}
/// Memory accounting for a single device.
#[derive(Debug, Clone)]
pub struct DeviceMemoryUsage {
/// Capacity fixed at construction by `DeviceMemoryUsage::new`.
pub total_memory: usize,
/// Last `used` value passed to `update_usage`.
pub used_memory: usize,
/// `total_memory - (used + reserved)`, clamped at zero, as computed by `update_usage`.
pub free_memory: usize,
/// Last `reserved` value passed to `update_usage`.
pub reserved_memory: usize,
/// `(used + reserved) / total_memory * 100` (a percentage, which can exceed 100 when
/// over-committed); 0.0 when `total_memory` is 0.
pub utilization_percent: f64,
/// NOTE(review): initialised to 0.0 and not updated in this module; units unconfirmed.
pub bandwidth_usage: f64,
/// NOTE(review): initialised to 0 and not updated in this module.
pub active_transfers: usize,
}
/// Host-side memory accounting.
#[derive(Debug, Clone)]
pub struct HostMemoryUsage {
/// Total host memory, in the same units as `available_memory`.
pub total_memory: usize,
/// Available host memory; `is_critically_low` treats it as bytes (divides by 1024^3).
pub available_memory: usize,
/// NOTE(review): only initialised (to 0) in this module.
pub process_memory: usize,
/// NOTE(review): only initialised (to 0) in this module.
pub pinned_memory: usize,
/// NOTE(review): only initialised (to 0) in this module.
pub virtual_memory: usize,
/// Derived pressure signals, refreshed by `update_pressure_indicators`.
pub pressure_indicators: MemoryPressureIndicators,
}
/// Pressure signals attached to `HostMemoryUsage`.
#[derive(Debug, Clone)]
pub struct MemoryPressureIndicators {
/// Level derived by `HostMemoryUsage::update_pressure_indicators`; it selects the
/// multiplier applied in `HostMemoryUsage::get_pressure_score`.
pub system_pressure: PressureLevel,
/// NOTE(review): never assigned in this module; only read by `combined_pressure_score`
/// and `requires_immediate_action`.
pub process_pressure: PressureLevel,
/// NOTE(review): only initialised (to 0) in this module; units unconfirmed.
pub swap_usage: usize,
/// Caller-supplied page-fault rate. Values above 500 / 1000 raise `system_pressure` to at
/// least Medium / High. Units are the caller's — confirm.
pub page_fault_rate: f64,
/// Caller-supplied allocation-failure rate. Above 0.1 it forces `system_pressure` to
/// Critical and makes `requires_immediate_action` return true.
pub allocation_failure_rate: f64,
}
/// Memory-bandwidth usage summary.
#[derive(Debug, Clone)]
pub struct BandwidthUtilization {
/// Available bandwidth. Presumably GB/s given `headroom_gbps` — confirm with callers.
pub total_bandwidth: f64,
/// Most recent value passed to `update_usage`.
pub current_usage: f64,
/// Largest value seen by `update_usage` (starts at 0.0).
pub peak_usage: f64,
/// `current_usage / total_bandwidth` capped at 1.0; 0.0 when `total_bandwidth` is not
/// positive. This is a 0–1 fraction, not a percentage.
pub efficiency: f64,
/// Per-device bandwidth figures. NOTE(review): not populated by this module.
pub device_breakdown: HashMap<Device, f64>,
}
/// Tracks the latest memory snapshot per device and turns them into pressure events.
pub struct MemoryPressureMonitor {
/// Latest snapshot per device; `update_snapshot` replaces the entry for a device.
current_snapshots: Arc<RwLock<HashMap<Device, MemorySnapshot>>>,
/// History of every event passed to `record_pressure_event`, in recording order.
/// NOTE(review): nothing in this module trims it, so it grows without bound.
pressure_events: Arc<Mutex<Vec<MemoryPressureEvent>>>,
/// Aggregate counters updated by `record_pressure_event`.
global_stats: Arc<Mutex<GlobalPressureStats>>,
/// Thresholds used to classify device, host and bandwidth readings.
thresholds: PressureThresholds,
/// Invoked synchronously, in order, by `record_pressure_event` for each event.
event_callbacks: Vec<Box<dyn Fn(&MemoryPressureEvent) + Send + Sync>>,
/// Time of the most recent `check_memory_pressure` call; `None` before the first.
last_check: Arc<Mutex<Option<Instant>>>,
/// NOTE(review): stored by `new` but not read anywhere in this module.
auto_mitigation: bool,
}
impl std::fmt::Debug for MemoryPressureMonitor {
    /// Debug-formats every field. Boxed closures have no `Debug` impl, so the callback
    /// list is summarised as a count instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let callback_summary = format!("{} callbacks", self.event_callbacks.len());

        let mut builder = f.debug_struct("MemoryPressureMonitor");
        builder.field("current_snapshots", &self.current_snapshots);
        builder.field("pressure_events", &self.pressure_events);
        builder.field("global_stats", &self.global_stats);
        builder.field("thresholds", &self.thresholds);
        builder.field("event_callbacks", &callback_summary);
        builder.field("last_check", &self.last_check);
        builder.field("auto_mitigation", &self.auto_mitigation);
        builder.finish()
    }
}
#[derive(Debug, Default)]
pub struct GlobalPressureStats {
pub total_events: AtomicU64,
pub events_by_level: HashMap<PressureLevel, AtomicU64>,
pub total_memory_freed: AtomicUsize,
pub avg_resolution_time: AtomicU64,
pub current_system_pressure: PressureLevel,
pub peak_memory_usage: AtomicUsize,
pub pressure_frequency: f64,
}
/// Thresholds used by `MemoryPressureMonitor` to classify pressure.
#[derive(Debug, Clone)]
pub struct PressureThresholds {
/// Readings strictly above this are at least `Low`. Compared against device
/// `utilization_percent` (a percentage); default 60.0.
pub low_pressure: f64,
/// Readings strictly above this are at least `Medium`; default 75.0.
pub medium_pressure: f64,
/// Readings strictly above this are at least `High`; default 85.0.
pub high_pressure: f64,
/// Readings strictly above this are `Critical`; default 95.0.
pub critical_pressure: f64,
/// Bandwidth saturation warning level; default 80.0.
/// NOTE(review): `BandwidthUtilization::efficiency` is a 0–1 fraction, so confirm the
/// comparison in `check_bandwidth_pressure` is scaled consistently with this value.
pub bandwidth_warning: f64,
/// Default 0.05. NOTE(review): not read anywhere in this module.
pub allocation_failure_threshold: f64,
/// Default 1000.0. NOTE(review): not read anywhere in this module.
pub page_fault_threshold: f64,
}
impl MemoryPressureEvent {
    /// Creates an unresolved event stamped with the current time and no recorded actions.
    pub fn new(
        pressure_level: PressureLevel,
        device: Option<Device>,
        memory_type: MemoryType,
        total_memory: usize,
        available_memory: usize,
    ) -> Self {
        let timestamp = Instant::now();
        Self {
            timestamp,
            pressure_level,
            device,
            memory_type,
            total_memory,
            available_memory,
            actions_taken: Vec::new(),
            resolution_time: None,
        }
    }

    /// Appends a mitigation step to this event's action log.
    pub fn add_action(&mut self, action: PressureAction) {
        self.actions_taken.push(action);
    }

    /// Marks the event resolved, recording how long it has been since it was created.
    /// Calling it again overwrites the earlier duration.
    pub fn mark_resolved(&mut self) {
        let elapsed = self.timestamp.elapsed();
        self.resolution_time = Some(elapsed);
    }

    /// Percentage of `total_memory` in use at the time of the event.
    ///
    /// Returns 0.0 for zero capacity. Used memory is clamped at zero when
    /// `available_memory` exceeds `total_memory`.
    pub fn memory_usage_percent(&self) -> f64 {
        match self.total_memory {
            0 => 0.0,
            total => {
                let used = total.saturating_sub(self.available_memory) as f64;
                used / total as f64 * 100.0
            }
        }
    }

    /// `true` once `mark_resolved` has been called.
    pub fn is_resolved(&self) -> bool {
        matches!(self.resolution_time, Some(_))
    }

    /// Sum of the memory released by this event's actions.
    ///
    /// Freed, swapped and cache-reduction amounts count. Compaction, GC, killed
    /// allocations and memory requests contribute nothing.
    pub fn total_memory_freed(&self) -> usize {
        self.actions_taken
            .iter()
            .map(|action| match *action {
                PressureAction::FreedUnusedMemory { amount }
                | PressureAction::SwappedToDisk { amount } => amount,
                PressureAction::ReducedCaches { cache_reduction } => cache_reduction,
                PressureAction::CompactedPools { .. }
                | PressureAction::TriggeredGarbageCollection
                | PressureAction::KilledAllocations { .. }
                | PressureAction::RequestedMoreMemory { .. } => 0,
            })
            .sum()
    }
}
impl MemorySnapshot {
    /// Creates an empty snapshot stamped with the current time: no devices, default host
    /// and bandwidth figures, and every counter at zero.
    pub fn new() -> Self {
        Self {
            timestamp: Instant::now(),
            device_usage: HashMap::new(),
            host_usage: HostMemoryUsage::default(),
            bandwidth_utilization: BandwidthUtilization::default(),
            memory_pressure: 0.0,
            fragmentation_level: 0.0,
            active_allocations: 0,
            total_allocated: 0,
        }
    }

    /// Recomputes `memory_pressure` as the arithmetic mean of the host pressure score and
    /// every device's utilisation expressed as a fraction (`utilization_percent / 100`).
    ///
    /// The host always contributes one sample, so the divisor is never zero.
    pub fn calculate_system_pressure(&mut self) {
        let host_sample = self.host_usage.get_pressure_score();
        let (sum, samples) = self
            .device_usage
            .values()
            .map(|usage| usage.utilization_percent / 100.0)
            .fold((host_sample, 1_usize), |(acc, n), fraction| (acc + fraction, n + 1));
        self.memory_pressure = sum / samples as f64;
    }

    /// Returns the device with the highest `utilization_percent` together with that
    /// percentage, or `None` if no devices are recorded.
    ///
    /// Incomparable (NaN) values are treated as ties. Among ties, the entry reached last
    /// in map iteration order is returned (`Iterator::max_by` semantics).
    pub fn highest_pressure_device(&self) -> Option<(Device, f64)> {
        let (device, usage) = self.device_usage.iter().max_by(|left, right| {
            let (lhs, rhs) = (left.1.utilization_percent, right.1.utilization_percent);
            lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal)
        })?;
        Some((device.clone(), usage.utilization_percent))
    }

    /// `true` if the aggregate `memory_pressure` exceeds `threshold` (a fraction), or any
    /// device's utilisation exceeds the same threshold expressed as a percentage.
    pub fn has_critical_pressure(&self, threshold: f64) -> bool {
        if self.memory_pressure > threshold {
            return true;
        }
        let percent_threshold = threshold * 100.0;
        self.device_usage
            .values()
            .any(|usage| usage.utilization_percent > percent_threshold)
    }
}
impl DeviceMemoryUsage {
    /// Creates a record for a device with `total_memory` capacity, all of it free.
    pub fn new(total_memory: usize) -> Self {
        Self {
            total_memory,
            used_memory: 0,
            free_memory: total_memory,
            reserved_memory: 0,
            utilization_percent: 0.0,
            bandwidth_usage: 0.0,
            active_transfers: 0,
        }
    }

    /// Records the current `used` and `reserved` amounts and recomputes `free_memory` and
    /// `utilization_percent`.
    ///
    /// Reserved memory counts as committed for both free space and utilisation. The sum
    /// `used + reserved` saturates at `usize::MAX` instead of overflowing (plain `+` panics
    /// in debug builds and wraps in release builds). `utilization_percent` is 0.0 for a
    /// zero-capacity device and can exceed 100.0 when more than `total_memory` is committed.
    pub fn update_usage(&mut self, used: usize, reserved: usize) {
        self.used_memory = used;
        self.reserved_memory = reserved;
        let committed = used.saturating_add(reserved);
        self.free_memory = self.total_memory.saturating_sub(committed);
        self.utilization_percent = if self.total_memory > 0 {
            committed as f64 / self.total_memory as f64 * 100.0
        } else {
            0.0
        };
    }

    /// `true` when `utilization_percent` is strictly above `threshold` (a percentage).
    pub fn is_under_pressure(&self, threshold: f64) -> bool {
        self.utilization_percent > threshold
    }

    /// Currently free capacity (the `free_memory` field).
    pub fn available_memory(&self) -> usize {
        self.free_memory
    }

    /// Ratio of reserved to used memory; 0.0 when nothing is used.
    pub fn fragmentation_ratio(&self) -> f64 {
        if self.used_memory > 0 {
            self.reserved_memory as f64 / self.used_memory as f64
        } else {
            0.0
        }
    }
}
impl HostMemoryUsage {
    /// Fraction of host memory currently in use.
    ///
    /// Returns 0.0 when `total_memory` is 0, which avoids a 0/0 NaN. Treats
    /// `available_memory > total_memory` as nothing used rather than underflowing, which
    /// would panic in debug builds and wrap to a huge value in release builds.
    fn usage_ratio(&self) -> f64 {
        if self.total_memory == 0 {
            return 0.0;
        }
        let used = self.total_memory.saturating_sub(self.available_memory);
        used as f64 / self.total_memory as f64
    }

    /// Host pressure score: the usage fraction scaled by a multiplier for the current
    /// `system_pressure` level (1.0 for None up to 3.0 for Critical). The score can
    /// therefore exceed 1.0. Returns 0.0 when `total_memory` is 0.
    pub fn get_pressure_score(&self) -> f64 {
        let pressure_multiplier = match self.pressure_indicators.system_pressure {
            PressureLevel::None => 1.0,
            PressureLevel::Low => 1.2,
            PressureLevel::Medium => 1.5,
            PressureLevel::High => 2.0,
            PressureLevel::Critical => 3.0,
        };
        self.usage_ratio() * pressure_multiplier
    }

    /// `true` when available memory, converted to GiB (bytes / 1024^3), is below `threshold`.
    pub fn is_critically_low(&self, threshold: f64) -> bool {
        let available_gb = self.available_memory as f64 / (1024.0 * 1024.0 * 1024.0);
        available_gb < threshold
    }

    /// Stores the supplied rates and re-derives `pressure_indicators.system_pressure`.
    ///
    /// The first matching rule wins:
    /// * Critical: usage > 95 % or `alloc_failures` > 0.1
    /// * High: usage > 85 % or `page_faults` > 1000
    /// * Medium: usage > 75 % or `page_faults` > 500
    /// * Low: usage > 60 %
    /// * None: otherwise
    ///
    /// The rates are stored in the caller's units. NOTE(review): these cut-offs are
    /// hard-coded rather than taken from `PressureThresholds`.
    pub fn update_pressure_indicators(&mut self, page_faults: f64, alloc_failures: f64) {
        self.pressure_indicators.page_fault_rate = page_faults;
        self.pressure_indicators.allocation_failure_rate = alloc_failures;
        let usage_ratio = self.usage_ratio();
        self.pressure_indicators.system_pressure = if usage_ratio > 0.95 || alloc_failures > 0.1 {
            PressureLevel::Critical
        } else if usage_ratio > 0.85 || page_faults > 1000.0 {
            PressureLevel::High
        } else if usage_ratio > 0.75 || page_faults > 500.0 {
            PressureLevel::Medium
        } else if usage_ratio > 0.60 {
            PressureLevel::Low
        } else {
            PressureLevel::None
        };
    }
}

impl Default for HostMemoryUsage {
    /// An all-zero usage record with default indicators (no pressure, zero rates).
    fn default() -> Self {
        Self {
            total_memory: 0,
            available_memory: 0,
            process_memory: 0,
            pinned_memory: 0,
            virtual_memory: 0,
            pressure_indicators: MemoryPressureIndicators::default(),
        }
    }
}
impl MemoryPressureIndicators {
pub fn default() -> Self {
Self {
system_pressure: PressureLevel::None,
process_pressure: PressureLevel::None,
swap_usage: 0,
page_fault_rate: 0.0,
allocation_failure_rate: 0.0,
}
}
pub fn combined_pressure_score(&self) -> f64 {
let system_score = match self.system_pressure {
PressureLevel::None => 0.0,
PressureLevel::Low => 0.2,
PressureLevel::Medium => 0.4,
PressureLevel::High => 0.7,
PressureLevel::Critical => 1.0,
};
let process_score = match self.process_pressure {
PressureLevel::None => 0.0,
PressureLevel::Low => 0.2,
PressureLevel::Medium => 0.4,
PressureLevel::High => 0.7,
PressureLevel::Critical => 1.0,
};
(system_score + process_score) / 2.0
}
pub fn requires_immediate_action(&self) -> bool {
matches!(self.system_pressure, PressureLevel::Critical)
|| matches!(self.process_pressure, PressureLevel::Critical)
|| self.allocation_failure_rate > 0.1
}
}
impl BandwidthUtilization {
pub fn default() -> Self {
Self {
total_bandwidth: 0.0,
current_usage: 0.0,
peak_usage: 0.0,
efficiency: 0.0,
device_breakdown: HashMap::new(),
}
}
pub fn update_usage(&mut self, current: f64) {
self.current_usage = current;
if current > self.peak_usage {
self.peak_usage = current;
}
self.efficiency = if self.total_bandwidth > 0.0 {
(self.current_usage / self.total_bandwidth).min(1.0)
} else {
0.0
};
}
pub fn is_underutilized(&self, threshold: f64) -> bool {
self.efficiency < threshold
}
pub fn is_saturated(&self, threshold: f64) -> bool {
self.efficiency > threshold
}
pub fn headroom_gbps(&self) -> f64 {
(self.total_bandwidth - self.current_usage).max(0.0)
}
}
impl MemoryPressureMonitor {
pub fn new(thresholds: PressureThresholds, auto_mitigation: bool) -> Self {
let mut events_by_level = HashMap::new();
events_by_level.insert(PressureLevel::Low, AtomicU64::new(0));
events_by_level.insert(PressureLevel::Medium, AtomicU64::new(0));
events_by_level.insert(PressureLevel::High, AtomicU64::new(0));
events_by_level.insert(PressureLevel::Critical, AtomicU64::new(0));
let global_stats = GlobalPressureStats {
total_events: AtomicU64::new(0),
events_by_level,
total_memory_freed: AtomicUsize::new(0),
avg_resolution_time: AtomicU64::new(0),
current_system_pressure: PressureLevel::None,
peak_memory_usage: AtomicUsize::new(0),
pressure_frequency: 0.0,
};
Self {
current_snapshots: Arc::new(RwLock::new(HashMap::new())),
pressure_events: Arc::new(Mutex::new(Vec::new())),
global_stats: Arc::new(Mutex::new(global_stats)),
thresholds,
event_callbacks: Vec::new(),
last_check: Arc::new(Mutex::new(None)),
auto_mitigation,
}
}
pub fn update_snapshot(&self, device: Device, snapshot: MemorySnapshot) {
let mut snapshots = self.current_snapshots.write();
snapshots.insert(device, snapshot);
}
pub fn check_memory_pressure(&self) -> Vec<MemoryPressureEvent> {
let mut events = Vec::new();
let snapshots = self.current_snapshots.read();
for (device, snapshot) in snapshots.iter() {
if let Some(device_usage) = snapshot.device_usage.get(device) {
if let Some(event) = self.check_device_pressure(device.clone(), device_usage) {
events.push(event);
}
}
if let Some(event) = self.check_host_pressure(&snapshot.host_usage) {
events.push(event);
}
if let Some(event) = self.check_bandwidth_pressure(&snapshot.bandwidth_utilization) {
events.push(event);
}
}
for event in &events {
self.record_pressure_event(event.clone());
}
*self.last_check.lock() = Some(Instant::now());
events
}
pub fn record_pressure_event(&self, event: MemoryPressureEvent) {
{
let mut stats = self.global_stats.lock();
stats.total_events.fetch_add(1, Ordering::Relaxed);
if let Some(counter) = stats.events_by_level.get(&event.pressure_level) {
counter.fetch_add(1, Ordering::Relaxed);
}
if event.pressure_level > stats.current_system_pressure {
stats.current_system_pressure = event.pressure_level;
}
}
for callback in &self.event_callbacks {
callback(&event);
}
self.pressure_events.lock().push(event);
}
pub fn recent_events(&self, since: Duration) -> Vec<MemoryPressureEvent> {
let cutoff = Instant::now() - since;
self.pressure_events
.lock()
.iter()
.filter(|event| event.timestamp > cutoff)
.cloned()
.collect()
}
fn check_device_pressure(
&self,
device: Device,
usage: &DeviceMemoryUsage,
) -> Option<MemoryPressureEvent> {
let pressure_level = if usage.utilization_percent > self.thresholds.critical_pressure {
PressureLevel::Critical
} else if usage.utilization_percent > self.thresholds.high_pressure {
PressureLevel::High
} else if usage.utilization_percent > self.thresholds.medium_pressure {
PressureLevel::Medium
} else if usage.utilization_percent > self.thresholds.low_pressure {
PressureLevel::Low
} else {
return None;
};
Some(MemoryPressureEvent::new(
pressure_level,
Some(device),
MemoryType::Device,
usage.total_memory,
usage.available_memory(),
))
}
fn check_host_pressure(&self, usage: &HostMemoryUsage) -> Option<MemoryPressureEvent> {
let pressure_score = usage.get_pressure_score();
let pressure_level = if pressure_score > self.thresholds.critical_pressure {
PressureLevel::Critical
} else if pressure_score > self.thresholds.high_pressure {
PressureLevel::High
} else if pressure_score > self.thresholds.medium_pressure {
PressureLevel::Medium
} else if pressure_score > self.thresholds.low_pressure {
PressureLevel::Low
} else {
return None;
};
Some(MemoryPressureEvent::new(
pressure_level,
None,
MemoryType::Host,
usage.total_memory,
usage.available_memory,
))
}
fn check_bandwidth_pressure(
&self,
bandwidth: &BandwidthUtilization,
) -> Option<MemoryPressureEvent> {
if bandwidth.efficiency > self.thresholds.bandwidth_warning {
let pressure_level = if bandwidth.efficiency > 0.95 {
PressureLevel::Critical
} else if bandwidth.efficiency > 0.85 {
PressureLevel::High
} else {
PressureLevel::Medium
};
Some(MemoryPressureEvent::new(
pressure_level,
None,
MemoryType::Host, (bandwidth.total_bandwidth * 1024.0 * 1024.0 * 1024.0) as usize, (bandwidth.headroom_gbps() * 1024.0 * 1024.0 * 1024.0) as usize,
))
} else {
None
}
}
}
impl Default for PressureThresholds {
    /// Percent-scale defaults:
    /// * levels: Low above 60, Medium above 75, High above 85, Critical above 95;
    /// * bandwidth warning at 80;
    /// * allocation-failure threshold 0.05;
    /// * page-fault threshold 1000.0.
    fn default() -> Self {
        let (low_pressure, medium_pressure, high_pressure, critical_pressure) =
            (60.0, 75.0, 85.0, 95.0);
        Self {
            low_pressure,
            medium_pressure,
            high_pressure,
            critical_pressure,
            bandwidth_warning: 80.0,
            allocation_failure_threshold: 0.05,
            page_fault_threshold: 1000.0,
        }
    }
}
impl std::fmt::Display for PressureAction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PressureAction::FreedUnusedMemory { amount } => {
write!(f, "Freed {} bytes of unused memory", amount)
}
PressureAction::CompactedPools { pools_affected } => {
write!(f, "Compacted {} memory pools", pools_affected)
}
PressureAction::TriggeredGarbageCollection => {
write!(f, "Triggered garbage collection")
}
PressureAction::ReducedCaches { cache_reduction } => {
write!(f, "Reduced caches by {} bytes", cache_reduction)
}
PressureAction::SwappedToDisk { amount } => {
write!(f, "Swapped {} bytes to disk", amount)
}
PressureAction::KilledAllocations { count } => {
write!(f, "Killed {} low-priority allocations", count)
}
PressureAction::RequestedMoreMemory { amount } => {
write!(f, "Requested {} additional bytes from system", amount)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const MIB: usize = 1024 * 1024;
    const GIB: usize = 1024 * MIB;

    #[test]
    fn test_memory_pressure_event_creation() {
        // 1 GiB total with 128 MiB free is 87.5% used.
        let event = MemoryPressureEvent::new(
            PressureLevel::High,
            None,
            MemoryType::Host,
            GIB,
            128 * MIB,
        );
        assert_eq!(event.pressure_level, PressureLevel::High);
        assert!(event.memory_usage_percent() > 85.0);
        assert!(!event.is_resolved());
    }

    #[test]
    fn test_device_memory_usage() {
        // 512 MiB used + 128 MiB reserved out of 1 GiB is 62.5% utilised.
        let mut usage = DeviceMemoryUsage::new(GIB);
        usage.update_usage(512 * MIB, 128 * MIB);
        assert_eq!(usage.utilization_percent, 62.5);
        assert!(usage.is_under_pressure(60.0));
        assert!(!usage.is_under_pressure(70.0));
    }

    #[test]
    fn test_host_memory_pressure_calculation() {
        // 7 of 8 GiB used (87.5%) lifts the system level to High, doubling the score.
        let mut host_usage = HostMemoryUsage {
            total_memory: 8 * GIB,
            available_memory: GIB,
            ..HostMemoryUsage::default()
        };
        host_usage.update_pressure_indicators(100.0, 0.01);
        let pressure_score = host_usage.get_pressure_score();
        assert!(pressure_score > 0.8);
        assert!(!host_usage.is_critically_low(0.5));
    }

    #[test]
    fn test_bandwidth_utilization() {
        let mut bandwidth = BandwidthUtilization {
            total_bandwidth: 100.0,
            ..BandwidthUtilization::default()
        };
        bandwidth.update_usage(85.0);
        assert_eq!(bandwidth.efficiency, 0.85);
        assert!(bandwidth.is_saturated(0.80));
        assert!(!bandwidth.is_underutilized(0.50));
        assert_eq!(bandwidth.headroom_gbps(), 15.0);
    }

    #[test]
    fn test_memory_snapshot() {
        // Host 6/8 GiB and device 3/4 GiB both sit at 75%, so the mean pressure is 0.75.
        let mut snapshot = MemorySnapshot::new();
        snapshot.host_usage.total_memory = 8 * GIB;
        snapshot.host_usage.available_memory = 2 * GIB;
        let device = Device::cpu().expect("Device should succeed");
        let mut device_usage = DeviceMemoryUsage::new(4 * GIB);
        device_usage.update_usage(3 * GIB, 0);
        snapshot.device_usage.insert(device, device_usage);
        snapshot.calculate_system_pressure();
        assert!(snapshot.memory_pressure > 0.5);
        assert!(snapshot.has_critical_pressure(0.7));
    }

    #[test]
    fn test_pressure_monitor() {
        // 900 MiB of 1 GiB is ~87.9% utilised, above the default High threshold (85).
        let monitor = MemoryPressureMonitor::new(PressureThresholds::default(), true);
        let device = Device::cpu().expect("Device should succeed");
        let mut device_usage = DeviceMemoryUsage::new(GIB);
        device_usage.update_usage(900 * MIB, 0);
        let mut snapshot = MemorySnapshot::new();
        snapshot.device_usage.insert(device.clone(), device_usage);
        monitor.update_snapshot(device, snapshot);

        let events = monitor.check_memory_pressure();
        assert!(!events.is_empty());
        assert!(events
            .iter()
            .any(|e| e.pressure_level >= PressureLevel::High));
    }

    #[test]
    fn test_pressure_thresholds() {
        let thresholds = PressureThresholds::default();
        assert_eq!(thresholds.low_pressure, 60.0);
        assert_eq!(thresholds.critical_pressure, 95.0);
        assert!(thresholds.medium_pressure > thresholds.low_pressure);
        assert!(thresholds.high_pressure > thresholds.medium_pressure);
    }
}