1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
41use std::sync::{Arc, Mutex};
42use std::thread::JoinHandle;
43use std::time::{Duration, Instant};
44
45use crate::buffer::CudaBuffer;
46use crate::device::GpuDevice;
47use crate::error::{GpuError, GpuResult};
48
/// Strategy applied when a device allocation fails with out-of-memory.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum OomPolicy {
    /// Propagate the OOM error to the caller immediately (the default).
    #[default]
    Fail,
    /// Ask the guard to free caches once, then retry the allocation a
    /// single time.
    RetryAfterFree,
    /// Poll free device memory until enough is available or the timeout
    /// elapses, then retry once.
    WaitAndRetry {
        /// Maximum number of seconds to wait for memory to become free.
        timeout_secs: u64,
    },
    /// Invoke the registered on-OOM callback (e.g. to checkpoint state),
    /// then return the original error.
    CheckpointAndFail,
}
70
/// A registered callback that can free device memory on demand when the
/// guard is over budget.
pub struct MemoryHook {
    /// Human-readable identifier, used to remove the hook later.
    pub name: String,
    /// How many bytes the hook expects to free when invoked; used to order
    /// hooks within the same priority (larger first).
    pub estimated_free_bytes: usize,
    /// Extra bytes the hook itself needs while running; the hook is skipped
    /// when the remaining budget headroom is smaller than this.
    pub execution_overhead_bytes: usize,
    /// Execution order: hooks with a lower value run first.
    pub priority: u32,
    /// Callback invoked to free memory; returns the number of bytes freed.
    pub callback: Box<dyn Fn() -> usize + Send + Sync>,
}
112
113impl std::fmt::Debug for MemoryHook {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("MemoryHook")
116 .field("name", &self.name)
117 .field("estimated_free_bytes", &self.estimated_free_bytes)
118 .field("execution_overhead_bytes", &self.execution_overhead_bytes)
119 .field("priority", &self.priority)
120 .finish()
121 }
122}
123
/// Coarse memory-pressure buckets derived from the used/budget ratio.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// gives `None < Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PressureLevel {
    /// More than 30% of the budget is free, or the budget is unlimited.
    None,
    /// Between 10% and 30% of the budget is free.
    Low,
    /// Between 5% and 10% of the budget is free.
    Medium,
    /// Less than 5% of the budget is free.
    High,
    /// Usage has reached or exceeded the budget.
    Critical,
}
142
143impl std::fmt::Display for PressureLevel {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 let label = match self {
146 Self::None => "none",
147 Self::Low => "low",
148 Self::Medium => "medium",
149 Self::High => "high",
150 Self::Critical => "critical",
151 };
152 f.write_str(label)
153 }
154}
155
/// Observer notified whenever the guard's [`PressureLevel`] changes.
pub trait MemoryPressureListener: Send + Sync {
    /// Called with the previous and the new pressure level.
    ///
    /// Invoked while the guard holds its listener-list lock, so
    /// implementations should return quickly and must not re-enter the
    /// guard's listener API.
    fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel);
}
169
/// A chunk of device memory held back from other allocators.
///
/// The underlying buffer is never read; it exists only to occupy device
/// memory and is released when the reservation is dropped.
pub struct MemoryReservation {
    // Kept alive solely so the device allocation persists.
    _reservation: CudaBuffer<u8>,
    /// Size of the reserved region in bytes.
    reserved_bytes: usize,
    /// Ordinal of the device the reservation lives on.
    device_ordinal: usize,
}
188
189impl MemoryReservation {
190 #[inline]
192 pub fn reserved_bytes(&self) -> usize {
193 self.reserved_bytes
194 }
195
196 #[inline]
198 pub fn device_ordinal(&self) -> usize {
199 self.device_ordinal
200 }
201}
202
203impl std::fmt::Debug for MemoryReservation {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 f.debug_struct("MemoryReservation")
206 .field("reserved_bytes", &self.reserved_bytes)
207 .field("device_ordinal", &self.device_ordinal)
208 .finish()
209 }
210}
211
/// Point-in-time snapshot of the guard's accounting; produced by
/// [`MemoryGuard::stats`]. Counters are read individually, so the snapshot
/// is not atomic across fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryStats {
    /// Bytes currently tracked as allocated through the guard.
    pub used_bytes: usize,
    /// Configured soft budget (`0` = unlimited).
    pub budget_bytes: usize,
    /// High-water mark of `used_bytes` since the last reset.
    pub peak_bytes: usize,
    /// Free bytes reported by the device (`0` if the query failed).
    pub free_device_bytes: usize,
    /// Total bytes reported by the device (`0` if the query failed).
    pub total_device_bytes: usize,
    /// Number of currently live tracked allocations.
    pub num_allocations: usize,
    /// OOM recovery attempts performed so far.
    pub num_oom_recoveries: usize,
}
234
/// Tracks device-memory usage against a soft budget, applies OOM-recovery
/// policies, runs memory hooks, and notifies pressure listeners.
pub struct MemoryGuard {
    /// Device all allocations are made on.
    device: Arc<GpuDevice>,
    /// Optional up-front reservation, released via `release_reservation`.
    reservation: Mutex<Option<MemoryReservation>>,
    /// Soft limit in bytes; `0` disables budget checks.
    budget_bytes: AtomicUsize,
    /// Bytes currently tracked as allocated through this guard.
    used_bytes: AtomicUsize,
    /// High-water mark of `used_bytes`; see `reset_peak_stats`.
    peak_bytes: AtomicUsize,
    /// Live allocation count (incremented on alloc, decremented on `free`).
    num_allocations: AtomicUsize,
    /// Number of OOM recovery attempts performed so far.
    num_oom_recoveries: AtomicUsize,
    /// Strategy applied when the device reports OOM.
    oom_policy: Mutex<OomPolicy>,
    /// Callback fired by `trigger_emergency_checkpoint`.
    on_oom_callback: Mutex<Option<Box<dyn Fn() + Send + Sync>>>,
    /// Hooks that can free memory on demand; see `run_hooks`.
    hooks: Mutex<Vec<MemoryHook>>,
    /// Observers of pressure-level transitions.
    pressure_listeners: Mutex<Vec<Box<dyn MemoryPressureListener>>>,
    /// Last level reported to listeners, used to dedupe notifications.
    last_pressure_level: Mutex<PressureLevel>,
}
274
// SAFETY: asserts `MemoryGuard` may be shared across threads. Every field is
// an atomic, a `Mutex`, or an `Arc`, which are `Send`/`Sync` on their own —
// these manual impls only matter if `GpuDevice` (or the CUDA handles it
// wraps) is not itself `Send + Sync`.
// NOTE(review): if `GpuDevice: Send + Sync`, both impls are redundant and
// should be removed; if it is not, their soundness needs an explicit
// argument — confirm against `GpuDevice`'s definition.
unsafe impl Send for MemoryGuard {}
unsafe impl Sync for MemoryGuard {}
278
279impl MemoryGuard {
    /// Replaces the allocation budget in bytes; `0` disables enforcement.
    ///
    /// `SeqCst` so the new limit becomes visible to all threads immediately.
    pub fn set_budget(&self, bytes: usize) {
        self.budget_bytes.store(bytes, Ordering::SeqCst);
    }
292
    /// Current soft budget in bytes (`0` = unlimited).
    #[inline]
    pub fn budget(&self) -> usize {
        self.budget_bytes.load(Ordering::Relaxed)
    }
298
    /// Registers the callback run by `trigger_emergency_checkpoint` (i.e.
    /// under the `CheckpointAndFail` policy), replacing any previous one.
    pub fn on_oom<F: Fn() + Send + Sync + 'static>(&self, f: F) {
        *self.on_oom_callback.lock().unwrap() = Some(Box::new(f));
    }
309
    /// Replaces the policy applied when an allocation hits device OOM.
    pub fn set_oom_policy(&self, policy: OomPolicy) {
        *self.oom_policy.lock().unwrap() = policy;
    }
314
    /// Adds a memory hook consulted by `safe_alloc_with_hooks`; execution
    /// order is decided at run time by priority, not insertion order.
    pub fn register_hook(&self, hook: MemoryHook) {
        self.hooks.lock().unwrap().push(hook);
    }
327
328 pub fn remove_hook(&self, name: &str) -> bool {
332 let mut hooks = self.hooks.lock().unwrap();
333 let before = hooks.len();
334 hooks.retain(|h| h.name != name);
335 hooks.len() < before
336 }
337
338 pub fn pressure_level(&self) -> PressureLevel {
343 let budget = self.budget_bytes.load(Ordering::Relaxed);
344 if budget == 0 {
345 return PressureLevel::None;
346 }
347 let used = self.used_bytes.load(Ordering::Relaxed);
348 Self::compute_pressure(budget, used)
349 }
350
351 fn compute_pressure(budget: usize, used: usize) -> PressureLevel {
355 if budget == 0 {
356 return PressureLevel::None;
357 }
358 if used >= budget {
359 return PressureLevel::Critical;
360 }
361 let free_frac = ((budget - used) as f64) / (budget as f64);
362 if free_frac > 0.30 {
363 PressureLevel::None
364 } else if free_frac > 0.10 {
365 PressureLevel::Low
366 } else if free_frac > 0.05 {
367 PressureLevel::Medium
368 } else {
369 PressureLevel::High
370 }
371 }
372
    /// Registers a listener notified on every pressure-level transition.
    pub fn add_pressure_listener(&self, listener: Box<dyn MemoryPressureListener>) {
        self.pressure_listeners.lock().unwrap().push(listener);
    }
378
379 fn notify_pressure_change(&self) {
381 let new_level = self.pressure_level();
382 let mut last = self.last_pressure_level.lock().unwrap();
383 if *last != new_level {
384 let old = *last;
385 *last = new_level;
386 drop(last);
389 let listeners = self.pressure_listeners.lock().unwrap();
390 for listener in listeners.iter() {
391 listener.on_pressure_change(old, new_level);
392 }
393 }
394 }
395
    /// Budget-checked allocation that runs registered [`MemoryHook`]s to
    /// free memory when the request does not fit the budget.
    ///
    /// Flow: try the fast path (budget check + alloc); otherwise compute the
    /// shortfall and run hooks; if they freed something, retry; if the
    /// budget still cannot accommodate the request, fail with
    /// `BudgetExceeded`; a final attempt defers device OOM to the
    /// configured [`OomPolicy`].
    ///
    /// # Errors
    /// `BudgetExceeded` when hooks cannot free enough, otherwise the driver
    /// or OOM error from the allocation itself.
    #[cfg(feature = "cuda")]
    pub fn safe_alloc_with_hooks<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let alloc_bytes = count.saturating_mul(std::mem::size_of::<T>());

        // Fast path: the request already fits within the budget.
        if self.check_budget(alloc_bytes).is_ok() {
            let result = self.try_alloc_zeros::<T>(count, alloc_bytes);
            if result.is_ok() {
                self.notify_pressure_change();
            }
            return result;
        }

        // Over budget: ask hooks to free at least the shortfall.
        let budget = self.budget_bytes.load(Ordering::Relaxed);
        let used = self.used_bytes.load(Ordering::Relaxed);
        let shortfall = (used + alloc_bytes).saturating_sub(budget);

        let freed = self.run_hooks(shortfall, budget, used);

        if freed > 0 {
            if self.check_budget(alloc_bytes).is_ok() {
                let result = self.try_alloc_zeros::<T>(count, alloc_bytes);
                if result.is_ok() {
                    self.notify_pressure_change();
                    return result;
                }
                if let Err(e) = result {
                    // The budget was fine but the device still refused:
                    // apply the OOM policy instead of a budget error.
                    if self.is_oom(&e) {
                        return self.handle_oom(count, alloc_bytes, e);
                    }
                    return Err(e);
                }
            }
        }

        // Hooks could not free enough (or freed nothing): report the budget
        // violation with freshly re-read counter values.
        if self.check_budget(alloc_bytes).is_err() {
            let budget = self.budget_bytes.load(Ordering::Relaxed);
            let used = self.used_bytes.load(Ordering::Relaxed);
            return Err(crate::error::GpuError::BudgetExceeded {
                requested_bytes: alloc_bytes,
                budget_bytes: budget,
                used_bytes: used,
            });
        }

        // Budget now fits (e.g. a concurrent free): one last attempt.
        match self.try_alloc_zeros::<T>(count, alloc_bytes) {
            Ok(buf) => {
                self.notify_pressure_change();
                Ok(buf)
            }
            Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
            Err(e) => Err(e),
        }
    }
476
    /// Runs registered hooks (lowest `priority` first; larger
    /// `estimated_free_bytes` first within a priority) until at least
    /// `shortfall` bytes are reported freed, then credits the total back to
    /// `used_bytes`. Returns the total bytes the hooks claimed to free.
    ///
    /// Hooks whose `execution_overhead_bytes` exceed the remaining headroom
    /// (`budget - running used estimate`) are skipped.
    ///
    /// NOTE(review): the hooks mutex is held while callbacks run, so a
    /// callback that calls `register_hook`/`remove_hook` on the same guard
    /// will deadlock — confirm hook implementations never do that.
    #[allow(dead_code)]
    fn run_hooks(&self, shortfall: usize, budget: usize, used: usize) -> usize {
        let hooks = self.hooks.lock().unwrap();
        if hooks.is_empty() {
            return 0;
        }

        // Sort indices instead of the hooks themselves: ascending priority,
        // then descending estimated yield.
        let mut indices: Vec<usize> = (0..hooks.len()).collect();
        indices.sort_by(|&a, &b| {
            hooks[a].priority.cmp(&hooks[b].priority).then_with(|| {
                hooks[b]
                    .estimated_free_bytes
                    .cmp(&hooks[a].estimated_free_bytes)
            })
        });

        let mut total_freed: usize = 0;
        let mut current_used = used;

        for &idx in &indices {
            // Stop once the shortfall is covered.
            if total_freed >= shortfall {
                break;
            }

            let hook = &hooks[idx];

            // Skip hooks that would themselves exceed the remaining budget
            // headroom while executing.
            let headroom = budget.saturating_sub(current_used);
            if hook.execution_overhead_bytes > headroom {
                continue;
            }

            let freed = (hook.callback)();
            total_freed = total_freed.saturating_add(freed);
            current_used = current_used.saturating_sub(freed);
        }

        // Reflect the freed bytes in the shared usage counter (saturating,
        // so an over-reporting hook cannot underflow it).
        if total_freed > 0 {
            self.used_bytes
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                    Some(current.saturating_sub(total_freed))
                })
                .ok();
        }

        total_freed
    }
535
536 pub fn release_reservation(&self) -> usize {
544 let mut lock = self.reservation.lock().unwrap();
545 if let Some(res) = lock.take() {
546 let bytes = res.reserved_bytes;
547 drop(res);
548 bytes
549 } else {
550 0
551 }
552 }
553
    /// True while an up-front reservation is still held.
    pub fn has_reservation(&self) -> bool {
        self.reservation.lock().unwrap().is_some()
    }
558
559 #[cfg(feature = "cuda")]
569 pub fn safe_alloc<T>(&self, count: usize) -> GpuResult<CudaBuffer<T>>
570 where
571 T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
572 {
573 let alloc_bytes = count.saturating_mul(std::mem::size_of::<T>());
574
575 self.check_budget(alloc_bytes)?;
577
578 match self.try_alloc_zeros::<T>(count, alloc_bytes) {
580 Ok(buf) => Ok(buf),
581 Err(e) if self.is_oom(&e) => self.handle_oom(count, alloc_bytes, e),
582 Err(e) => Err(e),
583 }
584 }
585
    /// Budget-checked host-to-device copy of `data`.
    ///
    /// On device OOM the configured [`OomPolicy`] is applied inline.
    ///
    /// NOTE(review): this duplicates the policy dispatch in `handle_oom`
    /// because that helper retries a *zeroed* allocation (and requires
    /// `ValidAsZeroBits`); keep the two match arms in sync. Also unlike
    /// `safe_alloc_with_hooks`, no pressure notification is emitted here.
    ///
    /// # Errors
    /// `BudgetExceeded` if the copy does not fit the budget, otherwise the
    /// driver/OOM error (possibly after one policy-driven retry).
    #[cfg(feature = "cuda")]
    pub fn safe_alloc_copy<T>(&self, data: &[T]) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let alloc_bytes = data.len().saturating_mul(std::mem::size_of::<T>());

        self.check_budget(alloc_bytes)?;

        match self.try_alloc_copy(data, alloc_bytes) {
            Ok(buf) => Ok(buf),
            Err(e) if self.is_oom(&e) => {
                // Clone the policy so the lock is not held across the retry.
                let policy = self.oom_policy.lock().unwrap().clone();
                match policy {
                    OomPolicy::Fail => Err(e),
                    OomPolicy::RetryAfterFree => {
                        self.free_caches();
                        self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                        self.try_alloc_copy(data, alloc_bytes)
                    }
                    OomPolicy::WaitAndRetry { timeout_secs } => {
                        self.wait_for_memory(alloc_bytes, timeout_secs)?;
                        self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                        self.try_alloc_copy(data, alloc_bytes)
                    }
                    OomPolicy::CheckpointAndFail => {
                        self.trigger_emergency_checkpoint();
                        Err(e)
                    }
                }
            }
            Err(e) => Err(e),
        }
    }
623
624 pub fn free<T>(&self, buffer: CudaBuffer<T>) {
627 let bytes = buffer
628 .len()
629 .checked_mul(std::mem::size_of::<T>())
630 .unwrap_or(0);
631 self.used_bytes
632 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
633 Some(current.saturating_sub(bytes))
634 })
635 .ok();
636 self.num_allocations
637 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
638 Some(current.saturating_sub(1))
639 })
640 .ok();
641 drop(buffer);
642 self.notify_pressure_change();
643 }
644
645 pub fn stats(&self) -> MemoryStats {
651 let (free_device, total_device) = self.query_device_memory();
652 MemoryStats {
653 used_bytes: self.used_bytes.load(Ordering::Relaxed),
654 budget_bytes: self.budget_bytes.load(Ordering::Relaxed),
655 peak_bytes: self.peak_bytes.load(Ordering::Relaxed),
656 free_device_bytes: free_device,
657 total_device_bytes: total_device,
658 num_allocations: self.num_allocations.load(Ordering::Relaxed),
659 num_oom_recoveries: self.num_oom_recoveries.load(Ordering::Relaxed),
660 }
661 }
662
663 pub fn reset_peak_stats(&self) {
665 let current = self.used_bytes.load(Ordering::Relaxed);
666 self.peak_bytes.store(current, Ordering::Relaxed);
667 }
668
    /// Borrows the underlying device.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }
674
    /// Borrows the shared device handle (e.g. for cloning the `Arc`).
    #[inline]
    pub fn device_arc(&self) -> &Arc<GpuDevice> {
        &self.device
    }
680
681 #[allow(dead_code)]
691 fn check_budget(&self, alloc_bytes: usize) -> GpuResult<()> {
692 let budget = self.budget_bytes.load(Ordering::Relaxed);
693 if budget == 0 {
694 return Ok(()); }
696 let used = self.used_bytes.load(Ordering::Relaxed);
697 if used.saturating_add(alloc_bytes) > budget {
698 return Err(GpuError::BudgetExceeded {
699 requested_bytes: alloc_bytes,
700 budget_bytes: budget,
701 used_bytes: used,
702 });
703 }
704 Ok(())
705 }
706
    /// Raw zeroed allocation plus accounting; no budget check, no OOM
    /// policy.
    ///
    /// Counters (`used_bytes`, `peak_bytes`, `num_allocations`) are updated
    /// only after the driver allocation succeeds.
    #[cfg(feature = "cuda")]
    fn try_alloc_zeros<T>(&self, count: usize, alloc_bytes: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        let slice = self.device.stream().alloc_zeros::<T>(count)?;

        // Track usage and keep the peak in sync atomically.
        let prev = self.used_bytes.fetch_add(alloc_bytes, Ordering::Relaxed);
        self.peak_bytes
            .fetch_max(prev + alloc_bytes, Ordering::Relaxed);
        self.num_allocations.fetch_add(1, Ordering::Relaxed);

        Ok(CudaBuffer {
            data: Some(slice),
            len: count,
            alloc_len: count,
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }
728
    /// Raw host-to-device copy plus accounting; no budget check, no OOM
    /// policy.
    ///
    /// Counters are updated only after the driver copy succeeds.
    #[cfg(feature = "cuda")]
    fn try_alloc_copy<T>(&self, data: &[T], alloc_bytes: usize) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr,
    {
        let slice = self.device.stream().clone_htod(data)?;

        // Track usage and keep the peak in sync atomically.
        let prev = self.used_bytes.fetch_add(alloc_bytes, Ordering::Relaxed);
        self.peak_bytes
            .fetch_max(prev + alloc_bytes, Ordering::Relaxed);
        self.num_allocations.fetch_add(1, Ordering::Relaxed);

        Ok(CudaBuffer {
            data: Some(slice),
            len: data.len(),
            alloc_len: data.len(),
            device_ordinal: self.device.ordinal(),
            pool_fn: None,
        })
    }
750
751 #[allow(dead_code)]
753 fn is_oom(&self, err: &GpuError) -> bool {
754 match err {
755 GpuError::OutOfMemory { .. } => true,
756 #[cfg(feature = "cuda")]
757 GpuError::Driver(driver_err) => {
758 let msg = format!("{driver_err}");
759 msg.contains("OUT_OF_MEMORY")
760 || msg.contains("out of memory")
761 || msg.contains("CUDA_ERROR_OUT_OF_MEMORY")
762 }
763 _ => false,
764 }
765 }
766
    /// Applies the configured [`OomPolicy`] after a zeroed allocation hit
    /// device OOM, retrying via `try_alloc_zeros` where the policy allows.
    ///
    /// `num_oom_recoveries` counts recovery *attempts*, not successes.
    #[cfg(feature = "cuda")]
    fn handle_oom<T>(
        &self,
        count: usize,
        alloc_bytes: usize,
        original_err: GpuError,
    ) -> GpuResult<CudaBuffer<T>>
    where
        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    {
        // Clone so the policy lock is not held across the retry.
        let policy = self.oom_policy.lock().unwrap().clone();
        match policy {
            OomPolicy::Fail => Err(original_err),
            OomPolicy::RetryAfterFree => {
                self.free_caches();
                self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                self.try_alloc_zeros(count, alloc_bytes)
            }
            OomPolicy::WaitAndRetry { timeout_secs } => {
                self.wait_for_memory(alloc_bytes, timeout_secs)?;
                self.num_oom_recoveries.fetch_add(1, Ordering::Relaxed);
                self.try_alloc_zeros(count, alloc_bytes)
            }
            OomPolicy::CheckpointAndFail => {
                self.trigger_emergency_checkpoint();
                Err(original_err)
            }
        }
    }
797
    /// Placeholder for dropping internal caches to reclaim device memory.
    ///
    /// Currently a no-op: the guard owns no caches itself, so the
    /// `RetryAfterFree` policy only helps if the driver can satisfy the
    /// retry anyway. Real cache-freeing should be wired up here or via
    /// memory hooks.
    #[allow(dead_code)]
    fn free_caches(&self) {
    }
806
807 #[allow(dead_code)]
810 fn wait_for_memory(&self, needed_bytes: usize, timeout_secs: u64) -> GpuResult<()> {
811 let deadline = Instant::now() + Duration::from_secs(timeout_secs);
812 loop {
813 let (free, _) = self.query_device_memory();
814 if free >= needed_bytes {
815 return Ok(());
816 }
817 if Instant::now() >= deadline {
818 return Err(GpuError::OutOfMemory {
819 requested_bytes: needed_bytes,
820 free_bytes: free,
821 });
822 }
823 std::thread::sleep(Duration::from_millis(100));
824 }
825 }
826
827 #[allow(dead_code)]
829 fn trigger_emergency_checkpoint(&self) {
830 let lock = self.on_oom_callback.lock().unwrap();
831 if let Some(cb) = lock.as_ref() {
832 cb();
833 }
834 }
835
    /// `(free, total)` device memory in bytes; `(0, 0)` when the query
    /// fails or the `cuda` feature is disabled.
    ///
    /// NOTE(review): `mem_get_info` reports on the driver's current context
    /// — presumably this guard's device is bound on the calling thread
    /// (the watchdog's variant binds the context explicitly first).
    fn query_device_memory(&self) -> (usize, usize) {
        #[cfg(feature = "cuda")]
        {
            cudarc::driver::result::mem_get_info().unwrap_or((0, 0))
        }
        #[cfg(not(feature = "cuda"))]
        {
            (0, 0)
        }
    }
850}
851
852impl std::fmt::Debug for MemoryGuard {
853 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
854 f.debug_struct("MemoryGuard")
855 .field("device_ordinal", &self.device.ordinal())
856 .field("budget_bytes", &self.budget_bytes.load(Ordering::Relaxed))
857 .field("used_bytes", &self.used_bytes.load(Ordering::Relaxed))
858 .field("peak_bytes", &self.peak_bytes.load(Ordering::Relaxed))
859 .field(
860 "has_reservation",
861 &self.reservation.lock().unwrap().is_some(),
862 )
863 .finish()
864 }
865}
866
/// Allocation entry points for builds without the `cuda` feature: every
/// allocation fails with [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
impl MemoryGuard {
    /// Always fails: device allocation requires the `cuda` feature.
    pub fn safe_alloc<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(GpuError::NoCudaFeature)
    }

    /// Always fails: device copies require the `cuda` feature.
    pub fn safe_alloc_copy<T>(&self, _data: &[T]) -> GpuResult<CudaBuffer<T>> {
        Err(GpuError::NoCudaFeature)
    }

    /// Always fails: device allocation requires the `cuda` feature.
    pub fn safe_alloc_with_hooks<T>(&self, _count: usize) -> GpuResult<CudaBuffer<T>> {
        Err(GpuError::NoCudaFeature)
    }
}
888
/// Fluent builder for [`MemoryGuard`]; finished with `build`.
pub struct MemoryGuardBuilder {
    /// Device the guard will allocate on.
    device: Arc<GpuDevice>,
    /// Soft budget in bytes (`0` = unlimited).
    budget_bytes: usize,
    /// Bytes to reserve up front at build time (`0` = none).
    reserve_bytes: usize,
    /// Initial policy applied on device OOM.
    oom_policy: OomPolicy,
}
913
914impl MemoryGuardBuilder {
915 pub fn new(device: Arc<GpuDevice>) -> Self {
917 Self {
918 device,
919 budget_bytes: 0,
920 reserve_bytes: 0,
921 oom_policy: OomPolicy::default(),
922 }
923 }
924
925 pub fn budget_bytes(mut self, bytes: usize) -> Self {
927 self.budget_bytes = bytes;
928 self
929 }
930
931 pub fn reserve_bytes(mut self, bytes: usize) -> Self {
934 self.reserve_bytes = bytes;
935 self
936 }
937
938 pub fn oom_policy(mut self, policy: OomPolicy) -> Self {
940 self.oom_policy = policy;
941 self
942 }
943
944 #[cfg(feature = "cuda")]
949 pub fn build(self) -> GpuResult<MemoryGuard> {
950 let reservation = if self.reserve_bytes > 0 {
951 let slice = self.device.stream().alloc_zeros::<u8>(self.reserve_bytes)?;
952 Some(MemoryReservation {
953 _reservation: CudaBuffer {
954 data: Some(slice),
955 len: self.reserve_bytes,
956 alloc_len: self.reserve_bytes,
957 device_ordinal: self.device.ordinal(),
958 pool_fn: None,
959 },
960 reserved_bytes: self.reserve_bytes,
961 device_ordinal: self.device.ordinal(),
962 })
963 } else {
964 None
965 };
966
967 Ok(MemoryGuard {
968 device: self.device,
969 reservation: Mutex::new(reservation),
970 budget_bytes: AtomicUsize::new(self.budget_bytes),
971 used_bytes: AtomicUsize::new(0),
972 peak_bytes: AtomicUsize::new(0),
973 num_allocations: AtomicUsize::new(0),
974 num_oom_recoveries: AtomicUsize::new(0),
975 oom_policy: Mutex::new(self.oom_policy),
976 on_oom_callback: Mutex::new(None),
977 hooks: Mutex::new(Vec::new()),
978 pressure_listeners: Mutex::new(Vec::new()),
979 last_pressure_level: Mutex::new(PressureLevel::None),
980 })
981 }
982
983 #[cfg(not(feature = "cuda"))]
985 pub fn build(self) -> GpuResult<MemoryGuard> {
986 Ok(MemoryGuard {
987 device: self.device,
988 reservation: Mutex::new(None),
989 budget_bytes: AtomicUsize::new(self.budget_bytes),
990 used_bytes: AtomicUsize::new(0),
991 peak_bytes: AtomicUsize::new(0),
992 num_allocations: AtomicUsize::new(0),
993 num_oom_recoveries: AtomicUsize::new(0),
994 oom_policy: Mutex::new(self.oom_policy),
995 on_oom_callback: Mutex::new(None),
996 hooks: Mutex::new(Vec::new()),
997 pressure_listeners: Mutex::new(Vec::new()),
998 last_pressure_level: Mutex::new(PressureLevel::None),
999 })
1000 }
1001}
1002
/// Background monitor that flips a `paused` flag while free device memory
/// stays below a threshold, letting workers throttle themselves.
pub struct MemoryWatchdog {
    /// Device whose free memory is polled.
    device: Arc<GpuDevice>,
    /// Free-memory floor in bytes; below this the watchdog reports pressure.
    pressure_threshold_bytes: usize,
    /// Sleep between polls while no pressure is detected.
    check_interval: Duration,
    /// True while free memory is below the threshold.
    paused: AtomicBool,
    /// Set by `stop` to terminate the polling thread.
    stop: AtomicBool,
    /// Set after the first completed poll; see `wait_for_first_check`.
    has_checked: AtomicBool,
}
1045
1046impl MemoryWatchdog {
1047 pub fn new(
1050 device: Arc<GpuDevice>,
1051 pressure_threshold_bytes: usize,
1052 check_interval: Duration,
1053 ) -> Self {
1054 Self {
1055 device,
1056 pressure_threshold_bytes,
1057 check_interval,
1058 paused: AtomicBool::new(false),
1059 stop: AtomicBool::new(false),
1060 has_checked: AtomicBool::new(false),
1061 }
1062 }
1063
1064 pub fn start(self: Arc<Self>) -> JoinHandle<()> {
1067 std::thread::Builder::new()
1068 .name("ferrotorch-memory-watchdog".into())
1069 .spawn(move || {
1070 while !self.stop.load(Ordering::Relaxed) {
1071 let free = self.query_free_memory();
1072 if free < self.pressure_threshold_bytes {
1073 self.paused.store(true, Ordering::SeqCst);
1074 while self.query_free_memory() < self.pressure_threshold_bytes {
1076 if self.stop.load(Ordering::Relaxed) {
1077 return;
1078 }
1079 std::thread::sleep(Duration::from_millis(500));
1080 }
1081 self.paused.store(false, Ordering::SeqCst);
1082 }
1083 self.has_checked.store(true, Ordering::SeqCst);
1084 std::thread::sleep(self.check_interval);
1085 }
1086 })
1087 .expect("failed to spawn memory watchdog thread")
1088 }
1089
1090 pub fn stop(&self) {
1092 self.stop.store(true, Ordering::SeqCst);
1093 }
1094
1095 #[inline]
1098 pub fn check_pressure(&self) -> bool {
1099 self.paused.load(Ordering::SeqCst)
1100 }
1101
1102 pub fn wait_if_paused(&self) {
1105 while self.paused.load(Ordering::SeqCst) {
1106 std::thread::sleep(Duration::from_millis(100));
1107 }
1108 }
1109
1110 pub fn wait_for_first_check(&self, timeout: Duration) {
1113 let start = std::time::Instant::now();
1114 while !self.has_checked.load(Ordering::SeqCst) {
1115 if start.elapsed() > timeout {
1116 return;
1117 }
1118 std::thread::sleep(Duration::from_millis(5));
1119 }
1120 }
1121
1122 #[inline]
1124 pub fn pressure_threshold_bytes(&self) -> usize {
1125 self.pressure_threshold_bytes
1126 }
1127
1128 fn query_free_memory(&self) -> usize {
1133 #[cfg(feature = "cuda")]
1134 {
1135 let ctx = self.device.context();
1137 let _ = ctx.bind_to_thread();
1138 cudarc::driver::result::mem_get_info()
1139 .map(|(free, _)| free)
1140 .unwrap_or(0)
1141 }
1142 #[cfg(not(feature = "cuda"))]
1143 {
1144 let _ = &self.device;
1145 0
1146 }
1147 }
1148}
1149
1150impl std::fmt::Debug for MemoryWatchdog {
1151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1152 f.debug_struct("MemoryWatchdog")
1153 .field("device_ordinal", &self.device.ordinal())
1154 .field("pressure_threshold_bytes", &self.pressure_threshold_bytes)
1155 .field("check_interval", &self.check_interval)
1156 .field("paused", &self.paused.load(Ordering::Relaxed))
1157 .finish()
1158 }
1159}
1160
/// Convenience memory queries on [`GpuDevice`].
impl GpuDevice {
    /// Returns `(free, total)` device memory in bytes.
    ///
    /// NOTE(review): queries the driver's current context — presumably the
    /// caller has this device's context bound on the current thread.
    ///
    /// # Errors
    /// Propagates the driver error from `mem_get_info`.
    #[cfg(feature = "cuda")]
    pub fn memory_info(&self) -> GpuResult<(usize, usize)> {
        let info = cudarc::driver::result::mem_get_info()?;
        Ok(info)
    }

    /// Without the `cuda` feature there is no device to query.
    ///
    /// # Errors
    /// Always returns [`GpuError::NoCudaFeature`].
    #[cfg(not(feature = "cuda"))]
    pub fn memory_info(&self) -> GpuResult<(usize, usize)> {
        Err(GpuError::NoCudaFeature)
    }
}
1189
/// Pairs a device with its [`MemoryGuard`] for convenient passing around.
pub struct MemoryGuardedDevice {
    /// The guard, which itself owns the device handle.
    pub guard: MemoryGuard,
}
1197
impl MemoryGuardedDevice {
    /// Borrows the underlying device.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        self.guard.device()
    }

    /// Borrows the guard itself.
    #[inline]
    pub fn guard(&self) -> &MemoryGuard {
        &self.guard
    }
}
1211
1212impl std::fmt::Debug for MemoryGuardedDevice {
1213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1214 f.debug_struct("MemoryGuardedDevice")
1215 .field("guard", &self.guard)
1216 .finish()
1217 }
1218}
1219
1220#[cfg(test)]
1225mod tests {
1226 use super::*;
1227
1228 #[test]
1233 fn oom_policy_default_is_fail() {
1234 assert_eq!(OomPolicy::default(), OomPolicy::Fail);
1235 }
1236
1237 #[test]
1238 fn oom_policy_debug() {
1239 let p = OomPolicy::WaitAndRetry { timeout_secs: 30 };
1240 let s = format!("{p:?}");
1241 assert!(s.contains("WaitAndRetry"));
1242 assert!(s.contains("30"));
1243 }
1244
1245 #[test]
1246 fn memory_stats_clone_eq() {
1247 let s = MemoryStats {
1248 used_bytes: 100,
1249 budget_bytes: 1000,
1250 peak_bytes: 200,
1251 free_device_bytes: 800,
1252 total_device_bytes: 2000,
1253 num_allocations: 5,
1254 num_oom_recoveries: 1,
1255 };
1256 let s2 = s.clone();
1257 assert_eq!(s, s2);
1258 }
1259
1260 #[test]
1261 fn memory_stats_debug() {
1262 let s = MemoryStats {
1263 used_bytes: 0,
1264 budget_bytes: 0,
1265 peak_bytes: 0,
1266 free_device_bytes: 0,
1267 total_device_bytes: 0,
1268 num_allocations: 0,
1269 num_oom_recoveries: 0,
1270 };
1271 let d = format!("{s:?}");
1272 assert!(d.contains("MemoryStats"));
1273 assert!(d.contains("used_bytes"));
1274 }
1275
1276 #[test]
1277 fn gpu_error_out_of_memory_display() {
1278 let e = GpuError::OutOfMemory {
1279 requested_bytes: 1024,
1280 free_bytes: 512,
1281 };
1282 let s = format!("{e}");
1283 assert!(s.contains("1024"));
1284 assert!(s.contains("512"));
1285 assert!(s.contains("out of memory"));
1286 }
1287
1288 #[test]
1289 fn gpu_error_budget_exceeded_display() {
1290 let e = GpuError::BudgetExceeded {
1291 requested_bytes: 500,
1292 budget_bytes: 1000,
1293 used_bytes: 800,
1294 };
1295 let s = format!("{e}");
1296 assert!(s.contains("500"));
1297 assert!(s.contains("1000"));
1298 assert!(s.contains("800"));
1299 assert!(s.contains("budget exceeded"));
1300 }
1301
1302 #[test]
1307 fn pressure_level_ordering() {
1308 assert!(PressureLevel::None < PressureLevel::Low);
1309 assert!(PressureLevel::Low < PressureLevel::Medium);
1310 assert!(PressureLevel::Medium < PressureLevel::High);
1311 assert!(PressureLevel::High < PressureLevel::Critical);
1312 }
1313
1314 #[test]
1315 fn pressure_level_display() {
1316 assert_eq!(format!("{}", PressureLevel::None), "none");
1317 assert_eq!(format!("{}", PressureLevel::Critical), "critical");
1318 }
1319
1320 #[test]
1321 fn pressure_level_debug_clone_eq() {
1322 let p = PressureLevel::Medium;
1323 let p2 = p;
1324 assert_eq!(p, p2);
1325 let s = format!("{p:?}");
1326 assert!(s.contains("Medium"));
1327 }
1328
1329 #[test]
1330 fn compute_pressure_unlimited_budget_is_none() {
1331 assert_eq!(MemoryGuard::compute_pressure(0, 0), PressureLevel::None);
1333 }
1334
1335 #[test]
1336 fn compute_pressure_thresholds() {
1337 let budget = 1000;
1338 assert_eq!(
1340 MemoryGuard::compute_pressure(budget, 0),
1341 PressureLevel::None
1342 );
1343 assert_eq!(
1344 MemoryGuard::compute_pressure(budget, 600),
1345 PressureLevel::None
1346 );
1347 assert_eq!(
1348 MemoryGuard::compute_pressure(budget, 699),
1349 PressureLevel::None
1350 );
1351 assert_eq!(
1353 MemoryGuard::compute_pressure(budget, 750),
1354 PressureLevel::Low
1355 );
1356 assert_eq!(
1357 MemoryGuard::compute_pressure(budget, 890),
1358 PressureLevel::Low
1359 );
1360 assert_eq!(
1362 MemoryGuard::compute_pressure(budget, 910),
1363 PressureLevel::Medium
1364 );
1365 assert_eq!(
1366 MemoryGuard::compute_pressure(budget, 949),
1367 PressureLevel::Medium
1368 );
1369 assert_eq!(
1371 MemoryGuard::compute_pressure(budget, 960),
1372 PressureLevel::High
1373 );
1374 assert_eq!(
1375 MemoryGuard::compute_pressure(budget, 999),
1376 PressureLevel::High
1377 );
1378 assert_eq!(
1380 MemoryGuard::compute_pressure(budget, 1000),
1381 PressureLevel::Critical
1382 );
1383 assert_eq!(
1384 MemoryGuard::compute_pressure(budget, 2000),
1385 PressureLevel::Critical
1386 );
1387 }
1388
1389 #[test]
1390 fn memory_hook_debug() {
1391 let hook = MemoryHook {
1392 name: "test_hook".into(),
1393 estimated_free_bytes: 1024,
1394 execution_overhead_bytes: 64,
1395 priority: 5,
1396 callback: Box::new(|| 1024),
1397 };
1398 let s = format!("{hook:?}");
1399 assert!(s.contains("test_hook"));
1400 assert!(s.contains("1024"));
1401 assert!(s.contains("64"));
1402 assert!(s.contains("5"));
1403 }
1404
1405 #[cfg(feature = "cuda")]
1410 mod gpu_tests {
1411 use super::*;
1412
        /// Opens CUDA device 0; panics (failing the test) when unavailable.
        fn make_device() -> Arc<GpuDevice> {
            Arc::new(GpuDevice::new(0).expect("CUDA device 0"))
        }
1416
1417 #[test]
1418 fn guard_construction_and_stats() {
1419 let device = make_device();
1420 let guard = MemoryGuardBuilder::new(device)
1421 .budget_bytes(1024 * 1024 * 1024) .oom_policy(OomPolicy::Fail)
1423 .build()
1424 .expect("build guard");
1425
1426 let stats = guard.stats();
1427 assert_eq!(stats.used_bytes, 0);
1428 assert_eq!(stats.budget_bytes, 1024 * 1024 * 1024);
1429 assert_eq!(stats.peak_bytes, 0);
1430 assert_eq!(stats.num_allocations, 0);
1431 assert_eq!(stats.num_oom_recoveries, 0);
1432 assert!(stats.total_device_bytes > 0);
1433 assert!(stats.free_device_bytes > 0);
1434 }
1435
1436 #[test]
1437 fn budget_enforcement_rejects_over_budget() {
1438 let device = make_device();
1439 let guard = MemoryGuardBuilder::new(device)
1440 .budget_bytes(256) .build()
1442 .expect("build guard");
1443
1444 let result = guard.safe_alloc::<f32>(1024); assert!(result.is_err());
1447 match result.unwrap_err() {
1448 GpuError::BudgetExceeded {
1449 requested_bytes,
1450 budget_bytes,
1451 used_bytes,
1452 } => {
1453 assert_eq!(requested_bytes, 1024 * 4);
1454 assert_eq!(budget_bytes, 256);
1455 assert_eq!(used_bytes, 0);
1456 }
1457 other => panic!("expected BudgetExceeded, got {other:?}"),
1458 }
1459 }
1460
1461 #[test]
1462 fn safe_alloc_tracks_usage() {
1463 let device = make_device();
1464 let guard = MemoryGuardBuilder::new(device)
1465 .budget_bytes(0) .build()
1467 .expect("build guard");
1468
1469 let buf = guard.safe_alloc::<f32>(256).expect("alloc 256 f32");
1470 let expected = 256 * std::mem::size_of::<f32>();
1471
1472 let stats = guard.stats();
1473 assert_eq!(stats.used_bytes, expected);
1474 assert_eq!(stats.peak_bytes, expected);
1475 assert_eq!(stats.num_allocations, 1);
1476
1477 guard.free(buf);
1478
1479 let stats = guard.stats();
1480 assert_eq!(stats.used_bytes, 0);
1481 assert_eq!(stats.num_allocations, 0);
1482 assert_eq!(stats.peak_bytes, expected);
1484 }
1485
1486 #[test]
1487 fn safe_alloc_copy_tracks_usage() {
1488 let device = make_device();
1489 let guard = MemoryGuardBuilder::new(device)
1490 .build()
1491 .expect("build guard");
1492
1493 let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
1494 let buf = guard.safe_alloc_copy(&data).expect("alloc_copy");
1495 let expected = 4 * std::mem::size_of::<f64>();
1496
1497 assert_eq!(guard.stats().used_bytes, expected);
1498 guard.free(buf);
1499 assert_eq!(guard.stats().used_bytes, 0);
1500 }
1501
1502 #[test]
1503 fn reset_peak_stats_works() {
1504 let device = make_device();
1505 let guard = MemoryGuardBuilder::new(device)
1506 .build()
1507 .expect("build guard");
1508
1509 let buf = guard.safe_alloc::<f32>(512).expect("alloc");
1510 let peak = guard.stats().peak_bytes;
1511 assert!(peak > 0);
1512
1513 guard.free(buf);
1514 assert_eq!(guard.stats().peak_bytes, peak); guard.reset_peak_stats();
1517 assert_eq!(guard.stats().peak_bytes, 0); }
1519
1520 #[test]
1521 fn emergency_checkpoint_callback_invoked() {
1522 let device = make_device();
1523 let guard = MemoryGuardBuilder::new(device)
1524 .build()
1525 .expect("build guard");
1526
1527 let called = Arc::new(AtomicBool::new(false));
1528 let called_clone = Arc::clone(&called);
1529 guard.on_oom(move || {
1530 called_clone.store(true, Ordering::SeqCst);
1531 });
1532
1533 guard.trigger_emergency_checkpoint();
1535 assert!(called.load(Ordering::SeqCst));
1536 }
1537
1538 #[test]
1539 fn set_budget_at_runtime() {
1540 let device = make_device();
1541 let guard = MemoryGuardBuilder::new(device)
1542 .budget_bytes(0)
1543 .build()
1544 .expect("build guard");
1545
1546 assert_eq!(guard.budget(), 0);
1547
1548 guard.set_budget(1024);
1549 assert_eq!(guard.budget(), 1024);
1550
1551 let result = guard.safe_alloc::<f32>(1024); assert!(result.is_err());
1554 }
1555
        /// `memory_info` reports sane free/total figures on a real device.
        #[test]
        fn memory_info_returns_nonzero() {
            let device = GpuDevice::new(0).expect("CUDA device 0");
            let (free, total) = device.memory_info().expect("memory_info");
            assert!(total > 0, "total device memory should be > 0");
            assert!(free > 0, "free device memory should be > 0");
            assert!(free <= total, "free should not exceed total");
        }
1564
1565 #[test]
1566 fn reservation_holds_memory() {
1567 let device = make_device();
1568 let (free_before, _) = device.memory_info().expect("memory_info");
1569
1570 let reserve_bytes = 64 * 1024 * 1024;
1572 let guard = MemoryGuardBuilder::new(device)
1573 .reserve_bytes(reserve_bytes)
1574 .build()
1575 .expect("build guard with reservation");
1576
1577 assert!(guard.has_reservation());
1578
1579 let released = guard.release_reservation();
1581 assert_eq!(released, reserve_bytes);
1582 assert!(!guard.has_reservation());
1583
1584 assert_eq!(guard.release_reservation(), 0);
1586
1587 let _ = free_before; }
1589
1590 #[test]
1591 fn guard_debug_impl() {
1592 let device = make_device();
1593 let guard = MemoryGuardBuilder::new(device)
1594 .budget_bytes(999)
1595 .build()
1596 .expect("build guard");
1597
1598 let s = format!("{guard:?}");
1599 assert!(s.contains("MemoryGuard"));
1600 assert!(s.contains("budget_bytes"));
1601 assert!(s.contains("999"));
1602 }
1603
1604 #[test]
1605 fn watchdog_detects_no_pressure_when_plenty_free() {
1606 let device = make_device();
1607 let watchdog = Arc::new(MemoryWatchdog::new(device, 1, Duration::from_millis(50)));
1609
1610 assert!(!watchdog.check_pressure());
1611 watchdog.wait_if_paused(); let wd = Arc::clone(&watchdog);
1615 let handle = wd.start();
1616 watchdog.wait_for_first_check(Duration::from_secs(5));
1617 assert!(!watchdog.check_pressure());
1618 watchdog.stop();
1619 handle.join().expect("watchdog thread");
1620 }
1621
1622 #[test]
1623 fn watchdog_debug_impl() {
1624 let device = make_device();
1625 let watchdog = MemoryWatchdog::new(device, 1024, Duration::from_secs(1));
1626 let s = format!("{watchdog:?}");
1627 assert!(s.contains("MemoryWatchdog"));
1628 assert!(s.contains("1024"));
1629 }
1630
1631 #[test]
1632 fn memory_guarded_device() {
1633 let device = make_device();
1634 let guard = MemoryGuardBuilder::new(Arc::clone(&device))
1635 .budget_bytes(1024 * 1024)
1636 .build()
1637 .expect("build guard");
1638
1639 let guarded = MemoryGuardedDevice { guard };
1640 assert_eq!(guarded.device().ordinal(), 0);
1641 assert_eq!(guarded.guard().budget(), 1024 * 1024);
1642
1643 let s = format!("{guarded:?}");
1644 assert!(s.contains("MemoryGuardedDevice"));
1645 }
1646
1647 #[test]
1648 fn oom_policy_retry_after_free() {
1649 let device = make_device();
1654 let guard = MemoryGuardBuilder::new(device)
1655 .oom_policy(OomPolicy::RetryAfterFree)
1656 .build()
1657 .expect("build guard");
1658
1659 let buf = guard.safe_alloc::<f32>(64).expect("alloc");
1662 assert_eq!(guard.stats().num_oom_recoveries, 0);
1663 guard.free(buf);
1664 }
1665
1666 #[test]
1667 fn multiple_allocations_budget_accounting() {
1668 let device = make_device();
1669 let budget = 2048_usize;
1670 let guard = MemoryGuardBuilder::new(device)
1671 .budget_bytes(budget)
1672 .build()
1673 .expect("build guard");
1674
1675 let buf1 = guard.safe_alloc::<f32>(128).expect("alloc 1");
1677 assert_eq!(guard.stats().used_bytes, 512);
1678 assert_eq!(guard.stats().num_allocations, 1);
1679
1680 let buf2 = guard.safe_alloc::<f32>(128).expect("alloc 2");
1682 assert_eq!(guard.stats().used_bytes, 1024);
1683 assert_eq!(guard.stats().num_allocations, 2);
1684
1685 let result = guard.safe_alloc::<f32>(512);
1687 assert!(result.is_err());
1688
1689 guard.free(buf1);
1693 guard.free(buf2);
1694 assert_eq!(guard.stats().used_bytes, 0);
1695
1696 let buf3 = guard.safe_alloc::<f32>(512).expect("alloc 3 after free");
1697 assert_eq!(guard.stats().used_bytes, 2048);
1698 guard.free(buf3);
1699 }
1700
1701 #[test]
1706 fn hook_called_on_budget_exceeded() {
1707 let device = make_device();
1708 let budget = 1024_usize; let guard = MemoryGuardBuilder::new(device)
1710 .budget_bytes(budget)
1711 .build()
1712 .expect("build guard");
1713
1714 let called = Arc::new(AtomicBool::new(false));
1715 let called_clone = Arc::clone(&called);
1716
1717 guard.register_hook(MemoryHook {
1718 name: "test_hook".into(),
1719 estimated_free_bytes: 2048,
1720 execution_overhead_bytes: 0,
1721 priority: 10,
1722 callback: Box::new(move || {
1723 called_clone.store(true, Ordering::SeqCst);
1724 0 }),
1726 });
1727
1728 let _result = guard.safe_alloc_with_hooks::<f32>(512);
1732 assert!(called.load(Ordering::SeqCst), "hook was not called");
1733 }
1734
1735 #[test]
1736 fn hook_frees_enough_memory_allocation_succeeds() {
1737 let device = make_device();
1738 let budget = 2048_usize;
1741 let guard = MemoryGuardBuilder::new(device)
1742 .budget_bytes(budget)
1743 .build()
1744 .expect("build guard");
1745
1746 let prefill = guard.safe_alloc::<f32>(384).expect("prefill");
1748 assert_eq!(guard.stats().used_bytes, 1536);
1749
1750 let called = Arc::new(AtomicBool::new(false));
1751 let called_clone = Arc::clone(&called);
1752
1753 guard.register_hook(MemoryHook {
1758 name: "free_1k".into(),
1759 estimated_free_bytes: 1024,
1760 execution_overhead_bytes: 0,
1761 priority: 10,
1762 callback: Box::new(move || {
1763 called_clone.store(true, Ordering::SeqCst);
1764 1024
1765 }),
1766 });
1767
1768 let buf = guard
1771 .safe_alloc_with_hooks::<f32>(256)
1772 .expect("alloc after hook");
1773 assert!(called.load(Ordering::SeqCst), "hook was not called");
1774
1775 assert_eq!(guard.stats().used_bytes, 1536);
1777
1778 guard.free(buf);
1779 guard.free(prefill);
1780 }
1781
1782 #[test]
1783 fn hook_not_enough_falls_through_to_oom_policy() {
1784 let device = make_device();
1785 let budget = 512_usize;
1786 let guard = MemoryGuardBuilder::new(device)
1787 .budget_bytes(budget)
1788 .oom_policy(OomPolicy::Fail)
1789 .build()
1790 .expect("build guard");
1791
1792 let called = Arc::new(AtomicBool::new(false));
1793 let called_clone = Arc::clone(&called);
1794
1795 guard.register_hook(MemoryHook {
1796 name: "weak_hook".into(),
1797 estimated_free_bytes: 64,
1798 execution_overhead_bytes: 0,
1799 priority: 10,
1800 callback: Box::new(move || {
1801 called_clone.store(true, Ordering::SeqCst);
1802 64 }),
1804 });
1805
1806 let result = guard.safe_alloc_with_hooks::<f32>(1024);
1809 assert!(
1810 called.load(Ordering::SeqCst),
1811 "hook should have been called"
1812 );
1813 assert!(result.is_err(), "allocation should have failed");
1814 }
1815
1816 #[test]
1817 fn hooks_called_in_priority_order() {
1818 let device = make_device();
1819 let budget = 1024_usize;
1820 let guard = MemoryGuardBuilder::new(device)
1821 .budget_bytes(budget)
1822 .build()
1823 .expect("build guard");
1824
1825 let order = Arc::new(Mutex::new(Vec::new()));
1826
1827 let o1 = Arc::clone(&order);
1828 guard.register_hook(MemoryHook {
1829 name: "priority_20".into(),
1830 estimated_free_bytes: 256,
1831 execution_overhead_bytes: 0,
1832 priority: 20,
1833 callback: Box::new(move || {
1834 o1.lock().unwrap().push(20_u32);
1835 256
1836 }),
1837 });
1838
1839 let o2 = Arc::clone(&order);
1840 guard.register_hook(MemoryHook {
1841 name: "priority_5".into(),
1842 estimated_free_bytes: 256,
1843 execution_overhead_bytes: 0,
1844 priority: 5,
1845 callback: Box::new(move || {
1846 o2.lock().unwrap().push(5_u32);
1847 256
1848 }),
1849 });
1850
1851 let o3 = Arc::clone(&order);
1852 guard.register_hook(MemoryHook {
1853 name: "priority_10".into(),
1854 estimated_free_bytes: 256,
1855 execution_overhead_bytes: 0,
1856 priority: 10,
1857 callback: Box::new(move || {
1858 o3.lock().unwrap().push(10_u32);
1859 256
1860 }),
1861 });
1862
1863 let _result = guard.safe_alloc_with_hooks::<f32>(512);
1866 let call_order = order.lock().unwrap();
1867 assert_eq!(
1868 &*call_order,
1869 &[5, 10, 20],
1870 "hooks should fire in priority order"
1871 );
1872 }
1873
1874 #[test]
1875 fn remove_hook_by_name() {
1876 let device = make_device();
1877 let guard = MemoryGuardBuilder::new(device)
1878 .budget_bytes(1024)
1879 .build()
1880 .expect("build guard");
1881
1882 let called = Arc::new(AtomicBool::new(false));
1883 let called_clone = Arc::clone(&called);
1884
1885 guard.register_hook(MemoryHook {
1886 name: "removable".into(),
1887 estimated_free_bytes: 2048,
1888 execution_overhead_bytes: 0,
1889 priority: 10,
1890 callback: Box::new(move || {
1891 called_clone.store(true, Ordering::SeqCst);
1892 2048
1893 }),
1894 });
1895
1896 assert!(guard.remove_hook("removable"));
1898 assert!(!guard.remove_hook("removable"));
1900
1901 let _result = guard.safe_alloc_with_hooks::<f32>(512);
1903 assert!(
1904 !called.load(Ordering::SeqCst),
1905 "removed hook should not have been called"
1906 );
1907 }
1908
1909 #[test]
1910 fn pressure_level_tracks_usage() {
1911 let device = make_device();
1912 let budget = 1000_usize;
1913 let guard = MemoryGuardBuilder::new(device)
1914 .budget_bytes(budget)
1915 .build()
1916 .expect("build guard");
1917
1918 assert_eq!(guard.pressure_level(), PressureLevel::None);
1920
1921 guard.used_bytes.store(750, Ordering::Relaxed);
1923 assert_eq!(guard.pressure_level(), PressureLevel::Low);
1924
1925 guard.used_bytes.store(920, Ordering::Relaxed);
1927 assert_eq!(guard.pressure_level(), PressureLevel::Medium);
1928
1929 guard.used_bytes.store(960, Ordering::Relaxed);
1931 assert_eq!(guard.pressure_level(), PressureLevel::High);
1932
1933 guard.used_bytes.store(1000, Ordering::Relaxed);
1935 assert_eq!(guard.pressure_level(), PressureLevel::Critical);
1936
1937 guard.set_budget(0);
1939 assert_eq!(guard.pressure_level(), PressureLevel::None);
1940 }
1941
1942 #[test]
1943 fn multiple_hooks_called_until_enough_freed() {
1944 let device = make_device();
1945 let budget = 2048_usize;
1946 let guard = MemoryGuardBuilder::new(device)
1947 .budget_bytes(budget)
1948 .build()
1949 .expect("build guard");
1950
1951 let prefill = guard.safe_alloc::<f32>(256).expect("prefill");
1953 assert_eq!(guard.stats().used_bytes, 1024);
1954
1955 let count = Arc::new(AtomicUsize::new(0));
1956
1957 let c1 = Arc::clone(&count);
1959 guard.register_hook(MemoryHook {
1960 name: "hook_a".into(),
1961 estimated_free_bytes: 256,
1962 execution_overhead_bytes: 0,
1963 priority: 1,
1964 callback: Box::new(move || {
1965 c1.fetch_add(1, Ordering::SeqCst);
1966 256
1967 }),
1968 });
1969
1970 let c2 = Arc::clone(&count);
1972 guard.register_hook(MemoryHook {
1973 name: "hook_b".into(),
1974 estimated_free_bytes: 512,
1975 execution_overhead_bytes: 0,
1976 priority: 2,
1977 callback: Box::new(move || {
1978 c2.fetch_add(1, Ordering::SeqCst);
1979 512
1980 }),
1981 });
1982
1983 let c3 = Arc::new(AtomicBool::new(false));
1986 let c3_clone = Arc::clone(&c3);
1987 guard.register_hook(MemoryHook {
1988 name: "hook_c".into(),
1989 estimated_free_bytes: 512,
1990 execution_overhead_bytes: 0,
1991 priority: 3,
1992 callback: Box::new(move || {
1993 c3_clone.store(true, Ordering::SeqCst);
1994 512
1995 }),
1996 });
1997
1998 let buf = guard
2002 .safe_alloc_with_hooks::<f32>(384)
2003 .expect("alloc with hooks");
2004 assert_eq!(count.load(Ordering::SeqCst), 2, "hooks A and B should fire");
2005 assert!(
2006 !c3.load(Ordering::SeqCst),
2007 "hook C should not have been called"
2008 );
2009
2010 guard.free(buf);
2011 guard.free(prefill);
2012 }
2013
2014 #[test]
2015 fn hook_with_excessive_overhead_is_skipped() {
2016 let device = make_device();
2017 let budget = 2048_usize;
2018 let guard = MemoryGuardBuilder::new(device)
2019 .budget_bytes(budget)
2020 .build()
2021 .expect("build guard");
2022
2023 let prefill = guard.safe_alloc::<f32>(480).expect("prefill");
2025 assert_eq!(guard.stats().used_bytes, 1920);
2026
2027 let expensive_called = Arc::new(AtomicBool::new(false));
2028 let expensive_clone = Arc::clone(&expensive_called);
2029
2030 guard.register_hook(MemoryHook {
2032 name: "expensive_hook".into(),
2033 estimated_free_bytes: 1024,
2034 execution_overhead_bytes: 256,
2035 priority: 1,
2036 callback: Box::new(move || {
2037 expensive_clone.store(true, Ordering::SeqCst);
2038 1024
2039 }),
2040 });
2041
2042 let cheap_called = Arc::new(AtomicBool::new(false));
2043 let cheap_clone = Arc::clone(&cheap_called);
2044
2045 guard.register_hook(MemoryHook {
2047 name: "cheap_hook".into(),
2048 estimated_free_bytes: 512,
2049 execution_overhead_bytes: 0,
2050 priority: 2,
2051 callback: Box::new(move || {
2052 cheap_clone.store(true, Ordering::SeqCst);
2053 512
2054 }),
2055 });
2056
2057 let buf = guard
2060 .safe_alloc_with_hooks::<f32>(64)
2061 .expect("alloc with hooks");
2062
2063 assert!(
2064 !expensive_called.load(Ordering::SeqCst),
2065 "expensive hook should have been skipped due to overhead"
2066 );
2067 assert!(
2068 cheap_called.load(Ordering::SeqCst),
2069 "cheap hook should have been called"
2070 );
2071
2072 guard.free(buf);
2073 guard.free(prefill);
2074 }
2075
2076 #[test]
2077 fn pressure_listener_notified_on_change() {
2078 let device = make_device();
2079 let budget = 1000_usize;
2080 let guard = MemoryGuardBuilder::new(device)
2081 .budget_bytes(budget)
2082 .build()
2083 .expect("build guard");
2084
2085 struct TestListener {
2086 changes: Mutex<Vec<(PressureLevel, PressureLevel)>>,
2087 }
2088
2089 impl MemoryPressureListener for TestListener {
2090 fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel) {
2091 self.changes.lock().unwrap().push((old, new));
2092 }
2093 }
2094
2095 let listener = Arc::new(TestListener {
2096 changes: Mutex::new(Vec::new()),
2097 });
2098 let listener_ref = Arc::clone(&listener);
2099
2100 struct ListenerWrapper(Arc<TestListener>);
2103 impl MemoryPressureListener for ListenerWrapper {
2104 fn on_pressure_change(&self, old: PressureLevel, new: PressureLevel) {
2105 self.0.on_pressure_change(old, new);
2106 }
2107 }
2108
2109 guard.add_pressure_listener(Box::new(ListenerWrapper(listener_ref)));
2110
2111 let buf1 = guard.safe_alloc::<f32>(1).expect("small alloc");
2114 guard.free(buf1);
2118
2119 guard.used_bytes.store(960, Ordering::Relaxed);
2122 guard.notify_pressure_change(); guard.used_bytes.store(0, Ordering::Relaxed);
2124 guard.notify_pressure_change(); let changes = listener.changes.lock().unwrap();
2127 assert!(
2128 changes.len() >= 2,
2129 "should have at least 2 pressure changes, got {}",
2130 changes.len()
2131 );
2132 assert_eq!(changes[0], (PressureLevel::None, PressureLevel::High));
2133 assert_eq!(changes[1], (PressureLevel::High, PressureLevel::None));
2134 }
2135
2136 #[test]
2137 fn safe_alloc_with_hooks_fast_path_no_hooks() {
2138 let device = make_device();
2141 let guard = MemoryGuardBuilder::new(device)
2142 .budget_bytes(1024 * 1024)
2143 .build()
2144 .expect("build guard");
2145
2146 let buf = guard
2147 .safe_alloc_with_hooks::<f32>(64)
2148 .expect("fast-path alloc");
2149 assert_eq!(guard.stats().used_bytes, 64 * 4);
2150 guard.free(buf);
2151 }
2152 }
2153}