use std::cell::Cell;
use std::future::Future;
// Per-task query count backing `bump` and `QueryCounter`.
// It only holds a value while a future wrapped by `COUNTER.scope(...)` (see
// `QueryCounter::scope`) is being polled; elsewhere `try_with` returns an error.
// `Cell` lets the count be mutated through the shared reference `try_with` provides.
tokio::task_local! {
static COUNTER: Cell<usize>;
}
pub(crate) fn bump() {
let _ = COUNTER.try_with(|c| c.set(c.get() + 1));
}
/// Namespace for establishing, reading and resetting the per-task query counter.
///
/// Queries are recorded by `bump`. They are only tracked while running inside
/// [`QueryCounter::scope`]. Outside a scope, `bump` is a no-op and the
/// accessors [`QueryCounter::current`] / [`QueryCounter::take`] panic.
#[derive(Debug)]
pub struct QueryCounter;

impl QueryCounter {
    /// Runs `fut` with a fresh counter starting at zero and returns its output.
    ///
    /// The counter is a tokio task-local that is installed only while `fut` is
    /// being polled. Work that runs on other tasks (for example anything handed
    /// to `tokio::spawn`) is polled outside this scope and is not counted.
    ///
    /// A nested `scope` installs its own zeroed counter for its duration.
    /// Queries recorded inside it are therefore not added to the enclosing
    /// scope's count.
    pub async fn scope<F: Future>(fut: F) -> F::Output {
        COUNTER.scope(Cell::new(0), fut).await
    }

    /// Returns the number of queries recorded so far in the innermost active scope.
    ///
    /// # Panics
    ///
    /// Panics if called outside a future wrapped by [`QueryCounter::scope`].
    #[must_use]
    pub fn current() -> usize {
        COUNTER
            .try_with(Cell::get)
            .expect("QueryCounter::current() called outside an active scope")
    }

    /// Returns the number of queries recorded so far and resets the counter to zero.
    ///
    /// The read and the reset happen in a single `Cell::take` call, so no
    /// `bump` can slip in between them.
    ///
    /// # Panics
    ///
    /// Panics if called outside a future wrapped by [`QueryCounter::scope`].
    pub fn take() -> usize {
        COUNTER
            .try_with(Cell::take)
            .expect("QueryCounter::take() called outside an active scope")
    }
}
/// Runs `fut` inside a fresh [`QueryCounter::scope`] and asserts that exactly
/// `expected` queries were recorded while it ran, then returns `fut`'s output.
///
/// Queries recorded inside a nested `QueryCounter::scope` (including a nested
/// `assert_num_queries`) go to that inner counter and are not included here.
///
/// # Panics
///
/// Panics, with a message containing
/// `assertNumQueries failed: expected {expected} queries, observed {actual}`,
/// when the recorded count differs from `expected`.
pub async fn assert_num_queries<F: Future>(expected: usize, fut: F) -> F::Output {
    // The check must run inside the scope: `QueryCounter::current` panics once
    // the scope has been exited.
    let checked = async move {
        let output = fut.await;
        let actual = QueryCounter::current();
        assert_eq!(
            actual, expected,
            "assertNumQueries failed: expected {expected} queries, observed {actual}"
        );
        output
    };
    QueryCounter::scope(checked).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `bump` must be harmless when called from code that is not being measured.
    #[tokio::test]
    async fn bump_outside_scope_is_no_op() {
        bump();
        bump();
        bump();
    }

    #[tokio::test]
    async fn assert_num_queries_passes_on_exact_count() {
        assert_num_queries(3, async {
            bump();
            bump();
            bump();
        })
        .await;
    }

    #[tokio::test]
    async fn assert_num_queries_passes_on_zero_when_no_queries() {
        assert_num_queries(0, async {}).await;
    }

    #[tokio::test]
    #[should_panic(expected = "assertNumQueries failed: expected 2 queries, observed 3")]
    async fn assert_num_queries_panics_with_count_in_message() {
        assert_num_queries(2, async {
            bump();
            bump();
            bump();
        })
        .await;
    }

    #[tokio::test]
    async fn returns_inner_future_output() {
        let value = assert_num_queries(1, async {
            bump();
            42
        })
        .await;
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn current_reads_running_count_mid_scope() {
        QueryCounter::scope(async {
            assert_eq!(QueryCounter::current(), 0);
            bump();
            assert_eq!(QueryCounter::current(), 1);
            bump();
            bump();
            assert_eq!(QueryCounter::current(), 3);
        })
        .await;
    }

    #[tokio::test]
    async fn take_resets_counter_atomically() {
        QueryCounter::scope(async {
            bump();
            bump();
            assert_eq!(QueryCounter::take(), 2);
            assert_eq!(QueryCounter::current(), 0);
            bump();
            assert_eq!(QueryCounter::take(), 1);
            assert_eq!(QueryCounter::current(), 0);
        })
        .await;
    }

    #[tokio::test]
    async fn parallel_scopes_count_independently() {
        // `yield_now` makes `join!` interleave the two futures between bumps,
        // so any counter shared between the scopes would break both assertions.
        tokio::join!(
            assert_num_queries(2, async {
                for _ in 0..2 {
                    bump();
                    tokio::task::yield_now().await;
                }
            }),
            assert_num_queries(5, async {
                for _ in 0..5 {
                    bump();
                    tokio::task::yield_now().await;
                }
            }),
        );
    }

    /// A nested scope gets its own zeroed counter and does not feed the outer one.
    #[tokio::test]
    async fn nested_scope_has_its_own_counter() {
        QueryCounter::scope(async {
            bump();
            QueryCounter::scope(async {
                assert_eq!(QueryCounter::current(), 0);
                bump();
                assert_eq!(QueryCounter::current(), 1);
            })
            .await;
            assert_eq!(QueryCounter::current(), 1);
        })
        .await;
    }

    #[tokio::test]
    #[should_panic(expected = "QueryCounter::current() called outside an active scope")]
    async fn current_panics_outside_scope() {
        let _ = QueryCounter::current();
    }

    #[tokio::test]
    #[should_panic(expected = "QueryCounter::take() called outside an active scope")]
    async fn take_panics_outside_scope() {
        let _ = QueryCounter::take();
    }
}