use std::{future::Future, sync::Arc};
use tokio::{
sync::{OwnedSemaphorePermit, Semaphore},
task::JoinHandle,
};
/// Error returned by [`BoundedExecutor::try_spawn`] when no permit is
/// immediately available, meaning the future was not spawned.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TrySpawnError;

impl std::fmt::Display for TrySpawnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("bounded executor has no available permits")
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` and
// similar error-erasing types.
impl std::error::Error for TrySpawnError {}
/// Spawns futures onto the Tokio runtime while capping how many of them may
/// be in flight at once.
///
/// Each spawned task holds one permit from `semaphore` until its future
/// completes. `max_available` records the permit count the executor was
/// created with.
#[derive(Debug)]
pub struct BoundedExecutor {
    /// Permits gating task concurrency; one is held per in-flight task.
    semaphore: Arc<Semaphore>,
    /// Total number of permits this executor was constructed with.
    max_available: usize,
}
impl BoundedExecutor {
pub fn new(num_permits: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(num_permits)),
max_available: num_permits,
}
}
pub fn allow_maximum() -> Self {
Self::new(Self::max_theoretical_tasks())
}
pub const fn max_theoretical_tasks() -> usize {
usize::MAX >> 4
}
pub fn can_spawn(&self) -> bool {
self.num_available() > 0
}
pub fn num_available(&self) -> usize {
self.semaphore.available_permits()
}
pub fn max_available(&self) -> usize {
self.max_available
}
pub fn try_spawn<F>(&self, future: F) -> Result<JoinHandle<F::Output>, TrySpawnError>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let permit = self.semaphore.clone().try_acquire_owned().map_err(|_| TrySpawnError)?;
let handle = self.do_spawn(permit, future);
Ok(handle)
}
pub async fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let permit = self.semaphore.clone().acquire_owned().await.expect("semaphore closed");
self.do_spawn(permit, future)
}
fn do_spawn<F>(&self, permit: OwnedSemaphorePermit, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
tokio::spawn(async move {
let ret = future.await;
drop(permit);
ret
})
}
}
#[cfg(test)]
mod test {
    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    use tokio::time::sleep;

    use super::*;

    /// With a single permit, the second `spawn` must wait until the first task
    /// has finished, so it observes the flag the first task set.
    #[tokio::test]
    async fn spawn() {
        let flag = Arc::new(AtomicBool::new(false));
        let flag_cloned = flag.clone();
        let executor = BoundedExecutor::new(1);
        let task1_fut = executor
            .spawn(async move {
                sleep(Duration::from_millis(1)).await;
                flag_cloned.store(true, Ordering::SeqCst);
            })
            .await;
        let task2_fut = executor
            .spawn(async move {
                assert!(flag.load(Ordering::SeqCst));
            })
            .await;
        task2_fut.await.unwrap();
        task1_fut.await.unwrap();
    }

    /// `try_spawn` fails fast while every permit is held, and succeeds again
    /// once the running task has released its permit.
    #[tokio::test]
    async fn try_spawn() {
        let executor = BoundedExecutor::new(1);
        assert_eq!(executor.max_available(), 1);
        assert!(executor.can_spawn());

        let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();
        let blocker = executor
            .try_spawn(async move {
                release_rx.await.expect("sender dropped");
            })
            .expect("a permit should be free");

        // The permit is taken synchronously inside `try_spawn`.
        assert_eq!(executor.num_available(), 0);
        assert!(!executor.can_spawn());
        assert!(executor.try_spawn(async {}).is_err());

        release_tx.send(()).expect("receiver dropped");
        blocker.await.unwrap();

        // The permit is dropped before the task returns its output, so it is
        // back by the time the join handle resolves.
        assert_eq!(executor.num_available(), 1);
        executor
            .try_spawn(async {})
            .expect("permit should have been released")
            .await
            .unwrap();
    }
}