use std::time::{Duration, Instant};
use crate::id::Cid;
/// Wall-clock budget tracker for a single commit.
///
/// The guard remembers when it was started and compares the elapsed time
/// against a soft budget and a hard wall each time a stage is charged.
#[derive(Debug)]
pub struct CommitBudgetGuard {
    /// Static label for this guard; copied verbatim into the report.
    pub tag: &'static str,
    /// Instant captured when the guard was created; all elapsed values are
    /// measured from here.
    pub start: Instant,
    /// Soft budget in milliseconds. A charge whose elapsed time exceeds this
    /// (without exceeding the hard wall) asks the caller to defer.
    pub budget_ms: u32,
    /// Hard wall in milliseconds. A charge whose elapsed time exceeds this
    /// fails with [`HardWallExceeded`].
    pub hard_wall_ms: u32,
    /// CID handed to the constructor. It is stored but not read by any
    /// method of the guard.
    pub commit_cid: Cid,
    /// Stage names recorded through `defer`, in call order.
    pub deferred: Vec<&'static str>,
    /// `(stage, elapsed_ms)` pairs recorded by every charge, in call order.
    pub charged: Vec<(&'static str, u32)>,
    /// Set once a charge observes an elapsed time above `budget_ms` but not
    /// above `hard_wall_ms`.
    pub breached: bool,
    /// Set once a charge observes an elapsed time above `hard_wall_ms`.
    pub hard_wall_hit: bool,
}
/// Outcome of a charge that did not hit the hard wall.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Decision {
    /// Elapsed time is within the soft budget.
    Proceed,
    /// Elapsed time is above the soft budget but not above the hard wall.
    ShouldDefer,
}
/// Error describing a charge whose elapsed time went past the hard wall.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HardWallExceeded {
    /// Elapsed milliseconds observed by the failing charge.
    pub elapsed_ms: u32,
    /// Hard wall, in milliseconds, that the elapsed time was compared against.
    pub hard_wall_ms: u32,
}

impl core::fmt::Display for HardWallExceeded {
    /// Renders a one-line message containing both the elapsed time and the wall.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The struct is `Copy`, so pull both fields out by value.
        let Self {
            elapsed_ms,
            hard_wall_ms,
        } = *self;
        write!(
            f,
            "commit-budget hard wall exceeded: elapsed {}ms > wall {}ms",
            elapsed_ms, hard_wall_ms
        )
    }
}

impl std::error::Error for HardWallExceeded {}
/// Plain-data summary of a finished commit-budget guard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommitBudgetReport {
    /// Label carried over from the guard.
    pub tag: &'static str,
    /// Elapsed milliseconds at the moment the report was built.
    pub elapsed_ms: u32,
    /// Soft budget, in milliseconds, the guard was running with.
    pub budget_ms: u32,
    /// Hard wall, in milliseconds, the guard was running with.
    pub hard_wall_ms: u32,
    /// Whether the guard's `breached` flag was set.
    pub breached: bool,
    /// Whether the guard's `hard_wall_hit` flag was set.
    pub hard_wall_hit: bool,
    /// Stage names that were deferred, in recording order.
    pub deferred_stages: Vec<&'static str>,
    /// `(stage, elapsed_ms)` pairs for each charge, in recording order.
    pub charged_stages: Vec<(&'static str, u32)>,
}
impl CommitBudgetGuard {
    /// Creates a guard whose clock starts at the moment of the call.
    ///
    /// `tag` and `commit_cid` are stored as given. `budget_ms` is the soft
    /// budget. `hard_wall_ms` is raised to `budget_ms` if it was passed in
    /// smaller, so the stored hard wall is never below the soft budget.
    #[must_use]
    pub fn start(tag: &'static str, budget_ms: u32, hard_wall_ms: u32, commit_cid: Cid) -> Self {
        let effective_wall_ms = hard_wall_ms.max(budget_ms);
        Self {
            tag,
            start: Instant::now(),
            budget_ms,
            hard_wall_ms: effective_wall_ms,
            commit_cid,
            deferred: Vec::new(),
            charged: Vec::new(),
            breached: false,
            hard_wall_hit: false,
        }
    }

    /// Returns the whole milliseconds elapsed since the guard was started,
    /// saturating at `u32::MAX` when the value does not fit in a `u32`.
    #[must_use]
    pub fn elapsed_ms(&self) -> u32 {
        let millis = self.start.elapsed().as_millis();
        match u32::try_from(millis) {
            Ok(ms) => ms,
            Err(_) => u32::MAX,
        }
    }

    /// Charges `stage` against the budget using the guard's own clock.
    ///
    /// Reads the current elapsed time and applies exactly the same rules as
    /// [`CommitBudgetGuard::charge_with`].
    ///
    /// # Errors
    ///
    /// Returns [`HardWallExceeded`] when the elapsed time is above the hard wall.
    pub fn charge(&mut self, stage: &'static str) -> Result<Decision, HardWallExceeded> {
        let now_ms = self.elapsed_ms();
        self.charge_with(stage, now_ms)
    }

    /// Charges `stage` against the budget using a caller-supplied elapsed time.
    ///
    /// The `(stage, elapsed_ms)` pair is always recorded first. Then:
    /// * above the hard wall: `hard_wall_hit` is set and an error is returned;
    /// * above the soft budget: `breached` is set and
    ///   [`Decision::ShouldDefer`] is returned;
    /// * otherwise [`Decision::Proceed`] is returned.
    ///
    /// # Errors
    ///
    /// Returns [`HardWallExceeded`] when `elapsed_ms` is above the hard wall.
    #[doc(hidden)]
    pub fn charge_with(
        &mut self,
        stage: &'static str,
        elapsed_ms: u32,
    ) -> Result<Decision, HardWallExceeded> {
        self.charged.push((stage, elapsed_ms));

        if elapsed_ms > self.hard_wall_ms {
            // NOTE(review): only `hard_wall_hit` is set on this path; `breached`
            // keeps its previous value. Confirm that is the intended reporting.
            self.hard_wall_hit = true;
            Err(HardWallExceeded {
                elapsed_ms,
                hard_wall_ms: self.hard_wall_ms,
            })
        } else if elapsed_ms > self.budget_ms {
            self.breached = true;
            Ok(Decision::ShouldDefer)
        } else {
            Ok(Decision::Proceed)
        }
    }

    /// Records `stage` as deferred. No timing check is performed.
    pub fn defer(&mut self, stage: &'static str) {
        self.deferred.push(stage);
    }

    /// Consumes the guard and returns a [`CommitBudgetReport`].
    ///
    /// The report's `elapsed_ms` is read from the clock at the time of this
    /// call; every other field is moved or copied from the guard unchanged.
    #[must_use]
    pub fn into_report(self) -> CommitBudgetReport {
        // Sample the clock before the guard is taken apart.
        let elapsed_ms = self.elapsed_ms();
        let Self {
            tag,
            budget_ms,
            hard_wall_ms,
            breached,
            hard_wall_hit,
            deferred,
            charged,
            ..
        } = self;

        CommitBudgetReport {
            tag,
            elapsed_ms,
            budget_ms,
            hard_wall_ms,
            breached,
            hard_wall_hit,
            deferred_stages: deferred,
            charged_stages: charged,
        }
    }
}
/// Returns how much time has passed between `start` and now.
#[doc(hidden)]
#[must_use]
pub fn since(start: Instant) -> Duration {
    let now = Instant::now();
    now.duration_since(start)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::Multihash;

    /// Builds a placeholder CID from `CODEC_RAW` and a `HASH_BLAKE3_256`
    /// multihash wrapping 32 zero bytes.
    fn zero_cid() -> Cid {
        let digest = [0u8; 32];
        let hash = Multihash::wrap(crate::id::HASH_BLAKE3_256, &digest).expect("32-byte digest");
        Cid::new(crate::id::CODEC_RAW, hash)
    }

    /// An elapsed time below the budget proceeds and leaves `breached` unset.
    #[test]
    fn charge_under_budget_proceeds() {
        let mut guard = CommitBudgetGuard::start("test", 100, 200, zero_cid());
        assert_eq!(guard.charge_with("stage_a", 50), Ok(Decision::Proceed));
        assert!(!guard.breached);
    }

    /// An elapsed time between budget and hard wall asks to defer and sets `breached`.
    #[test]
    fn charge_over_budget_defers() {
        let mut guard = CommitBudgetGuard::start("test", 50, 200, zero_cid());
        assert_eq!(guard.charge_with("stage_a", 75), Ok(Decision::ShouldDefer));
        assert!(guard.breached);
    }

    /// An elapsed time above the hard wall yields the error and sets `hard_wall_hit`.
    #[test]
    fn charge_over_hard_wall_aborts() {
        let mut guard = CommitBudgetGuard::start("test", 50, 100, zero_cid());
        let outcome = guard.charge_with("stage_a", 150);
        assert_eq!(
            outcome,
            Err(HardWallExceeded {
                elapsed_ms: 150,
                hard_wall_ms: 100,
            })
        );
        assert!(guard.hard_wall_hit);
    }

    /// A hard wall passed in below the budget is raised to the budget.
    #[test]
    fn hard_wall_clamped_to_at_least_budget() {
        let guard = CommitBudgetGuard::start("test", 100, 50, zero_cid());
        assert_eq!(guard.hard_wall_ms, 100);
    }

    /// The report carries the tag, flags, deferred stages and charged stages.
    #[test]
    fn report_records_charged_and_deferred() {
        let mut guard = CommitBudgetGuard::start("test", 50, 200, zero_cid());
        guard.charge_with("a", 10).unwrap();
        guard.charge_with("b", 80).unwrap();
        guard.defer("c");
        guard.defer("d");

        let report = guard.into_report();
        assert_eq!(report.tag, "test");
        assert!(report.breached);
        assert!(!report.hard_wall_hit);
        assert_eq!(report.deferred_stages, ["c", "d"]);

        // Only the stage names are checked; the recorded timings are not.
        let charged_names: Vec<&str> = report
            .charged_stages
            .iter()
            .map(|&(name, _)| name)
            .collect();
        assert_eq!(charged_names, ["a", "b"]);
    }
}