use std::sync::Arc;
use tokio::sync::Semaphore;
/// Runs blocking closures from async code.
///
/// Uses `tokio::task::block_in_place` when the runtime flavor allows it and
/// runs the closure inline otherwise.
#[derive(Clone, Debug)]
pub struct BlockingSpawner {
    /// Whether `tokio::task::block_in_place` may be used.
    ///
    /// `new` sets this to `false` on a current-thread runtime, where
    /// `block_in_place` would panic.
    allow_block_in_place: bool,
    /// Limits how many `block_in_place_with_semaphore` closures run at once.
    ///
    /// Held behind an `Arc`, so clones of the spawner share the same limit.
    concurrent_block_in_place_semaphore: Arc<Semaphore>,
}
impl BlockingSpawner {
    /// Creates a spawner configured for the Tokio runtime that is current at the call site.
    ///
    /// `tokio::task::block_in_place` panics on a current-thread runtime. On that flavor,
    /// blocking closures are run inline instead; every other flavor enables `block_in_place`.
    ///
    /// `max_blocking_threads` caps how many closures passed to
    /// [`Self::block_in_place_with_semaphore`] may run at the same time. The value is clamped
    /// to `1..=Semaphore::MAX_PERMITS` for two reasons:
    /// - `0` still allows progress.
    /// - Very large values, such as `usize::MAX` meaning "unlimited", do not make
    ///   `Semaphore::new` panic.
    ///
    /// # Panics
    ///
    /// Panics if called outside the context of a Tokio runtime (`Handle::current`).
    pub fn new(max_blocking_threads: usize) -> Self {
        let flavor = tokio::runtime::Handle::current().runtime_flavor();
        // `RuntimeFlavor` is non-exhaustive; only the current-thread flavor is known to
        // reject `block_in_place`, so every other flavor keeps it enabled.
        let allow_block_in_place =
            !matches!(flavor, tokio::runtime::RuntimeFlavor::CurrentThread);

        // `Semaphore::new` panics above `MAX_PERMITS`, so clamp instead of trusting the caller.
        let permits = max_blocking_threads.clamp(1, Semaphore::MAX_PERMITS);

        Self {
            allow_block_in_place,
            concurrent_block_in_place_semaphore: Arc::new(Semaphore::new(permits)),
        }
    }

    /// Runs `f` and returns its result.
    ///
    /// When the runtime supports it, `f` runs via `tokio::task::block_in_place`. On a
    /// current-thread runtime, `f` runs directly on the calling thread instead.
    pub fn block_in_place<F: FnOnce() -> R, R>(&self, f: F) -> R {
        if self.allow_block_in_place {
            tokio::task::block_in_place(f)
        } else {
            f()
        }
    }

    /// Like [`Self::block_in_place`], but first waits for a permit from the shared semaphore.
    ///
    /// The permit is held until `f` returns. This bounds how many closures are inside
    /// `block_in_place` at once.
    ///
    /// When `block_in_place` is disabled (current-thread runtime), `f` runs directly and no
    /// permit is acquired.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore has been closed, for example through the handle returned by
    /// [`Self::semaphore`].
    pub async fn block_in_place_with_semaphore<F: FnOnce() -> R, R>(&self, f: F) -> R {
        if !self.allow_block_in_place {
            return f();
        }
        // `acquire` only fails once the semaphore is closed; the handle is public, so
        // report that misuse clearly instead of a bare `unwrap` panic.
        let _permit = self
            .concurrent_block_in_place_semaphore
            .acquire()
            .await
            .expect("BlockingSpawner: concurrency semaphore was closed");
        tokio::task::block_in_place(f)
    }

    /// Returns a shared handle to the semaphore that limits concurrent
    /// [`Self::block_in_place_with_semaphore`] calls.
    ///
    /// Closing this semaphore makes later `block_in_place_with_semaphore` calls panic on the
    /// `block_in_place` path.
    pub fn semaphore(&self) -> Arc<Semaphore> {
        Arc::clone(&self.concurrent_block_in_place_semaphore)
    }
}