use std::time::Instant;
use crate::error::SolverError;
use crate::types::ComputeBudget;
/// Memory ceiling, in bytes, used by [`BudgetEnforcer::new`] (256 MiB).
const DEFAULT_MEMORY_LIMIT: usize = 256 * 1024 * 1024;

/// Tracks a solver run's resource consumption against a [`ComputeBudget`].
///
/// Three limits are enforced:
/// * the iteration count (`budget.max_iterations`), via [`check_iteration`](Self::check_iteration);
/// * wall-clock time since construction (`budget.max_time`), also via `check_iteration`;
/// * a byte tally capped at `memory_limit`, via [`check_memory`](Self::check_memory).
///
/// The memory tally is self-reported: callers declare how many bytes they are
/// about to use and the enforcer only adds them up. Nothing is measured, and no
/// method on this type decreases the tally.
pub struct BudgetEnforcer {
    /// Reference point for every elapsed-time measurement.
    start_time: Instant,
    /// Limits supplied by the caller.
    budget: ComputeBudget,
    /// Number of `check_iteration` calls so far, including failing ones.
    iterations_used: usize,
    /// Sum of all `additional` amounts accepted by `check_memory`.
    memory_used: usize,
    /// Upper bound (inclusive), in bytes, for `memory_used`.
    memory_limit: usize,
}

impl BudgetEnforcer {
    /// Creates an enforcer for `budget` with the default memory limit of 256 MiB.
    ///
    /// The wall-clock timer starts now.
    pub fn new(budget: ComputeBudget) -> Self {
        Self::with_memory_limit(budget, DEFAULT_MEMORY_LIMIT)
    }

    /// Creates an enforcer for `budget` with an explicit memory limit in bytes.
    ///
    /// The wall-clock timer starts now.
    pub fn with_memory_limit(budget: ComputeBudget, memory_limit: usize) -> Self {
        Self {
            start_time: Instant::now(),
            budget,
            iterations_used: 0,
            memory_used: 0,
            memory_limit,
        }
    }

    /// Records one solver iteration and checks the iteration and time limits.
    ///
    /// The counter is incremented first, so the call that pushes it past
    /// `max_iterations` is the one that fails. The iteration limit is checked
    /// before the wall-clock limit.
    ///
    /// # Errors
    ///
    /// Returns [`SolverError::BudgetExhausted`] if more than `max_iterations`
    /// iterations have been recorded, or if more than `max_time` has elapsed
    /// since construction.
    pub fn check_iteration(&mut self) -> Result<(), SolverError> {
        // Saturate rather than overflow-panic (debug builds) on absurdly long runs.
        self.iterations_used = self.iterations_used.saturating_add(1);
        if self.iterations_used > self.budget.max_iterations {
            return Err(Self::exhausted(
                format!(
                    "iteration limit reached ({} > {})",
                    self.iterations_used, self.budget.max_iterations,
                ),
                self.start_time.elapsed(),
            ));
        }

        let elapsed = self.start_time.elapsed();
        if elapsed > self.budget.max_time {
            // Reuse the same measurement so the message and the error field agree.
            return Err(Self::exhausted(
                format!(
                    "wall-clock time limit reached ({:.2?} > {:.2?})",
                    elapsed, self.budget.max_time,
                ),
                elapsed,
            ));
        }
        Ok(())
    }

    /// Reserves `additional` bytes against the memory limit.
    ///
    /// On success the tally grows by `additional`; on failure it is left
    /// unchanged. A total exactly equal to the limit is allowed.
    ///
    /// # Errors
    ///
    /// Returns [`SolverError::BudgetExhausted`] if the new total would exceed
    /// the memory limit or would not fit in a `usize`.
    pub fn check_memory(&mut self, additional: usize) -> Result<(), SolverError> {
        // `checked_add` rather than `saturating_add`: a saturated sum equals
        // `usize::MAX` and would slip past a limit of `usize::MAX`.
        match self.memory_used.checked_add(additional) {
            Some(new_total) if new_total <= self.memory_limit => {
                self.memory_used = new_total;
                Ok(())
            }
            Some(new_total) => Err(Self::exhausted(
                format!(
                    "memory limit reached ({} + {} = {} > {} bytes)",
                    self.memory_used, additional, new_total, self.memory_limit,
                ),
                self.start_time.elapsed(),
            )),
            None => Err(Self::exhausted(
                format!(
                    "memory limit reached ({} + {} overflows usize; limit {} bytes)",
                    self.memory_used, additional, self.memory_limit,
                ),
                self.start_time.elapsed(),
            )),
        }
    }

    /// Microseconds elapsed since construction, saturating at `u64::MAX`.
    #[inline]
    pub fn elapsed_us(&self) -> u64 {
        // `as_micros` yields u128; clamp before narrowing instead of truncating.
        self.start_time.elapsed().as_micros().min(u128::from(u64::MAX)) as u64
    }

    /// Time elapsed since construction.
    #[inline]
    pub fn elapsed(&self) -> std::time::Duration {
        self.start_time.elapsed()
    }

    /// Number of `check_iteration` calls made so far, including failing ones.
    #[inline]
    pub fn iterations_used(&self) -> usize {
        self.iterations_used
    }

    /// Bytes accepted so far by `check_memory`.
    #[inline]
    pub fn memory_used(&self) -> usize {
        self.memory_used
    }

    /// The `tolerance` field of the underlying budget.
    #[inline]
    pub fn tolerance(&self) -> f64 {
        self.budget.tolerance
    }

    /// The budget this enforcer was created with.
    #[inline]
    pub fn budget(&self) -> &ComputeBudget {
        &self.budget
    }

    /// Builds the error returned by every failed budget check.
    #[cold]
    fn exhausted(reason: String, elapsed: std::time::Duration) -> SolverError {
        SolverError::BudgetExhausted { reason, elapsed }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ComputeBudget;
    use std::time::Duration;

    /// Budget with a generous 60 s time limit and a 5-iteration cap.
    fn tiny_budget() -> ComputeBudget {
        ComputeBudget {
            max_time: Duration::from_secs(60),
            max_iterations: 5,
            tolerance: 1e-6,
        }
    }

    #[test]
    fn iterations_within_budget() {
        let mut enforcer = BudgetEnforcer::new(tiny_budget());
        for _ in 0..5 {
            enforcer.check_iteration().unwrap();
        }
        assert_eq!(enforcer.iterations_used(), 5);
    }

    #[test]
    fn iteration_limit_exceeded() {
        let mut enforcer = BudgetEnforcer::new(tiny_budget());
        for _ in 0..5 {
            enforcer.check_iteration().unwrap();
        }
        let err = enforcer.check_iteration().unwrap_err();
        match err {
            SolverError::BudgetExhausted { ref reason, .. } => {
                assert!(reason.contains("iteration"), "reason: {reason}");
            }
            other => panic!("expected BudgetExhausted, got {other:?}"),
        }
    }

    #[test]
    fn wall_clock_limit_exceeded() {
        let budget = ComputeBudget {
            max_time: Duration::from_nanos(1),
            max_iterations: 1_000_000,
            tolerance: 1e-6,
        };
        let mut enforcer = BudgetEnforcer::new(budget);
        // Guarantee that more than 1 ns has passed before the check.
        std::thread::sleep(Duration::from_micros(10));
        let err = enforcer.check_iteration().unwrap_err();
        match err {
            SolverError::BudgetExhausted { ref reason, .. } => {
                assert!(reason.contains("wall-clock"), "reason: {reason}");
            }
            other => panic!("expected BudgetExhausted for time, got {other:?}"),
        }
    }

    #[test]
    fn memory_within_budget() {
        let mut enforcer = BudgetEnforcer::with_memory_limit(tiny_budget(), 1024);
        enforcer.check_memory(512).unwrap();
        enforcer.check_memory(512).unwrap();
        assert_eq!(enforcer.memory_used(), 1024);
    }

    #[test]
    fn memory_limit_is_inclusive() {
        let mut enforcer = BudgetEnforcer::with_memory_limit(tiny_budget(), 1024);
        enforcer.check_memory(1024).unwrap();
        assert!(enforcer.check_memory(1).is_err());
        assert_eq!(enforcer.memory_used(), 1024);
    }

    #[test]
    fn new_uses_default_memory_limit() {
        let mut enforcer = BudgetEnforcer::new(tiny_budget());
        enforcer.check_memory(DEFAULT_MEMORY_LIMIT).unwrap();
        assert!(enforcer.check_memory(1).is_err());
    }

    #[test]
    fn memory_limit_exceeded() {
        let mut enforcer = BudgetEnforcer::with_memory_limit(tiny_budget(), 1024);
        enforcer.check_memory(800).unwrap();
        let err = enforcer.check_memory(300).unwrap_err();
        match err {
            SolverError::BudgetExhausted { ref reason, .. } => {
                assert!(reason.contains("memory"), "reason: {reason}");
            }
            other => panic!("expected BudgetExhausted for memory, got {other:?}"),
        }
        // A rejected request must not be counted.
        assert_eq!(enforcer.memory_used(), 800);
    }

    #[test]
    fn memory_saturating_add_no_panic() {
        let limit = usize::MAX / 2;
        let mut enforcer = BudgetEnforcer::with_memory_limit(tiny_budget(), limit);
        enforcer.check_memory(limit - 1).unwrap();
        let err = enforcer.check_memory(usize::MAX).unwrap_err();
        assert!(matches!(err, SolverError::BudgetExhausted { .. }));
    }

    #[test]
    fn elapsed_is_monotonic() {
        let enforcer = BudgetEnforcer::new(tiny_budget());
        let first_us = enforcer.elapsed_us();
        let first = enforcer.elapsed();
        let second = enforcer.elapsed();
        let second_us = enforcer.elapsed_us();
        assert!(second >= first);
        assert!(second_us >= first_us);
    }

    #[test]
    fn tolerance_accessor() {
        let enforcer = BudgetEnforcer::new(tiny_budget());
        assert!((enforcer.tolerance() - 1e-6).abs() < f64::EPSILON);
    }

    #[test]
    fn budget_accessor() {
        let enforcer = BudgetEnforcer::new(tiny_budget());
        assert_eq!(enforcer.budget().max_iterations, 5);
        assert_eq!(enforcer.budget().max_time, Duration::from_secs(60));
    }
}