use anyhow::{Result, bail};
use std::sync::atomic::{AtomicUsize, Ordering};
/// Tracks the nesting depth of a recursive computation and enforces an upper bound.
///
/// Call [`RecursionGuard::enter`] before descending one level and
/// [`RecursionGuard::exit`] after returning from it. A failed `enter` leaves the
/// guard untouched, so callers that propagate the error with `?` (and therefore
/// skip the matching `exit`) keep the depth bookkeeping balanced.
#[derive(Debug)]
pub struct RecursionGuard {
    /// Largest depth that `enter` will allow.
    max_depth: usize,
    /// Depth of the currently active nesting.
    current_depth: usize,
    /// Highest depth successfully entered since construction.
    max_depth_reached: usize,
}

impl RecursionGuard {
    /// Creates a guard that permits at most `max_depth` nested levels.
    ///
    /// # Errors
    ///
    /// Returns an error if `max_depth` is 0, since such a guard could never be entered.
    pub fn new(max_depth: usize) -> Result<Self> {
        if max_depth == 0 {
            bail!("RecursionGuard max_depth cannot be 0");
        }
        Ok(Self {
            max_depth,
            current_depth: 0,
            max_depth_reached: 0,
        })
    }

    /// Enters one more level of nesting.
    ///
    /// On success the current depth is incremented and the high-water mark
    /// ([`RecursionGuard::max_depth_reached`]) is updated. On failure the guard
    /// is left exactly as it was, so no `exit` is needed for the rejected level.
    ///
    /// # Errors
    ///
    /// Returns [`RecursionError::DepthLimitExceeded`] when entering would exceed
    /// `max_depth`; `current` is the depth that was attempted.
    pub fn enter(&mut self) -> Result<(), RecursionError> {
        if self.current_depth >= self.max_depth {
            return Err(RecursionError::DepthLimitExceeded {
                // Report the depth the caller tried to enter; saturating_add only
                // matters in the theoretical max_depth == usize::MAX case.
                current: self.current_depth.saturating_add(1),
                limit: self.max_depth,
            });
        }
        // Cannot overflow: current_depth < max_depth <= usize::MAX.
        self.current_depth += 1;
        self.max_depth_reached = self.max_depth_reached.max(self.current_depth);
        Ok(())
    }

    /// Leaves one level of nesting. Calling this at depth 0 is a no-op.
    pub fn exit(&mut self) {
        self.current_depth = self.current_depth.saturating_sub(1);
    }

    /// Depth of the currently active nesting.
    #[must_use]
    pub fn current_depth(&self) -> usize {
        self.current_depth
    }

    /// Highest depth successfully entered since construction.
    #[must_use]
    pub fn max_depth_reached(&self) -> usize {
        self.max_depth_reached
    }

    /// Configured maximum depth.
    #[must_use]
    pub fn max_depth(&self) -> usize {
        self.max_depth
    }
}
/// Thread-safe budget of evaluation "fuel" that callers draw down as they work.
///
/// The balance lives in an [`AtomicUsize`], so a shared reference is enough to
/// consume fuel. The starting budget is remembered so it can be reported on and
/// restored with [`ExprFuelCounter::reset`].
#[derive(Debug)]
pub struct ExprFuelCounter {
    /// Fuel still available.
    fuel: AtomicUsize,
    /// Budget supplied at construction; the value `reset` restores.
    initial_fuel: usize,
}

impl ExprFuelCounter {
    /// Builds a counter holding `initial_fuel` units.
    ///
    /// # Errors
    ///
    /// Fails when `initial_fuel` is 0.
    pub fn new(initial_fuel: usize) -> Result<Self> {
        if initial_fuel == 0 {
            bail!("ExprFuelCounter initial_fuel cannot be 0");
        }
        let fuel = AtomicUsize::new(initial_fuel);
        Ok(Self { fuel, initial_fuel })
    }

    /// Atomically removes `amount` units, or removes nothing if fewer remain.
    ///
    /// A request for 0 units always succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`RecursionError::FuelExhausted`] carrying the balance observed by
    /// the failed attempt; in that case the balance is left untouched.
    pub fn consume(&self, amount: usize) -> Result<(), RecursionError> {
        self.fuel
            // `checked_sub` yields `None` exactly when the balance is too small,
            // which makes `fetch_update` give up without storing anything.
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |balance| {
                balance.checked_sub(amount)
            })
            .map(|_previous| ())
            .map_err(|balance| RecursionError::FuelExhausted {
                remaining: balance,
                requested: amount,
            })
    }

    /// Current balance. Other holders may change it concurrently.
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.fuel.load(Ordering::SeqCst)
    }

    /// Budget the counter was created with.
    #[must_use]
    pub fn initial_fuel(&self) -> usize {
        self.initial_fuel
    }

    /// Units used since construction or the last `reset`.
    #[must_use]
    pub fn consumed(&self) -> usize {
        let left = self.remaining();
        self.initial_fuel.saturating_sub(left)
    }

    /// Whether at least `amount` units remain right now.
    ///
    /// This is a snapshot: a later `consume` may still fail if another holder
    /// draws fuel in between.
    #[must_use]
    pub fn has_fuel(&self, amount: usize) -> bool {
        amount <= self.remaining()
    }

    /// Restores the balance to the initial budget.
    pub fn reset(&self) {
        self.fuel.store(self.initial_fuel, Ordering::SeqCst);
    }
}
#[derive(Debug, thiserror::Error)]
pub enum RecursionError {
#[error("Recursion depth limit exceeded: depth {current} > limit {limit}")]
DepthLimitExceeded {
current: usize,
limit: usize,
},
#[error(
"Expression evaluation fuel exhausted: requested {requested}, only {remaining} remaining"
)]
FuelExhausted {
remaining: usize,
requested: usize,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Recurses `n` levels below the current one, bracketing each level with
    /// `guard.enter()` / `guard.exit()` the way a guarded evaluator would.
    fn recursive_countdown(
        n: usize,
        guard: &mut RecursionGuard,
    ) -> Result<usize, RecursionError> {
        guard.enter()?;
        let result = if n == 0 {
            Ok(0)
        } else {
            recursive_countdown(n - 1, guard)
        };
        guard.exit();
        result
    }

    /// Charges one unit of fuel per node, stopping at the first failure.
    fn evaluate_tree(nodes: usize, fuel: &ExprFuelCounter) -> Result<(), RecursionError> {
        (0..nodes).try_for_each(|_| fuel.consume(1))
    }

    #[test]
    fn test_guard_new() {
        let guard = RecursionGuard::new(100).unwrap();
        assert_eq!(guard.current_depth(), 0);
        assert_eq!(guard.max_depth(), 100);
        assert_eq!(guard.max_depth_reached(), 0);
    }

    #[test]
    fn test_guard_new_zero_fails() {
        let result = RecursionGuard::new(0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be 0"));
    }

    #[test]
    fn test_guard_enter_exit() {
        let mut guard = RecursionGuard::new(10).unwrap();
        guard.enter().unwrap();
        assert_eq!(guard.current_depth(), 1);
        assert_eq!(guard.max_depth_reached(), 1);
        guard.enter().unwrap();
        assert_eq!(guard.current_depth(), 2);
        assert_eq!(guard.max_depth_reached(), 2);
        guard.exit();
        assert_eq!(guard.current_depth(), 1);
        assert_eq!(guard.max_depth_reached(), 2);
        guard.exit();
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_guard_depth_limit_enforced() {
        let mut guard = RecursionGuard::new(3).unwrap();
        guard.enter().unwrap();
        guard.enter().unwrap();
        guard.enter().unwrap();
        let err = guard.enter().unwrap_err();
        assert!(matches!(
            err,
            RecursionError::DepthLimitExceeded {
                current: 4,
                limit: 3
            }
        ));
    }

    #[test]
    fn test_guard_exit_at_zero_is_safe() {
        let mut guard = RecursionGuard::new(10).unwrap();
        guard.exit();
        assert_eq!(guard.current_depth(), 0);
    }

    #[test]
    fn test_guard_max_depth_tracking() {
        let mut guard = RecursionGuard::new(100).unwrap();
        for _ in 0..5 {
            guard.enter().unwrap();
        }
        assert_eq!(guard.max_depth_reached(), 5);
        for _ in 0..3 {
            guard.exit();
        }
        assert_eq!(guard.current_depth(), 2);
        assert_eq!(guard.max_depth_reached(), 5);
        guard.enter().unwrap();
        assert_eq!(guard.max_depth_reached(), 5);
    }

    #[test]
    fn test_fuel_new() {
        let fuel = ExprFuelCounter::new(1000).unwrap();
        assert_eq!(fuel.remaining(), 1000);
        assert_eq!(fuel.initial_fuel(), 1000);
        assert_eq!(fuel.consumed(), 0);
    }

    #[test]
    fn test_fuel_new_zero_fails() {
        let result = ExprFuelCounter::new(0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be 0"));
    }

    #[test]
    fn test_fuel_consume() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        fuel.consume(30).unwrap();
        assert_eq!(fuel.remaining(), 70);
        assert_eq!(fuel.consumed(), 30);
        fuel.consume(40).unwrap();
        assert_eq!(fuel.remaining(), 30);
        assert_eq!(fuel.consumed(), 70);
    }

    #[test]
    fn test_fuel_exhaustion() {
        let fuel = ExprFuelCounter::new(50).unwrap();
        fuel.consume(30).unwrap();
        assert_eq!(fuel.remaining(), 20);
        let err = fuel.consume(30).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted {
                remaining: 20,
                requested: 30
            }
        ));
        assert_eq!(fuel.remaining(), 20);
    }

    #[test]
    fn test_fuel_exact_exhaustion() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        fuel.consume(100).unwrap();
        assert_eq!(fuel.remaining(), 0);
        let err = fuel.consume(1).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted {
                remaining: 0,
                requested: 1
            }
        ));
    }

    #[test]
    fn test_fuel_consume_zero_always_succeeds() {
        let fuel = ExprFuelCounter::new(1).unwrap();
        fuel.consume(1).unwrap();
        fuel.consume(0).unwrap();
        assert_eq!(fuel.remaining(), 0);
    }

    #[test]
    fn test_fuel_has_fuel() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        assert!(fuel.has_fuel(50));
        assert!(fuel.has_fuel(100));
        assert!(!fuel.has_fuel(101));
        fuel.consume(60).unwrap();
        assert!(fuel.has_fuel(40));
        assert!(!fuel.has_fuel(41));
    }

    #[test]
    fn test_fuel_reset() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        fuel.consume(80).unwrap();
        assert_eq!(fuel.remaining(), 20);
        fuel.reset();
        assert_eq!(fuel.remaining(), 100);
        assert_eq!(fuel.consumed(), 0);
    }

    #[test]
    fn test_fuel_no_underflow_on_exhaustion() {
        let fuel = ExprFuelCounter::new(5).unwrap();
        let err = fuel.consume(10).unwrap_err();
        assert!(matches!(
            err,
            RecursionError::FuelExhausted {
                remaining: 5,
                requested: 10
            }
        ));
        assert_eq!(fuel.remaining(), 5);
    }

    #[test]
    fn test_fuel_multiple_small_consumes() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        for _ in 0..10 {
            fuel.consume(10).unwrap();
        }
        assert_eq!(fuel.remaining(), 0);
        assert_eq!(fuel.consumed(), 100);
    }

    #[test]
    fn test_fuel_concurrent_consumption_is_exact() {
        // Four threads each draw 25 units; the atomic update must neither lose
        // nor double-count any of them.
        let fuel = std::sync::Arc::new(ExprFuelCounter::new(100).unwrap());
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let fuel = std::sync::Arc::clone(&fuel);
                std::thread::spawn(move || {
                    for _ in 0..25 {
                        fuel.consume(1).unwrap();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(fuel.remaining(), 0);
        assert!(fuel.consume(1).is_err());
    }

    #[test]
    fn test_recursive_function_with_guard() {
        let mut guard = RecursionGuard::new(100).unwrap();
        let result = recursive_countdown(50, &mut guard);
        assert!(result.is_ok());
        // Every successful enter is matched by an exit.
        assert_eq!(guard.current_depth(), 0);
        // Levels 50 down to 0 inclusive is 51 nested entries.
        assert_eq!(guard.max_depth_reached(), 51);
    }

    #[test]
    fn test_recursive_function_exceeds_limit() {
        let mut guard = RecursionGuard::new(10).unwrap();
        let result = recursive_countdown(20, &mut guard);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RecursionError::DepthLimitExceeded { .. }
        ));
    }

    #[test]
    fn test_expression_evaluation_with_fuel() {
        let fuel = ExprFuelCounter::new(100).unwrap();
        let result = evaluate_tree(50, &fuel);
        assert!(result.is_ok());
        assert_eq!(fuel.remaining(), 50);
    }

    #[test]
    fn test_expression_evaluation_exhausts_fuel() {
        let fuel = ExprFuelCounter::new(50).unwrap();
        let result = evaluate_tree(100, &fuel);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RecursionError::FuelExhausted { .. }
        ));
    }
}