// amaters_core/memory_limiter.rs

//! Memory-use limiter with cooperative back-pressure.
//!
//! [`MemoryLimiter`] tracks a global byte count via an [`AtomicUsize`] counter.
//! Callers must obtain an [`AllocationGuard`] via [`MemoryLimiter::try_allocate`];
//! the guard decrements the counter automatically on drop, ensuring the tracked
//! usage stays accurate even across panics.
//!
//! # Design
//!
//! The compare-exchange loop in [`MemoryLimiter::try_allocate`] gives linearisable
//! semantics: either the allocation is fully visible or it is rejected, with no
//! window where the counter exceeds `max_bytes`.
//!
//! # Example
//!
//! ```rust
//! use amaters_core::memory_limiter::MemoryLimiter;
//!
//! let limiter = MemoryLimiter::new(1024);
//! let guard = limiter.try_allocate(512).expect("should fit");
//! assert_eq!(limiter.current_bytes(), 512);
//! drop(guard);
//! assert_eq!(limiter.current_bytes(), 0);
//! ```

use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

// ---------------------------------------------------------------------------
// OomError
// ---------------------------------------------------------------------------

/// Error returned when a [`MemoryLimiter`] rejects an allocation because the
/// requested bytes would exceed `max_bytes`.
///
/// All fields are public so callers can inspect the rejected request.
/// Derives `Clone`/`PartialEq`/`Eq` (in addition to `Debug`) so the error is
/// cheap to propagate, store, and compare in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OomError {
    /// Configured maximum, in bytes.
    pub max_bytes: usize,
    /// Number of bytes currently tracked.
    pub current_bytes: usize,
    /// Number of bytes that were requested.
    pub requested: usize,
}

46impl fmt::Display for OomError {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        write!(
49            f,
50            "OOM: requested {} bytes but only {} / {} available",
51            self.requested,
52            self.max_bytes.saturating_sub(self.current_bytes),
53            self.max_bytes,
54        )
55    }
56}
57
58impl std::error::Error for OomError {}
59
// ---------------------------------------------------------------------------
// MemoryLimiter
// ---------------------------------------------------------------------------

/// A cooperative memory limiter backed by an atomic byte counter.
///
/// Reserve bytes with [`try_allocate`](MemoryLimiter::try_allocate); the
/// returned [`AllocationGuard`] hands the reservation back when it is dropped.
#[derive(Debug)]
pub struct MemoryLimiter {
    // Hard ceiling for the tracked byte count; fixed at construction.
    max_bytes: usize,
    // Shared with every outstanding AllocationGuard so a guard can release
    // its bytes on drop without borrowing the limiter.
    current: Arc<AtomicUsize>,
}

74impl MemoryLimiter {
75    /// Create a new limiter with the given `max_bytes` ceiling.
76    pub fn new(max_bytes: usize) -> Self {
77        Self {
78            max_bytes,
79            current: Arc::new(AtomicUsize::new(0)),
80        }
81    }
82
83    /// Return the number of bytes currently reserved.
84    pub fn current_bytes(&self) -> usize {
85        self.current.load(Ordering::Acquire)
86    }
87
88    /// Return the configured maximum, in bytes.
89    pub fn max_bytes(&self) -> usize {
90        self.max_bytes
91    }
92
93    /// Attempt to reserve `n` bytes.
94    ///
95    /// Succeeds if `current + n <= max_bytes`; otherwise returns [`OomError`].
96    /// The reservation is released automatically when the returned
97    /// [`AllocationGuard`] is dropped.
98    ///
99    /// The implementation uses a compare-exchange loop to ensure that the
100    /// counter never transiently exceeds `max_bytes`, even under heavy
101    /// concurrent load.
102    pub fn try_allocate(&self, n: usize) -> Result<AllocationGuard, OomError> {
103        loop {
104            let cur = self.current.load(Ordering::Acquire);
105            if cur + n > self.max_bytes {
106                return Err(OomError {
107                    max_bytes: self.max_bytes,
108                    current_bytes: cur,
109                    requested: n,
110                });
111            }
112            match self
113                .current
114                .compare_exchange(cur, cur + n, Ordering::AcqRel, Ordering::Acquire)
115            {
116                Ok(_) => {
117                    return Ok(AllocationGuard {
118                        n,
119                        current: Arc::clone(&self.current),
120                    });
121                }
122                Err(_) => continue,
123            }
124        }
125    }
126}
127
// ---------------------------------------------------------------------------
// AllocationGuard
// ---------------------------------------------------------------------------

/// RAII guard that releases a reservation from a [`MemoryLimiter`] on drop.
///
/// Marked `#[must_use]`: ignoring the guard drops it immediately, which
/// silently releases the reservation the caller just made.
#[must_use = "dropping the guard immediately releases the reserved bytes"]
pub struct AllocationGuard {
    /// Number of bytes this guard reserved.
    n: usize,
    /// Counter shared with the owning limiter; decremented by `n` on drop.
    current: Arc<AtomicUsize>,
}

138impl Drop for AllocationGuard {
139    fn drop(&mut self) {
140        self.current.fetch_sub(self.n, Ordering::AcqRel);
141    }
142}
143
144impl fmt::Debug for AllocationGuard {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        f.debug_struct("AllocationGuard")
147            .field("reserved_bytes", &self.n)
148            .finish()
149    }
150}
151
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_memory_limiter_allows_under_limit() {
        let limiter = MemoryLimiter::new(1024);
        let reservation = limiter
            .try_allocate(512)
            .expect("should succeed under limit");
        assert_eq!(limiter.current_bytes(), 512);
        drop(reservation);
        assert_eq!(limiter.current_bytes(), 0);
    }

    #[test]
    fn test_memory_limiter_rejects_over_limit() {
        let limiter = MemoryLimiter::new(1024);
        let _held = limiter
            .try_allocate(900)
            .expect("first alloc should succeed");
        let err = limiter
            .try_allocate(200)
            .expect_err("should reject over limit");
        assert_eq!(err.max_bytes, 1024);
        assert_eq!(err.requested, 200);
        assert!(err.current_bytes >= 900);
    }

    #[test]
    fn test_memory_limiter_releases_on_drop() {
        let limiter = MemoryLimiter::new(1024);
        {
            let _scoped = limiter.try_allocate(512).expect("should succeed");
            assert_eq!(limiter.current_bytes(), 512);
        }
        // Guard dropped at the end of the scope — bytes should be released.
        assert_eq!(limiter.current_bytes(), 0);
        // The full budget must be available again.
        let _refill = limiter
            .try_allocate(1024)
            .expect("full budget available again");
        assert_eq!(limiter.current_bytes(), 1024);
    }

    #[test]
    fn test_memory_limiter_concurrent_allocations() {
        use std::sync::Barrier;
        use std::thread;

        // 5 KiB budget; ten threads each race for 1 KiB, so exactly five fit.
        let max: usize = 5 * 1024;
        let limiter = Arc::new(MemoryLimiter::new(max));
        let successes = Arc::new(AtomicUsize::new(0));

        // First barrier: all threads rendezvous before attempting, so the
        // allocations genuinely race. Second barrier: every reservation is
        // held until all attempts are done, so the counter cannot dip
        // between attempts and let a sixth thread in.
        let start = Arc::new(Barrier::new(10));
        let done = Arc::new(Barrier::new(10));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                let successes = Arc::clone(&successes);
                let start = Arc::clone(&start);
                let done = Arc::clone(&done);
                thread::spawn(move || {
                    start.wait();
                    let reservation = limiter.try_allocate(1024);
                    if reservation.is_ok() {
                        successes.fetch_add(1, Ordering::Relaxed);
                    }
                    done.wait();
                    // Reservation (if Ok) dropped here, releasing its bytes.
                    drop(reservation);
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        // All guards dropped — the counter must be back at zero.
        assert_eq!(limiter.current_bytes(), 0);

        // Exactly 5 threads should have succeeded (5 × 1024 == 5120 max).
        assert_eq!(successes.load(Ordering::Relaxed), 5);
    }
}

254}