use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
/// Strategy applied when a device allocation fails with out-of-memory.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum OomPolicy {
    /// Propagate the OOM error to the caller immediately (default).
    #[default]
    Fail,
    /// Free internal caches once, then retry the allocation a single time.
    RetryAfterFree,
    /// Poll free device memory until enough is available or the timeout
    /// expires, then retry the allocation once.
    WaitAndRetry {
        /// Maximum number of seconds to wait for memory to become free.
        timeout_secs: u64,
    },
    /// Invoke the registered on-OOM callback (emergency checkpoint), then
    /// return the original error.
    CheckpointAndFail,
}
/// A caller-registered callback that can free device memory under pressure.
pub struct MemoryHook {
    /// Human-readable identifier; used by `MemoryGuard::remove_hook`.
    pub name: String,
    /// How many bytes the hook expects to be able to free.
    pub estimated_free_bytes: usize,
    /// Bytes of temporary headroom the hook itself needs while running;
    /// hooks are skipped when this exceeds the remaining budget headroom.
    pub execution_overhead_bytes: usize,
    /// Lower values run first (ties broken by larger `estimated_free_bytes`).
    pub priority: u32,
    /// Invoked to free memory; must return the number of bytes actually freed.
    pub(crate) callback: Box<dyn Fn() -> usize + Send + Sync>,
}
impl MemoryHook {
    /// Builds a hook from its metadata and the callback that performs the
    /// actual freeing.
    ///
    /// `callback` must return the number of bytes it actually released.
    pub fn new<S, F>(
        name: S,
        estimated_free_bytes: usize,
        execution_overhead_bytes: usize,
        priority: u32,
        callback: F,
    ) -> Self
    where
        S: Into<String>,
        F: Fn() -> usize + Send + Sync + 'static,
    {
        let name = name.into();
        let callback: Box<dyn Fn() -> usize + Send + Sync> = Box::new(callback);
        Self {
            name,
            estimated_free_bytes,
            execution_overhead_bytes,
            priority,
            callback,
        }
    }
}
impl std::fmt::Debug for MemoryHook {
    /// Shows all metadata fields; `callback` is omitted (closures are not
    /// `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryHook");
        out.field("name", &self.name);
        out.field("estimated_free_bytes", &self.estimated_free_bytes);
        out.field("execution_overhead_bytes", &self.execution_overhead_bytes);
        out.field("priority", &self.priority);
        out.finish()
    }
}
/// Coarse memory-pressure buckets derived from the used/budget ratio.
///
/// Variants are declared in increasing severity, so the derived `Ord`
/// reflects pressure (`None < Low < ... < Critical`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PressureLevel {
    None,
    Low,
    Medium,
    High,
    Critical,
}
impl std::fmt::Display for PressureLevel {
    /// Renders the level as a lowercase label, e.g. "none" or "critical".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::None => "none",
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
            Self::Critical => "critical",
        })
    }
}
/// Observer notified whenever the guard's computed pressure level changes.
pub trait MemoryPressureListener: Send + Sync {
    /// Called with the previous and the new pressure level.
    fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel);
}
/// A chunk of device memory held back at guard construction time.
///
/// Dropping the reservation (see `MemoryGuard::release_reservation`) returns
/// the bytes to the device.
pub struct MemoryReservation {
    // Keeps the underlying device allocation alive; freed when dropped.
    _reservation: CudaBuffer<u8>,
    reserved_bytes: usize,
    device_ordinal: usize,
}
impl MemoryReservation {
    /// Ordinal of the device this reservation lives on.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }

    /// Size of the reserved region in bytes.
    #[inline]
    pub fn reserved_bytes(&self) -> usize {
        self.reserved_bytes
    }
}
impl std::fmt::Debug for MemoryReservation {
    /// Skips the raw buffer; reports only the size and device ordinal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryReservation");
        out.field("reserved_bytes", &self.reserved_bytes);
        out.field("device_ordinal", &self.device_ordinal);
        out.finish()
    }
}
/// Point-in-time snapshot of the guard's accounting plus device-wide figures.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct MemoryStats {
    /// Bytes currently tracked as allocated through the guard.
    pub used_bytes: usize,
    /// Configured budget in bytes; 0 means unlimited.
    pub budget_bytes: usize,
    /// High-water mark of `used_bytes` since construction or the last
    /// `reset_peak_stats`.
    pub peak_bytes: usize,
    /// Free bytes reported by the device (0 without the `cuda` feature).
    pub free_device_bytes: usize,
    /// Total bytes reported by the device (0 without the `cuda` feature).
    pub total_device_bytes: usize,
    /// Number of currently live allocations made through the guard.
    pub num_allocations: usize,
    /// How many times an OOM was recovered via the configured policy.
    pub num_oom_recoveries: usize,
}
/// Budget-enforcing wrapper around device allocations for a single GPU.
///
/// Tracks usage with atomics, runs registered `MemoryHook`s when the budget
/// would be exceeded, and applies an `OomPolicy` when the driver reports OOM.
pub struct MemoryGuard {
    device: Arc<GpuDevice>,
    /// Optional block of memory held back at build time (see the builder).
    reservation: Mutex<Option<MemoryReservation>>,
    /// Soft limit in bytes; 0 disables budget enforcement.
    budget_bytes: AtomicUsize,
    used_bytes: AtomicUsize,
    peak_bytes: AtomicUsize,
    num_allocations: AtomicUsize,
    num_oom_recoveries: AtomicUsize,
    oom_policy: Mutex<OomPolicy>,
    /// Invoked by `trigger_emergency_checkpoint` (CheckpointAndFail policy).
    on_oom_callback: Mutex<Option<Box<dyn Fn() + Send + Sync>>>,
    hooks: Mutex<Vec<MemoryHook>>,
    pressure_listeners: Mutex<Vec<Box<dyn MemoryPressureListener>>>,
    /// Last level reported to listeners, used to de-duplicate notifications.
    last_pressure_level: Mutex<PressureLevel>,
}
// SAFETY: all interior state is either atomic or Mutex-protected, and the
// boxed callbacks/listeners are bounded by `Send + Sync`.
// NOTE(review): these impls additionally assert that `Arc<GpuDevice>` and the
// `CudaBuffer` inside the reservation are safe to share across threads —
// confirm against those types' definitions.
unsafe impl Send for MemoryGuard {}
unsafe impl Sync for MemoryGuard {}
impl MemoryGuard {
    /// Replaces the allocation budget; `bytes == 0` disables enforcement.
    pub fn set_budget(&self, bytes: usize) {
        self.budget_bytes.store(bytes, Ordering::SeqCst);
    }

    /// Current budget in bytes (0 = unlimited).
    #[inline]
    pub fn budget(&self) -> usize {
        self.budget_bytes.load(Ordering::Relaxed)
    }

    /// Registers the callback invoked by `trigger_emergency_checkpoint`
    /// (used by the `CheckpointAndFail` policy).
    pub fn on_oom<F: Fn() + Send + Sync + 'static>(&self, f: F) {
        *self.on_oom_callback.lock().unwrap() = Some(Box::new(f));
    }

    /// Selects how driver-level OOM errors are handled.
    pub fn set_oom_policy(&self, policy: OomPolicy) {
        *self.oom_policy.lock().unwrap() = policy;
    }

    /// Adds a memory-freeing hook consulted by `safe_alloc_with_hooks`.
    pub fn register_hook(&self, hook: MemoryHook) {
        self.hooks.lock().unwrap().push(hook);
    }

    /// Removes all hooks with the given name; returns true if any existed.
    pub fn remove_hook(&self, name: &str) -> bool {
        let mut hooks = self.hooks.lock().unwrap();
        let before = hooks.len();
        hooks.retain(|h| h.name != name);
        hooks.len() < before
    }

    /// Pressure level derived from the current used/budget ratio.
    pub fn pressure_level(&self) -> PressureLevel {
        let budget = self.budget_bytes.load(Ordering::Relaxed);
        if budget == 0 {
            return PressureLevel::None;
        }
        let used = self.used_bytes.load(Ordering::Relaxed);
        Self::compute_pressure(budget, used)
    }

    /// Maps a (budget, used) pair onto a `PressureLevel`.
    ///
    /// Thresholds are on the *free* fraction: >30% => None, >10% => Low,
    /// >5% => Medium, otherwise High; `used >= budget` is Critical.
    fn compute_pressure(budget: usize, used: usize) -> PressureLevel {
        if budget == 0 {
            return PressureLevel::None;
        }
        if used >= budget {
            return PressureLevel::Critical;
        }
        let free_frac = ((budget - used) as f64) / (budget as f64);
        if free_frac > 0.30 {
            PressureLevel::None
        } else if free_frac > 0.10 {
            PressureLevel::Low
        } else if free_frac > 0.05 {
            PressureLevel::Medium
        } else {
            PressureLevel::High
        }
    }

    /// Registers an observer for pressure-level transitions.
    pub fn add_pressure_listener(&self, listener: Box<dyn MemoryPressureListener>) {
        self.pressure_listeners.lock().unwrap().push(listener);
    }

    /// Recomputes the pressure level and, if it changed, notifies listeners.
    ///
    /// The level lock is released before listeners run, so a listener may
    /// call back into the guard without deadlocking on that lock.
    fn notify_pressure_change(&self) {
        let new_level = self.pressure_level();
        let mut last = self.last_pressure_level.lock().unwrap();
        if *last != new_level {
            let old = *last;
            *last = new_level;
            drop(last);
            let listeners = self.pressure_listeners.lock().unwrap();
            for listener in listeners.iter() {
                listener.on_pressure_change(old, new_level);
            }
        }
    }

    /// Allocates `count` zeroed elements of `T`, running registered hooks to
    /// free memory first if the allocation would exceed the budget.
    ///
    /// Fixes vs. previous revision: `run_hooks` already subtracts freed bytes
    /// from `used_bytes`, so headroom no longer adds `freed` a second time
    /// (the double count allowed allocations over budget); and a driver OOM
    /// on the fast path now goes through the configured `OomPolicy`, matching
    /// `safe_alloc`.
    ///
    /// # Errors
    /// `BudgetExceeded` when hooks cannot make enough room; otherwise the
    /// driver error (possibly after OOM recovery per policy).
    #[cfg(feature = "cuda")]
    pub fn safe_alloc_with_hooks<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let alloc_bytes = count.saturating_mul(std::mem::size_of::<T>());
        // Fast path: already within budget, no hooks needed.
        if self.check_budget(alloc_bytes).is_ok() {
            return match self.try_alloc_zeros::<T>(count, alloc_bytes) {
                Ok(buf) => {
                    self.notify_pressure_change();
                    Ok(buf)
                }
                Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
                Err(e) => Err(e),
            };
        }
        let budget = self.budget_bytes.load(Ordering::Relaxed);
        let used = self.used_bytes.load(Ordering::Relaxed);
        let shortfall = (used + alloc_bytes).saturating_sub(budget);
        // run_hooks decrements `used_bytes` by whatever the hooks freed, so
        // all headroom computations below must use the fresh `used_bytes`
        // value only — never add `freed` on top of it.
        let freed = self.run_hooks(shortfall, budget, used);
        if freed > 0 {
            let used_now = self.used_bytes.load(Ordering::Relaxed);
            if budget.saturating_sub(used_now) >= alloc_bytes {
                return match self.try_alloc_zeros::<T>(count, alloc_bytes) {
                    Ok(buf) => {
                        self.notify_pressure_change();
                        Ok(buf)
                    }
                    Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
                    Err(e) => Err(e),
                };
            }
        }
        // Hooks could not make enough room: report a budget violation.
        let used_now = self.used_bytes.load(Ordering::Relaxed);
        if budget.saturating_sub(used_now) < alloc_bytes {
            return Err(crate::error::GpuError::BudgetExceeded {
                requested_bytes: alloc_bytes,
                budget_bytes: budget,
                used_bytes: used_now,
            });
        }
        // Reachable only if a concurrent free created headroom between the
        // checks above; attempt the allocation one final time.
        match self.try_alloc_zeros::<T>(count, alloc_bytes) {
            Ok(buf) => {
                self.notify_pressure_change();
                Ok(buf)
            }
            Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
            Err(e) => Err(e),
        }
    }

    /// Runs hooks (ascending priority, larger `estimated_free_bytes` first on
    /// ties) until at least `shortfall` bytes are reported freed.
    ///
    /// Hooks whose `execution_overhead_bytes` exceed the remaining headroom
    /// are skipped. The total freed is subtracted from `used_bytes` here —
    /// callers must not subtract (or re-add) it again.
    #[allow(dead_code)]
    fn run_hooks(&self, shortfall: usize, budget: usize, used: usize) -> usize {
        let hooks = self.hooks.lock().unwrap();
        if hooks.is_empty() {
            return 0;
        }
        let mut indices: Vec<usize> = (0..hooks.len()).collect();
        indices.sort_by(|&a, &b| {
            hooks[a].priority.cmp(&hooks[b].priority).then_with(|| {
                hooks[b]
                    .estimated_free_bytes
                    .cmp(&hooks[a].estimated_free_bytes)
            })
        });
        let mut total_freed: usize = 0;
        let mut current_used = used;
        for &idx in &indices {
            if total_freed >= shortfall {
                break;
            }
            let hook = &hooks[idx];
            let headroom = budget.saturating_sub(current_used);
            if hook.execution_overhead_bytes > headroom {
                continue;
            }
            let freed = (hook.callback)();
            total_freed = total_freed.saturating_add(freed);
            current_used = current_used.saturating_sub(freed);
        }
        if total_freed > 0 {
            self.used_bytes
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                    Some(current.saturating_sub(total_freed))
                })
                .ok();
        }
        total_freed
    }

    /// Drops the up-front reservation (if any) and returns its size in bytes;
    /// 0 when there was none.
    pub fn release_reservation(&self) -> usize {
        let mut lock = self.reservation.lock().unwrap();
        if let Some(res) = lock.take() {
            let bytes = res.reserved_bytes;
            drop(res);
            bytes
        } else {
            0
        }
    }

    /// Whether an up-front reservation is currently held.
    pub fn has_reservation(&self) -> bool {
        self.reservation.lock().unwrap().is_some()
    }

    /// Allocates `count` zeroed elements of `T`, enforcing the budget and the
    /// configured OOM policy. Hooks are not consulted; see
    /// `safe_alloc_with_hooks`.
    ///
    /// # Errors
    /// `BudgetExceeded` when over budget; otherwise the driver error
    /// (possibly after OOM recovery per policy).
    #[cfg(feature = "cuda")]
    pub fn safe_alloc<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let alloc_bytes = count.saturating_mul(std::mem::size_of::<T>());
        self.check_budget(alloc_bytes)?;
        match self.try_alloc_zeros::<T>(count, alloc_bytes) {
            Ok(buf) => Ok(buf),
            Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
            Err(e) => Err(e),
        }
    }

    /// Allocates a device buffer initialized from `data`, enforcing the
    /// budget and the configured OOM policy.
    ///
    /// The OOM-policy handling is inlined here (rather than delegated to
    /// `handle_oom`) because `handle_oom` retries via `try_alloc_zeros`,
    /// which requires the `ValidAsZeroBits` bound this method does not have.
    #[cfg(feature = "cuda")]
    pub fn safe_alloc_copy<T>(&self, data: &[T]) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let alloc_bytes = data.len().saturating_mul(std::mem::size_of::<T>());
        self.check_budget(alloc_bytes)?;
        match self.try_alloc_copy(data, alloc_bytes) {
            Ok(buf) => Ok(buf),
            Err(e) if self.is_oom(&e) => {
                let policy = self.oom_policy.lock().unwrap().clone();
                match policy {
                    OomPolicy::Fail => Err(e),
                    OomPolicy::RetryAfterFree => {
                        self.free_caches();
                        self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                        self.try_alloc_copy(data, alloc_bytes)
                    }
                    OomPolicy::WaitAndRetry { timeout_secs } => {
                        self.wait_for_memory(alloc_bytes, timeout_secs)?;
                        self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                        self.try_alloc_copy(data, alloc_bytes)
                    }
                    OomPolicy::CheckpointAndFail => {
                        self.trigger_emergency_checkpoint();
                        Err(e)
                    }
                }
            }
            Err(e) => Err(e),
        }
    }

    /// Returns a buffer's bytes to the accounting, drops it, and re-evaluates
    /// the pressure level.
    ///
    /// Uses `saturating_mul`, mirroring the size computation on the alloc
    /// path, so the books stay symmetric even in the (theoretical) overflow
    /// case; the previous `checked_mul().unwrap_or(0)` would subtract 0 for a
    /// size the alloc path had counted as `usize::MAX`.
    pub fn free<T>(&self, buffer: CudaBuffer<T>) {
        let bytes = buffer.len().saturating_mul(std::mem::size_of::<T>());
        self.used_bytes
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(bytes))
            })
            .ok();
        self.num_allocations
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(1))
            })
            .ok();
        drop(buffer);
        self.notify_pressure_change();
    }

    /// Snapshot of guard accounting plus device-wide free/total figures.
    pub fn stats(&self) -> MemoryStats {
        let (free_device, total_device) = self.query_device_memory();
        MemoryStats {
            used_bytes: self.used_bytes.load(Ordering::Relaxed),
            budget_bytes: self.budget_bytes.load(Ordering::Relaxed),
            peak_bytes: self.peak_bytes.load(Ordering::Relaxed),
            free_device_bytes: free_device,
            total_device_bytes: total_device,
            num_allocations: self.num_allocations.load(Ordering::Relaxed),
            num_oom_recoveries: self.num_oom_recoveries.load(Ordering::Relaxed),
        }
    }

    /// Resets the peak-usage high-water mark to the current usage.
    pub fn reset_peak_stats(&self) {
        let current = self.used_bytes.load(Ordering::Relaxed);
        self.peak_bytes.store(current, Ordering::Relaxed);
    }

    /// Borrow of the underlying device.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Borrow of the shared device handle, for cloning.
    #[inline]
    pub fn device_arc(&self) -> &Arc<GpuDevice> {
        &self.device
    }

    /// Returns Ok when the allocation fits (budget 0 = unlimited), otherwise
    /// `BudgetExceeded` with the current figures.
    #[allow(dead_code)]
    fn check_budget(&self, alloc_bytes: usize) -> GpuResult<()> {
        let budget = self.budget_bytes.load(Ordering::Relaxed);
        if budget == 0 {
            return Ok(());
        }
        let used = self.used_bytes.load(Ordering::Relaxed);
        if used.saturating_add(alloc_bytes) > budget {
            return Err(GpuError::BudgetExceeded {
                requested_bytes: alloc_bytes,
                budget_bytes: budget,
                used_bytes: used,
            });
        }
        Ok(())
    }

    /// Performs the raw zeroed allocation and updates the accounting
    /// (used/peak/count). No budget check here — callers do that.
    #[cfg(feature = "cuda")]
    fn try_alloc_zeros<T>(&self, count: usize, alloc_bytes: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let slice = self.device.stream().alloc_zeros::<T>(count)?;
        let prev = self.used_bytes.fetch_add(alloc_bytes, Ordering::Relaxed);
        // saturating_add mirrors the saturating accounting used elsewhere.
        self.peak_bytes
            .fetch_max(prev.saturating_add(alloc_bytes), Ordering::Relaxed);
        self.num_allocations.fetch_add(1, Ordering::Relaxed);
        Ok(CudaBuffer {
            data: Some(slice),
            len: count,
            alloc_len: count,
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Performs the raw host-to-device copy allocation and updates the
    /// accounting. No budget check here — callers do that.
    #[cfg(feature = "cuda")]
    fn try_alloc_copy<T>(&self, data: &[T], alloc_bytes: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let slice = self.device.stream().clone_htod(data)?;
        let prev = self.used_bytes.fetch_add(alloc_bytes, Ordering::Relaxed);
        self.peak_bytes
            .fetch_max(prev.saturating_add(alloc_bytes), Ordering::Relaxed);
        self.num_allocations.fetch_add(1, Ordering::Relaxed);
        Ok(CudaBuffer {
            data: Some(slice),
            len: data.len(),
            alloc_len: data.len(),
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }

    /// Heuristic OOM classification: explicit `OutOfMemory`, or a driver
    /// error whose message mentions an out-of-memory condition.
    #[allow(dead_code)]
    fn is_oom(&self, err: &GpuError) -> bool {
        match err {
            GpuError::OutOfMemory { .. } => true,
            #[cfg(feature = "cuda")]
            GpuError::Driver(driver_err) => {
                let msg = format!("{driver_err}");
                msg.contains("OUT_OF_MEMORY")
                    || msg.contains("out of memory")
                    || msg.contains("CUDA_ERROR_OUT_OF_MEMORY")
            }
            _ => false,
        }
    }

    /// Applies the configured `OomPolicy` after a failed zeroed allocation,
    /// retrying once where the policy allows.
    #[cfg(feature = "cuda")]
    fn handle_oom<T>(
        &self,
        count: usize,
        alloc_bytes: usize,
        original_err: GpuError,
    ) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let policy = self.oom_policy.lock().unwrap().clone();
        match policy {
            OomPolicy::Fail => Err(original_err),
            OomPolicy::RetryAfterFree => {
                self.free_caches();
                self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                self.try_alloc_zeros(count, alloc_bytes)
            }
            OomPolicy::WaitAndRetry { timeout_secs } => {
                self.wait_for_memory(alloc_bytes, timeout_secs)?;
                self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                self.try_alloc_zeros(count, alloc_bytes)
            }
            OomPolicy::CheckpointAndFail => {
                self.trigger_emergency_checkpoint();
                Err(original_err)
            }
        }
    }

    /// Currently a no-op: no internal cache layer exists yet, so
    /// `RetryAfterFree` retries without actually freeing anything extra.
    #[allow(dead_code)]
    fn free_caches(&self) {}

    /// Polls free device memory every 100 ms until `needed_bytes` are
    /// available or `timeout_secs` elapse.
    ///
    /// # Errors
    /// `OutOfMemory` with the last observed free figure on timeout.
    #[allow(dead_code)]
    fn wait_for_memory(&self, needed_bytes: usize, timeout_secs: u64) -> GpuResult<()> {
        let deadline = Instant::now() + Duration::from_secs(timeout_secs);
        loop {
            let (free, _) = self.query_device_memory();
            if free >= needed_bytes {
                return Ok(());
            }
            if Instant::now() >= deadline {
                return Err(GpuError::OutOfMemory {
                    requested_bytes: needed_bytes,
                    free_bytes: free,
                });
            }
            std::thread::sleep(Duration::from_millis(100));
        }
    }

    /// Runs the registered on-OOM callback, if any.
    #[allow(dead_code)]
    fn trigger_emergency_checkpoint(&self) {
        let lock = self.on_oom_callback.lock().unwrap();
        if let Some(cb) = lock.as_ref() {
            cb();
        }
    }

    /// `(free, total)` bytes from the driver; `(0, 0)` on failure or when
    /// built without the `cuda` feature.
    fn query_device_memory(&self) -> (usize, usize) {
        #[cfg(feature = "cuda")]
        {
            cudarc::driver::result::mem_get_info().unwrap_or((0, 0))
        }
        #[cfg(not(feature = "cuda"))]
        {
            (0, 0)
        }
    }
}
impl std::fmt::Debug for MemoryGuard {
    /// Snapshot-style output; atomic counters are read with Relaxed loads.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_reservation = self.reservation.lock().unwrap().is_some();
        let mut out = f.debug_struct("MemoryGuard");
        out.field("device_ordinal", &self.device.ordinal());
        out.field("budget_bytes", &self.budget_bytes.load(Ordering::Relaxed));
        out.field("used_bytes", &self.used_bytes.load(Ordering::Relaxed));
        out.field("peak_bytes", &self.peak_bytes.load(Ordering::Relaxed));
        out.field("has_reservation", &has_reservation);
        out.finish()
    }
}
#[cfg(not(feature = "cuda"))]
impl MemoryGuard {
    /// Without the `cuda` feature no device allocation is possible.
    pub fn safe_alloc<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(GpuError::NoCudaFeature)
    }
    /// Stub mirroring the CUDA variant; always fails.
    pub fn safe_alloc_copy<T>(&self, _data: &[T]) -> GpuResult<CudaBuffer<T>> {
        Err(GpuError::NoCudaFeature)
    }
    /// Stub mirroring the CUDA variant; always fails.
    pub fn safe_alloc_with_hooks<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(GpuError::NoCudaFeature)
    }
}
/// Fluent builder for `MemoryGuard`.
pub struct MemoryGuardBuilder {
    device: Arc<GpuDevice>,
    /// 0 means no budget enforcement.
    budget_bytes: usize,
    /// If non-zero, `build` reserves this many bytes up front.
    reserve_bytes: usize,
    oom_policy: OomPolicy,
}
impl std::fmt::Debug for MemoryGuardBuilder {
    /// Shows the configured knobs plus the target device ordinal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryGuardBuilder");
        out.field("budget_bytes", &self.budget_bytes);
        out.field("reserve_bytes", &self.reserve_bytes);
        out.field("oom_policy", &self.oom_policy);
        out.field("device_ordinal", &self.device.ordinal());
        out.finish()
    }
}
impl MemoryGuardBuilder {
pub fn new(device: Arc<GpuDevice>) -> Self {
Self {
device,
budget_bytes: 0,
reserve_bytes: 0,
oom_policy: OomPolicy::default(),
}
}
pub fn budget_bytes(mut self, bytes: usize) -> Self {
self.budget_bytes = bytes;
self
}
pub fn reserve_bytes(mut self, bytes: usize) -> Self {
self.reserve_bytes = bytes;
self
}
pub fn oom_policy(mut self, policy: OomPolicy) -> Self {
self.oom_policy = policy;
self
}
#[cfg(feature = "cuda")]
pub fn build(self) -> GpuResult<MemoryGuard> {
let reservation = if self.reserve_bytes > 0 {
let slice = self.device.stream().alloc_zeros::<u8>(self.reserve_bytes)?;
Some(MemoryReservation {
_reservation: CudaBuffer {
data: Some(slice),
len: self.reserve_bytes,
alloc_len: self.reserve_bytes,
device_ordinal: self.device.ordinal(),
pool_fn: None,
},
reserved_bytes: self.reserve_bytes,
device_ordinal: self.device.ordinal(),
})
} else {
None
};
Ok(MemoryGuard {
device: self.device,
reservation: Mutex::new(reservation),
budget_bytes: AtomicUsize::new(self.budget_bytes),
used_bytes: AtomicUsize::new(0),
peak_bytes: AtomicUsize::new(0),
num_allocations: AtomicUsize::new(0),
num_oom_recoveries: AtomicUsize::new(0),
oom_policy: Mutex::new(self.oom_policy),
on_oom_callback: Mutex::new(None),
hooks: Mutex::new(Vec::new()),
pressure_listeners: Mutex::new(Vec::new()),
last_pressure_level: Mutex::new(PressureLevel::None),
})
}
#[cfg(not(feature = "cuda"))]
pub fn build(self) -> GpuResult<MemoryGuard> {
Ok(MemoryGuard {
device: self.device,
reservation: Mutex::new(None),
budget_bytes: AtomicUsize::new(self.budget_bytes),
used_bytes: AtomicUsize::new(0),
peak_bytes: AtomicUsize::new(0),
num_allocations: AtomicUsize::new(0),
num_oom_recoveries: AtomicUsize::new(0),
oom_policy: Mutex::new(self.oom_policy),
on_oom_callback: Mutex::new(None),
hooks: Mutex::new(Vec::new()),
pressure_listeners: Mutex::new(Vec::new()),
last_pressure_level: Mutex::new(PressureLevel::None),
})
}
}
/// Background monitor that flips a `paused` flag while free device memory
/// stays below a threshold.
pub struct MemoryWatchdog {
    device: Arc<GpuDevice>,
    /// Free-byte floor; below this the watchdog reports pressure.
    pressure_threshold_bytes: usize,
    /// How long the monitor thread sleeps between checks.
    check_interval: Duration,
    /// True while free memory is below the threshold.
    paused: AtomicBool,
    /// Set by `stop()` to terminate the monitor thread.
    stop: AtomicBool,
    /// Set once the monitor thread completes a check
    /// (see `wait_for_first_check`).
    has_checked: AtomicBool,
}
impl MemoryWatchdog {
pub fn new(
device: Arc<GpuDevice>,
pressure_threshold_bytes: usize,
check_interval: Duration,
) -> Self {
Self {
device,
pressure_threshold_bytes,
check_interval,
paused: AtomicBool::new(false),
stop: AtomicBool::new(false),
has_checked: AtomicBool::new(false),
}
}
pub fn start(self: Arc<Self>) -> JoinHandle<()> {
std::thread::Builder::new()
.name("ferrotorch-memory-watchdog".into())
.spawn(move || {
while !self.stop.load(Ordering::Relaxed) {
let free = self.query_free_memory();
if free < self.pressure_threshold_bytes {
self.paused.store(true, Ordering::SeqCst);
while self.query_free_memory() < self.pressure_threshold_bytes {
if self.stop.load(Ordering::Relaxed) {
return;
}
std::thread::sleep(Duration::from_millis(500));
}
self.paused.store(false, Ordering::SeqCst);
}
self.has_checked.store(true, Ordering::SeqCst);
std::thread::sleep(self.check_interval);
}
})
.expect("failed to spawn memory watchdog thread")
}
pub fn stop(&self) {
self.stop.store(true, Ordering::SeqCst);
}
#[inline]
pub fn check_pressure(&self) -> bool {
self.paused.load(Ordering::SeqCst)
}
pub fn wait_if_paused(&self) {
while self.paused.load(Ordering::SeqCst) {
std::thread::sleep(Duration::from_millis(100));
}
}
pub fn wait_for_first_check(&self, timeout: Duration) {
let start = std::time::Instant::now();
while !self.has_checked.load(Ordering::SeqCst) {
if start.elapsed() > timeout {
return;
}
std::thread::sleep(Duration::from_millis(5));
}
}
#[inline]
pub fn pressure_threshold_bytes(&self) -> usize {
self.pressure_threshold_bytes
}
fn query_free_memory(&self) -> usize {
#[cfg(feature = "cuda")]
{
let ctx = self.device.context();
let _ = ctx.bind_to_thread();
cudarc::driver::result::mem_get_info()
.map(|(free, _)| free)
.unwrap_or(0)
}
#[cfg(not(feature = "cuda"))]
{
let _ = &self.device;
0
}
}
}
impl std::fmt::Debug for MemoryWatchdog {
    /// Reports configuration plus the current `paused` flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryWatchdog");
        out.field("device_ordinal", &self.device.ordinal());
        out.field("pressure_threshold_bytes", &self.pressure_threshold_bytes);
        out.field("check_interval", &self.check_interval);
        out.field("paused", &self.paused.load(Ordering::Relaxed));
        out.finish()
    }
}
impl GpuDevice {
    /// Returns `(free_bytes, total_bytes)` for the current device.
    ///
    /// # Errors
    /// Propagates the driver error from `mem_get_info`.
    #[cfg(feature = "cuda")]
    pub fn memory_info(&self) -> GpuResult<(usize, usize)> {
        cudarc::driver::result::mem_get_info().map_err(Into::into)
    }

    /// Built without the `cuda` feature: always fails.
    #[cfg(not(feature = "cuda"))]
    pub fn memory_info(&self) -> GpuResult<(usize, usize)> {
        Err(GpuError::NoCudaFeature)
    }
}
/// Convenience pairing of a device with the guard that polices it.
pub struct MemoryGuardedDevice {
    /// The owning guard; the device is reachable through it.
    pub guard: MemoryGuard,
}
impl MemoryGuardedDevice {
    /// Borrow of the wrapped guard.
    #[inline]
    pub fn guard(&self) -> &MemoryGuard {
        &self.guard
    }

    /// Borrow of the underlying device, via the guard.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        self.guard.device()
    }
}
impl std::fmt::Debug for MemoryGuardedDevice {
    /// Delegates to the guard's `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryGuardedDevice");
        out.field("guard", &self.guard);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn oom_policy_default_is_fail() {
assert_eq!(OomPolicy::default(), OomPolicy::Fail);
}
#[test]
fn oom_policy_debug() {
let p = OomPolicy::WaitAndRetry { timeout_secs: 30 };
let s = format!("{p:?}");
assert!(s.contains("WaitAndRetry"));
assert!(s.contains("30"));
}
#[test]
fn memory_stats_clone_eq() {
let s = MemoryStats {
used_bytes: 100,
budget_bytes: 1000,
peak_bytes: 200,
free_device_bytes: 800,
total_device_bytes: 2000,
num_allocations: 5,
num_oom_recoveries: 1,
};
let s2 = s.clone();
assert_eq!(s, s2);
}
#[test]
fn memory_stats_debug() {
let s = MemoryStats {
used_bytes: 0,
budget_bytes: 0,
peak_bytes: 0,
free_device_bytes: 0,
total_device_bytes: 0,
num_allocations: 0,
num_oom_recoveries: 0,
};
let d = format!("{s:?}");
assert!(d.contains("MemoryStats"));
assert!(d.contains("used_bytes"));
}
#[test]
fn gpu_error_out_of_memory_display() {
let e = GpuError::OutOfMemory {
requested_bytes: 1024,
free_bytes: 512,
};
let s = format!("{e}");
assert!(s.contains("1024"));
assert!(s.contains("512"));
assert!(s.contains("out of memory"));
}
#[test]
fn gpu_error_budget_exceeded_display() {
let e = GpuError::BudgetExceeded {
requested_bytes: 500,
budget_bytes: 1000,
used_bytes: 800,
};
let s = format!("{e}");
assert!(s.contains("500"));
assert!(s.contains("1000"));
assert!(s.contains("800"));
assert!(s.contains("budget exceeded"));
}
#[test]
fn pressure_level_ordering() {
assert!(PressureLevel::None < PressureLevel::Low);
assert!(PressureLevel::Low < PressureLevel::Medium);
assert!(PressureLevel::Medium < PressureLevel::High);
assert!(PressureLevel::High < PressureLevel::Critical);
}
#[test]
fn pressure_level_display() {
assert_eq!(format!("{}", PressureLevel::None), "none");
assert_eq!(format!("{}", PressureLevel::Critical), "critical");
}
#[test]
fn pressure_level_debug_clone_eq() {
let p = PressureLevel::Medium;
let p2 = p;
assert_eq!(p, p2);
let s = format!("{p:?}");
assert!(s.contains("Medium"));
}
#[test]
fn compute_pressure_unlimited_budget_is_none() {
assert_eq!(MemoryGuard::compute_pressure(0, 0), PressureLevel::None);
}
#[test]
fn compute_pressure_thresholds() {
let budget = 1000;
assert_eq!(
MemoryGuard::compute_pressure(budget, 0),
PressureLevel::None
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 600),
PressureLevel::None
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 699),
PressureLevel::None
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 750),
PressureLevel::Low
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 890),
PressureLevel::Low
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 910),
PressureLevel::Medium
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 949),
PressureLevel::Medium
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 960),
PressureLevel::High
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 999),
PressureLevel::High
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 1000),
PressureLevel::Critical
);
assert_eq!(
MemoryGuard::compute_pressure(budget, 2000),
PressureLevel::Critical
);
}
#[test]
fn memory_hook_debug() {
let hook = MemoryHook::new("test_hook", 1024, 64, 5, || 1024);
let s = format!("{hook:?}");
assert!(s.contains("test_hook"));
assert!(s.contains("1024"));
assert!(s.contains("64"));
assert!(s.contains("5"));
}
#[cfg(feature = "cuda")]
mod gpu_tests {
use super::*;
fn make_device() -> Arc<GpuDevice> {
Arc::new(GpuDevice::new(0).expect("CUDA device 0"))
}
#[test]
fn guard_construction_and_stats() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(1024 * 1024 * 1024) .oom_policy(OomPolicy::Fail)
.build()
.expect("build guard");
let stats = guard.stats();
assert_eq!(stats.used_bytes, 0);
assert_eq!(stats.budget_bytes, 1024 * 1024 * 1024);
assert_eq!(stats.peak_bytes, 0);
assert_eq!(stats.num_allocations, 0);
assert_eq!(stats.num_oom_recoveries, 0);
assert!(stats.total_device_bytes > 0);
assert!(stats.free_device_bytes > 0);
}
#[test]
fn budget_enforcement_rejects_over_budget() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(256) .build()
.expect("build guard");
let result = guard.safe_alloc::<f32>(1024); assert!(result.is_err());
match result.unwrap_err() {
GpuError::BudgetExceeded {
requested_bytes,
budget_bytes,
used_bytes,
} => {
assert_eq!(requested_bytes, 1024 * 4);
assert_eq!(budget_bytes, 256);
assert_eq!(used_bytes, 0);
}
other => panic!("expected BudgetExceeded, got {other:?}"),
}
}
#[test]
fn safe_alloc_tracks_usage() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(0) .build()
.expect("build guard");
let buf = guard.safe_alloc::<f32>(256).expect("alloc 256 f32");
let expected = 256 * std::mem::size_of::<f32>();
let stats = guard.stats();
assert_eq!(stats.used_bytes, expected);
assert_eq!(stats.peak_bytes, expected);
assert_eq!(stats.num_allocations, 1);
guard.free(buf);
let stats = guard.stats();
assert_eq!(stats.used_bytes, 0);
assert_eq!(stats.num_allocations, 0);
assert_eq!(stats.peak_bytes, expected);
}
#[test]
fn safe_alloc_copy_tracks_usage() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.build()
.expect("build guard");
let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
let buf = guard.safe_alloc_copy(&data).expect("alloc_copy");
let expected = 4 * std::mem::size_of::<f64>();
assert_eq!(guard.stats().used_bytes, expected);
guard.free(buf);
assert_eq!(guard.stats().used_bytes, 0);
}
#[test]
fn reset_peak_stats_works() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.build()
.expect("build guard");
let buf = guard.safe_alloc::<f32>(512).expect("alloc");
let peak = guard.stats().peak_bytes;
assert!(peak > 0);
guard.free(buf);
assert_eq!(guard.stats().peak_bytes, peak);
guard.reset_peak_stats();
assert_eq!(guard.stats().peak_bytes, 0); }
#[test]
fn emergency_checkpoint_callback_invoked() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.build()
.expect("build guard");
let called = Arc::new(AtomicBool::new(false));
let called_clone = Arc::clone(&called);
guard.on_oom(move || {
called_clone.store(true, Ordering::SeqCst);
});
guard.trigger_emergency_checkpoint();
assert!(called.load(Ordering::SeqCst));
}
#[test]
fn set_budget_at_runtime() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(0)
.build()
.expect("build guard");
assert_eq!(guard.budget(), 0);
guard.set_budget(1024);
assert_eq!(guard.budget(), 1024);
let result = guard.safe_alloc::<f32>(1024); assert!(result.is_err());
}
#[test]
fn memory_info_returns_nonzero() {
let device = GpuDevice::new(0).expect("CUDA device 0");
let (free, total) = device.memory_info().expect("memory_info");
assert!(total > 0, "total device memory should be > 0");
assert!(free > 0, "free device memory should be > 0");
assert!(free <= total, "free should not exceed total");
}
#[test]
fn reservation_holds_memory() {
let device = make_device();
let (free_before, _) = device.memory_info().expect("memory_info");
let reserve_bytes = 64 * 1024 * 1024;
let guard = MemoryGuardBuilder::new(device)
.reserve_bytes(reserve_bytes)
.build()
.expect("build guard with reservation");
assert!(guard.has_reservation());
let released = guard.release_reservation();
assert_eq!(released, reserve_bytes);
assert!(!guard.has_reservation());
assert_eq!(guard.release_reservation(), 0);
let _ = free_before; }
#[test]
fn guard_debug_impl() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(999)
.build()
.expect("build guard");
let s = format!("{guard:?}");
assert!(s.contains("MemoryGuard"));
assert!(s.contains("budget_bytes"));
assert!(s.contains("999"));
}
#[test]
fn watchdog_detects_no_pressure_when_plenty_free() {
let device = make_device();
let watchdog = Arc::new(MemoryWatchdog::new(device, 1, Duration::from_millis(50)));
assert!(!watchdog.check_pressure());
watchdog.wait_if_paused();
let wd = Arc::clone(&watchdog);
let handle = wd.start();
watchdog.wait_for_first_check(Duration::from_secs(5));
assert!(!watchdog.check_pressure());
watchdog.stop();
handle.join().expect("watchdog thread");
}
#[test]
fn watchdog_debug_impl() {
let device = make_device();
let watchdog = MemoryWatchdog::new(device, 1024, Duration::from_secs(1));
let s = format!("{watchdog:?}");
assert!(s.contains("MemoryWatchdog"));
assert!(s.contains("1024"));
}
#[test]
fn memory_guarded_device() {
let device = make_device();
let guard = MemoryGuardBuilder::new(Arc::clone(&device))
.budget_bytes(1024 * 1024)
.build()
.expect("build guard");
let guarded = MemoryGuardedDevice { guard };
assert_eq!(guarded.device().ordinal(), 0);
assert_eq!(guarded.guard().budget(), 1024 * 1024);
let s = format!("{guarded:?}");
assert!(s.contains("MemoryGuardedDevice"));
}
#[test]
fn oom_policy_retry_after_free() {
let device = make_device();
let guard = MemoryGuardBuilder::new(device)
.oom_policy(OomPolicy::RetryAfterFree)
.build()
.expect("build guard");
let buf = guard.safe_alloc::<f32>(64).expect("alloc");
assert_eq!(guard.stats().num_oom_recoveries, 0);
guard.free(buf);
}
#[test]
fn multiple_allocations_budget_accounting() {
let device = make_device();
let budget = 2048_usize;
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(budget)
.build()
.expect("build guard");
let buf1 = guard.safe_alloc::<f32>(128).expect("alloc 1");
assert_eq!(guard.stats().used_bytes, 512);
assert_eq!(guard.stats().num_allocations, 1);
let buf2 = guard.safe_alloc::<f32>(128).expect("alloc 2");
assert_eq!(guard.stats().used_bytes, 1024);
assert_eq!(guard.stats().num_allocations, 2);
let result = guard.safe_alloc::<f32>(512);
assert!(result.is_err());
guard.free(buf1);
guard.free(buf2);
assert_eq!(guard.stats().used_bytes, 0);
let buf3 = guard.safe_alloc::<f32>(512).expect("alloc 3 after free");
assert_eq!(guard.stats().used_bytes, 2048);
guard.free(buf3);
}
#[test]
fn hook_called_on_budget_exceeded() {
let device = make_device();
let budget = 1024_usize; let guard = MemoryGuardBuilder::new(device)
.budget_bytes(budget)
.build()
.expect("build guard");
let called = Arc::new(AtomicBool::new(false));
let called_clone = Arc::clone(&called);
guard.register_hook(MemoryHook::new("test_hook", 2048, 0, 10, move || {
called_clone.store(true, Ordering::SeqCst);
0 }));
let _result = guard.safe_alloc_with_hooks::<f32>(512);
assert!(called.load(Ordering::SeqCst), "hook was not called");
}
#[test]
fn hook_frees_enough_memory_allocation_succeeds() {
let device = make_device();
let budget = 2048_usize;
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(budget)
.build()
.expect("build guard");
let prefill = guard.safe_alloc::<f32>(384).expect("prefill");
assert_eq!(guard.stats().used_bytes, 1536);
let called = Arc::new(AtomicBool::new(false));
let called_clone = Arc::clone(&called);
guard.register_hook(MemoryHook::new("free_1k", 1024, 0, 10, move || {
called_clone.store(true, Ordering::SeqCst);
1024
}));
let buf = guard
.safe_alloc_with_hooks::<f32>(256)
.expect("alloc after hook");
assert!(called.load(Ordering::SeqCst), "hook was not called");
assert_eq!(guard.stats().used_bytes, 1536);
guard.free(buf);
guard.free(prefill);
}
#[test]
fn hook_not_enough_falls_through_to_oom_policy() {
let device = make_device();
let budget = 512_usize;
let guard = MemoryGuardBuilder::new(device)
.budget_bytes(budget)
.oom_policy(OomPolicy::Fail)
.build()
.expect("build guard");
let called = Arc::new(AtomicBool::new(false));
let called_clone = Arc::clone(&called);
guard.register_hook(MemoryHook::new("weak_hook", 64, 0, 10, move || {
called_clone.store(true, Ordering::SeqCst);
64 }));
let result = guard.safe_alloc_with_hooks::<f32>(1024);
assert!(
called.load(Ordering::SeqCst),
"hook should have been called"
);
assert!(result.is_err(), "allocation should have failed");
}
#[test]
fn hooks_called_in_priority_order() {
    let guard = MemoryGuardBuilder::new(make_device())
        .budget_bytes(1024)
        .build()
        .expect("build guard");

    // Register hooks deliberately out of priority order (20, 5, 10); each
    // records its own priority into the shared log when it fires.
    let log = Arc::new(Mutex::new(Vec::new()));
    for (name, priority) in [("priority_20", 20_u32), ("priority_5", 5), ("priority_10", 10)] {
        let log = Arc::clone(&log);
        guard.register_hook(MemoryHook::new(name, 256, 0, priority, move || {
            log.lock().unwrap().push(priority);
            256
        }));
    }

    // An over-budget request forces the hook path; the allocation result
    // itself is irrelevant to this test.
    let _ = guard.safe_alloc_with_hooks::<f32>(512);

    assert_eq!(
        &*log.lock().unwrap(),
        &[5, 10, 20],
        "hooks should fire in priority order"
    );
}
#[test]
fn remove_hook_by_name() {
    let guard = MemoryGuardBuilder::new(make_device())
        .budget_bytes(1024)
        .build()
        .expect("build guard");

    let fired = Arc::new(AtomicBool::new(false));
    let fired_flag = Arc::clone(&fired);
    guard.register_hook(MemoryHook::new("removable", 2048, 0, 10, move || {
        fired_flag.store(true, Ordering::SeqCst);
        2048
    }));

    // First removal succeeds; removing the same name again is a no-op.
    assert!(guard.remove_hook("removable"));
    assert!(!guard.remove_hook("removable"));

    // Trigger the hook path; the removed hook must stay silent.
    let _ = guard.safe_alloc_with_hooks::<f32>(512);
    assert!(
        !fired.load(Ordering::SeqCst),
        "removed hook should not have been called"
    );
}
#[test]
fn pressure_level_tracks_usage() {
    let guard = MemoryGuardBuilder::new(make_device())
        .budget_bytes(1000)
        .build()
        .expect("build guard");
    assert_eq!(guard.pressure_level(), PressureLevel::None);

    // Walk the usage counter through each threshold band in turn.
    let bands = [
        (750_usize, PressureLevel::Low),
        (920, PressureLevel::Medium),
        (960, PressureLevel::High),
        (1000, PressureLevel::Critical),
    ];
    for (used, expected) in bands {
        guard.used_bytes.store(used, Ordering::Relaxed);
        assert_eq!(guard.pressure_level(), expected);
    }

    // A zero budget reports no pressure regardless of the usage counter.
    guard.set_budget(0);
    assert_eq!(guard.pressure_level(), PressureLevel::None);
}
#[test]
fn multiple_hooks_called_until_enough_freed() {
    // Budget 2048; the prefill takes 1024 bytes, leaving 1024 free.
    let guard = MemoryGuardBuilder::new(make_device())
        .budget_bytes(2048)
        .build()
        .expect("build guard");
    let prefill = guard.safe_alloc::<f32>(256).expect("prefill");
    assert_eq!(guard.stats().used_bytes, 1024);

    // Hooks A and B share one counter; hook C gets its own flag so we can
    // prove the chain stopped before reaching it.
    let fired = Arc::new(AtomicUsize::new(0));
    let fired_a = Arc::clone(&fired);
    guard.register_hook(MemoryHook::new("hook_a", 256, 0, 1, move || {
        fired_a.fetch_add(1, Ordering::SeqCst);
        256
    }));
    let fired_b = Arc::clone(&fired);
    guard.register_hook(MemoryHook::new("hook_b", 512, 0, 2, move || {
        fired_b.fetch_add(1, Ordering::SeqCst);
        512
    }));
    let c_fired = Arc::new(AtomicBool::new(false));
    let c_flag = Arc::clone(&c_fired);
    guard.register_hook(MemoryHook::new("hook_c", 512, 0, 3, move || {
        c_flag.store(true, Ordering::SeqCst);
        512
    }));

    // 384 f32s = 1536 bytes against 1024 free: a 512-byte shortfall that
    // A (256) plus B (512) already cover, so C must never run.
    let buf = guard
        .safe_alloc_with_hooks::<f32>(384)
        .expect("alloc with hooks");
    assert_eq!(fired.load(Ordering::SeqCst), 2, "hooks A and B should fire");
    assert!(
        !c_fired.load(Ordering::SeqCst),
        "hook C should not have been called"
    );
    guard.free(buf);
    guard.free(prefill);
}
#[test]
fn hook_with_excessive_overhead_is_skipped() {
    // Budget 2048; the prefill takes 1920 bytes, leaving only 128 of
    // headroom for hook execution overhead.
    let guard = MemoryGuardBuilder::new(make_device())
        .budget_bytes(2048)
        .build()
        .expect("build guard");
    let prefill = guard.safe_alloc::<f32>(480).expect("prefill");
    assert_eq!(guard.stats().used_bytes, 1920);

    // Declares 256 bytes of execution overhead — more than the 128 free —
    // so the guard is expected to skip it entirely.
    let expensive_fired = Arc::new(AtomicBool::new(false));
    let expensive_flag = Arc::clone(&expensive_fired);
    guard.register_hook(MemoryHook::new("expensive_hook", 1024, 256, 1, move || {
        expensive_flag.store(true, Ordering::SeqCst);
        1024
    }));

    // Zero-overhead hook that frees plenty; this one should run instead.
    let cheap_fired = Arc::new(AtomicBool::new(false));
    let cheap_flag = Arc::clone(&cheap_fired);
    guard.register_hook(MemoryHook::new("cheap_hook", 512, 0, 2, move || {
        cheap_flag.store(true, Ordering::SeqCst);
        512
    }));

    let buf = guard
        .safe_alloc_with_hooks::<f32>(64)
        .expect("alloc with hooks");
    assert!(
        !expensive_fired.load(Ordering::SeqCst),
        "expensive hook should have been skipped due to overhead"
    );
    assert!(
        cheap_fired.load(Ordering::SeqCst),
        "cheap hook should have been called"
    );
    guard.free(buf);
    guard.free(prefill);
}
#[test]
fn pressure_listener_notified_on_change() {
    let guard = MemoryGuardBuilder::new(make_device())
        .budget_bytes(1000)
        .build()
        .expect("build guard");

    // Listener that appends every (old, new) transition to a shared log.
    // Holding the Vec behind an Arc lets the test keep reading it after
    // the boxed listener has been handed to the guard.
    struct RecordingListener {
        log: Arc<Mutex<Vec<(PressureLevel, PressureLevel)>>>,
    }
    impl MemoryPressureListener for RecordingListener {
        fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel) {
            self.log.lock().unwrap().push((old, new));
        }
    }
    let log = Arc::new(Mutex::new(Vec::new()));
    guard.add_pressure_listener(Box::new(RecordingListener {
        log: Arc::clone(&log),
    }));

    // A tiny alloc/free pair stays in the None band (asserted below via
    // transitions[0]), so it contributes no transition of its own.
    let small = guard.safe_alloc::<f32>(1).expect("small alloc");
    guard.free(small);

    // Force None -> High (960 / 1000), then High -> None.
    guard.used_bytes.store(960, Ordering::Relaxed);
    guard.notify_pressure_change();
    guard.used_bytes.store(0, Ordering::Relaxed);
    guard.notify_pressure_change();

    let transitions = log.lock().unwrap();
    assert!(
        transitions.len() >= 2,
        "should have at least 2 pressure changes, got {}",
        transitions.len()
    );
    assert_eq!(transitions[0], (PressureLevel::None, PressureLevel::High));
    assert_eq!(transitions[1], (PressureLevel::High, PressureLevel::None));
}
#[test]
fn safe_alloc_with_hooks_fast_path_no_hooks() {
    // With no hooks registered and an ample budget, safe_alloc_with_hooks
    // should behave exactly like a plain allocation.
    let guard = MemoryGuardBuilder::new(make_device())
        .budget_bytes(1024 * 1024)
        .build()
        .expect("build guard");
    let buf = guard
        .safe_alloc_with_hooks::<f32>(64)
        .expect("fast-path alloc");
    // 64 f32 elements at 4 bytes each.
    assert_eq!(guard.stats().used_bytes, 64 * 4);
    guard.free(buf);
}
}
}