use std::future::Future;
use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::task::{AbortHandle, JoinSet};
/// Error returned by `Rider::spawn` when the rider can no longer accept tasks
/// because its internal semaphore has been closed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RiderError(());

impl std::fmt::Display for RiderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("rider is closed")
    }
}

// Lets `RiderError` flow through `?` into `Box<dyn Error>` and similar error types.
impl std::error::Error for RiderError {}
/// Spawns futures onto a Tokio `JoinSet`, limiting how many may be running at once
/// with a semaphore.
#[derive(Debug)]
pub struct Rider {
    // Concurrency limiter: `spawn` acquires one permit per task, and the permit is
    // dropped once that task's future has finished (or is dropped).
    sem: Arc<Semaphore>,
    // Every task spawned by `spawn`; `shutdown` drains it with `join_next`.
    set: JoinSet<()>,
}
impl RiderError {
fn closed() -> RiderError {
RiderError(())
}
}
impl Rider {
pub const MAX_CAPACITY: usize = Semaphore::MAX_PERMITS;
pub fn new(capacity: usize) -> Rider {
let sem = Arc::new(Semaphore::new(capacity));
let set = JoinSet::new();
Rider { sem, set }
}
pub async fn spawn<F>(&mut self, task: F) -> Result<AbortHandle, RiderError>
where
F: Future<Output = ()>,
F: Send + 'static,
{
let permit = self
.sem
.clone()
.acquire_owned()
.await
.map_err(|_| RiderError::closed())?;
Ok(self.set.spawn(async move {
task.await;
drop(permit);
}))
}
pub async fn shutdown(mut self) {
self.sem.close();
while let Some(handle) = self.set.join_next().await {
handle.expect("task in rider failed");
}
}
}