use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Error returned by [`MemoryLimiter::try_allocate`] when a reservation
/// would exceed the limiter's configured budget.
///
/// All fields are snapshots taken at the moment the request was rejected.
#[derive(Debug)]
pub struct OomError {
/// Total budget of the limiter, in bytes.
pub max_bytes: usize,
/// Bytes already reserved when the request was rejected.
pub current_bytes: usize,
/// Size of the rejected request, in bytes.
pub requested: usize,
}
impl fmt::Display for OomError {
    /// Formats the error as a one-line human-readable message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // saturating_sub guards against a snapshot where current_bytes
        // transiently exceeds max_bytes (avoids an underflow panic).
        let available = self.max_bytes.saturating_sub(self.current_bytes);
        write!(
            f,
            "OOM: requested {} bytes but only {} / {} available",
            self.requested, available, self.max_bytes
        )
    }
}
impl std::error::Error for OomError {}
/// Tracks a shared byte budget: reservations are made with
/// [`MemoryLimiter::try_allocate`] and released when the returned
/// [`AllocationGuard`] is dropped.
#[derive(Debug)]
pub struct MemoryLimiter {
// Hard upper bound on total reserved bytes.
max_bytes: usize,
// Currently reserved bytes. Arc-shared so each outstanding guard holds
// its own handle to the counter and can release independently on drop.
current: Arc<AtomicUsize>,
}
impl MemoryLimiter {
    /// Creates a limiter with a total budget of `max_bytes`.
    pub fn new(max_bytes: usize) -> Self {
        Self {
            max_bytes,
            current: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Number of bytes currently reserved.
    pub fn current_bytes(&self) -> usize {
        self.current.load(Ordering::Acquire)
    }

    /// Total budget in bytes.
    pub fn max_bytes(&self) -> usize {
        self.max_bytes
    }

    /// Attempts to reserve `n` bytes against the budget.
    ///
    /// On success the returned [`AllocationGuard`] holds the reservation
    /// and returns it to the budget when dropped.
    ///
    /// # Errors
    ///
    /// Returns an [`OomError`] when the reservation would exceed
    /// `max_bytes` — including when `current + n` would overflow `usize`,
    /// which is treated as exceeding the budget.
    pub fn try_allocate(&self, n: usize) -> Result<AllocationGuard, OomError> {
        // fetch_update runs the CAS retry loop for us. The closure returns
        // None (rejecting the update) when the request does not fit;
        // fetch_update then yields Err with the value it observed.
        //
        // checked_add fixes a latent bug in the original `cur + n` check:
        // in release builds the sum wraps on overflow and a huge request
        // could falsely appear to fit under the limit.
        match self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
                cur.checked_add(n).filter(|&total| total <= self.max_bytes)
            }) {
            Ok(_) => Ok(AllocationGuard {
                n,
                current: Arc::clone(&self.current),
            }),
            Err(cur) => Err(OomError {
                max_bytes: self.max_bytes,
                current_bytes: cur,
                requested: n,
            }),
        }
    }
}
/// RAII handle for a successful reservation: the `Drop` impl returns the
/// reserved bytes to the issuing [`MemoryLimiter`]'s shared counter.
///
/// Marked `#[must_use]` because ignoring the returned guard drops it
/// immediately, which releases the reservation on the spot and defeats
/// the limit — almost certainly a caller bug.
#[must_use = "dropping the guard immediately releases the reservation"]
pub struct AllocationGuard {
    // Number of bytes this guard reserved.
    n: usize,
    // Handle to the limiter's shared reservation counter.
    current: Arc<AtomicUsize>,
}
impl Drop for AllocationGuard {
/// Returns this guard's reserved bytes to the shared counter.
fn drop(&mut self) {
// Cannot underflow as long as the counter was incremented by exactly
// `n` when this guard was created (see MemoryLimiter::try_allocate).
self.current.fetch_sub(self.n, Ordering::AcqRel);
}
}
impl fmt::Debug for AllocationGuard {
    /// Manual `Debug`: reports only the reserved byte count; the shared
    /// atomic counter is intentionally omitted from the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("AllocationGuard");
        builder.field("reserved_bytes", &self.n);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A reservation under the budget succeeds and is returned on drop.
    #[test]
    fn test_memory_limiter_allows_under_limit() {
        let lim = MemoryLimiter::new(1024);
        let g = lim.try_allocate(512).expect("should succeed under limit");
        assert_eq!(lim.current_bytes(), 512);
        drop(g);
        assert_eq!(lim.current_bytes(), 0);
    }

    /// A request that would push the total over the budget is rejected,
    /// and the error snapshot reflects the rejected request.
    #[test]
    fn test_memory_limiter_rejects_over_limit() {
        let lim = MemoryLimiter::new(1024);
        let _held = lim.try_allocate(900).expect("first alloc should succeed");
        let err = lim
            .try_allocate(200)
            .expect_err("should reject over limit");
        assert_eq!(err.max_bytes, 1024);
        assert_eq!(err.requested, 200);
        assert!(err.current_bytes >= 900);
    }

    /// Dropping a guard makes its bytes available for re-allocation.
    #[test]
    fn test_memory_limiter_releases_on_drop() {
        let lim = MemoryLimiter::new(1024);
        {
            let _g = lim.try_allocate(512).expect("should succeed");
            assert_eq!(lim.current_bytes(), 512);
        }
        assert_eq!(lim.current_bytes(), 0);
        let _again = lim
            .try_allocate(1024)
            .expect("full budget available again");
        assert_eq!(lim.current_bytes(), 1024);
    }

    /// Ten threads race for a budget that fits exactly five reservations;
    /// barriers force all attempts to overlap while guards are held.
    #[test]
    fn test_memory_limiter_concurrent_allocations() {
        use std::sync::Barrier;
        use std::thread;

        const THREADS: usize = 10;
        let budget: usize = 5 * 1024;
        let lim = Arc::new(MemoryLimiter::new(budget));
        let wins = Arc::new(AtomicUsize::new(0));
        let start_gate = Arc::new(Barrier::new(THREADS));
        let hold_gate = Arc::new(Barrier::new(THREADS));

        let mut workers = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let lim = Arc::clone(&lim);
            let wins = Arc::clone(&wins);
            let start = Arc::clone(&start_gate);
            let hold = Arc::clone(&hold_gate);
            workers.push(thread::spawn(move || {
                start.wait();
                let reservation = lim.try_allocate(1024);
                if reservation.is_ok() {
                    wins.fetch_add(1, Ordering::Relaxed);
                }
                // Keep every successful reservation alive until all
                // threads have attempted, so exactly five can win.
                hold.wait();
                drop(reservation);
            }));
        }
        for worker in workers {
            worker.join().expect("thread panicked");
        }
        assert_eq!(lim.current_bytes(), 0);
        assert_eq!(wins.load(Ordering::Relaxed), 5);
    }
}