use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
/// A thread-safe byte counter that enforces a fixed upper limit.
///
/// Bytes are reserved with [`ByteBudget::try_charge`] and released with
/// [`ByteBudget::credit`]. Both operations update the counter atomically, so a single
/// budget can be shared (e.g. behind a `&` or an `Arc`) between threads. A charge that would
/// push the total above `limit`, or overflow `u64`, is rejected without modifying the counter.
#[derive(Debug)]
pub struct ByteBudget {
    /// Bytes currently charged against the budget; never exceeds `limit`.
    used: AtomicU64,
    /// Maximum number of bytes that may be charged at the same time.
    limit: u64,
    /// Name included in error and panic messages to identify this budget.
    label: String,
}

impl ByteBudget {
    /// Creates an empty budget of `limit` bytes labelled `"memory"`.
    pub fn new(limit: u64) -> Self {
        Self::labeled(limit, "memory")
    }

    /// Creates an empty budget of `limit` bytes with a custom `label` used in diagnostics.
    pub fn labeled(limit: u64, label: impl Into<String>) -> Self {
        Self {
            used: AtomicU64::new(0),
            limit,
            label: label.into(),
        }
    }

    /// Returns the configured byte limit.
    pub fn limit(&self) -> u64 {
        self.limit
    }

    /// Returns the number of bytes currently charged.
    pub fn used(&self) -> u64 {
        self.used.load(Ordering::Acquire)
    }

    /// Returns the label used in error and panic messages.
    pub fn label(&self) -> &str {
        &self.label
    }

    /// Returns how many more bytes could be charged at the moment of the call.
    ///
    /// Under concurrent use this is only a snapshot; another thread may charge or credit
    /// bytes immediately afterwards.
    pub fn remaining(&self) -> u64 {
        self.limit.saturating_sub(self.used())
    }

    /// Atomically reserves `bytes` from the budget.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::StorageFull`] when `used + bytes` would
    /// exceed the limit or overflow `u64`. The message names the budget's label, the
    /// usage observed at the time of the failed attempt, the request, and the limit.
    /// On error the budget is left unchanged.
    pub fn try_charge(&self, bytes: u64) -> io::Result<()> {
        self.used
            // `fetch_update` retries the closure on contention. Returning `None` aborts
            // without writing, so a rejected charge never touches the counter.
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| {
                current
                    .checked_add(bytes)
                    .filter(|&next| next <= self.limit)
            })
            .map(|_| ())
            .map_err(|current| {
                io::Error::new(
                    io::ErrorKind::StorageFull,
                    format!(
                        "{} budget exhausted: {} bytes used + {} requested exceeds the {} byte limit",
                        self.label, current, bytes, self.limit
                    ),
                )
            })
    }

    /// Returns `bytes` previously reserved with [`ByteBudget::try_charge`] to the budget.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` exceeds the amount currently charged, which indicates an
    /// accounting bug in the caller. The counter is validated *before* it is modified, so
    /// a caught panic leaves the budget in its prior, consistent state. Without that check
    /// the counter would wrap to near `u64::MAX` and exhaust the budget permanently.
    pub fn credit(&self, bytes: u64) {
        let outcome = self
            .used
            .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| {
                current.checked_sub(bytes)
            });
        if let Err(current) = outcome {
            panic!(
                "ByteBudget '{}' underflow: credited {} bytes with only {} charged — accounting bug",
                self.label, bytes, current
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_charge_and_credit() {
        let budget = ByteBudget::new(100);
        budget.try_charge(60).unwrap();
        assert_eq!(budget.used(), 60);
        assert_eq!(budget.remaining(), 40);
        budget.credit(20);
        assert_eq!(budget.used(), 40);
    }

    #[test]
    fn test_exceeding_fails_loudly_without_charging() {
        let budget = ByteBudget::labeled(100, "vfs-budget");
        budget.try_charge(90).unwrap();
        let error = budget.try_charge(11).unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::StorageFull);
        assert!(error.to_string().contains("vfs-budget"));
        assert!(error.to_string().contains("100"));
        assert_eq!(budget.used(), 90);
        budget.try_charge(10).unwrap();
    }

    #[test]
    fn test_overflow_treated_as_full() {
        let budget = ByteBudget::new(u64::MAX);
        budget.try_charge(u64::MAX - 1).unwrap();
        let error = budget.try_charge(u64::MAX).unwrap_err();
        assert_eq!(error.kind(), io::ErrorKind::StorageFull);
    }

    #[test]
    #[should_panic(expected = "underflow")]
    fn test_credit_underflow_panics() {
        let budget = ByteBudget::new(100);
        budget.try_charge(10).unwrap();
        budget.credit(11);
    }

    /// A zero-byte request must succeed even when the budget is exactly full.
    #[test]
    fn test_zero_byte_charge_succeeds_when_full() {
        let budget = ByteBudget::new(10);
        budget.try_charge(10).unwrap();
        budget.try_charge(0).unwrap();
        assert_eq!(budget.used(), 10);
        assert_eq!(budget.remaining(), 0);
    }

    /// The accessors report construction-time configuration; `new` uses the "memory" label.
    #[test]
    fn test_accessors_report_configuration() {
        let budget = ByteBudget::labeled(42, "cache");
        assert_eq!(budget.limit(), 42);
        assert_eq!(budget.label(), "cache");
        assert_eq!(ByteBudget::new(1).label(), "memory");
    }

    /// Several threads race single-byte charges. Exactly `limit` of them may succeed, and
    /// the counter must land on the limit, never above it.
    #[test]
    fn test_concurrent_charges_never_exceed_limit() {
        let budget = ByteBudget::new(100);
        let shared = &budget;
        let granted: usize = std::thread::scope(|scope| {
            let workers: Vec<_> = (0..4)
                .map(|_| {
                    scope.spawn(move || {
                        (0..50)
                            .filter(|_| shared.try_charge(1).is_ok())
                            .count()
                    })
                })
                .collect();
            workers
                .into_iter()
                .map(|worker| worker.join().expect("worker thread panicked"))
                .sum()
        });
        assert_eq!(granted, 100);
        assert_eq!(budget.used(), 100);
        assert_eq!(budget.remaining(), 0);
    }
}