use std::cell::Cell;
use elfo_utils::time::Instant;
/// Time-based budget: `consume_budget` yields once at least this many
/// nanoseconds (5 ms) have elapsed since the instant passed to `reset`.
const MAX_TIME_NS: u64 = 5_000_000;

/// Count-based budget: after `reset(None)`, this many `consume_budget` calls
/// return without yielding; the call after them yields.
const MAX_COUNT: u32 = 64;
thread_local! {
    /// The calling thread's coop budget. `consume_budget` reads it (and, for
    /// count-based budgets, decrements it); `reset` overwrites it.
    ///
    /// It starts as `ByCount(0)`, so until `reset` runs on a thread, every
    /// `consume_budget` call made on that thread yields.
    static BUDGET: Cell<Budget> = const {
        Cell::new(Budget::ByCount(0))
    };
}
/// Coop budget of the current thread, as stored in `BUDGET`.
#[derive(Clone, Copy, Debug)]
// Equality is only needed by the unit tests' assertions.
#[cfg_attr(test, derive(PartialEq))]
enum Budget {
    /// Time-based budget: `consume_budget` yields once `MAX_TIME_NS`
    /// nanoseconds have elapsed since this instant.
    ByTime(Instant),
    /// Count-based budget: how many more `consume_budget` calls may return
    /// without yielding.
    ByCount(u32),
}
/// Refills the calling thread's coop budget.
///
/// With `Some(instant)`, the budget becomes time-based and is measured from
/// `instant`. With `None`, it becomes count-based: the next `MAX_COUNT` calls
/// to `consume_budget` return without yielding.
///
/// In this file's tests it is invoked before every poll of the wrapped future.
#[inline]
pub(crate) fn reset(busy_since: Option<Instant>) {
    let fresh = match busy_since {
        Some(since) => Budget::ByTime(since),
        None => Budget::ByCount(MAX_COUNT),
    };
    BUDGET.with(|slot| slot.set(fresh));
}
/// Spends one unit of the calling thread's coop budget. If the budget is
/// exhausted, yields back to the tokio runtime via `tokio::task::yield_now`.
///
/// * Time-based budget (`Budget::ByTime`): nothing is decremented. The call
///   yields once at least `MAX_TIME_NS` nanoseconds have elapsed since the
///   stored instant.
/// * Count-based budget (`Budget::ByCount`): a non-zero counter is decremented
///   and the call returns without yielding. A zero counter is left untouched
///   and the call yields.
///
/// This function never refills the budget; `reset` does that. On a thread
/// where `reset` has not run yet, the initial `ByCount(0)` makes every call
/// yield.
#[inline]
pub async fn consume_budget() {
    let exhausted = BUDGET.with(|slot| match slot.get() {
        Budget::ByTime(busy_since) => busy_since.elapsed_nanos() >= MAX_TIME_NS,
        Budget::ByCount(left) => match left.checked_sub(1) {
            Some(remaining) => {
                slot.set(Budget::ByCount(remaining));
                false
            }
            // Counter already at zero: keep it there and preempt.
            None => true,
        },
    });

    if !exhausted {
        return;
    }

    tokio::task::yield_now().await;
}
#[cfg(test)]
mod tests {
    use std::{
        future::Future,
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    };

    use elfo_utils::time::with_instant_mock;
    use pin_project::pin_project;
    use tokio::runtime::Builder;

    use super::*;

    /// Returns a copy of the calling thread's current budget.
    fn current_budget() -> Budget {
        BUDGET.with(Cell::get)
    }

    /// Test wrapper that calls `reset` before every poll of the inner future:
    /// with `Some(Instant::now())` when the flag is `true`, with `None` otherwise.
    #[pin_project]
    struct ResetOnPoll<F>(bool, #[pin] F);

    impl<F: Future> Future for ResetOnPoll<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.project();
            reset(if *this.0 { Some(Instant::now()) } else { None });
            this.1.poll(cx)
        }
    }

    #[test]
    fn by_count() {
        let rt = Builder::new_current_thread().build().unwrap();

        let task = async {
            // The first poll must start with a full count-based budget.
            assert_eq!(current_budget(), Budget::ByCount(MAX_COUNT));

            for _ in 0..10 {
                // `MAX_COUNT` calls decrement the counter. The call made at zero
                // yields, and the re-poll refills the budget to `MAX_COUNT`.
                for i in (0..=MAX_COUNT).rev() {
                    assert_eq!(current_budget(), Budget::ByCount(i));
                    consume_budget().await;
                }
            }
        };

        rt.block_on(ResetOnPoll(false, task));
    }

    #[test]
    fn by_time() {
        let rt = Builder::new_current_thread().build().unwrap();
        let steps = 10;
        let timestep = Duration::from_nanos(MAX_TIME_NS / steps);

        with_instant_mock(|mock| {
            let task = async move {
                // The mocked clock has not advanced since `reset`, so the budget
                // must be time-based and anchored at the current (mocked) instant.
                assert_eq!(current_budget(), Budget::ByTime(Instant::now()));

                for _ in 0..10 {
                    let before = current_budget();
                    assert!(matches!(before, Budget::ByTime(_)));

                    for _ in 0..steps {
                        mock.advance(timestep);
                        // `consume_budget` never mutates a time-based budget.
                        assert_eq!(current_budget(), before);
                        consume_budget().await;
                    }

                    // After `steps` advances, `MAX_TIME_NS` has elapsed. The last
                    // call therefore yielded, and the re-poll re-anchored the budget.
                    assert_ne!(current_budget(), before);
                }
            };

            rt.block_on(ResetOnPoll(true, task));
        })
    }
}