#![allow(unsafe_code)]
use std::alloc::{GlobalAlloc, Layout, System};
use std::cell::Cell;
use std::sync::atomic::{AtomicBool, Ordering};
thread_local! {
    /// Running byte count for the current thread, raised by `note_alloc` and
    /// lowered by `note_dealloc` while a budget is active.
    static QUERY_BYTES: Cell<usize> = const { Cell::new(0) };

    /// Byte ceiling for the current thread. `0` means "no budget installed":
    /// both accounting hooks return early without touching `QUERY_BYTES`.
    static QUERY_LIMIT: Cell<usize> = const { Cell::new(0) };

    /// Flag that `note_alloc` sets to `true` once `QUERY_BYTES` exceeds
    /// `QUERY_LIMIT`. Null when no flag is installed.
    static QUERY_CANCEL: Cell<*const AtomicBool> = const { Cell::new(std::ptr::null()) };
}
/// Allocator that forwards every request to [`System`] and reports the size of
/// each successful operation to the per-thread accounting hooks
/// (`note_alloc` / `note_dealloc`).
pub struct BudgetAllocator;

unsafe impl GlobalAlloc for BudgetAllocator {
    /// Allocates through [`System`]; a successful allocation is charged with
    /// `layout.size()` bytes. A null result is returned without accounting.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // SAFETY: `layout` is forwarded unchanged, so the caller's obligations
        // under `GlobalAlloc::alloc` carry over to `System`.
        let block = unsafe { System.alloc(layout) };
        if block.is_null() {
            return block;
        }
        note_alloc(layout.size());
        block
    }

    /// Same as [`alloc`](Self::alloc), but the memory comes back zeroed.
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        // SAFETY: `layout` is forwarded unchanged, so the caller's obligations
        // under `GlobalAlloc::alloc_zeroed` carry over to `System`.
        let block = unsafe { System.alloc_zeroed(layout) };
        if block.is_null() {
            return block;
        }
        note_alloc(layout.size());
        block
    }

    /// Credits `layout.size()` bytes back, then releases the block via [`System`].
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        note_dealloc(layout.size());
        // SAFETY: `ptr` and `layout` are forwarded unchanged; every block this
        // allocator hands out was obtained from `System`.
        unsafe { System.dealloc(ptr, layout) };
    }

    /// Resizes through [`System`] and records only the size difference.
    ///
    /// When the call fails (null) nothing is recorded, so the charge for the
    /// original block stays as it was.
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let old_size = layout.size();
        // SAFETY: all arguments are forwarded unchanged; every block this
        // allocator hands out was obtained from `System`.
        let block = unsafe { System.realloc(ptr, layout, new_size) };
        if block.is_null() {
            return block;
        }
        // `checked_sub` is `Some` exactly when the block grew or kept its size.
        match new_size.checked_sub(old_size) {
            Some(grown_by) => note_alloc(grown_by),
            None => note_dealloc(old_size - new_size),
        }
        block
    }
}
#[inline]
fn note_alloc(size: usize) {
let _ = QUERY_LIMIT.try_with(|lim| {
let limit = lim.get();
if limit == 0 {
return;
}
let used = QUERY_BYTES.with(|b| {
let n = b.get().saturating_add(size);
b.set(n);
n
});
if used > limit {
QUERY_CANCEL.with(|c| {
let ptr = c.get();
if !ptr.is_null() {
unsafe { (*ptr).store(true, Ordering::Relaxed) };
}
});
}
});
}
#[inline]
fn note_dealloc(size: usize) {
let _ = QUERY_LIMIT.try_with(|lim| {
if lim.get() == 0 {
return;
}
QUERY_BYTES.with(|b| b.set(b.get().saturating_sub(size)));
});
}
/// Installs a fresh budget on the calling thread.
///
/// The running byte count restarts at zero, the ceiling becomes `limit_bytes`
/// (a value of `0` leaves accounting disabled), and `cancel` becomes the flag
/// that `note_alloc` sets to `true` once the count exceeds the ceiling.
///
/// # Caller contract
///
/// Only a raw pointer to `cancel` is stored, and `note_alloc` dereferences it
/// later. The signature does not enforce this, so the caller must keep
/// `cancel` alive until [`clear_query_budget`] is called or another flag is
/// installed on this thread.
pub fn reset_query_budget(limit_bytes: usize, cancel: &AtomicBool) {
    // Reference-to-raw-pointer coercion; the borrow's lifetime is erased here.
    let flag_ptr: *const AtomicBool = cancel;

    QUERY_BYTES.with(|used| used.set(0));
    QUERY_LIMIT.with(|ceiling| ceiling.set(limit_bytes));
    QUERY_CANCEL.with(|slot| slot.set(flag_ptr));
}
/// Removes the calling thread's budget: zeroes the limit, forgets the cancel
/// flag pointer, and zeroes the running byte count.
pub fn clear_query_budget() {
    // The limit goes first: once it reads 0 the accounting hooks return early
    // and never reach the cancel pointer.
    QUERY_LIMIT.set(0);
    QUERY_CANCEL.set(core::ptr::null());
    QUERY_BYTES.set(0);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no budget installed, an allocation note must leave the counter at zero.
    #[test]
    fn disabled_budget_never_tracks() {
        clear_query_budget();
        note_alloc(1 << 30);

        let counted = QUERY_BYTES.with(Cell::get);
        assert_eq!(counted, 0, "untracked thread must not count");
    }

    /// One allocation larger than the whole ceiling must raise the cancel flag.
    #[test]
    fn single_over_limit_alloc_trips_cancel_flag() {
        let cancel = AtomicBool::new(false);
        reset_query_budget(1024, &cancel);
        note_alloc(4096);
        let tripped = cancel.load(Ordering::Relaxed);

        // The budget is cleared before asserting, presumably so that a failing
        // assertion cannot leave a pointer to `cancel` installed on this thread.
        clear_query_budget();
        assert!(
            tripped,
            "a single allocation past the ceiling must trip the flag"
        );
    }

    /// A matching dealloc note must bring the net counter back down, and a
    /// `usize::MAX` ceiling must never raise the flag.
    #[test]
    fn dealloc_subtracts_from_net() {
        let cancel = AtomicBool::new(false);
        reset_query_budget(usize::MAX, &cancel);

        note_alloc(8192);
        let after_alloc = QUERY_BYTES.with(|bytes| bytes.get());
        note_dealloc(8192);
        let after_dealloc = QUERY_BYTES.with(|bytes| bytes.get());

        clear_query_budget();
        assert!(
            after_dealloc < after_alloc,
            "dealloc must lower the net counter ({after_dealloc} !< {after_alloc})"
        );
        let tripped = cancel.load(Ordering::Relaxed);
        assert!(!tripped, "huge limit must never trip");
    }
}