use std::sync::Arc;
use tokio::sync::Semaphore;
/// Runs synchronous, potentially blocking closures from async code while
/// bounding how many of them may run at the same time.
///
/// Clones are cheap and share one semaphore through an `Arc`, so the
/// concurrency bound is enforced across every clone, not per clone.
#[derive(Debug, Clone)]
pub(crate) struct BlockingSpawner {
    /// When `true`, closures are wrapped in `tokio::task::block_in_place`;
    /// otherwise they are invoked directly on the calling thread.
    allow_block_in_place: bool,
    /// Shared pool of permits that bounds concurrent blocking work.
    semaphore: Arc<Semaphore>,
}
impl BlockingSpawner {
    /// Creates a spawner that lets at most `max_blocking` closures run
    /// concurrently through [`BlockingSpawner::block_in_place`].
    ///
    /// The runtime flavor is sampled once, at construction time.
    /// `tokio::task::block_in_place` is used only when this constructor runs
    /// inside a multi-threaded Tokio runtime, because that function panics on a
    /// current-thread runtime. When constructed outside any Tokio runtime,
    /// closures are always invoked inline.
    ///
    /// NOTE(review): `max_blocking == 0` yields a semaphore with no permits, so
    /// every `block_in_place` call would wait forever. Confirm callers never
    /// pass 0. Tokio's `Semaphore::new` panics if `max_blocking` exceeds
    /// `Semaphore::MAX_PERMITS`.
    pub(crate) fn new(max_blocking: usize) -> Self {
        // `try_current` instead of `current`: constructing the spawner outside a
        // runtime must not panic; there is simply no `block_in_place` to use.
        let allow_block_in_place = match tokio::runtime::Handle::try_current() {
            Ok(handle) => matches!(
                handle.runtime_flavor(),
                tokio::runtime::RuntimeFlavor::MultiThread
            ),
            Err(_) => false,
        };
        Self {
            allow_block_in_place,
            semaphore: Arc::new(Semaphore::new(max_blocking)),
        }
    }

    /// Waits for a concurrency permit, then runs the blocking closure `f` and
    /// returns its result.
    ///
    /// The permit is held until `f` returns, so at most `max_blocking`
    /// closures run at once across all clones of this spawner.
    ///
    /// When `block_in_place` is not allowed (current-thread runtime, or
    /// constructed outside a runtime), `f` runs inline and blocks the calling
    /// executor thread for its whole duration.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore has been closed; nothing in this impl closes it.
    pub(crate) async fn block_in_place<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let _permit = self
            .semaphore
            .acquire()
            .await
            .expect("BlockingSpawner semaphore closed");
        self.run_blocking(f)
    }

    /// Runs the blocking closure `f` immediately and returns its result.
    ///
    /// NOTE(review): unlike [`BlockingSpawner::block_in_place`], this does not
    /// acquire a permit, so it bypasses the `max_blocking` limit. Confirm that
    /// sync callers are meant to be exempt.
    pub(crate) fn block_in_place_sync<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        self.run_blocking(f)
    }

    /// Invokes `f` on the current thread, wrapping it in
    /// `tokio::task::block_in_place` when the runtime flavor allows it.
    fn run_blocking<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        if self.allow_block_in_place {
            tokio::task::block_in_place(f)
        } else {
            f()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Four tasks share a spawner capped at two permits; the peak number of
    /// closures observed running simultaneously must not exceed that cap.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn blocking_spawner_limits_concurrency() {
        let spawner = BlockingSpawner::new(2);
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let s = spawner.clone();
            let c = Arc::clone(&concurrent);
            let m = Arc::clone(&max_observed);
            handles.push(tokio::spawn(async move {
                s.block_in_place(|| {
                    let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                    m.fetch_max(current, Ordering::SeqCst);
                    // Hold the permit long enough for the other tasks to contend for it.
                    std::thread::sleep(Duration::from_millis(50));
                    c.fetch_sub(1, Ordering::SeqCst);
                })
                .await;
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let max = max_observed.load(Ordering::SeqCst);
        // Guards against a vacuous pass in which no closure ever ran.
        assert!(
            max >= 1,
            "expected at least one closure to run, observed {max}"
        );
        assert!(
            max <= 2,
            "expected at most 2 concurrent ops, observed {max}"
        );
    }

    /// With a single permit, a second closure must not start until the first
    /// one (which already holds the permit) has finished.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn blocking_spawner_semaphore_backpressure() {
        let spawner = BlockingSpawner::new(1);
        let order = Arc::new(parking_lot::Mutex::new(Vec::new()));
        // Fired from inside the first closure, i.e. only once it holds the sole
        // permit. This replaces a fixed sleep, which raced under scheduler load.
        let (started_tx, started_rx) = tokio::sync::oneshot::channel::<()>();

        let s1 = spawner.clone();
        let o1 = Arc::clone(&order);
        let h1 = tokio::spawn(async move {
            s1.block_in_place(|| {
                o1.lock().push("first-start");
                // The receiver is only gone if the test body already failed.
                let _ = started_tx.send(());
                std::thread::sleep(Duration::from_millis(80));
                o1.lock().push("first-end");
            })
            .await;
        });

        started_rx
            .await
            .expect("first task finished without signalling that it started");

        let s2 = spawner.clone();
        let o2 = Arc::clone(&order);
        let h2 = tokio::spawn(async move {
            s2.block_in_place(|| {
                o2.lock().push("second-start");
            })
            .await;
        });

        h1.await.unwrap();
        h2.await.unwrap();
        let log = order.lock();
        let first_end = log.iter().position(|s| *s == "first-end").unwrap();
        let second_start = log.iter().position(|s| *s == "second-start").unwrap();
        assert!(
            first_end < second_start,
            "expected first-end before second-start, got: {log:?}"
        );
    }

    /// On a current-thread runtime `block_in_place` would panic, so both entry
    /// points must fall back to calling the closure directly.
    #[test]
    fn blocking_spawner_single_threaded_runtime() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let spawner = BlockingSpawner::new(2);
            let result = spawner.block_in_place(|| 42).await;
            assert_eq!(result, 42);
            let sync_result = spawner.block_in_place_sync(|| 99);
            assert_eq!(sync_result, 99);
        });
    }
}