use tokio::sync::{Semaphore, OwnedSemaphorePermit};
use std::{time::Duration, sync::Arc};
use crate::{policy::Policy, error::DoOverError};
/// Bulkhead policy that caps how many operations may run at the same time.
///
/// Concurrency is bounded by a shared tokio [`Semaphore`]. Cloning a
/// `Bulkhead` clones only the `Arc`, so every clone draws from the same pool
/// of permits.
#[derive(Debug, Clone)]
pub struct Bulkhead {
    /// Permits handed out to concurrently running executions.
    semaphore: Arc<Semaphore>,
    /// How long a caller may wait for a free permit; `None` means a caller is
    /// rejected immediately when no permit is free.
    queue_timeout: Option<Duration>,
}
impl Bulkhead {
pub fn new(max_concurrent: usize) -> Self {
Self { semaphore: Arc::new(Semaphore::new(max_concurrent)), queue_timeout: None }
}
pub fn with_queue_timeout(mut self, timeout: Duration) -> Self {
self.queue_timeout = Some(timeout);
self
}
async fn acquire(&self) -> Result<OwnedSemaphorePermit, DoOverError<()>> {
match self.queue_timeout {
Some(t) => tokio::time::timeout(t, self.semaphore.clone().acquire_owned())
.await
.map_err(|_| DoOverError::BulkheadFull)?
.map_err(|_| DoOverError::BulkheadFull),
None => self.semaphore.clone().try_acquire_owned()
.map_err(|_| DoOverError::BulkheadFull),
}
}
}
/// Runs operations under the bulkhead's concurrency limit.
///
/// Each call must first obtain a permit. If it cannot (the pool is exhausted,
/// or the queue timeout elapsed), `f` is never invoked and the call returns
/// `DoOverError::BulkheadFull`. Otherwise `f` is invoked once and its result
/// is returned unchanged, with the permit held until that result is ready.
#[async_trait::async_trait]
impl<E> Policy<DoOverError<E>> for Bulkhead
where
    E: Send + Sync,
{
    async fn execute<F, Fut, T>(&self, f: F) -> Result<T, DoOverError<E>>
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<T, DoOverError<E>>> + Send,
        T: Send,
    {
        // Bind to a named `_permit` (not `_`) so the slot stays occupied until
        // the operation below has finished; it is released at end of scope.
        let _permit = match self.acquire().await {
            Ok(permit) => permit,
            Err(_) => return Err(DoOverError::BulkheadFull),
        };
        f().await
    }
}