use hopper_runtime::error::ProgramError;
/// Tracks per-slot sizes and caps the cumulative growth of all slots against
/// a fixed budget.
///
/// Each slot must first be [`register`](Self::register)ed with its starting
/// size. [`check_growth`](Self::check_growth) only validates a proposed size;
/// [`commit_growth`](Self::commit_growth) records a new size *without*
/// re-checking the budget, so callers should check before committing.
/// Shrinking a slot returns credit to the budget.
///
/// `N` is the maximum number of slots. Sizes are stored as `u32`; sizes that
/// do not fit in a `u32` are rejected with `ProgramError::ArithmeticOverflow`
/// instead of being truncated.
pub struct ReallocGuard<const N: usize> {
    /// Size of each slot at (last) registration time.
    original: [u32; N],
    /// Last committed size of each slot.
    current: [u32; N],
    /// Whether each slot has been registered; all other slots are rejected.
    registered: [bool; N],
    /// Maximum net growth allowed across all slots.
    budget: u32,
    /// Net growth committed so far (never below zero).
    consumed: u32,
}

impl<const N: usize> ReallocGuard<N> {
    /// Creates a guard with no registered slots and `budget` units of
    /// allowed growth.
    #[inline(always)]
    pub const fn new(budget: u32) -> Self {
        Self {
            original: [0u32; N],
            current: [0u32; N],
            registered: [false; N],
            budget,
            consumed: 0,
        }
    }

    /// Converts a size to the `u32` representation stored internally.
    ///
    /// A plain `as u32` cast would silently truncate (e.g. `2^32 + 10` would
    /// become `10`) and let an oversized request pass as a shrink.
    #[inline(always)]
    fn size_to_u32(size: usize) -> Result<u32, ProgramError> {
        if size > u32::MAX as usize {
            Err(ProgramError::ArithmeticOverflow)
        } else {
            Ok(size as u32)
        }
    }

    /// Returns `true` if `slot` is in range and has been registered.
    #[inline(always)]
    fn is_registered(&self, slot: usize) -> bool {
        matches!(self.registered.get(slot), Some(true))
    }

    /// Registers (or re-registers) `slot` with its current `size`.
    ///
    /// Re-registering resets both the original and current size of the slot
    /// to `size`; the budget already consumed is left unchanged.
    ///
    /// # Errors
    /// - `InvalidArgument` if `slot >= N`.
    /// - `ArithmeticOverflow` if `size` does not fit in a `u32`.
    #[inline(always)]
    pub fn register(&mut self, slot: usize, size: usize) -> Result<(), ProgramError> {
        if slot >= N {
            return Err(ProgramError::InvalidArgument);
        }
        let size32 = Self::size_to_u32(size)?;
        self.original[slot] = size32;
        self.current[slot] = size32;
        self.registered[slot] = true;
        Ok(())
    }

    /// Checks whether resizing `slot` to `new_size` would stay within the
    /// budget. Nothing is recorded.
    ///
    /// Keeping the same size or shrinking always succeeds. Growth succeeds if
    /// the already-consumed growth plus the increase does not exceed the
    /// budget.
    ///
    /// # Errors
    /// - `InvalidArgument` if `slot` has not been registered.
    /// - `ArithmeticOverflow` if `new_size` does not fit in a `u32`, or the
    ///   consumed total would overflow a `u32`.
    /// - `InvalidRealloc` if the growth would exceed the budget.
    #[inline]
    pub fn check_growth(&self, slot: usize, new_size: usize) -> Result<(), ProgramError> {
        if !self.is_registered(slot) {
            return Err(ProgramError::InvalidArgument);
        }
        let new_size32 = Self::size_to_u32(new_size)?;
        let current = self.current[slot];
        if new_size32 <= current {
            return Ok(());
        }
        let new_consumed = self
            .consumed
            .checked_add(new_size32 - current)
            .ok_or(ProgramError::ArithmeticOverflow)?;
        if new_consumed > self.budget {
            return Err(ProgramError::InvalidRealloc);
        }
        Ok(())
    }

    /// Records that `slot` now has size `new_size` and adjusts the consumed
    /// total accordingly.
    ///
    /// Growth adds the increase to the consumed total; shrinking subtracts the
    /// decrease, saturating at zero. The budget is NOT enforced here — call
    /// [`check_growth`](Self::check_growth) first. On error, no state changes.
    ///
    /// # Errors
    /// - `InvalidArgument` if `slot` has not been registered.
    /// - `ArithmeticOverflow` if `new_size` does not fit in a `u32`, or the
    ///   consumed total would overflow a `u32`.
    #[inline(always)]
    pub fn commit_growth(&mut self, slot: usize, new_size: usize) -> Result<(), ProgramError> {
        if !self.is_registered(slot) {
            return Err(ProgramError::InvalidArgument);
        }
        let new_size32 = Self::size_to_u32(new_size)?;
        let current = self.current[slot];
        // Compute the new total before mutating so an error leaves `self` untouched.
        let consumed = if new_size32 >= current {
            self.consumed
                .checked_add(new_size32 - current)
                .ok_or(ProgramError::ArithmeticOverflow)?
        } else {
            self.consumed.saturating_sub(current - new_size32)
        };
        self.consumed = consumed;
        self.current[slot] = new_size32;
        Ok(())
    }

    /// Budget still available (`budget - consumed`, saturating at zero).
    #[inline(always)]
    pub const fn remaining(&self) -> u32 {
        self.budget.saturating_sub(self.consumed)
    }

    /// Net growth committed so far across all slots.
    #[inline(always)]
    pub const fn consumed(&self) -> u32 {
        self.consumed
    }

    /// The budget this guard was created with.
    #[inline(always)]
    pub const fn budget(&self) -> u32 {
        self.budget
    }

    /// Net size change of `slot` since registration (current minus original),
    /// or `0` if the slot is not registered.
    ///
    /// The difference is computed in `i64` and clamped to the `i32` range, so
    /// sizes above `i32::MAX` no longer wrap to bogus values.
    #[inline(always)]
    pub fn slot_growth(&self, slot: usize) -> i32 {
        if !self.is_registered(slot) {
            return 0;
        }
        let diff = i64::from(self.current[slot]) - i64::from(self.original[slot]);
        diff.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_growth_tracking() {
        let mut guard = ReallocGuard::<4>::new(1024);
        guard.register(0, 100).unwrap();
        guard.register(1, 200).unwrap();
        assert!(guard.check_growth(0, 200).is_ok());
        guard.commit_growth(0, 200).unwrap();
        assert_eq!(guard.consumed(), 100);
        assert_eq!(guard.remaining(), 924);
        assert_eq!(guard.slot_growth(0), 100);
        assert_eq!(guard.slot_growth(1), 0);
    }

    #[test]
    fn budget_exceeded() {
        let mut guard = ReallocGuard::<4>::new(100);
        guard.register(0, 50).unwrap();
        assert!(guard.check_growth(0, 200).is_err());
    }

    /// Growth that lands exactly on the budget is allowed; one more unit is not.
    #[test]
    fn exact_budget_boundary_is_allowed() {
        let mut guard = ReallocGuard::<4>::new(100);
        guard.register(0, 50).unwrap();
        assert!(guard.check_growth(0, 150).is_ok());
        guard.commit_growth(0, 150).unwrap();
        assert_eq!(guard.consumed(), 100);
        assert_eq!(guard.remaining(), 0);
        assert!(guard.check_growth(0, 151).is_err());
        // Same size and shrinking never need budget.
        assert!(guard.check_growth(0, 150).is_ok());
        assert!(guard.check_growth(0, 10).is_ok());
    }

    /// Growth on one slot reduces what is left for every other slot.
    #[test]
    fn budget_is_shared_across_slots() {
        let mut guard = ReallocGuard::<4>::new(100);
        guard.register(0, 10).unwrap();
        guard.register(1, 10).unwrap();
        guard.commit_growth(0, 70).unwrap();
        assert_eq!(guard.consumed(), 60);
        assert!(guard.check_growth(1, 60).is_err());
        assert!(guard.check_growth(1, 50).is_ok());
        assert_eq!(guard.remaining(), 40);
    }

    /// `check_growth` validates only; it must not record anything.
    #[test]
    fn check_growth_does_not_commit() {
        let mut guard = ReallocGuard::<4>::new(100);
        guard.register(0, 10).unwrap();
        assert!(guard.check_growth(0, 60).is_ok());
        assert_eq!(guard.consumed(), 0);
        assert_eq!(guard.slot_growth(0), 0);
    }

    #[test]
    fn shrink_returns_credit() {
        let mut guard = ReallocGuard::<4>::new(200);
        guard.register(0, 100).unwrap();
        guard.commit_growth(0, 200).unwrap();
        assert_eq!(guard.consumed(), 100);
        guard.commit_growth(0, 150).unwrap();
        assert_eq!(guard.consumed(), 50);
        assert_eq!(guard.remaining(), 150);
    }

    /// Shrinking below the registered size saturates `consumed` at zero and
    /// reports negative growth for the slot.
    #[test]
    fn shrink_below_original_reports_negative_growth() {
        let mut guard = ReallocGuard::<4>::new(100);
        guard.register(0, 100).unwrap();
        guard.commit_growth(0, 40).unwrap();
        assert_eq!(guard.consumed(), 0);
        assert_eq!(guard.remaining(), 100);
        assert_eq!(guard.slot_growth(0), -60);
    }

    #[test]
    fn same_size_is_noop() {
        let mut guard = ReallocGuard::<4>::new(100);
        guard.register(0, 100).unwrap();
        assert!(guard.check_growth(0, 100).is_ok());
        guard.commit_growth(0, 100).unwrap();
        assert_eq!(guard.consumed(), 0);
    }

    #[test]
    fn register_out_of_bounds() {
        let mut guard = ReallocGuard::<2>::new(1024);
        assert!(guard.register(0, 100).is_ok());
        assert!(guard.register(1, 200).is_ok());
        assert!(guard.register(2, 300).is_err());
    }

    #[test]
    fn commit_unregistered_slot() {
        let mut guard = ReallocGuard::<4>::new(1024);
        guard.register(0, 100).unwrap();
        assert!(guard.commit_growth(3, 200).is_err());
        assert!(guard.check_growth(3, 200).is_err());
        assert_eq!(guard.slot_growth(3), 0);
    }
}