use std::sync::Arc;
use nodedb_types::DatabaseId;
use super::budget::MaintenanceBudgetTracker;
/// Result of attempting a maintenance task through [`with_budget`].
///
/// Either the task's closure ran and produced a value of type `R`, or it was
/// skipped because the budget tracker did not hand out a lease.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MaintenanceOutcome<R> {
    /// A lease was acquired and the work closure ran, producing this value.
    Ran(R),
    /// No lease was granted; the work closure was not invoked.
    Deferred,
}

impl<R> MaintenanceOutcome<R> {
    /// Returns `true` if the work closure ran.
    #[must_use]
    pub fn ran(&self) -> bool {
        matches!(self, Self::Ran(_))
    }

    /// Returns `true` if the work was deferred (no lease was granted).
    #[must_use]
    pub fn deferred(&self) -> bool {
        matches!(self, Self::Deferred)
    }

    /// Consumes the outcome, returning the work's result if it ran.
    ///
    /// Returns `None` for [`MaintenanceOutcome::Deferred`].
    #[must_use]
    pub fn into_ran(self) -> Option<R> {
        match self {
            Self::Ran(result) => Some(result),
            Self::Deferred => None,
        }
    }
}
/// Runs `work_fn` only if `tracker` grants a budget lease for `db`.
///
/// Asks the tracker for a lease sized by `estimated_secs`. If none is granted,
/// `work_fn` is never called and [`MaintenanceOutcome::Deferred`] is returned.
/// Otherwise the lease is held for the whole duration of `work_fn`, and its
/// return value is wrapped in [`MaintenanceOutcome::Ran`].
pub fn with_budget<R, F>(
    tracker: &Arc<MaintenanceBudgetTracker>,
    db: DatabaseId,
    estimated_secs: f64,
    work_fn: F,
) -> MaintenanceOutcome<R>
where
    F: FnOnce() -> R,
{
    let maybe_lease = tracker.try_acquire(db, estimated_secs);
    maybe_lease.map_or(MaintenanceOutcome::Deferred, |_lease| {
        // Bound to a named parameter (not `_`) so the lease stays alive until
        // `work_fn` has returned; it is released when this closure ends.
        MaintenanceOutcome::Ran(work_fn())
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// The two accessors must agree with the variant and never both be true.
    #[test]
    fn outcome_accessors_match_variant() {
        let ran: MaintenanceOutcome<u32> = MaintenanceOutcome::Ran(7);
        assert!(ran.ran());
        assert!(!ran.deferred());

        let deferred: MaintenanceOutcome<u32> = MaintenanceOutcome::Deferred;
        assert!(!deferred.ran());
        assert!(deferred.deferred());
    }

    #[test]
    fn with_budget_runs_within_cap() {
        let tracker = Arc::new(MaintenanceBudgetTracker::new());
        let db = DatabaseId::new(1);
        tracker.set_cap(db, 25);
        // Exhaustive match so an unexpected deferral fails the test loudly
        // instead of silently skipping the value check.
        match with_budget(&tracker, db, 1.0, || 42u32) {
            MaintenanceOutcome::Ran(v) => assert_eq!(v, 42),
            MaintenanceOutcome::Deferred => panic!("expected work to run within cap"),
        }
    }

    #[test]
    fn with_budget_defers_when_over_cap() {
        let tracker = Arc::new(MaintenanceBudgetTracker::new());
        let db = DatabaseId::new(2);
        tracker.set_cap(db, 1);
        {
            let mut consumed = 0.0f64;
            while consumed < 0.6 {
                if let Some(_l) = tracker.try_acquire(db, 0.0) {
                    consumed += 0.001;
                    std::thread::sleep(std::time::Duration::from_millis(1));
                } else {
                    break;
                }
            }
        }

        let called = Cell::new(false);
        let outcome = with_budget(&tracker, db, 0.0, || {
            called.set(true);
            99u32
        });

        // NOTE(review): whether this call is actually deferred depends on the
        // tracker's accounting, which is not visible here, so deferral itself is
        // not asserted. These checks hold either way and replace the previous
        // no-op `let _ = outcome.ran() || outcome.deferred();`.
        assert_ne!(outcome.ran(), outcome.deferred(), "exactly one accessor must be true");
        assert_eq!(called.get(), outcome.ran(), "work_fn must run iff the outcome is Ran");
        if let MaintenanceOutcome::Ran(v) = outcome {
            assert_eq!(v, 99);
        }
    }
}