amaters_core/memory_limiter.rs
1//! Memory-use limiter with cooperative back-pressure.
2//!
3//! [`MemoryLimiter`] tracks a global byte count via an [`AtomicUsize`] counter.
4//! Callers must obtain an [`AllocationGuard`] via [`MemoryLimiter::try_allocate`];
5//! the guard decrements the counter automatically on drop, ensuring the tracked
6//! usage stays accurate even across panics.
7//!
8//! # Design
9//!
10//! The compare-exchange loop in [`MemoryLimiter::try_allocate`] gives linearisable
11//! semantics: either the allocation is fully visible or it is rejected, with no
12//! window where the counter exceeds `max_bytes`.
13//!
14//! # Example
15//!
16//! ```rust
17//! use amaters_core::memory_limiter::MemoryLimiter;
18//!
19//! let limiter = MemoryLimiter::new(1024);
20//! let guard = limiter.try_allocate(512).expect("should fit");
21//! assert_eq!(limiter.current_bytes(), 512);
22//! drop(guard);
23//! assert_eq!(limiter.current_bytes(), 0);
24//! ```
25
26use std::fmt;
27use std::sync::Arc;
28use std::sync::atomic::{AtomicUsize, Ordering};
29
30// ---------------------------------------------------------------------------
31// OomError
32// ---------------------------------------------------------------------------
33
34/// Error returned when a [`MemoryLimiter`] rejects an allocation because the
35/// requested bytes would exceed `max_bytes`.
36#[derive(Debug)]
37pub struct OomError {
38 /// Configured maximum, in bytes.
39 pub max_bytes: usize,
40 /// Number of bytes currently tracked.
41 pub current_bytes: usize,
42 /// Number of bytes that were requested.
43 pub requested: usize,
44}
45
46impl fmt::Display for OomError {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 write!(
49 f,
50 "OOM: requested {} bytes but only {} / {} available",
51 self.requested,
52 self.max_bytes.saturating_sub(self.current_bytes),
53 self.max_bytes,
54 )
55 }
56}
57
58impl std::error::Error for OomError {}
59
60// ---------------------------------------------------------------------------
61// MemoryLimiter
62// ---------------------------------------------------------------------------
63
64/// A cooperative memory limiter backed by an atomic byte counter.
65///
66/// Use [`try_allocate`](MemoryLimiter::try_allocate) to reserve bytes.
67/// The returned [`AllocationGuard`] releases the reservation when dropped.
68#[derive(Debug)]
69pub struct MemoryLimiter {
70 max_bytes: usize,
71 current: Arc<AtomicUsize>,
72}
73
74impl MemoryLimiter {
75 /// Create a new limiter with the given `max_bytes` ceiling.
76 pub fn new(max_bytes: usize) -> Self {
77 Self {
78 max_bytes,
79 current: Arc::new(AtomicUsize::new(0)),
80 }
81 }
82
83 /// Return the number of bytes currently reserved.
84 pub fn current_bytes(&self) -> usize {
85 self.current.load(Ordering::Acquire)
86 }
87
88 /// Return the configured maximum, in bytes.
89 pub fn max_bytes(&self) -> usize {
90 self.max_bytes
91 }
92
93 /// Attempt to reserve `n` bytes.
94 ///
95 /// Succeeds if `current + n <= max_bytes`; otherwise returns [`OomError`].
96 /// The reservation is released automatically when the returned
97 /// [`AllocationGuard`] is dropped.
98 ///
99 /// The implementation uses a compare-exchange loop to ensure that the
100 /// counter never transiently exceeds `max_bytes`, even under heavy
101 /// concurrent load.
102 pub fn try_allocate(&self, n: usize) -> Result<AllocationGuard, OomError> {
103 loop {
104 let cur = self.current.load(Ordering::Acquire);
105 if cur + n > self.max_bytes {
106 return Err(OomError {
107 max_bytes: self.max_bytes,
108 current_bytes: cur,
109 requested: n,
110 });
111 }
112 match self
113 .current
114 .compare_exchange(cur, cur + n, Ordering::AcqRel, Ordering::Acquire)
115 {
116 Ok(_) => {
117 return Ok(AllocationGuard {
118 n,
119 current: Arc::clone(&self.current),
120 });
121 }
122 Err(_) => continue,
123 }
124 }
125 }
126}
127
128// ---------------------------------------------------------------------------
129// AllocationGuard
130// ---------------------------------------------------------------------------
131
132/// RAII guard that releases a reservation from a [`MemoryLimiter`] on drop.
133pub struct AllocationGuard {
134 n: usize,
135 current: Arc<AtomicUsize>,
136}
137
138impl Drop for AllocationGuard {
139 fn drop(&mut self) {
140 self.current.fetch_sub(self.n, Ordering::AcqRel);
141 }
142}
143
144impl fmt::Debug for AllocationGuard {
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 f.debug_struct("AllocationGuard")
147 .field("reserved_bytes", &self.n)
148 .finish()
149 }
150}
151
152// ---------------------------------------------------------------------------
153// Tests
154// ---------------------------------------------------------------------------
155
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Barrier};
    use std::thread;

    #[test]
    fn test_memory_limiter_allows_under_limit() {
        let limiter = MemoryLimiter::new(1024);
        let reservation = limiter
            .try_allocate(512)
            .expect("should succeed under limit");
        assert_eq!(limiter.current_bytes(), 512);

        drop(reservation);
        assert_eq!(limiter.current_bytes(), 0);
    }

    #[test]
    fn test_memory_limiter_rejects_over_limit() {
        let limiter = MemoryLimiter::new(1024);
        let _held = limiter
            .try_allocate(900)
            .expect("first alloc should succeed");

        // 900 + 200 > 1024, so the second request must be refused.
        let err = limiter
            .try_allocate(200)
            .expect_err("should reject over limit");
        assert_eq!(err.max_bytes, 1024);
        assert_eq!(err.requested, 200);
        assert!(err.current_bytes >= 900);
    }

    #[test]
    fn test_memory_limiter_releases_on_drop() {
        let limiter = MemoryLimiter::new(1024);

        // Reserve half the budget inside a scope so the guard drops at `}`.
        {
            let _scoped = limiter.try_allocate(512).expect("should succeed");
            assert_eq!(limiter.current_bytes(), 512);
        }
        assert_eq!(limiter.current_bytes(), 0);

        // With the earlier reservation gone, the whole budget is available.
        let _full = limiter
            .try_allocate(1024)
            .expect("full budget available again");
        assert_eq!(limiter.current_bytes(), 1024);
    }

    #[test]
    fn test_memory_limiter_concurrent_allocations() {
        const THREADS: usize = 10;
        const CHUNK: usize = 1024;
        // Budget for exactly half of the threads: 5 × 1 KB.
        let max = 5 * CHUNK;

        let limiter = Arc::new(MemoryLimiter::new(max));
        let successes = Arc::new(AtomicUsize::new(0));
        // `start` lines every thread up so the allocation attempts really race.
        let start = Arc::new(Barrier::new(THREADS));
        // `attempted` keeps every reservation alive until all threads have
        // tried, so a released reservation can never be re-used by a late thread.
        let attempted = Arc::new(Barrier::new(THREADS));

        let mut workers = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let limiter = Arc::clone(&limiter);
            let successes = Arc::clone(&successes);
            let start = Arc::clone(&start);
            let attempted = Arc::clone(&attempted);
            workers.push(thread::spawn(move || {
                start.wait();
                let outcome = limiter.try_allocate(CHUNK);
                if outcome.is_ok() {
                    successes.fetch_add(1, Ordering::Relaxed);
                }
                attempted.wait();
                // Dropping the `Result` releases the reservation when it is `Ok`.
                drop(outcome);
            }));
        }

        for worker in workers {
            worker.join().expect("thread panicked");
        }

        // Every reservation has been dropped, so nothing should remain tracked.
        assert_eq!(limiter.current_bytes(), 0);
        // 5 × 1024 == max, so exactly five reservations fit.
        assert_eq!(successes.load(Ordering::Relaxed), 5);
    }
}