Skip to main content

ferrotorch_gpu/
memory_guard.rs

1//! GPU memory safety system — reservation, OOM recovery, pressure monitoring, and
2//! pre-OOM hooks.
3//!
4//! This module implements four layers of protection against GPU memory issues:
5//!
6//! 1. **Memory Reservation** ([`MemoryReservation`]) — Pre-allocate a large block at
7//!    startup so other processes cannot steal VRAM out from under a training run.
8//!
9//! 2. **OOM Recovery** ([`OomPolicy`], [`MemoryGuard::safe_alloc`]) — Configurable
10//!    behaviour when an allocation fails: retry after freeing cache, wait for memory
11//!    to become available, or save a checkpoint before crashing.
12//!
13//! 3. **Memory Pressure Monitoring** ([`MemoryWatchdog`]) — Background thread that
14//!    pauses training when free VRAM drops below a threshold, resuming automatically
15//!    once the pressure lifts.
16//!
17//! 4. **Pre-OOM Hooks** ([`MemoryHook`], [`MemoryGuard::safe_alloc_with_hooks`]) —
18//!    User-registered callbacks that fire *before* an allocation fails. Hooks declare
19//!    upfront how much memory they expect to free (and any execution overhead), so the
20//!    guard can call them in priority order until enough headroom is recovered.
21//!
22//! # Quick start
23//!
24//! ```rust,no_run
25//! use std::sync::Arc;
26//! use ferrotorch_gpu::memory_guard::{MemoryGuard, MemoryGuardBuilder, OomPolicy};
27//! use ferrotorch_gpu::GpuDevice;
28//!
29//! let device = Arc::new(GpuDevice::new(0).unwrap());
30//! let guard = MemoryGuardBuilder::new(Arc::clone(&device))
31//!     .budget_bytes(20 * 1024 * 1024 * 1024) // 20 GiB
32//!     .oom_policy(OomPolicy::RetryAfterFree)
33//!     .build()
34//!     .unwrap();
35//!
36//! let stats = guard.stats();
37//! println!("free: {} / total: {}", stats.free_device_bytes, stats.total_device_bytes);
38//! ```
39
40use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
41use std::sync::{Arc, Mutex};
42use std::thread::JoinHandle;
43use std::time::{Duration, Instant};
44
45use crate::buffer::CudaBuffer;
46use crate::device::GpuDevice;
47use crate::error::{GpuError, GpuResult};
48
49// ---------------------------------------------------------------------------
50// OomPolicy
51// ---------------------------------------------------------------------------
52
/// Recovery strategy applied when the driver reports out-of-memory for an
/// allocation made through the guard.
///
/// The default is [`OomPolicy::Fail`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum OomPolicy {
    /// Propagate the error immediately (the behaviour PyTorch uses by default).
    #[default]
    Fail,
    /// Evict the allocator cache, then retry the allocation exactly once.
    RetryAfterFree,
    /// Poll for free device memory for at most `timeout_secs` seconds, then
    /// retry the allocation exactly once.
    WaitAndRetry {
        /// Upper bound on the wait, in seconds.
        timeout_secs: u64,
    },
    /// Run the registered emergency-checkpoint callback, then propagate the
    /// error.
    CheckpointAndFail,
}
70
71// ---------------------------------------------------------------------------
72// MemoryHook — pre-OOM callback
73// ---------------------------------------------------------------------------
74
/// A hook that can free memory or reduce memory demand before an OOM.
///
/// Hooks declare upfront how much memory they expect to free, so the
/// memory guard can decide which hooks to call and in what order.
///
/// # Example: halving a batch when memory is tight
///
/// ```rust
/// use ferrotorch_gpu::memory_guard::MemoryHook;
///
/// let hook = MemoryHook {
///     name: "halve_batch_size".into(),
///     estimated_free_bytes: 512 * 1024 * 1024, // expects to free ~512 MiB
///     execution_overhead_bytes: 4096,           // metadata setup cost
///     priority: 10,
///     callback: Box::new(|| {
///         // ... split the batch, free old tensors ...
///         512 * 1024 * 1024 // actual bytes freed
///     }),
/// };
/// ```
pub struct MemoryHook {
    /// Human-readable name (e.g., `"halve_batch_size"`, `"free_kv_cache"`).
    /// This is also the key used by `MemoryGuard::remove_hook`.
    pub name: String,
    /// Estimated bytes this hook will free when called.
    ///
    /// Besides informing the caller, this acts as a tie-breaker: among hooks
    /// with the same `priority`, larger estimates are invoked first.
    pub estimated_free_bytes: usize,
    /// Extra bytes this hook needs temporarily to execute (e.g., metadata
    /// setup for a batch split). If the available headroom is less than this
    /// overhead the hook is skipped.
    pub execution_overhead_bytes: usize,
    /// Priority: lower values fire first. Ties are broken by
    /// `estimated_free_bytes` (largest first), and full ties keep
    /// registration order.
    pub priority: u32,
    /// The callback. Returns the *actual* bytes freed (may differ from the
    /// estimate).
    ///
    /// The guard subtracts the returned value from its tracked usage. If the
    /// callback also releases that memory through `MemoryGuard::free`, the
    /// same bytes are subtracted twice.
    pub callback: Box<dyn Fn() -> usize + Send + Sync>,
}

impl std::fmt::Debug for MemoryHook {
    /// Formats every field except `callback`, which is not `Debug`.
    ///
    /// `finish_non_exhaustive` appends `..` so readers can see that a field
    /// was elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MemoryHook")
            .field("name", &self.name)
            .field("estimated_free_bytes", &self.estimated_free_bytes)
            .field("execution_overhead_bytes", &self.execution_overhead_bytes)
            .field("priority", &self.priority)
            .finish_non_exhaustive()
    }
}
123
124// ---------------------------------------------------------------------------
125// PressureLevel
126// ---------------------------------------------------------------------------
127
/// Level of memory pressure relative to the configured budget.
///
/// Levels are declared from least to most severe, and the derived `Ord`
/// follows that order, so they can be compared with `<` / `>`. The bands
/// below describe the fraction of the budget that is still free, as computed
/// by `MemoryGuard::pressure_level`. With no budget set (unlimited), the
/// level is always [`PressureLevel::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PressureLevel {
    /// More than 30% of the budget is free (or no budget is set).
    None,
    /// More than 10% and at most 30% of the budget is free.
    Low,
    /// More than 5% and at most 10% of the budget is free.
    Medium,
    /// At most 5% of the budget is free, but usage is still below it.
    High,
    /// Tracked usage has reached or exceeded the budget.
    Critical,
}

impl std::fmt::Display for PressureLevel {
    /// Writes the lowercase level name: `"none"`, `"low"`, `"medium"`,
    /// `"high"` or `"critical"`.
    ///
    /// Uses `write_str`, so width and fill flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::None => "none",
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
            Self::Critical => "critical",
        })
    }
}
155
156// ---------------------------------------------------------------------------
157// MemoryPressureListener
158// ---------------------------------------------------------------------------
159
160/// Trait for continuous pressure-level monitoring.
161///
162/// Types that conform to this trait can be registered via
163/// [`MemoryGuard::add_pressure_listener`] to receive callbacks whenever the
164/// pressure level changes (e.g., after every allocation or free).
165pub trait MemoryPressureListener: Send + Sync {
166    /// Called when the memory-pressure level transitions between two values.
167    fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel);
168}
169
170// ---------------------------------------------------------------------------
171// MemoryReservation
172// ---------------------------------------------------------------------------
173
174/// A sentinel CUDA allocation that reserves physical VRAM.
175///
176/// As long as this struct is alive the driver cannot give the reserved bytes to
177/// another process. Drop the reservation (or call
178/// [`MemoryGuard::release_reservation`]) to free the memory for reuse.
179pub struct MemoryReservation {
180    /// The reserved CUDA allocation that holds our budget.
181    /// Other processes cannot use this memory while this buffer exists.
182    _reservation: CudaBuffer<u8>,
183    /// Number of bytes reserved.
184    reserved_bytes: usize,
185    /// Which device the reservation lives on.
186    device_ordinal: usize,
187}
188
189impl MemoryReservation {
190    /// How many bytes are held by this reservation.
191    #[inline]
192    pub fn reserved_bytes(&self) -> usize {
193        self.reserved_bytes
194    }
195
196    /// The device ordinal the reservation lives on.
197    #[inline]
198    pub fn device_ordinal(&self) -> usize {
199        self.device_ordinal
200    }
201}
202
203impl std::fmt::Debug for MemoryReservation {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        f.debug_struct("MemoryReservation")
206            .field("reserved_bytes", &self.reserved_bytes)
207            .field("device_ordinal", &self.device_ordinal)
208            .finish()
209    }
210}
211
212// ---------------------------------------------------------------------------
213// MemoryStats
214// ---------------------------------------------------------------------------
215
/// Point-in-time copy of the guard's counters, plus the driver's view of
/// device memory, as returned by [`MemoryGuard::stats`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MemoryStats {
    /// Bytes the guard currently counts as allocated.
    pub used_bytes: usize,
    /// Configured ceiling in bytes; `0` means no ceiling.
    pub budget_bytes: usize,
    /// Highest tracked usage since construction or the last peak reset.
    pub peak_bytes: usize,
    /// Free device memory reported by the driver (`0` if the query failed).
    pub free_device_bytes: usize,
    /// Total device memory reported by the driver (`0` if the query failed).
    pub total_device_bytes: usize,
    /// Allocations the guard currently counts as live.
    pub num_allocations: usize,
    /// OOM events that were successfully recovered.
    pub num_oom_recoveries: usize,
}
234
235// ---------------------------------------------------------------------------
236// MemoryGuard
237// ---------------------------------------------------------------------------
238
239/// Central memory-safety controller for a single GPU.
240///
241/// Wraps a [`GpuDevice`] and provides:
242/// - Optional upfront VRAM reservation (sentinel allocation).
243/// - Budget enforcement — allocations that would exceed the budget are
244///   rejected *before* touching the driver.
245/// - Configurable OOM recovery via [`OomPolicy`].
246/// - An emergency-checkpoint callback.
247///
248/// Construct via [`MemoryGuardBuilder`].
249pub struct MemoryGuard {
250    device: Arc<GpuDevice>,
251    /// Pre-allocated reservation block.
252    reservation: Mutex<Option<MemoryReservation>>,
253    /// Maximum bytes we are allowed to use (0 = unlimited).
254    budget_bytes: AtomicUsize,
255    /// Current live allocation bytes tracked by the guard.
256    used_bytes: AtomicUsize,
257    /// Peak tracked usage.
258    peak_bytes: AtomicUsize,
259    /// Number of live allocations.
260    num_allocations: AtomicUsize,
261    /// Number of successful OOM recoveries.
262    num_oom_recoveries: AtomicUsize,
263    /// Policy when OOM occurs.
264    oom_policy: Mutex<OomPolicy>,
265    /// Callback for emergency checkpoint.
266    on_oom_callback: Mutex<Option<Box<dyn Fn() + Send + Sync>>>,
267    /// Pre-OOM hooks, called before an allocation failure is propagated.
268    hooks: Mutex<Vec<MemoryHook>>,
269    /// Continuous pressure-level listeners.
270    pressure_listeners: Mutex<Vec<Box<dyn MemoryPressureListener>>>,
271    /// Cached pressure level for change detection.
272    last_pressure_level: Mutex<PressureLevel>,
273}
274
275// SAFETY: All interior mutability is via atomics or `Mutex`.
276unsafe impl Send for MemoryGuard {}
277unsafe impl Sync for MemoryGuard {}
278
279impl MemoryGuard {
280    // ------------------------------------------------------------------
281    // Budget
282    // ------------------------------------------------------------------
283
284    /// Set a hard budget in bytes. Allocations that would push `used_bytes`
285    /// past this limit return [`GpuError::BudgetExceeded`] without touching
286    /// the driver.
287    ///
288    /// Pass `0` to remove the budget (unlimited).
289    pub fn set_budget(&self, bytes: usize) {
290        self.budget_bytes.store(bytes, Ordering::SeqCst);
291    }
292
293    /// Current budget (0 = unlimited).
294    #[inline]
295    pub fn budget(&self) -> usize {
296        self.budget_bytes.load(Ordering::Relaxed)
297    }
298
299    // ------------------------------------------------------------------
300    // OOM callback
301    // ------------------------------------------------------------------
302
303    /// Register a callback that will be invoked on OOM when the policy is
304    /// [`OomPolicy::CheckpointAndFail`]. Typically used to save a training
305    /// checkpoint so work is not lost.
306    pub fn on_oom<F: Fn() + Send + Sync + 'static>(&self, f: F) {
307        *self.on_oom_callback.lock().unwrap() = Some(Box::new(f));
308    }
309
310    /// Change the OOM policy at runtime.
311    pub fn set_oom_policy(&self, policy: OomPolicy) {
312        *self.oom_policy.lock().unwrap() = policy;
313    }
314
315    // ------------------------------------------------------------------
316    // Pre-OOM hooks
317    // ------------------------------------------------------------------
318
319    /// Register a pre-OOM hook.
320    ///
321    /// Hooks are called (in priority order, lowest first) when an allocation
322    /// would exceed the budget. Each hook gets a chance to free memory before
323    /// the guard falls through to the [`OomPolicy`].
324    pub fn register_hook(&self, hook: MemoryHook) {
325        self.hooks.lock().unwrap().push(hook);
326    }
327
328    /// Remove a previously registered hook by name.
329    ///
330    /// Returns `true` if a hook with that name was found and removed.
331    pub fn remove_hook(&self, name: &str) -> bool {
332        let mut hooks = self.hooks.lock().unwrap();
333        let before = hooks.len();
334        hooks.retain(|h| h.name != name);
335        hooks.len() < before
336    }
337
338    /// Current pressure level based on budget usage.
339    ///
340    /// If no budget is set (budget = 0 / unlimited), always returns
341    /// [`PressureLevel::None`].
342    pub fn pressure_level(&self) -> PressureLevel {
343        let budget = self.budget_bytes.load(Ordering::Relaxed);
344        if budget == 0 {
345            return PressureLevel::None;
346        }
347        let used = self.used_bytes.load(Ordering::Relaxed);
348        Self::compute_pressure(budget, used)
349    }
350
351    /// Compute the pressure level from a budget and usage pair.
352    ///
353    /// A budget of `0` means unlimited and always returns [`PressureLevel::None`].
354    fn compute_pressure(budget: usize, used: usize) -> PressureLevel {
355        if budget == 0 {
356            return PressureLevel::None;
357        }
358        if used >= budget {
359            return PressureLevel::Critical;
360        }
361        let free_frac = ((budget - used) as f64) / (budget as f64);
362        if free_frac > 0.30 {
363            PressureLevel::None
364        } else if free_frac > 0.10 {
365            PressureLevel::Low
366        } else if free_frac > 0.05 {
367            PressureLevel::Medium
368        } else {
369            PressureLevel::High
370        }
371    }
372
373    /// Register a listener that is notified whenever the pressure level
374    /// changes (checked after every allocation and free through the guard).
375    pub fn add_pressure_listener(&self, listener: Box<dyn MemoryPressureListener>) {
376        self.pressure_listeners.lock().unwrap().push(listener);
377    }
378
379    /// Check whether the pressure level has changed and notify listeners.
380    fn notify_pressure_change(&self) {
381        let new_level = self.pressure_level();
382        let mut last = self.last_pressure_level.lock().unwrap();
383        if *last != new_level {
384            let old = *last;
385            *last = new_level;
386            // Release the last-level lock before calling listeners to avoid
387            // deadlocks if a listener queries the guard.
388            drop(last);
389            let listeners = self.pressure_listeners.lock().unwrap();
390            for listener in listeners.iter() {
391                listener.on_pressure_change(old, new_level);
392            }
393        }
394    }
395
396    /// Allocate `count` zero-initialized elements, trying pre-OOM hooks
397    /// before falling through to the [`OomPolicy`].
398    ///
399    /// The algorithm:
400    ///
401    /// 1. Check if the allocation fits within the budget -- if so, allocate
402    ///    directly.
403    /// 2. If not, compute the shortfall.
404    /// 3. Sort hooks by `(priority, estimated_free_bytes descending)`.
405    /// 4. Call hooks one at a time, skipping any whose
406    ///    `execution_overhead_bytes` exceeds current headroom, until enough
407    ///    cumulative memory has been freed.
408    /// 5. Retry the allocation.
409    /// 6. If still insufficient after all hooks, fall through to the regular
410    ///    [`OomPolicy`] path.
411    #[cfg(feature = "cuda")]
412    pub fn safe_alloc_with_hooks<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
413    where
414        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
415    {
416        let alloc_bytes = count.saturating_mul(std::mem::size_of::<T>());
417
418        // Fast path: fits within budget.
419        if self.check_budget(alloc_bytes).is_ok() {
420            let result = self.try_alloc_zeros::<T>(count, alloc_bytes);
421            if result.is_ok() {
422                self.notify_pressure_change();
423            }
424            return result;
425        }
426
427        // Compute the shortfall.
428        let budget = self.budget_bytes.load(Ordering::Relaxed);
429        let used = self.used_bytes.load(Ordering::Relaxed);
430        let shortfall = (used + alloc_bytes).saturating_sub(budget);
431
432        // Try hooks.
433        let freed = self.run_hooks(shortfall, budget, used);
434
435        if freed > 0 {
436            // Re-check budget after hooks freed memory.
437            if self.check_budget(alloc_bytes).is_ok() {
438                let result = self.try_alloc_zeros::<T>(count, alloc_bytes);
439                if result.is_ok() {
440                    self.notify_pressure_change();
441                    return result;
442                }
443                // Driver-level OOM despite budget check passing -- fall through
444                // to OomPolicy.
445                if let Err(e) = result {
446                    if self.is_oom(&e) {
447                        return self.handle_oom(count, alloc_bytes, e);
448                    }
449                    return Err(e);
450                }
451            }
452        }
453
454        // Hooks were not enough. Re-check budget — if still over, enforce
455        // the budget rather than letting the driver allocate beyond it.
456        if self.check_budget(alloc_bytes).is_err() {
457            let budget = self.budget_bytes.load(Ordering::Relaxed);
458            let used = self.used_bytes.load(Ordering::Relaxed);
459            return Err(crate::error::GpuError::BudgetExceeded {
460                requested_bytes: alloc_bytes,
461                budget_bytes: budget,
462                used_bytes: used,
463            });
464        }
465
466        // Budget check passed (hooks freed enough). Try the driver.
467        match self.try_alloc_zeros::<T>(count, alloc_bytes) {
468            Ok(buf) => {
469                self.notify_pressure_change();
470                Ok(buf)
471            }
472            Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
473            Err(e) => Err(e),
474        }
475    }
476
477    /// Run pre-OOM hooks in priority order until `shortfall` bytes have been
478    /// freed (or all hooks have been tried).
479    ///
480    /// Returns the total actual bytes freed across all invoked hooks.
481    #[allow(dead_code)]
482    fn run_hooks(&self, shortfall: usize, budget: usize, used: usize) -> usize {
483        // Build a sorted index of hooks. We sort by (priority ASC,
484        // estimated_free_bytes DESC) so high-impact hooks at a given
485        // priority run first.
486        let hooks = self.hooks.lock().unwrap();
487        if hooks.is_empty() {
488            return 0;
489        }
490
491        let mut indices: Vec<usize> = (0..hooks.len()).collect();
492        indices.sort_by(|&a, &b| {
493            hooks[a].priority.cmp(&hooks[b].priority).then_with(|| {
494                hooks[b]
495                    .estimated_free_bytes
496                    .cmp(&hooks[a].estimated_free_bytes)
497            })
498        });
499
500        let mut total_freed: usize = 0;
501        let mut current_used = used;
502
503        for &idx in &indices {
504            if total_freed >= shortfall {
505                break;
506            }
507
508            let hook = &hooks[idx];
509
510            // Skip if overhead exceeds available headroom. "Available
511            // headroom" is whatever room we have *right now* within the
512            // budget, after accounting for memory freed so far.
513            let headroom = budget.saturating_sub(current_used);
514            if hook.execution_overhead_bytes > headroom {
515                continue;
516            }
517
518            let freed = (hook.callback)();
519            total_freed = total_freed.saturating_add(freed);
520            current_used = current_used.saturating_sub(freed);
521        }
522
523        // Reflect freed memory in the atomic counter. Hooks freed memory
524        // outside the guard's tracking, so we adjust used_bytes downward.
525        if total_freed > 0 {
526            self.used_bytes
527                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
528                    Some(current.saturating_sub(total_freed))
529                })
530                .ok();
531        }
532
533        total_freed
534    }
535
536    // ------------------------------------------------------------------
537    // Reservation management
538    // ------------------------------------------------------------------
539
540    /// Release the upfront reservation, making its memory available for
541    /// normal allocations. Returns the number of bytes released, or `0` if
542    /// there was no active reservation.
543    pub fn release_reservation(&self) -> usize {
544        let mut lock = self.reservation.lock().unwrap();
545        if let Some(res) = lock.take() {
546            let bytes = res.reserved_bytes;
547            drop(res);
548            bytes
549        } else {
550            0
551        }
552    }
553
554    /// Whether an upfront reservation is currently held.
555    pub fn has_reservation(&self) -> bool {
556        self.reservation.lock().unwrap().is_some()
557    }
558
559    // ------------------------------------------------------------------
560    // Allocation with safety layers
561    // ------------------------------------------------------------------
562
563    /// Allocate `count` zero-initialized elements on the device, enforcing
564    /// the budget and OOM policy.
565    ///
566    /// This is the primary allocation entry point when using the memory
567    /// guard. Prefer this over raw `CudaAllocator::alloc_zeros`.
568    #[cfg(feature = "cuda")]
569    pub fn safe_alloc<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
570    where
571        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
572    {
573        let alloc_bytes = count.saturating_mul(std::mem::size_of::<T>());
574
575        // --- Layer 1: budget check ---
576        self.check_budget(alloc_bytes)?;
577
578        // --- Layer 2: attempt allocation ---
579        match self.try_alloc_zeros::<T>(count, alloc_bytes) {
580            Ok(buf) => Ok(buf),
581            Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
582            Err(e) => Err(e),
583        }
584    }
585
586    /// Allocate by copying host data to the device, enforcing budget and
587    /// OOM policy.
588    #[cfg(feature = "cuda")]
589    pub fn safe_alloc_copy<T>(&self, data: &[T]) -> GpuResult<CudaBuffer<T>>
590    where
591        T: cudarc::driver::DeviceRepr,
592    {
593        let alloc_bytes = data.len().saturating_mul(std::mem::size_of::<T>());
594
595        self.check_budget(alloc_bytes)?;
596
597        match self.try_alloc_copy(data, alloc_bytes) {
598            Ok(buf) => Ok(buf),
599            Err(e) if self.is_oom(&e) => {
600                // For copy allocs, retry with the same data.
601                let policy = self.oom_policy.lock().unwrap().clone();
602                match policy {
603                    OomPolicy::Fail => Err(e),
604                    OomPolicy::RetryAfterFree => {
605                        self.free_caches();
606                        self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
607                        self.try_alloc_copy(data, alloc_bytes)
608                    }
609                    OomPolicy::WaitAndRetry { timeout_secs } => {
610                        self.wait_for_memory(alloc_bytes, timeout_secs)?;
611                        self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
612                        self.try_alloc_copy(data, alloc_bytes)
613                    }
614                    OomPolicy::CheckpointAndFail => {
615                        self.trigger_emergency_checkpoint();
616                        Err(e)
617                    }
618                }
619            }
620            Err(e) => Err(e),
621        }
622    }
623
624    /// Return a buffer to the guard, freeing GPU memory and updating
625    /// statistics.
626    pub fn free<T>(&self, buffer: CudaBuffer<T>) {
627        let bytes = buffer
628            .len()
629            .checked_mul(std::mem::size_of::<T>())
630            .unwrap_or(0);
631        self.used_bytes
632            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
633                Some(current.saturating_sub(bytes))
634            })
635            .ok();
636        self.num_allocations
637            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
638                Some(current.saturating_sub(1))
639            })
640            .ok();
641        drop(buffer);
642        self.notify_pressure_change();
643    }
644
645    // ------------------------------------------------------------------
646    // Statistics
647    // ------------------------------------------------------------------
648
649    /// Snapshot the current memory statistics.
650    pub fn stats(&self) -> MemoryStats {
651        let (free_device, total_device) = self.query_device_memory();
652        MemoryStats {
653            used_bytes: self.used_bytes.load(Ordering::Relaxed),
654            budget_bytes: self.budget_bytes.load(Ordering::Relaxed),
655            peak_bytes: self.peak_bytes.load(Ordering::Relaxed),
656            free_device_bytes: free_device,
657            total_device_bytes: total_device,
658            num_allocations: self.num_allocations.load(Ordering::Relaxed),
659            num_oom_recoveries: self.num_oom_recoveries.load(Ordering::Relaxed),
660        }
661    }
662
663    /// Reset the peak-usage counter to the current usage level.
664    pub fn reset_peak_stats(&self) {
665        let current = self.used_bytes.load(Ordering::Relaxed);
666        self.peak_bytes.store(current, Ordering::Relaxed);
667    }
668
669    /// The underlying device.
670    #[inline]
671    pub fn device(&self) -> &GpuDevice {
672        &self.device
673    }
674
675    /// The underlying device as an `Arc`.
676    #[inline]
677    pub fn device_arc(&self) -> &Arc<GpuDevice> {
678        &self.device
679    }
680
681    // ------------------------------------------------------------------
682    // Internal helpers
683    //
684    // These methods are used by the `#[cfg(feature = "cuda")]` allocation
685    // paths. In the no-cuda build the callers do not exist, so we suppress
686    // the dead-code lint.
687    // ------------------------------------------------------------------
688
689    /// Check whether `alloc_bytes` would exceed the budget.
690    #[allow(dead_code)]
691    fn check_budget(&self, alloc_bytes: usize) -> GpuResult<()> {
692        let budget = self.budget_bytes.load(Ordering::Relaxed);
693        if budget == 0 {
694            return Ok(()); // unlimited
695        }
696        let used = self.used_bytes.load(Ordering::Relaxed);
697        if used.saturating_add(alloc_bytes) > budget {
698            return Err(GpuError::BudgetExceeded {
699                requested_bytes: alloc_bytes,
700                budget_bytes: budget,
701                used_bytes: used,
702            });
703        }
704        Ok(())
705    }
706
707    /// Low-level zero-init allocation with tracking.
708    #[cfg(feature = "cuda")]
709    fn try_alloc_zeros<T>(&self, count: usize, alloc_bytes: usize) -> GpuResult<CudaBuffer<T>>
710    where
711        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
712    {
713        let slice = self.device.stream().alloc_zeros::<T>(count)?;
714
715        let prev = self.used_bytes.fetch_add(alloc_bytes, Ordering::Relaxed);
716        self.peak_bytes
717            .fetch_max(prev + alloc_bytes, Ordering::Relaxed);
718        self.num_allocations.fetch_add(1, Ordering::Relaxed);
719
720        Ok(CudaBuffer {
721            data: Some(slice),
722            len: count,
723            alloc_len: count,
724            device_ordinal: self.device.ordinal(),
725            pool_fn: None,
726        })
727    }
728
729    /// Low-level host-to-device copy allocation with tracking.
730    #[cfg(feature = "cuda")]
731    fn try_alloc_copy<T>(&self, data: &[T], alloc_bytes: usize) -> GpuResult<CudaBuffer<T>>
732    where
733        T: cudarc::driver::DeviceRepr,
734    {
735        let slice = self.device.stream().clone_htod(data)?;
736
737        let prev = self.used_bytes.fetch_add(alloc_bytes, Ordering::Relaxed);
738        self.peak_bytes
739            .fetch_max(prev + alloc_bytes, Ordering::Relaxed);
740        self.num_allocations.fetch_add(1, Ordering::Relaxed);
741
742        Ok(CudaBuffer {
743            data: Some(slice),
744            len: data.len(),
745            alloc_len: data.len(),
746            device_ordinal: self.device.ordinal(),
747            pool_fn: None,
748        })
749    }
750
751    /// Determine whether an error is an out-of-memory condition.
752    #[allow(dead_code)]
753    fn is_oom(&self, err: &GpuError) -> bool {
754        match err {
755            GpuError::OutOfMemory { .. } => true,
756            #[cfg(feature = "cuda")]
757            GpuError::Driver(driver_err) => {
758                let msg = format!("{driver_err}");
759                msg.contains("OUT_OF_MEMORY")
760                    || msg.contains("out of memory")
761                    || msg.contains("CUDA_ERROR_OUT_OF_MEMORY")
762            }
763            _ => false,
764        }
765    }
766
767    /// Handle an OOM according to the current policy.
768    #[cfg(feature = "cuda")]
769    fn handle_oom<T>(
770        &self,
771        count: usize,
772        alloc_bytes: usize,
773        original_err: GpuError,
774    ) -> GpuResult<CudaBuffer<T>>
775    where
776        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
777    {
778        let policy = self.oom_policy.lock().unwrap().clone();
779        match policy {
780            OomPolicy::Fail => Err(original_err),
781            OomPolicy::RetryAfterFree => {
782                self.free_caches();
783                self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
784                self.try_alloc_zeros(count, alloc_bytes)
785            }
786            OomPolicy::WaitAndRetry { timeout_secs } => {
787                self.wait_for_memory(alloc_bytes, timeout_secs)?;
788                self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
789                self.try_alloc_zeros(count, alloc_bytes)
790            }
791            OomPolicy::CheckpointAndFail => {
792                self.trigger_emergency_checkpoint();
793                Err(original_err)
794            }
795        }
796    }
797
798    /// Best-effort cache eviction. Currently a no-op placeholder — the
799    /// caching allocator is not yet implemented. When it is, this will
800    /// release all cached-but-free blocks.
801    #[allow(dead_code)]
802    fn free_caches(&self) {
803        // Future: delegate to CudaAllocator::empty_cache() once block
804        // caching is implemented.
805    }
806
807    /// Block until at least `needed_bytes` are free, or until `timeout_secs`
808    /// elapses.
809    #[allow(dead_code)]
810    fn wait_for_memory(&self, needed_bytes: usize, timeout_secs: u64) -> GpuResult<()> {
811        let deadline = Instant::now() + Duration::from_secs(timeout_secs);
812        loop {
813            let (free, _) = self.query_device_memory();
814            if free >= needed_bytes {
815                return Ok(());
816            }
817            if Instant::now() >= deadline {
818                return Err(GpuError::OutOfMemory {
819                    requested_bytes: needed_bytes,
820                    free_bytes: free,
821                });
822            }
823            std::thread::sleep(Duration::from_millis(100));
824        }
825    }
826
827    /// Invoke the user-registered emergency checkpoint callback.
828    #[allow(dead_code)]
829    fn trigger_emergency_checkpoint(&self) {
830        let lock = self.on_oom_callback.lock().unwrap();
831        if let Some(cb) = lock.as_ref() {
832            cb();
833        }
834    }
835
836    /// Query free and total device memory from the driver.
837    ///
838    /// Returns `(free_bytes, total_bytes)`. On error (or when the `cuda`
839    /// feature is disabled), returns `(0, 0)`.
840    fn query_device_memory(&self) -> (usize, usize) {
841        #[cfg(feature = "cuda")]
842        {
843            cudarc::driver::result::mem_get_info().unwrap_or((0, 0))
844        }
845        #[cfg(not(feature = "cuda"))]
846        {
847            (0, 0)
848        }
849    }
850}
851
852impl std::fmt::Debug for MemoryGuard {
853    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
854        f.debug_struct("MemoryGuard")
855            .field("device_ordinal", &self.device.ordinal())
856            .field("budget_bytes", &self.budget_bytes.load(Ordering::Relaxed))
857            .field("used_bytes", &self.used_bytes.load(Ordering::Relaxed))
858            .field("peak_bytes", &self.peak_bytes.load(Ordering::Relaxed))
859            .field(
860                "has_reservation",
861                &self.reservation.lock().unwrap().is_some(),
862            )
863            .finish()
864    }
865}
866
867// ---------------------------------------------------------------------------
868// Stub when `cuda` feature is disabled
869// ---------------------------------------------------------------------------
870
871#[cfg(not(feature = "cuda"))]
872impl MemoryGuard {
873    /// Stub — returns [`GpuError::NoCudaFeature`].
874    pub fn safe_alloc<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
875        Err(GpuError::NoCudaFeature)
876    }
877
878    /// Stub — returns [`GpuError::NoCudaFeature`].
879    pub fn safe_alloc_copy<T>(&self, _data: &[T]) -> GpuResult<CudaBuffer<T>> {
880        Err(GpuError::NoCudaFeature)
881    }
882
883    /// Stub — returns [`GpuError::NoCudaFeature`].
884    pub fn safe_alloc_with_hooks<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
885        Err(GpuError::NoCudaFeature)
886    }
887}
888
889// ---------------------------------------------------------------------------
890// MemoryGuardBuilder
891// ---------------------------------------------------------------------------
892
893/// Builder for [`MemoryGuard`].
894///
895/// ```rust,no_run
896/// # use std::sync::Arc;
897/// # use ferrotorch_gpu::memory_guard::{MemoryGuardBuilder, OomPolicy};
898/// # use ferrotorch_gpu::GpuDevice;
899/// let device = Arc::new(GpuDevice::new(0).unwrap());
900/// let guard = MemoryGuardBuilder::new(device)
901///     .budget_bytes(16 * 1024 * 1024 * 1024)
902///     .reserve_bytes(16 * 1024 * 1024 * 1024)
903///     .oom_policy(OomPolicy::RetryAfterFree)
904///     .build()
905///     .unwrap();
906/// ```
907pub struct MemoryGuardBuilder {
908    device: Arc<GpuDevice>,
909    budget_bytes: usize,
910    reserve_bytes: usize,
911    oom_policy: OomPolicy,
912}
913
914impl MemoryGuardBuilder {
915    /// Create a new builder for the given device.
916    pub fn new(device: Arc<GpuDevice>) -> Self {
917        Self {
918            device,
919            budget_bytes: 0,
920            reserve_bytes: 0,
921            oom_policy: OomPolicy::default(),
922        }
923    }
924
925    /// Set the hard memory budget in bytes. `0` means unlimited.
926    pub fn budget_bytes(mut self, bytes: usize) -> Self {
927        self.budget_bytes = bytes;
928        self
929    }
930
931    /// Pre-allocate `bytes` of VRAM as a reservation sentinel.
932    /// Other processes cannot use this memory while the guard is alive.
933    pub fn reserve_bytes(mut self, bytes: usize) -> Self {
934        self.reserve_bytes = bytes;
935        self
936    }
937
938    /// Set the OOM recovery policy.
939    pub fn oom_policy(mut self, policy: OomPolicy) -> Self {
940        self.oom_policy = policy;
941        self
942    }
943
944    /// Build the [`MemoryGuard`].
945    ///
946    /// If `reserve_bytes` was set, this will attempt to allocate the
947    /// sentinel buffer. Failure to allocate is returned as an error.
948    #[cfg(feature = "cuda")]
949    pub fn build(self) -> GpuResult<MemoryGuard> {
950        let reservation = if self.reserve_bytes > 0 {
951            let slice = self.device.stream().alloc_zeros::<u8>(self.reserve_bytes)?;
952            Some(MemoryReservation {
953                _reservation: CudaBuffer {
954                    data: Some(slice),
955                    len: self.reserve_bytes,
956                    alloc_len: self.reserve_bytes,
957                    device_ordinal: self.device.ordinal(),
958                    pool_fn: None,
959                },
960                reserved_bytes: self.reserve_bytes,
961                device_ordinal: self.device.ordinal(),
962            })
963        } else {
964            None
965        };
966
967        Ok(MemoryGuard {
968            device: self.device,
969            reservation: Mutex::new(reservation),
970            budget_bytes: AtomicUsize::new(self.budget_bytes),
971            used_bytes: AtomicUsize::new(0),
972            peak_bytes: AtomicUsize::new(0),
973            num_allocations: AtomicUsize::new(0),
974            num_oom_recoveries: AtomicUsize::new(0),
975            oom_policy: Mutex::new(self.oom_policy),
976            on_oom_callback: Mutex::new(None),
977            hooks: Mutex::new(Vec::new()),
978            pressure_listeners: Mutex::new(Vec::new()),
979            last_pressure_level: Mutex::new(PressureLevel::None),
980        })
981    }
982
983    /// Stub build when `cuda` feature is disabled.
984    #[cfg(not(feature = "cuda"))]
985    pub fn build(self) -> GpuResult<MemoryGuard> {
986        Ok(MemoryGuard {
987            device: self.device,
988            reservation: Mutex::new(None),
989            budget_bytes: AtomicUsize::new(self.budget_bytes),
990            used_bytes: AtomicUsize::new(0),
991            peak_bytes: AtomicUsize::new(0),
992            num_allocations: AtomicUsize::new(0),
993            num_oom_recoveries: AtomicUsize::new(0),
994            oom_policy: Mutex::new(self.oom_policy),
995            on_oom_callback: Mutex::new(None),
996            hooks: Mutex::new(Vec::new()),
997            pressure_listeners: Mutex::new(Vec::new()),
998            last_pressure_level: Mutex::new(PressureLevel::None),
999        })
1000    }
1001}
1002
1003// ---------------------------------------------------------------------------
1004// MemoryWatchdog
1005// ---------------------------------------------------------------------------
1006
1007/// Background monitor that pauses training when free VRAM drops below a
1008/// threshold.
1009///
1010/// Create a watchdog, wrap it in an `Arc`, and call [`start`](Self::start) to
1011/// spawn the monitoring thread. Between training batches, call
1012/// [`wait_if_paused`](Self::wait_if_paused) to block until memory pressure
1013/// is resolved.
1014///
1015/// ```rust,no_run
1016/// use std::sync::Arc;
1017/// use std::time::Duration;
1018/// use ferrotorch_gpu::memory_guard::MemoryWatchdog;
1019/// use ferrotorch_gpu::GpuDevice;
1020///
1021/// let device = Arc::new(GpuDevice::new(0).unwrap());
1022/// let watchdog = Arc::new(MemoryWatchdog::new(
1023///     device,
1024///     512 * 1024 * 1024, // pause when <512 MiB free
1025///     Duration::from_secs(1),
1026/// ));
1027/// let handle = Arc::clone(&watchdog).start();
1028///
1029/// // In training loop:
1030/// watchdog.wait_if_paused();
1031/// ```
1032pub struct MemoryWatchdog {
1033    device: Arc<GpuDevice>,
1034    /// Minimum free bytes before we pause.
1035    pressure_threshold_bytes: usize,
1036    /// How often to poll the driver.
1037    check_interval: Duration,
1038    /// Whether training is currently paused due to memory pressure.
1039    paused: AtomicBool,
1040    /// Signal to stop the background thread.
1041    stop: AtomicBool,
1042    /// Set to `true` after the first check cycle completes.
1043    has_checked: AtomicBool,
1044}
1045
1046impl MemoryWatchdog {
1047    /// Create a new watchdog. Does not start monitoring until [`start`](Self::start)
1048    /// is called.
1049    pub fn new(
1050        device: Arc<GpuDevice>,
1051        pressure_threshold_bytes: usize,
1052        check_interval: Duration,
1053    ) -> Self {
1054        Self {
1055            device,
1056            pressure_threshold_bytes,
1057            check_interval,
1058            paused: AtomicBool::new(false),
1059            stop: AtomicBool::new(false),
1060            has_checked: AtomicBool::new(false),
1061        }
1062    }
1063
1064    /// Start the monitoring thread. Returns a `JoinHandle` that can be used
1065    /// to wait for shutdown (after calling [`stop`](Self::stop)).
1066    pub fn start(self: Arc<Self>) -> JoinHandle<()> {
1067        std::thread::Builder::new()
1068            .name("ferrotorch-memory-watchdog".into())
1069            .spawn(move || {
1070                while !self.stop.load(Ordering::Relaxed) {
1071                    let free = self.query_free_memory();
1072                    if free < self.pressure_threshold_bytes {
1073                        self.paused.store(true, Ordering::SeqCst);
1074                        // Spin until memory is available or we are told to stop.
1075                        while self.query_free_memory() < self.pressure_threshold_bytes {
1076                            if self.stop.load(Ordering::Relaxed) {
1077                                return;
1078                            }
1079                            std::thread::sleep(Duration::from_millis(500));
1080                        }
1081                        self.paused.store(false, Ordering::SeqCst);
1082                    }
1083                    self.has_checked.store(true, Ordering::SeqCst);
1084                    std::thread::sleep(self.check_interval);
1085                }
1086            })
1087            .expect("failed to spawn memory watchdog thread")
1088    }
1089
1090    /// Signal the background thread to exit.
1091    pub fn stop(&self) {
1092        self.stop.store(true, Ordering::SeqCst);
1093    }
1094
1095    /// Returns `true` if the watchdog currently has training paused due to
1096    /// memory pressure.
1097    #[inline]
1098    pub fn check_pressure(&self) -> bool {
1099        self.paused.load(Ordering::SeqCst)
1100    }
1101
1102    /// Block the calling thread until memory pressure is resolved.
1103    /// Call this between training batches.
1104    pub fn wait_if_paused(&self) {
1105        while self.paused.load(Ordering::SeqCst) {
1106            std::thread::sleep(Duration::from_millis(100));
1107        }
1108    }
1109
1110    /// Block until the watchdog has completed at least one check cycle.
1111    /// Useful in tests to avoid timing races.
1112    pub fn wait_for_first_check(&self, timeout: Duration) {
1113        let start = std::time::Instant::now();
1114        while !self.has_checked.load(Ordering::SeqCst) {
1115            if start.elapsed() > timeout {
1116                return;
1117            }
1118            std::thread::sleep(Duration::from_millis(5));
1119        }
1120    }
1121
1122    /// The pressure threshold in bytes.
1123    #[inline]
1124    pub fn pressure_threshold_bytes(&self) -> usize {
1125        self.pressure_threshold_bytes
1126    }
1127
1128    /// Query the amount of free device memory.
1129    ///
1130    /// Binds the device's CUDA context on the current thread before querying,
1131    /// so this is safe to call from the watchdog's background thread.
1132    fn query_free_memory(&self) -> usize {
1133        #[cfg(feature = "cuda")]
1134        {
1135            // Bind the CUDA context on this thread so mem_get_info works.
1136            let ctx = self.device.context();
1137            let _ = ctx.bind_to_thread();
1138            cudarc::driver::result::mem_get_info()
1139                .map(|(free, _)| free)
1140                .unwrap_or(0)
1141        }
1142        #[cfg(not(feature = "cuda"))]
1143        {
1144            let _ = &self.device;
1145            0
1146        }
1147    }
1148}
1149
1150impl std::fmt::Debug for MemoryWatchdog {
1151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1152        f.debug_struct("MemoryWatchdog")
1153            .field("device_ordinal", &self.device.ordinal())
1154            .field("pressure_threshold_bytes", &self.pressure_threshold_bytes)
1155            .field("check_interval", &self.check_interval)
1156            .field("paused", &self.paused.load(Ordering::Relaxed))
1157            .finish()
1158    }
1159}
1160
1161// ---------------------------------------------------------------------------
1162// GpuDevice extensions
1163// ---------------------------------------------------------------------------
1164
1165impl GpuDevice {
1166    /// Query free and total GPU memory for this device.
1167    ///
1168    /// Returns `(free_bytes, total_bytes)`.
1169    ///
1170    /// # Errors
1171    ///
1172    /// Returns [`GpuError::Driver`] if the CUDA driver call fails.
1173    #[cfg(feature = "cuda")]
1174    pub fn memory_info(&self) -> GpuResult<(usize, usize)> {
1175        // cuMemGetInfo operates on the current context, so we need to ensure
1176        // this device's context is bound. The cudarc CudaContext does this
1177        // internally for allocations, but mem_get_info is a free function.
1178        // Binding is handled by the caller having created the device.
1179        let info = cudarc::driver::result::mem_get_info()?;
1180        Ok(info)
1181    }
1182
1183    /// Query free and total GPU memory — stub when `cuda` is disabled.
1184    #[cfg(not(feature = "cuda"))]
1185    pub fn memory_info(&self) -> GpuResult<(usize, usize)> {
1186        Err(GpuError::NoCudaFeature)
1187    }
1188}
1189
1190/// A [`GpuDevice`] paired with a [`MemoryGuard`] for convenient use.
1191///
1192/// Created by [`GpuDevice::with_memory_guard`].
1193pub struct MemoryGuardedDevice {
1194    /// The memory guard managing allocations.
1195    pub guard: MemoryGuard,
1196}
1197
1198impl MemoryGuardedDevice {
1199    /// Access the underlying device.
1200    #[inline]
1201    pub fn device(&self) -> &GpuDevice {
1202        self.guard.device()
1203    }
1204
1205    /// Access the memory guard.
1206    #[inline]
1207    pub fn guard(&self) -> &MemoryGuard {
1208        &self.guard
1209    }
1210}
1211
1212impl std::fmt::Debug for MemoryGuardedDevice {
1213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1214        f.debug_struct("MemoryGuardedDevice")
1215            .field("guard", &self.guard)
1216            .finish()
1217    }
1218}
1219
1220// ---------------------------------------------------------------------------
1221// Tests
1222// ---------------------------------------------------------------------------
1223
1224#[cfg(test)]
1225mod tests {
1226    use super::*;
1227
1228    // ---------------------------------------------------------------
1229    // Unit tests (no GPU required)
1230    // ---------------------------------------------------------------
1231
1232    #[test]
1233    fn oom_policy_default_is_fail() {
1234        assert_eq!(OomPolicy::default(), OomPolicy::Fail);
1235    }
1236
1237    #[test]
1238    fn oom_policy_debug() {
1239        let p = OomPolicy::WaitAndRetry { timeout_secs: 30 };
1240        let s = format!("{p:?}");
1241        assert!(s.contains("WaitAndRetry"));
1242        assert!(s.contains("30"));
1243    }
1244
1245    #[test]
1246    fn memory_stats_clone_eq() {
1247        let s = MemoryStats {
1248            used_bytes: 100,
1249            budget_bytes: 1000,
1250            peak_bytes: 200,
1251            free_device_bytes: 800,
1252            total_device_bytes: 2000,
1253            num_allocations: 5,
1254            num_oom_recoveries: 1,
1255        };
1256        let s2 = s.clone();
1257        assert_eq!(s, s2);
1258    }
1259
1260    #[test]
1261    fn memory_stats_debug() {
1262        let s = MemoryStats {
1263            used_bytes: 0,
1264            budget_bytes: 0,
1265            peak_bytes: 0,
1266            free_device_bytes: 0,
1267            total_device_bytes: 0,
1268            num_allocations: 0,
1269            num_oom_recoveries: 0,
1270        };
1271        let d = format!("{s:?}");
1272        assert!(d.contains("MemoryStats"));
1273        assert!(d.contains("used_bytes"));
1274    }
1275
1276    #[test]
1277    fn gpu_error_out_of_memory_display() {
1278        let e = GpuError::OutOfMemory {
1279            requested_bytes: 1024,
1280            free_bytes: 512,
1281        };
1282        let s = format!("{e}");
1283        assert!(s.contains("1024"));
1284        assert!(s.contains("512"));
1285        assert!(s.contains("out of memory"));
1286    }
1287
1288    #[test]
1289    fn gpu_error_budget_exceeded_display() {
1290        let e = GpuError::BudgetExceeded {
1291            requested_bytes: 500,
1292            budget_bytes: 1000,
1293            used_bytes: 800,
1294        };
1295        let s = format!("{e}");
1296        assert!(s.contains("500"));
1297        assert!(s.contains("1000"));
1298        assert!(s.contains("800"));
1299        assert!(s.contains("budget exceeded"));
1300    }
1301
1302    // ---------------------------------------------------------------
1303    // Pre-OOM hooks & pressure unit tests (no GPU required)
1304    // ---------------------------------------------------------------
1305
1306    #[test]
1307    fn pressure_level_ordering() {
1308        assert!(PressureLevel::None < PressureLevel::Low);
1309        assert!(PressureLevel::Low < PressureLevel::Medium);
1310        assert!(PressureLevel::Medium < PressureLevel::High);
1311        assert!(PressureLevel::High < PressureLevel::Critical);
1312    }
1313
1314    #[test]
1315    fn pressure_level_display() {
1316        assert_eq!(format!("{}", PressureLevel::None), "none");
1317        assert_eq!(format!("{}", PressureLevel::Critical), "critical");
1318    }
1319
1320    #[test]
1321    fn pressure_level_debug_clone_eq() {
1322        let p = PressureLevel::Medium;
1323        let p2 = p;
1324        assert_eq!(p, p2);
1325        let s = format!("{p:?}");
1326        assert!(s.contains("Medium"));
1327    }
1328
1329    #[test]
1330    fn compute_pressure_unlimited_budget_is_none() {
1331        // budget=0 means unlimited, should always be None.
1332        assert_eq!(MemoryGuard::compute_pressure(0, 0), PressureLevel::None);
1333    }
1334
1335    #[test]
1336    fn compute_pressure_thresholds() {
1337        let budget = 1000;
1338        // >30% free => None
1339        assert_eq!(
1340            MemoryGuard::compute_pressure(budget, 0),
1341            PressureLevel::None
1342        );
1343        assert_eq!(
1344            MemoryGuard::compute_pressure(budget, 600),
1345            PressureLevel::None
1346        );
1347        assert_eq!(
1348            MemoryGuard::compute_pressure(budget, 699),
1349            PressureLevel::None
1350        );
1351        // 10-30% free => Low
1352        assert_eq!(
1353            MemoryGuard::compute_pressure(budget, 750),
1354            PressureLevel::Low
1355        );
1356        assert_eq!(
1357            MemoryGuard::compute_pressure(budget, 890),
1358            PressureLevel::Low
1359        );
1360        // 5-10% free => Medium
1361        assert_eq!(
1362            MemoryGuard::compute_pressure(budget, 910),
1363            PressureLevel::Medium
1364        );
1365        assert_eq!(
1366            MemoryGuard::compute_pressure(budget, 949),
1367            PressureLevel::Medium
1368        );
1369        // <5% free => High
1370        assert_eq!(
1371            MemoryGuard::compute_pressure(budget, 960),
1372            PressureLevel::High
1373        );
1374        assert_eq!(
1375            MemoryGuard::compute_pressure(budget, 999),
1376            PressureLevel::High
1377        );
1378        // At or over budget => Critical
1379        assert_eq!(
1380            MemoryGuard::compute_pressure(budget, 1000),
1381            PressureLevel::Critical
1382        );
1383        assert_eq!(
1384            MemoryGuard::compute_pressure(budget, 2000),
1385            PressureLevel::Critical
1386        );
1387    }
1388
1389    #[test]
1390    fn memory_hook_debug() {
1391        let hook = MemoryHook {
1392            name: "test_hook".into(),
1393            estimated_free_bytes: 1024,
1394            execution_overhead_bytes: 64,
1395            priority: 5,
1396            callback: Box::new(|| 1024),
1397        };
1398        let s = format!("{hook:?}");
1399        assert!(s.contains("test_hook"));
1400        assert!(s.contains("1024"));
1401        assert!(s.contains("64"));
1402        assert!(s.contains("5"));
1403    }
1404
1405    // ---------------------------------------------------------------
1406    // GPU tests (require `cuda` feature and a real device)
1407    // ---------------------------------------------------------------
1408
1409    #[cfg(feature = "cuda")]
1410    mod gpu_tests {
1411        use super::*;
1412
1413        fn make_device() -> Arc<GpuDevice> {
1414            Arc::new(GpuDevice::new(0).expect("CUDA device 0"))
1415        }
1416
1417        #[test]
1418        fn guard_construction_and_stats() {
1419            let device = make_device();
1420            let guard = MemoryGuardBuilder::new(device)
1421                .budget_bytes(1024 * 1024 * 1024) // 1 GiB
1422                .oom_policy(OomPolicy::Fail)
1423                .build()
1424                .expect("build guard");
1425
1426            let stats = guard.stats();
1427            assert_eq!(stats.used_bytes, 0);
1428            assert_eq!(stats.budget_bytes, 1024 * 1024 * 1024);
1429            assert_eq!(stats.peak_bytes, 0);
1430            assert_eq!(stats.num_allocations, 0);
1431            assert_eq!(stats.num_oom_recoveries, 0);
1432            assert!(stats.total_device_bytes > 0);
1433            assert!(stats.free_device_bytes > 0);
1434        }
1435
1436        #[test]
1437        fn budget_enforcement_rejects_over_budget() {
1438            let device = make_device();
1439            let guard = MemoryGuardBuilder::new(device)
1440                .budget_bytes(256) // tiny budget: 256 bytes
1441                .build()
1442                .expect("build guard");
1443
1444            // Try to allocate way more than the budget.
1445            let result = guard.safe_alloc::<f32>(1024); // 4096 bytes
1446            assert!(result.is_err());
1447            match result.unwrap_err() {
1448                GpuError::BudgetExceeded {
1449                    requested_bytes,
1450                    budget_bytes,
1451                    used_bytes,
1452                } => {
1453                    assert_eq!(requested_bytes, 1024 * 4);
1454                    assert_eq!(budget_bytes, 256);
1455                    assert_eq!(used_bytes, 0);
1456                }
1457                other => panic!("expected BudgetExceeded, got {other:?}"),
1458            }
1459        }
1460
1461        #[test]
1462        fn safe_alloc_tracks_usage() {
1463            let device = make_device();
1464            let guard = MemoryGuardBuilder::new(device)
1465                .budget_bytes(0) // unlimited
1466                .build()
1467                .expect("build guard");
1468
1469            let buf = guard.safe_alloc::<f32>(256).expect("alloc 256 f32");
1470            let expected = 256 * std::mem::size_of::<f32>();
1471
1472            let stats = guard.stats();
1473            assert_eq!(stats.used_bytes, expected);
1474            assert_eq!(stats.peak_bytes, expected);
1475            assert_eq!(stats.num_allocations, 1);
1476
1477            guard.free(buf);
1478
1479            let stats = guard.stats();
1480            assert_eq!(stats.used_bytes, 0);
1481            assert_eq!(stats.num_allocations, 0);
1482            // Peak should still be high.
1483            assert_eq!(stats.peak_bytes, expected);
1484        }
1485
1486        #[test]
1487        fn safe_alloc_copy_tracks_usage() {
1488            let device = make_device();
1489            let guard = MemoryGuardBuilder::new(device)
1490                .build()
1491                .expect("build guard");
1492
1493            let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
1494            let buf = guard.safe_alloc_copy(&data).expect("alloc_copy");
1495            let expected = 4 * std::mem::size_of::<f64>();
1496
1497            assert_eq!(guard.stats().used_bytes, expected);
1498            guard.free(buf);
1499            assert_eq!(guard.stats().used_bytes, 0);
1500        }
1501
1502        #[test]
1503        fn reset_peak_stats_works() {
1504            let device = make_device();
1505            let guard = MemoryGuardBuilder::new(device)
1506                .build()
1507                .expect("build guard");
1508
1509            let buf = guard.safe_alloc::<f32>(512).expect("alloc");
1510            let peak = guard.stats().peak_bytes;
1511            assert!(peak > 0);
1512
1513            guard.free(buf);
1514            assert_eq!(guard.stats().peak_bytes, peak); // still high
1515
1516            guard.reset_peak_stats();
1517            assert_eq!(guard.stats().peak_bytes, 0); // reset to current (0)
1518        }
1519
1520        #[test]
1521        fn emergency_checkpoint_callback_invoked() {
1522            let device = make_device();
1523            let guard = MemoryGuardBuilder::new(device)
1524                .build()
1525                .expect("build guard");
1526
1527            let called = Arc::new(AtomicBool::new(false));
1528            let called_clone = Arc::clone(&called);
1529            guard.on_oom(move || {
1530                called_clone.store(true, Ordering::SeqCst);
1531            });
1532
1533            // Directly invoke the internal method to test the callback.
1534            guard.trigger_emergency_checkpoint();
1535            assert!(called.load(Ordering::SeqCst));
1536        }
1537
1538        #[test]
1539        fn set_budget_at_runtime() {
1540            let device = make_device();
1541            let guard = MemoryGuardBuilder::new(device)
1542                .budget_bytes(0)
1543                .build()
1544                .expect("build guard");
1545
1546            assert_eq!(guard.budget(), 0);
1547
1548            guard.set_budget(1024);
1549            assert_eq!(guard.budget(), 1024);
1550
1551            // Now an allocation over 1024 bytes should fail.
1552            let result = guard.safe_alloc::<f32>(1024); // 4096 bytes > 1024 budget
1553            assert!(result.is_err());
1554        }
1555
1556        #[test]
1557        fn memory_info_returns_nonzero() {
1558            let device = GpuDevice::new(0).expect("CUDA device 0");
1559            let (free, total) = device.memory_info().expect("memory_info");
1560            assert!(total > 0, "total device memory should be > 0");
1561            assert!(free > 0, "free device memory should be > 0");
1562            assert!(free <= total, "free should not exceed total");
1563        }
1564
1565        #[test]
1566        fn reservation_holds_memory() {
1567            let device = make_device();
1568            let (free_before, _) = device.memory_info().expect("memory_info");
1569
1570            // Reserve 64 MiB.
1571            let reserve_bytes = 64 * 1024 * 1024;
1572            let guard = MemoryGuardBuilder::new(device)
1573                .reserve_bytes(reserve_bytes)
1574                .build()
1575                .expect("build guard with reservation");
1576
1577            assert!(guard.has_reservation());
1578
1579            // Release the reservation.
1580            let released = guard.release_reservation();
1581            assert_eq!(released, reserve_bytes);
1582            assert!(!guard.has_reservation());
1583
1584            // Releasing again returns 0.
1585            assert_eq!(guard.release_reservation(), 0);
1586
1587            let _ = free_before; // suppress unused warning
1588        }
1589
1590        #[test]
1591        fn guard_debug_impl() {
1592            let device = make_device();
1593            let guard = MemoryGuardBuilder::new(device)
1594                .budget_bytes(999)
1595                .build()
1596                .expect("build guard");
1597
1598            let s = format!("{guard:?}");
1599            assert!(s.contains("MemoryGuard"));
1600            assert!(s.contains("budget_bytes"));
1601            assert!(s.contains("999"));
1602        }
1603
1604        #[test]
1605        fn watchdog_detects_no_pressure_when_plenty_free() {
1606            let device = make_device();
1607            // Threshold of 1 byte — should never trigger pressure.
1608            let watchdog = Arc::new(MemoryWatchdog::new(device, 1, Duration::from_millis(50)));
1609
1610            assert!(!watchdog.check_pressure());
1611            watchdog.wait_if_paused(); // should return immediately
1612
1613            // Start watchdog and wait for it to complete at least one cycle.
1614            let wd = Arc::clone(&watchdog);
1615            let handle = wd.start();
1616            watchdog.wait_for_first_check(Duration::from_secs(5));
1617            assert!(!watchdog.check_pressure());
1618            watchdog.stop();
1619            handle.join().expect("watchdog thread");
1620        }
1621
1622        #[test]
1623        fn watchdog_debug_impl() {
1624            let device = make_device();
1625            let watchdog = MemoryWatchdog::new(device, 1024, Duration::from_secs(1));
1626            let s = format!("{watchdog:?}");
1627            assert!(s.contains("MemoryWatchdog"));
1628            assert!(s.contains("1024"));
1629        }
1630
1631        #[test]
1632        fn memory_guarded_device() {
1633            let device = make_device();
1634            let guard = MemoryGuardBuilder::new(Arc::clone(&device))
1635                .budget_bytes(1024 * 1024)
1636                .build()
1637                .expect("build guard");
1638
1639            let guarded = MemoryGuardedDevice { guard };
1640            assert_eq!(guarded.device().ordinal(), 0);
1641            assert_eq!(guarded.guard().budget(), 1024 * 1024);
1642
1643            let s = format!("{guarded:?}");
1644            assert!(s.contains("MemoryGuardedDevice"));
1645        }
1646
1647        #[test]
1648        fn oom_policy_retry_after_free() {
1649            // This test verifies the RetryAfterFree policy increments the
1650            // recovery counter. We cannot easily force a real OOM in a unit
1651            // test, so we verify the policy is stored correctly and the
1652            // counter machinery works.
1653            let device = make_device();
1654            let guard = MemoryGuardBuilder::new(device)
1655                .oom_policy(OomPolicy::RetryAfterFree)
1656                .build()
1657                .expect("build guard");
1658
1659            // With RetryAfterFree and no actual OOM, allocation succeeds
1660            // on the first attempt — recovery counter stays at 0.
1661            let buf = guard.safe_alloc::<f32>(64).expect("alloc");
1662            assert_eq!(guard.stats().num_oom_recoveries, 0);
1663            guard.free(buf);
1664        }
1665
1666        #[test]
1667        fn multiple_allocations_budget_accounting() {
1668            let device = make_device();
1669            let budget = 2048_usize;
1670            let guard = MemoryGuardBuilder::new(device)
1671                .budget_bytes(budget)
1672                .build()
1673                .expect("build guard");
1674
1675            // First alloc: 128 f32 = 512 bytes. Should succeed.
1676            let buf1 = guard.safe_alloc::<f32>(128).expect("alloc 1");
1677            assert_eq!(guard.stats().used_bytes, 512);
1678            assert_eq!(guard.stats().num_allocations, 1);
1679
1680            // Second alloc: 128 f32 = 512 bytes. Total = 1024. Should succeed.
1681            let buf2 = guard.safe_alloc::<f32>(128).expect("alloc 2");
1682            assert_eq!(guard.stats().used_bytes, 1024);
1683            assert_eq!(guard.stats().num_allocations, 2);
1684
1685            // Third alloc: 512 f32 = 2048 bytes. Total would be 3072 > 2048. Should fail.
1686            let result = guard.safe_alloc::<f32>(512);
1687            assert!(result.is_err());
1688
1689            // Free buf1, then the third alloc should succeed (1024 + 512 < 2048? no, 512 + 2048 = 2560)
1690            // Actually 1024-512 = 512 used, then 512 + 2048 = 2560 > 2048. Still too big.
1691            // Let's free both and try a fitting alloc.
1692            guard.free(buf1);
1693            guard.free(buf2);
1694            assert_eq!(guard.stats().used_bytes, 0);
1695
1696            let buf3 = guard.safe_alloc::<f32>(512).expect("alloc 3 after free");
1697            assert_eq!(guard.stats().used_bytes, 2048);
1698            guard.free(buf3);
1699        }
1700
1701        // ---------------------------------------------------------------
1702        // Pre-OOM hooks GPU tests
1703        // ---------------------------------------------------------------
1704
1705        #[test]
1706        fn hook_called_on_budget_exceeded() {
1707            let device = make_device();
1708            let budget = 1024_usize; // 1024 bytes
1709            let guard = MemoryGuardBuilder::new(device)
1710                .budget_bytes(budget)
1711                .build()
1712                .expect("build guard");
1713
1714            let called = Arc::new(AtomicBool::new(false));
1715            let called_clone = Arc::clone(&called);
1716
1717            guard.register_hook(MemoryHook {
1718                name: "test_hook".into(),
1719                estimated_free_bytes: 2048,
1720                execution_overhead_bytes: 0,
1721                priority: 10,
1722                callback: Box::new(move || {
1723                    called_clone.store(true, Ordering::SeqCst);
1724                    0 // does not actually free tracked memory
1725                }),
1726            });
1727
1728            // Allocation of 512 f32 = 2048 bytes > budget of 1024.
1729            // Hook will be called but won't free enough, so alloc falls
1730            // through. The hook should still have been invoked.
1731            let _result = guard.safe_alloc_with_hooks::<f32>(512);
1732            assert!(called.load(Ordering::SeqCst), "hook was not called");
1733        }
1734
1735        #[test]
1736        fn hook_frees_enough_memory_allocation_succeeds() {
1737            let device = make_device();
1738            // Budget: 2048. Pre-fill 1536 bytes, then request 1024 bytes
1739            // (total would be 2560 > 2048). Hook frees 1024 bytes.
1740            let budget = 2048_usize;
1741            let guard = MemoryGuardBuilder::new(device)
1742                .budget_bytes(budget)
1743                .build()
1744                .expect("build guard");
1745
1746            // Pre-fill: 384 f32 = 1536 bytes.
1747            let prefill = guard.safe_alloc::<f32>(384).expect("prefill");
1748            assert_eq!(guard.stats().used_bytes, 1536);
1749
1750            let called = Arc::new(AtomicBool::new(false));
1751            let called_clone = Arc::clone(&called);
1752
1753            // Hook "frees" 1024 bytes by adjusting the guard's tracked usage.
1754            // In a real scenario the hook would drop GPU buffers. Here we
1755            // simulate by having the hook report 1024 freed; run_hooks
1756            // subtracts that from used_bytes.
1757            guard.register_hook(MemoryHook {
1758                name: "free_1k".into(),
1759                estimated_free_bytes: 1024,
1760                execution_overhead_bytes: 0,
1761                priority: 10,
1762                callback: Box::new(move || {
1763                    called_clone.store(true, Ordering::SeqCst);
1764                    1024
1765                }),
1766            });
1767
1768            // Request 256 f32 = 1024 bytes. 1536 + 1024 = 2560 > 2048.
1769            // Shortfall = 512. Hook frees 1024 (enough).
1770            let buf = guard
1771                .safe_alloc_with_hooks::<f32>(256)
1772                .expect("alloc after hook");
1773            assert!(called.load(Ordering::SeqCst), "hook was not called");
1774
1775            // used_bytes: was 1536, hook freed 1024 => 512, then alloc adds 1024 => 1536.
1776            assert_eq!(guard.stats().used_bytes, 1536);
1777
1778            guard.free(buf);
1779            guard.free(prefill);
1780        }
1781
1782        #[test]
1783        fn hook_not_enough_falls_through_to_oom_policy() {
1784            let device = make_device();
1785            let budget = 512_usize;
1786            let guard = MemoryGuardBuilder::new(device)
1787                .budget_bytes(budget)
1788                .oom_policy(OomPolicy::Fail)
1789                .build()
1790                .expect("build guard");
1791
1792            let called = Arc::new(AtomicBool::new(false));
1793            let called_clone = Arc::clone(&called);
1794
1795            guard.register_hook(MemoryHook {
1796                name: "weak_hook".into(),
1797                estimated_free_bytes: 64,
1798                execution_overhead_bytes: 0,
1799                priority: 10,
1800                callback: Box::new(move || {
1801                    called_clone.store(true, Ordering::SeqCst);
1802                    64 // only frees 64 bytes
1803                }),
1804            });
1805
1806            // Request 1024 f32 = 4096 bytes >> 512 budget.
1807            // Hook frees 64, still not enough, falls through to OomPolicy::Fail.
1808            let result = guard.safe_alloc_with_hooks::<f32>(1024);
1809            assert!(
1810                called.load(Ordering::SeqCst),
1811                "hook should have been called"
1812            );
1813            assert!(result.is_err(), "allocation should have failed");
1814        }
1815
1816        #[test]
1817        fn hooks_called_in_priority_order() {
1818            let device = make_device();
1819            let budget = 1024_usize;
1820            let guard = MemoryGuardBuilder::new(device)
1821                .budget_bytes(budget)
1822                .build()
1823                .expect("build guard");
1824
1825            let order = Arc::new(Mutex::new(Vec::new()));
1826
1827            let o1 = Arc::clone(&order);
1828            guard.register_hook(MemoryHook {
1829                name: "priority_20".into(),
1830                estimated_free_bytes: 256,
1831                execution_overhead_bytes: 0,
1832                priority: 20,
1833                callback: Box::new(move || {
1834                    o1.lock().unwrap().push(20_u32);
1835                    256
1836                }),
1837            });
1838
1839            let o2 = Arc::clone(&order);
1840            guard.register_hook(MemoryHook {
1841                name: "priority_5".into(),
1842                estimated_free_bytes: 256,
1843                execution_overhead_bytes: 0,
1844                priority: 5,
1845                callback: Box::new(move || {
1846                    o2.lock().unwrap().push(5_u32);
1847                    256
1848                }),
1849            });
1850
1851            let o3 = Arc::clone(&order);
1852            guard.register_hook(MemoryHook {
1853                name: "priority_10".into(),
1854                estimated_free_bytes: 256,
1855                execution_overhead_bytes: 0,
1856                priority: 10,
1857                callback: Box::new(move || {
1858                    o3.lock().unwrap().push(10_u32);
1859                    256
1860                }),
1861            });
1862
1863            // Request 512 f32 = 2048 bytes > 1024 budget.
1864            // All three hooks needed. Should fire: 5, 10, 20.
1865            let _result = guard.safe_alloc_with_hooks::<f32>(512);
1866            let call_order = order.lock().unwrap();
1867            assert_eq!(
1868                &*call_order,
1869                &[5, 10, 20],
1870                "hooks should fire in priority order"
1871            );
1872        }
1873
1874        #[test]
1875        fn remove_hook_by_name() {
1876            let device = make_device();
1877            let guard = MemoryGuardBuilder::new(device)
1878                .budget_bytes(1024)
1879                .build()
1880                .expect("build guard");
1881
1882            let called = Arc::new(AtomicBool::new(false));
1883            let called_clone = Arc::clone(&called);
1884
1885            guard.register_hook(MemoryHook {
1886                name: "removable".into(),
1887                estimated_free_bytes: 2048,
1888                execution_overhead_bytes: 0,
1889                priority: 10,
1890                callback: Box::new(move || {
1891                    called_clone.store(true, Ordering::SeqCst);
1892                    2048
1893                }),
1894            });
1895
1896            // Remove the hook.
1897            assert!(guard.remove_hook("removable"));
1898            // Removing again returns false.
1899            assert!(!guard.remove_hook("removable"));
1900
1901            // Trigger an over-budget allocation; removed hook should NOT fire.
1902            let _result = guard.safe_alloc_with_hooks::<f32>(512);
1903            assert!(
1904                !called.load(Ordering::SeqCst),
1905                "removed hook should not have been called"
1906            );
1907        }
1908
1909        #[test]
1910        fn pressure_level_tracks_usage() {
1911            let device = make_device();
1912            let budget = 1000_usize;
1913            let guard = MemoryGuardBuilder::new(device)
1914                .budget_bytes(budget)
1915                .build()
1916                .expect("build guard");
1917
1918            // No usage => None.
1919            assert_eq!(guard.pressure_level(), PressureLevel::None);
1920
1921            // Manually bump used_bytes to 750 => 25% free => Low.
1922            guard.used_bytes.store(750, Ordering::Relaxed);
1923            assert_eq!(guard.pressure_level(), PressureLevel::Low);
1924
1925            // 920 used => 8% free => Medium.
1926            guard.used_bytes.store(920, Ordering::Relaxed);
1927            assert_eq!(guard.pressure_level(), PressureLevel::Medium);
1928
1929            // 960 used => 4% free => High.
1930            guard.used_bytes.store(960, Ordering::Relaxed);
1931            assert_eq!(guard.pressure_level(), PressureLevel::High);
1932
1933            // At budget => Critical.
1934            guard.used_bytes.store(1000, Ordering::Relaxed);
1935            assert_eq!(guard.pressure_level(), PressureLevel::Critical);
1936
1937            // Unlimited budget => always None regardless of usage.
1938            guard.set_budget(0);
1939            assert_eq!(guard.pressure_level(), PressureLevel::None);
1940        }
1941
1942        #[test]
1943        fn multiple_hooks_called_until_enough_freed() {
1944            let device = make_device();
1945            let budget = 2048_usize;
1946            let guard = MemoryGuardBuilder::new(device)
1947                .budget_bytes(budget)
1948                .build()
1949                .expect("build guard");
1950
1951            // Pre-fill: 256 f32 = 1024 bytes.
1952            let prefill = guard.safe_alloc::<f32>(256).expect("prefill");
1953            assert_eq!(guard.stats().used_bytes, 1024);
1954
1955            let count = Arc::new(AtomicUsize::new(0));
1956
1957            // Hook A: priority 1, frees 256 bytes.
1958            let c1 = Arc::clone(&count);
1959            guard.register_hook(MemoryHook {
1960                name: "hook_a".into(),
1961                estimated_free_bytes: 256,
1962                execution_overhead_bytes: 0,
1963                priority: 1,
1964                callback: Box::new(move || {
1965                    c1.fetch_add(1, Ordering::SeqCst);
1966                    256
1967                }),
1968            });
1969
1970            // Hook B: priority 2, frees 512 bytes.
1971            let c2 = Arc::clone(&count);
1972            guard.register_hook(MemoryHook {
1973                name: "hook_b".into(),
1974                estimated_free_bytes: 512,
1975                execution_overhead_bytes: 0,
1976                priority: 2,
1977                callback: Box::new(move || {
1978                    c2.fetch_add(1, Ordering::SeqCst);
1979                    512
1980                }),
1981            });
1982
1983            // Hook C: priority 3, frees 512 bytes. Should NOT be called if
1984            // A+B free enough.
1985            let c3 = Arc::new(AtomicBool::new(false));
1986            let c3_clone = Arc::clone(&c3);
1987            guard.register_hook(MemoryHook {
1988                name: "hook_c".into(),
1989                estimated_free_bytes: 512,
1990                execution_overhead_bytes: 0,
1991                priority: 3,
1992                callback: Box::new(move || {
1993                    c3_clone.store(true, Ordering::SeqCst);
1994                    512
1995                }),
1996            });
1997
1998            // Request 384 f32 = 1536 bytes. 1024 + 1536 = 2560 > 2048.
1999            // Shortfall = 512. Hook A frees 256 (not enough), Hook B frees
2000            // 512 (now 768 >= 512). Hook C should be skipped.
2001            let buf = guard
2002                .safe_alloc_with_hooks::<f32>(384)
2003                .expect("alloc with hooks");
2004            assert_eq!(count.load(Ordering::SeqCst), 2, "hooks A and B should fire");
2005            assert!(
2006                !c3.load(Ordering::SeqCst),
2007                "hook C should not have been called"
2008            );
2009
2010            guard.free(buf);
2011            guard.free(prefill);
2012        }
2013
2014        #[test]
2015        fn hook_with_excessive_overhead_is_skipped() {
2016            let device = make_device();
2017            let budget = 2048_usize;
2018            let guard = MemoryGuardBuilder::new(device)
2019                .budget_bytes(budget)
2020                .build()
2021                .expect("build guard");
2022
2023            // Pre-fill: 480 f32 = 1920 bytes. Headroom = 128 bytes.
2024            let prefill = guard.safe_alloc::<f32>(480).expect("prefill");
2025            assert_eq!(guard.stats().used_bytes, 1920);
2026
2027            let expensive_called = Arc::new(AtomicBool::new(false));
2028            let expensive_clone = Arc::clone(&expensive_called);
2029
2030            // Hook with overhead of 256 > headroom of 128 => should be skipped.
2031            guard.register_hook(MemoryHook {
2032                name: "expensive_hook".into(),
2033                estimated_free_bytes: 1024,
2034                execution_overhead_bytes: 256,
2035                priority: 1,
2036                callback: Box::new(move || {
2037                    expensive_clone.store(true, Ordering::SeqCst);
2038                    1024
2039                }),
2040            });
2041
2042            let cheap_called = Arc::new(AtomicBool::new(false));
2043            let cheap_clone = Arc::clone(&cheap_called);
2044
2045            // Hook with zero overhead => should fire.
2046            guard.register_hook(MemoryHook {
2047                name: "cheap_hook".into(),
2048                estimated_free_bytes: 512,
2049                execution_overhead_bytes: 0,
2050                priority: 2,
2051                callback: Box::new(move || {
2052                    cheap_clone.store(true, Ordering::SeqCst);
2053                    512
2054                }),
2055            });
2056
2057            // Request 64 f32 = 256 bytes. 1920 + 256 = 2176 > 2048.
2058            // Shortfall = 128.
2059            let buf = guard
2060                .safe_alloc_with_hooks::<f32>(64)
2061                .expect("alloc with hooks");
2062
2063            assert!(
2064                !expensive_called.load(Ordering::SeqCst),
2065                "expensive hook should have been skipped due to overhead"
2066            );
2067            assert!(
2068                cheap_called.load(Ordering::SeqCst),
2069                "cheap hook should have been called"
2070            );
2071
2072            guard.free(buf);
2073            guard.free(prefill);
2074        }
2075
2076        #[test]
2077        fn pressure_listener_notified_on_change() {
2078            let device = make_device();
2079            let budget = 1000_usize;
2080            let guard = MemoryGuardBuilder::new(device)
2081                .budget_bytes(budget)
2082                .build()
2083                .expect("build guard");
2084
2085            struct TestListener {
2086                changes: Mutex<Vec<(PressureLevel, PressureLevel)>>,
2087            }
2088
2089            impl MemoryPressureListener for TestListener {
2090                fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel) {
2091                    self.changes.lock().unwrap().push((old, new));
2092                }
2093            }
2094
2095            let listener = Arc::new(TestListener {
2096                changes: Mutex::new(Vec::new()),
2097            });
2098            let listener_ref = Arc::clone(&listener);
2099
2100            // Wrap in a Box<dyn MemoryPressureListener> — we need to share
2101            // the Arc for assertions, so we use a thin wrapper.
2102            struct ListenerWrapper(Arc<TestListener>);
2103            impl MemoryPressureListener for ListenerWrapper {
2104                fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel) {
2105                    self.0.on_pressure_change(old, new);
2106                }
2107            }
2108
2109            guard.add_pressure_listener(Box::new(ListenerWrapper(listener_ref)));
2110
2111            // Allocate a small amount. Pressure should stay None, no
2112            // notification expected (None -> None is not a change).
2113            let buf1 = guard.safe_alloc::<f32>(1).expect("small alloc");
2114            // notify_pressure_change is only called by free and
2115            // safe_alloc_with_hooks; safe_alloc does not call it. Trigger
2116            // manually via free.
2117            guard.free(buf1);
2118
2119            // Force a pressure change by directly setting used_bytes and
2120            // calling notify_pressure_change.
2121            guard.used_bytes.store(960, Ordering::Relaxed);
2122            guard.notify_pressure_change(); // None -> High
2123            guard.used_bytes.store(0, Ordering::Relaxed);
2124            guard.notify_pressure_change(); // High -> None
2125
2126            let changes = listener.changes.lock().unwrap();
2127            assert!(
2128                changes.len() >= 2,
2129                "should have at least 2 pressure changes, got {}",
2130                changes.len()
2131            );
2132            assert_eq!(changes[0], (PressureLevel::None, PressureLevel::High));
2133            assert_eq!(changes[1], (PressureLevel::High, PressureLevel::None));
2134        }
2135
2136        #[test]
2137        fn safe_alloc_with_hooks_fast_path_no_hooks() {
2138            // When the allocation fits within budget, hooks should not fire
2139            // and the allocation should succeed.
2140            let device = make_device();
2141            let guard = MemoryGuardBuilder::new(device)
2142                .budget_bytes(1024 * 1024)
2143                .build()
2144                .expect("build guard");
2145
2146            let buf = guard
2147                .safe_alloc_with_hooks::<f32>(64)
2148                .expect("fast-path alloc");
2149            assert_eq!(guard.stats().used_bytes, 64 * 4);
2150            guard.free(buf);
2151        }
2152    }
2153}