#![allow(unsafe_code)]
#![allow(clippy::panic)]
use std::alloc::{GlobalAlloc, Layout, System};
use std::cell::Cell;
use std::sync::atomic::{AtomicBool, Ordering};
/// `FAIL_AT` value meaning "no failure is scheduled on this thread".
const FAIL_DISABLED: usize = 0;
/// Message identifying panics injected by `OomAllocator`.
const OOM_SENTINEL: &str = "vyre-conform oom injection";
/// Process-wide switch: while `true`, allocations on every thread are counted,
/// though only a thread with a `FAIL_AT` target can actually be failed.
static ENABLED: AtomicBool = AtomicBool::new(false);

thread_local! {
    /// Allocations counted on this thread since its last reset.
    static COUNTER: Cell<usize> = const { Cell::new(0) };
}

thread_local! {
    /// 1-based allocation ordinal at which this thread should fail, or `FAIL_DISABLED`.
    static FAIL_AT: Cell<usize> = const { Cell::new(FAIL_DISABLED) };
}

thread_local! {
    /// Ordinal of the allocation that was failed on this thread, or `0` if none was.
    static TRIGGERED: Cell<usize> = const { Cell::new(0) };
}
/// Global allocator that forwards to the system allocator and can inject a
/// panic at a chosen per-thread allocation ordinal (see `arm_thread`).
///
/// It is a zero-sized unit type, so the common value traits are derived for free.
#[derive(Debug, Clone, Copy, Default)]
pub struct OomAllocator;
/// Arms OOM injection for the calling thread.
///
/// Resets this thread's allocation counter and trigger record, schedules the
/// `fail_at`-th subsequently counted allocation on this thread to panic, and turns
/// on the process-wide `ENABLED` switch.
///
/// Passing `0` (`FAIL_DISABLED`) enables counting without ever injecting a failure.
/// Because `ENABLED` is shared, other threads also start counting their own
/// allocations. They are never failed unless they are armed themselves.
#[inline]
pub fn arm_thread(fail_at: usize) {
    TRIGGERED.with(|slot| slot.set(0));
    COUNTER.with(|slot| slot.set(0));
    FAIL_AT.with(|slot| slot.set(fail_at));
    // Flip the shared switch last, once this thread's state is in place.
    ENABLED.store(true, Ordering::SeqCst);
}
/// Turns OOM injection off.
///
/// Cancels the calling thread's pending failure target. It also clears the
/// process-wide `ENABLED` switch, which stops allocation counting on every thread,
/// not only the caller.
///
/// The counter and trigger record are left untouched, so `allocation_count` and
/// `triggered_at` still report the last run.
#[inline]
pub fn disarm_thread() {
    FAIL_AT.with(|target| target.set(FAIL_DISABLED));
    ENABLED.store(false, Ordering::SeqCst);
}
/// Resets all of the calling thread's injection bookkeeping to its initial state.
///
/// Zeroes the allocation counter and the trigger record, and removes any failure
/// target. Unlike `disarm_thread`, this leaves the process-wide `ENABLED` switch
/// as it is.
#[inline]
pub fn clear_thread() {
    FAIL_AT.with(|target| target.set(FAIL_DISABLED));
    for slot in [&COUNTER, &TRIGGERED] {
        slot.with(|cell| cell.set(0));
    }
}
#[inline]
pub fn allocation_count() -> usize {
COUNTER.with(Cell::get)
}
#[inline]
pub fn triggered_at() -> usize {
TRIGGERED.with(Cell::get)
}
#[inline]
pub fn is_oom_payload(payload: &(dyn std::any::Any + Send)) -> bool {
payload
.downcast_ref::<&'static str>()
.map(|message| *message == OOM_SENTINEL)
.unwrap_or(false)
}
unsafe impl GlobalAlloc for OomAllocator {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
if !ENABLED.load(Ordering::SeqCst) {
return System.alloc(layout);
}
let should_fail = COUNTER.with(|counter| {
let next = counter.get().saturating_add(1);
counter.set(next);
FAIL_AT.with(|fail_at| fail_at.get() != FAIL_DISABLED && next == fail_at.get())
});
if should_fail {
TRIGGERED.with(|triggered| triggered.set(allocation_count()));
ENABLED.store(false, Ordering::SeqCst);
panic!("{OOM_SENTINEL}");
}
System.alloc(layout)
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
System.dealloc(ptr, layout);
}
}
#[global_allocator]
pub static GLOBAL_ALLOCATOR: OomAllocator = OomAllocator;