Skip to main content

infernum_arbiter/
memory.rs

1//! GPU memory tracking and pressure monitoring.
2//!
3//! Tracks VRAM usage across all workloads and signals pressure levels.
4
5use serde::{Deserialize, Serialize};
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Discrete memory pressure levels derived from a utilization fraction.
///
/// Variants are declared in increasing order of severity, so the derived
/// ordering gives `Low < Moderate < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryPressure {
    /// Plenty of memory available (utilization below 50%).
    Low,
    /// Moderate usage (at least 50%, below 75%).
    Moderate,
    /// High usage (at least 75%, below 90%).
    High,
    /// Critical usage (90% or more).
    Critical,
}
20
21impl MemoryPressure {
22    /// Creates pressure level from utilization fraction (0.0 - 1.0).
23    pub fn from_utilization(utilization: f32) -> Self {
24        if utilization < 0.5 {
25            Self::Low
26        } else if utilization < 0.75 {
27            Self::Moderate
28        } else if utilization < 0.9 {
29            Self::High
30        } else {
31            Self::Critical
32        }
33    }
34
35    /// Returns the quality reduction factor for this pressure level.
36    pub fn quality_factor(self) -> f32 {
37        match self {
38            Self::Low => 1.0,
39            Self::Moderate => 0.9,
40            Self::High => 0.7,
41            Self::Critical => 0.5,
42        }
43    }
44
45    /// Returns whether new allocations should be blocked.
46    pub fn should_block_new(self) -> bool {
47        matches!(self, Self::Critical)
48    }
49
50    /// Returns whether background work should pause.
51    pub fn should_pause_background(self) -> bool {
52        matches!(self, Self::High | Self::Critical)
53    }
54}
55
/// Snapshot of memory-usage counters, as produced by `GpuMemoryTracker::stats`.
///
/// All fields default to zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Cumulative number of allocations recorded.
    pub total_allocations: u64,
    /// Number of allocations currently outstanding (not yet deallocated).
    pub current_allocations: u64,
    /// Cumulative bytes allocated.
    pub total_allocated: u64,
    /// Cumulative bytes deallocated.
    pub total_deallocated: u64,
    /// Highest usage, in bytes, observed when recording an allocation.
    pub peak_usage: u64,
    /// Number of critical-pressure events recorded by the tracker.
    pub critical_pressure_count: u64,
}
72
/// Tracks GPU memory usage with lock-free counters.
///
/// Every mutable field is an [`AtomicU64`], so the tracker can be updated
/// through a shared reference. `Debug` is derived so the tracker can be
/// logged or embedded in other `Debug` types; atomics print their current
/// value.
#[derive(Debug)]
pub struct GpuMemoryTracker {
    /// Total capacity in bytes (fixed at construction).
    capacity: u64,
    /// Bytes currently in use.
    used: AtomicU64,
    /// Highest value of `used` observed when recording an allocation.
    peak: AtomicU64,
    /// Cumulative number of allocations.
    total_allocations: AtomicU64,
    /// Number of allocations currently outstanding.
    current_allocations: AtomicU64,
    /// Cumulative bytes allocated.
    total_allocated: AtomicU64,
    /// Cumulative bytes deallocated.
    total_deallocated: AtomicU64,
    /// Number of critical-pressure events recorded.
    critical_events: AtomicU64,
}
92
93impl GpuMemoryTracker {
94    /// Creates a new tracker with the given capacity.
95    pub fn new(capacity: u64) -> Self {
96        Self {
97            capacity,
98            used: AtomicU64::new(0),
99            peak: AtomicU64::new(0),
100            total_allocations: AtomicU64::new(0),
101            current_allocations: AtomicU64::new(0),
102            total_allocated: AtomicU64::new(0),
103            total_deallocated: AtomicU64::new(0),
104            critical_events: AtomicU64::new(0),
105        }
106    }
107
108    /// Returns total capacity.
109    pub fn capacity(&self) -> u64 {
110        self.capacity
111    }
112
113    /// Returns current usage in bytes.
114    pub fn used(&self) -> u64 {
115        self.used.load(Ordering::Relaxed)
116    }
117
118    /// Returns available bytes.
119    pub fn available(&self) -> u64 {
120        self.capacity.saturating_sub(self.used())
121    }
122
123    /// Returns utilization as fraction (0.0 - 1.0).
124    pub fn utilization(&self) -> f32 {
125        if self.capacity == 0 {
126            return 1.0;
127        }
128        self.used() as f32 / self.capacity as f32
129    }
130
131    /// Returns current pressure level.
132    pub fn pressure(&self) -> f32 {
133        self.utilization()
134    }
135
136    /// Returns current pressure as enum.
137    pub fn pressure_level(&self) -> MemoryPressure {
138        MemoryPressure::from_utilization(self.utilization())
139    }
140
141    /// Allocates memory, returning true if successful.
142    pub fn try_allocate(&self, bytes: u64) -> bool {
143        let mut current = self.used.load(Ordering::Relaxed);
144        loop {
145            let new = current + bytes;
146            if new > self.capacity {
147                return false;
148            }
149            match self
150                .used
151                .compare_exchange_weak(current, new, Ordering::AcqRel, Ordering::Relaxed)
152            {
153                Ok(_) => {
154                    self.total_allocations.fetch_add(1, Ordering::Relaxed);
155                    self.current_allocations.fetch_add(1, Ordering::Relaxed);
156                    self.total_allocated.fetch_add(bytes, Ordering::Relaxed);
157
158                    // Update peak
159                    let mut peak = self.peak.load(Ordering::Relaxed);
160                    while new > peak {
161                        match self.peak.compare_exchange_weak(
162                            peak,
163                            new,
164                            Ordering::AcqRel,
165                            Ordering::Relaxed,
166                        ) {
167                            Ok(_) => break,
168                            Err(p) => peak = p,
169                        }
170                    }
171
172                    // Track critical events
173                    if MemoryPressure::from_utilization(new as f32 / self.capacity as f32)
174                        == MemoryPressure::Critical
175                    {
176                        self.critical_events.fetch_add(1, Ordering::Relaxed);
177                    }
178
179                    return true;
180                },
181                Err(c) => current = c,
182            }
183        }
184    }
185
186    /// Allocates memory unconditionally.
187    pub fn allocate(&self, bytes: u64) {
188        self.used.fetch_add(bytes, Ordering::Relaxed);
189        self.total_allocations.fetch_add(1, Ordering::Relaxed);
190        self.current_allocations.fetch_add(1, Ordering::Relaxed);
191        self.total_allocated.fetch_add(bytes, Ordering::Relaxed);
192
193        // Update peak
194        let new = self.used.load(Ordering::Relaxed);
195        let mut peak = self.peak.load(Ordering::Relaxed);
196        while new > peak {
197            match self
198                .peak
199                .compare_exchange_weak(peak, new, Ordering::AcqRel, Ordering::Relaxed)
200            {
201                Ok(_) => break,
202                Err(p) => peak = p,
203            }
204        }
205    }
206
207    /// Deallocates memory.
208    pub fn deallocate(&self, bytes: u64) {
209        self.used
210            .fetch_sub(bytes.min(self.used()), Ordering::Relaxed);
211        self.current_allocations.fetch_sub(1, Ordering::Relaxed);
212        self.total_deallocated.fetch_add(bytes, Ordering::Relaxed);
213    }
214
215    /// Returns current statistics.
216    pub fn stats(&self) -> MemoryStats {
217        MemoryStats {
218            total_allocations: self.total_allocations.load(Ordering::Relaxed),
219            current_allocations: self.current_allocations.load(Ordering::Relaxed),
220            total_allocated: self.total_allocated.load(Ordering::Relaxed),
221            total_deallocated: self.total_deallocated.load(Ordering::Relaxed),
222            peak_usage: self.peak.load(Ordering::Relaxed),
223            critical_pressure_count: self.critical_events.load(Ordering::Relaxed),
224        }
225    }
226
227    /// Resets the tracker (for testing).
228    #[cfg(test)]
229    pub fn reset(&self) {
230        self.used.store(0, Ordering::Relaxed);
231        self.peak.store(0, Ordering::Relaxed);
232        self.total_allocations.store(0, Ordering::Relaxed);
233        self.current_allocations.store(0, Ordering::Relaxed);
234        self.total_allocated.store(0, Ordering::Relaxed);
235        self.total_deallocated.store(0, Ordering::Relaxed);
236        self.critical_events.store(0, Ordering::Relaxed);
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Each sample utilization maps to the expected pressure level.
    #[test]
    fn test_pressure_levels() {
        let cases = [
            (0.3, MemoryPressure::Low),
            (0.6, MemoryPressure::Moderate),
            (0.8, MemoryPressure::High),
            (0.95, MemoryPressure::Critical),
        ];
        for &(utilization, expected) in cases.iter() {
            assert_eq!(MemoryPressure::from_utilization(utilization), expected);
        }
    }

    /// Bounded allocation succeeds while it fits and fails without side
    /// effects once the request exceeds the remaining space.
    #[test]
    fn test_tracker_allocation() {
        let tracker = GpuMemoryTracker::new(1000);

        let first_fits = tracker.try_allocate(500);
        assert!(first_fits);
        assert_eq!(tracker.used(), 500);
        assert_eq!(tracker.available(), 500);

        let second_fits = tracker.try_allocate(400);
        assert!(second_fits);
        assert_eq!(tracker.used(), 900);

        // Only 100 bytes remain, so a 200-byte request must be refused.
        let third_fits = tracker.try_allocate(200);
        assert!(!third_fits);
        assert_eq!(tracker.used(), 900);
    }

    /// Deallocation reduces usage and increases availability.
    #[test]
    fn test_tracker_deallocation() {
        let tracker = GpuMemoryTracker::new(1000);

        tracker.allocate(500);
        tracker.deallocate(300);

        let (in_use, free) = (tracker.used(), tracker.available());
        assert_eq!(in_use, 200);
        assert_eq!(free, 800);
    }

    /// Utilization is the used/capacity ratio and 75% counts as `High`.
    #[test]
    fn test_utilization() {
        let tracker = GpuMemoryTracker::new(1000);

        tracker.allocate(750);

        let deviation = (tracker.utilization() - 0.75).abs();
        assert!(deviation < 0.001);
        assert_eq!(tracker.pressure_level(), MemoryPressure::High);
    }

    /// Peak usage keeps the high-water mark across later deallocations.
    #[test]
    fn test_peak_tracking() {
        let tracker = GpuMemoryTracker::new(1000);

        tracker.allocate(800);
        tracker.deallocate(500);
        tracker.allocate(200);

        let peak_usage = tracker.stats().peak_usage;
        assert_eq!(peak_usage, 800);
    }
}