use std::cell::{Cell, RefCell};
use std::fmt;
use std::time::{Duration, Instant};
/// Resource limits applied to a reference implementation while it is verified.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ReferenceBudget {
    /// Largest allowed wall-clock time for a single case, in milliseconds.
    pub max_per_case_ms: u64,
    /// Largest allowed cumulative wall-clock time across all cases, in seconds.
    pub max_total_seconds: u64,
    /// Largest accepted input length, in bytes.
    pub max_input_bytes: usize,
}
/// Error raised when a reference implementation breaks its [`ReferenceBudget`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReferenceBombDetected {
    /// Identifier of the operation whose reference implementation was checked.
    pub op_id: String,
    /// Index of the case the violation is attributed to.
    pub case_index: u64,
    /// Measured time in milliseconds (0 when the input itself was rejected).
    pub elapsed_ms: u64,
    /// Budget, in milliseconds, reported alongside `elapsed_ms`.
    pub budget_ms: u64,
}

impl fmt::Display for ReferenceBombDetected {
    /// Renders all four fields on one line, followed by a remediation hint.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            op_id,
            case_index,
            elapsed_ms,
            budget_ms,
        } = self;
        write!(
            f,
            "ReferenceBombDetected: op_id={op_id} case_index={case_index} \
             elapsed_ms={elapsed_ms} budget_ms={budget_ms}. \
             Fix: reference implementation exceeds wall-clock budget; check for algorithmic bombs."
        )
    }
}

impl std::error::Error for ReferenceBombDetected {}
/// Name of an operation archetype (e.g. `"hash-bytes-to-u32"`), used to pick a
/// default [`ReferenceBudget`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Archetype(pub &'static str);
/// Built-in [`ReferenceBudget`] for an archetype.
///
/// Hash archetypes get 1 ms per case and 60 s in total, decoders 5 ms / 120 s,
/// and compressors 50 ms / 300 s, all with a 64 KiB input cap. Any other name
/// falls back to 1 ms / 30 s / 1 KiB; each of those limits is no larger than
/// the corresponding limit of a named archetype.
#[must_use]
#[inline]
pub fn default_budget_for_archetype(arch: &Archetype) -> ReferenceBudget {
    const SIXTY_FOUR_KIB: usize = 64 * 1024;
    let (max_per_case_ms, max_total_seconds, max_input_bytes) = match arch.0 {
        "hash-bytes-to-u32" | "hash-bytes-to-u64" => (1, 60, SIXTY_FOUR_KIB),
        "decode-bytes-to-bytes" => (5, 120, SIXTY_FOUR_KIB),
        "compression-bytes-to-bytes" => (50, 300, SIXTY_FOUR_KIB),
        _ => (1, 30, 1024),
    };
    ReferenceBudget {
        max_per_case_ms,
        max_total_seconds,
        max_input_bytes,
    }
}
/// Accumulates the wall-clock time spent on one operation's reference cases
/// and reports the first budget violation as a [`ReferenceBombDetected`].
pub struct BudgetTracker {
    budget: ReferenceBudget,
    op_id: String,
    // Sum of every case duration that passed the per-case check.
    total_elapsed: Duration,
    // Index of the case currently being recorded; advanced only after a case
    // passes both the per-case and the cumulative check.
    case_index: u64,
}

impl BudgetTracker {
    /// Creates a tracker for `op_id` that enforces `budget`, starting at case 0
    /// with no time accumulated.
    #[inline]
    pub fn new(budget: ReferenceBudget, op_id: &str) -> Self {
        Self {
            budget,
            op_id: op_id.to_string(),
            total_elapsed: Duration::ZERO,
            case_index: 0,
        }
    }

    /// Index that the next [`record_case`](Self::record_case) call is attributed to
    /// (equal to the number of cases accepted so far).
    #[inline]
    pub fn case_index(&self) -> u64 {
        self.case_index
    }

    /// Checks `input` against the budget's `max_input_bytes`.
    ///
    /// # Errors
    /// Returns [`ReferenceBombDetected`] when `input` is longer than the limit.
    /// The error type has no field for byte counts, so it reports
    /// `elapsed_ms = 0` and the per-case budget in `budget_ms`.
    #[inline]
    pub fn check_input(&self, input: &[u8]) -> Result<(), ReferenceBombDetected> {
        if input.len() > self.budget.max_input_bytes {
            return Err(self.bomb(0, self.budget.max_per_case_ms));
        }
        Ok(())
    }

    /// Records one case that took `elapsed`.
    ///
    /// A case over `max_per_case_ms` is rejected without being added to the
    /// running total. Otherwise its time is added and the total is checked
    /// against `max_total_seconds`. The case index advances only when both
    /// checks pass. Comparisons use whole milliseconds.
    ///
    /// # Errors
    /// Returns [`ReferenceBombDetected`] on a violation. For a per-case overrun,
    /// `elapsed_ms`/`budget_ms` are this case's time and the per-case limit. For
    /// a cumulative overrun they are the running total and the total limit,
    /// both in milliseconds.
    #[inline]
    pub fn record_case(&mut self, elapsed: Duration) -> Result<(), ReferenceBombDetected> {
        let elapsed_ms = Self::whole_millis(elapsed);
        if elapsed_ms > self.budget.max_per_case_ms {
            return Err(self.bomb(elapsed_ms, self.budget.max_per_case_ms));
        }
        self.total_elapsed = self.total_elapsed.saturating_add(elapsed);
        let total_ms = Self::whole_millis(self.total_elapsed);
        let total_budget_ms = self.total_budget_ms();
        if total_ms > total_budget_ms {
            return Err(self.bomb(total_ms, total_budget_ms));
        }
        self.case_index += 1;
        Ok(())
    }

    /// Consumes the tracker and performs a final check of the cumulative time.
    ///
    /// # Errors
    /// Returns [`ReferenceBombDetected`] if the accumulated time exceeds
    /// `max_total_seconds`. `elapsed_ms` and `budget_ms` are then the total time
    /// and the total limit in milliseconds, and `case_index` is the index of the
    /// last accepted case (0 if none were accepted).
    #[inline]
    pub fn finish(self) -> Result<(), ReferenceBombDetected> {
        let total_ms = Self::whole_millis(self.total_elapsed);
        let total_budget_ms = self.total_budget_ms();
        if total_ms > total_budget_ms {
            return Err(ReferenceBombDetected {
                op_id: self.op_id,
                case_index: self.case_index.saturating_sub(1),
                elapsed_ms: total_ms,
                budget_ms: total_budget_ms,
            });
        }
        Ok(())
    }

    // Cumulative limit in milliseconds. Saturates instead of overflowing for
    // very large `max_total_seconds` values (e.g. from a spec override).
    fn total_budget_ms(&self) -> u64 {
        self.budget.max_total_seconds.saturating_mul(1000)
    }

    // Whole milliseconds in `d`, clamped to `u64::MAX`; the cast is lossless
    // after the clamp.
    fn whole_millis(d: Duration) -> u64 {
        d.as_millis().min(u128::from(u64::MAX)) as u64
    }

    // Builds the error for the case currently being recorded.
    fn bomb(&self, elapsed_ms: u64, budget_ms: u64) -> ReferenceBombDetected {
        ReferenceBombDetected {
            op_id: self.op_id.clone(),
            case_index: self.case_index,
            elapsed_ms,
            budget_ms,
        }
    }
}
/// Runs `f` as one budgeted case on `tracker`.
///
/// `input` is used only for the size check and is not passed to `f`. The size
/// is checked before `f` runs; `f` is then timed and the duration recorded.
///
/// # Errors
/// Returns [`ReferenceBombDetected`] without running `f` if `input` is too
/// large. If the case breaks a time budget, `f`'s output is discarded and the
/// error is returned.
#[inline]
pub fn call_with_budget<F>(
    f: F,
    tracker: &mut BudgetTracker,
    input: &[u8],
) -> Result<Vec<u8>, ReferenceBombDetected>
where
    F: FnOnce() -> Vec<u8>,
{
    tracker.check_input(input)?;
    let clock = Instant::now();
    let produced = f();
    tracker.record_case(clock.elapsed()).map(|()| produced)
}
thread_local! {
    /// Tracker consulted by `certify_budget_wrapper` and `exec_budget_record`;
    /// installed by `with_certify_budget` and `with_exec_budget`.
    static ACTIVE_BUDGET: RefCell<Option<BudgetTracker>> = const { RefCell::new(None) };
}
thread_local! {
    /// Reference function that `certify_budget_wrapper` delegates to.
    static ACTIVE_FN: Cell<Option<fn(&[u8]) -> Vec<u8>>> = const { Cell::new(None) };
}
thread_local! {
    /// Most recent budget violation stored by `certify_budget_wrapper`;
    /// cleared by `with_certify_budget` and drained by `take_last_bomb`.
    static LAST_BOMB: RefCell<Option<ReferenceBombDetected>> = const { RefCell::new(None) };
}
/// Runs `f` with `tracker` and `real_fn` installed as this thread's active
/// certify state, so that [`certify_budget_wrapper`] calls made inside `f` run
/// `real_fn` and record each case against `tracker`.
///
/// On entry, any bomb previously stored on this thread is cleared. Violations
/// detected inside `f` do not fail this call; they are stored for
/// [`take_last_bomb`]. After `f` returns, the cumulative time is checked with
/// [`BudgetTracker::finish`].
///
/// The budget and function that were active before this call are reinstated
/// when it ends, including when `f` panics. As a result, a nested scope does
/// not discard the enclosing scope's tracker, and an unwinding `f` does not
/// leave stale state installed on the thread.
///
/// # Errors
/// Returns [`ReferenceBombDetected`] if the cumulative time recorded inside `f`
/// exceeds the tracker's `max_total_seconds`.
#[inline]
pub fn with_certify_budget<F, R>(
    tracker: BudgetTracker,
    real_fn: fn(&[u8]) -> Vec<u8>,
    f: F,
) -> Result<R, ReferenceBombDetected>
where
    F: FnOnce() -> R,
{
    // Reinstates the enclosing scope's thread-local state on drop, which runs
    // on normal exit and during unwinding alike.
    struct RestoreOuter {
        budget: Option<BudgetTracker>,
        real_fn: Option<fn(&[u8]) -> Vec<u8>>,
    }
    impl Drop for RestoreOuter {
        fn drop(&mut self) {
            let budget = self.budget.take();
            ACTIVE_BUDGET.with(|b| *b.borrow_mut() = budget);
            ACTIVE_FN.with(|c| c.set(self.real_fn));
        }
    }

    let restore = RestoreOuter {
        budget: ACTIVE_BUDGET.with(|b| b.replace(Some(tracker))),
        real_fn: ACTIVE_FN.with(|c| c.replace(Some(real_fn))),
    };
    LAST_BOMB.with(|b| {
        *b.borrow_mut() = None;
    });
    let result = f();
    // Take our own tracker out before the guard puts the outer one back.
    let ours = ACTIVE_BUDGET.with(|b| b.borrow_mut().take());
    drop(restore);
    match ours {
        Some(tracker) => tracker.finish().map(|()| result),
        // Only reachable if code inside `f` removed the tracker itself.
        None => Ok(result),
    }
}
#[inline]
pub fn certify_budget_wrapper(input: &[u8]) -> Vec<u8> {
let real_fn = ACTIVE_FN
.with(|c| c.get())
.expect("certify_budget_wrapper called without active function");
let start = Instant::now();
let output = real_fn(input);
let elapsed = start.elapsed();
ACTIVE_BUDGET.with(|b| {
if let Some(ref mut tracker) = *b.borrow_mut() {
if let Err(bomb) = tracker.record_case(elapsed) {
LAST_BOMB.with(|lb| {
*lb.borrow_mut() = Some(bomb);
});
}
}
});
output
}
#[inline]
pub fn take_last_bomb() -> Option<ReferenceBombDetected> {
LAST_BOMB.with(|b| b.borrow_mut().take())
}
/// Runs `f` with `tracker` installed as this thread's active budget, so that
/// [`exec_budget_record`] calls made inside `f` are recorded against it. Once
/// `f` returns, the cumulative time is checked with [`BudgetTracker::finish`].
///
/// The budget that was active before this call is reinstated when it ends,
/// including when `f` panics. As a result, a nested scope does not discard the
/// enclosing scope's tracker (which would skip the outer `finish` check), and
/// an unwinding `f` does not leave its tracker installed on the thread.
///
/// # Errors
/// Returns [`ReferenceBombDetected`] if the cumulative time recorded inside `f`
/// exceeds the tracker's `max_total_seconds`.
#[inline]
pub fn with_exec_budget<F, R>(tracker: BudgetTracker, f: F) -> Result<R, ReferenceBombDetected>
where
    F: FnOnce() -> R,
{
    // Reinstates the enclosing scope's tracker on drop, which runs on normal
    // exit and during unwinding alike.
    struct RestoreOuter(Option<BudgetTracker>);
    impl Drop for RestoreOuter {
        fn drop(&mut self) {
            let outer = self.0.take();
            ACTIVE_BUDGET.with(|b| *b.borrow_mut() = outer);
        }
    }

    let restore = RestoreOuter(ACTIVE_BUDGET.with(|b| b.replace(Some(tracker))));
    let result = f();
    // Take our own tracker out before the guard puts the outer one back.
    let ours = ACTIVE_BUDGET.with(|b| b.borrow_mut().take());
    drop(restore);
    match ours {
        Some(tracker) => tracker.finish().map(|()| result),
        // Only reachable if code inside `f` removed the tracker itself.
        None => Ok(result),
    }
}
#[inline]
pub fn exec_budget_record(elapsed: Duration) -> Result<(), ReferenceBombDetected> {
ACTIVE_BUDGET.with(|b| {
if let Some(ref mut tracker) = *b.borrow_mut() {
tracker.record_case(elapsed)
} else {
Ok(())
}
})
}
/// Loads a per-operation budget override from
/// `core/src/ops/<op_id with '.' replaced by '/'>/spec.toml`, reading the
/// integer keys `max_per_case_ms`, `max_total_seconds` and `max_input_bytes`
/// from its `[verify.budget]` table.
///
/// The path is relative, so it is resolved against the process's current
/// working directory.
///
/// Returns `None` in any of these cases:
/// - the file cannot be read or parsed;
/// - the table or any key is missing or is not an integer;
/// - a value is negative or does not fit the field type.
///
/// Rejecting such values (rather than wrapping them with an `as` cast) keeps a
/// typo like `-1` from becoming a near-unlimited budget.
#[must_use]
#[inline]
pub fn load_budget_override(op_id: &str) -> Option<ReferenceBudget> {
    use std::convert::TryFrom;

    let path = format!("core/src/ops/{}/spec.toml", op_id.replace('.', "/"));
    let text = std::fs::read_to_string(&path).ok()?;
    let value: toml::Value = text.parse().ok()?;
    let budget_table = value.get("verify")?.get("budget")?;
    let integer = |key: &str| budget_table.get(key).and_then(toml::Value::as_integer);
    let max_per_case_ms = u64::try_from(integer("max_per_case_ms")?).ok()?;
    let max_total_seconds = u64::try_from(integer("max_total_seconds")?).ok()?;
    let max_input_bytes = usize::try_from(integer("max_input_bytes")?).ok()?;
    Some(ReferenceBudget {
        max_per_case_ms,
        max_total_seconds,
        max_input_bytes,
    })
}
/// Budget for `op_id`: the spec.toml override if one can be loaded, otherwise
/// the built-in default for `archetype`.
#[must_use]
#[inline]
pub fn budget_for_op(op_id: &str, archetype: &Archetype) -> ReferenceBudget {
    match load_budget_override(op_id) {
        Some(overridden) => overridden,
        None => default_budget_for_archetype(archetype),
    }
}
// Unit tests for the archetype defaults, the tracker's three checks, and the
// thread-local certify wrapper.
#[cfg(test)]
mod tests {
use super::{
certify_budget_wrapper, default_budget_for_archetype, take_last_bomb, with_certify_budget,
Archetype, BudgetTracker, ReferenceBudget,
};
use std::time::Duration;
// The default_budget_* tests pin the exact limits returned for each archetype.
#[test]
fn default_budget_hash_bytes_to_u32() {
let b = default_budget_for_archetype(&Archetype("hash-bytes-to-u32"));
assert_eq!(b.max_per_case_ms, 1);
assert_eq!(b.max_total_seconds, 60);
assert_eq!(b.max_input_bytes, 64 * 1024);
}
#[test]
fn default_budget_decode_bytes_to_bytes() {
let b = default_budget_for_archetype(&Archetype("decode-bytes-to-bytes"));
assert_eq!(b.max_per_case_ms, 5);
assert_eq!(b.max_total_seconds, 120);
assert_eq!(b.max_input_bytes, 64 * 1024);
}
#[test]
fn default_budget_compression_bytes_to_bytes() {
let b = default_budget_for_archetype(&Archetype("compression-bytes-to-bytes"));
assert_eq!(b.max_per_case_ms, 50);
assert_eq!(b.max_total_seconds, 300);
assert_eq!(b.max_input_bytes, 64 * 1024);
}
// Any unrecognised archetype name takes the fallback arm.
#[test]
fn default_budget_unknown_falls_back() {
let b = default_budget_for_archetype(&Archetype("unknown-archetype"));
assert_eq!(b.max_per_case_ms, 1);
assert_eq!(b.max_total_seconds, 30);
assert_eq!(b.max_input_bytes, 1024);
}
// 6 ms against a 5 ms per-case limit fails on the very first case (index 0).
#[test]
fn tracker_rejects_per_case_overrun() {
let budget = ReferenceBudget {
max_per_case_ms: 5,
max_total_seconds: 60,
max_input_bytes: 1024,
};
let mut tracker = BudgetTracker::new(budget, "test.op");
let err = tracker.record_case(Duration::from_millis(6)).unwrap_err();
assert_eq!(err.case_index, 0);
assert_eq!(err.elapsed_ms, 6);
assert_eq!(err.budget_ms, 5);
}
// 500 ms and 600 ms each fit the 1000 ms per-case limit, but together they
// exceed the 1 s cumulative limit, so the second case is rejected.
// NOTE(review): 1000 is both the per-case limit and the total limit in ms, so
// the `budget_ms` assertion cannot tell which of the two limits is reported.
#[test]
fn tracker_rejects_total_overrun() {
let budget = ReferenceBudget {
max_per_case_ms: 1000,
max_total_seconds: 1,
max_input_bytes: 1024,
};
let mut tracker = BudgetTracker::new(budget, "test.op");
tracker.record_case(Duration::from_millis(500)).unwrap();
let err = tracker.record_case(Duration::from_millis(600)).unwrap_err();
assert_eq!(err.budget_ms, 1000);
}
// A 5-byte input against a 4-byte limit is rejected before any case is recorded.
#[test]
fn tracker_rejects_oversized_input() {
let budget = ReferenceBudget {
max_per_case_ms: 1000,
max_total_seconds: 60,
max_input_bytes: 4,
};
let tracker = BudgetTracker::new(budget, "test.op");
let err = tracker.check_input(&[0; 5]).unwrap_err();
assert_eq!(err.case_index, 0);
}
// A 10 ms sleep overruns the 1 ms per-case limit. The wrapper stores the bomb
// instead of failing the scope, so `with_certify_budget` still returns `Ok`
// and the bomb is retrieved afterwards with `take_last_bomb`.
#[test]
fn certify_wrapper_detects_bomb() {
let budget = ReferenceBudget {
max_per_case_ms: 1,
max_total_seconds: 60,
max_input_bytes: 1024,
};
let tracker = BudgetTracker::new(budget, "test.op");
let slow_fn: fn(&[u8]) -> Vec<u8> = |_input| {
std::thread::sleep(std::time::Duration::from_millis(10));
vec![1]
};
with_certify_budget(tracker, slow_fn, || {
certify_budget_wrapper(&[]);
})
.unwrap();
let bomb = take_last_bomb().expect("expected a bomb");
assert_eq!(bomb.op_id, "test.op");
assert!(bomb.elapsed_ms >= 10);
}
}