use async_lock::Semaphore;
use cfg_if::cfg_if;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
// Compile exactly one async-runtime backend module. Each branch is taken only when its own
// `rt-*` feature is enabled and the other two are not. If no runtime feature is enabled, or
// more than one is, the final `else` branch aborts the build with `compile_error!`.
// NOTE(review): the backend modules are not visible here; presumably each one supplies the
// runtime-specific parts of `Bulkhead` (e.g. waiting up to `max_wait_duration`) — confirm.
cfg_if!(
if #[cfg(all(not(any(feature = "rt-async-std", feature = "rt-smol")), feature = "rt-tokio"))] {
mod tokio;
} else if #[cfg(all(not(any(feature = "rt-tokio", feature = "rt-smol")), feature = "rt-async-std"))] {
mod async_std;
} else if #[cfg(all(not(any(feature = "rt-tokio", feature = "rt-async-std")), feature = "rt-smol"))] {
mod smol;
} else {
compile_error!("you must enable one feature between `rt-tokio`, `rt-async-std` and `rt-smol`");
}
);
// The unit-test module is compiled only for test builds that use the tokio backend.
#[cfg(all(test, feature = "rt-tokio"))]
mod tests;
#[derive(Debug, Error)]
pub enum BulkheadError {
#[error("the maximum number of concurrent calls is met")]
Timeout,
#[error("max concurrent calls must be at least 1")]
InvalidConcurrentCalls,
}
/// Builder for [`Bulkhead`].
///
/// Obtain one with [`Bulkhead::builder`] or [`BulkheadBuilder::default`], adjust it with the
/// setter methods, then call [`BulkheadBuilder::build`].
#[derive(Debug, Copy, Clone)]
pub struct BulkheadBuilder {
    /// Number of permits given to the bulkhead's semaphore; `build` rejects 0.
    max_concurrent_calls: usize,
    /// Copied unchanged onto the built [`Bulkhead`].
    max_wait_duration: Duration,
}

impl BulkheadBuilder {
    /// Sets the maximum number of calls allowed to run concurrently.
    ///
    /// The value is not checked here; [`BulkheadBuilder::build`] rejects 0.
    // The builder is taken by value, so ignoring the returned builder would silently
    // discard the new setting.
    #[must_use]
    pub fn max_concurrent_calls(mut self, max_concurrent_calls: usize) -> Self {
        self.max_concurrent_calls = max_concurrent_calls;
        self
    }

    /// Sets the maximum wait duration stored on the built [`Bulkhead`].
    ///
    /// NOTE(review): presumably how long a call may wait for a permit before failing with
    /// [`BulkheadError::Timeout`]; the waiting logic is not in this file — confirm.
    #[must_use]
    pub fn max_wait_duration(mut self, max_wait_duration: Duration) -> Self {
        self.max_wait_duration = max_wait_duration;
        self
    }

    /// Consumes the builder and creates a [`Bulkhead`].
    ///
    /// The bulkhead's semaphore is created with `max_concurrent_calls` permits.
    ///
    /// # Errors
    ///
    /// Returns [`BulkheadError::InvalidConcurrentCalls`] if `max_concurrent_calls` is 0.
    pub fn build(self) -> Result<Bulkhead, BulkheadError> {
        if self.max_concurrent_calls == 0 {
            return Err(BulkheadError::InvalidConcurrentCalls);
        }
        Ok(Bulkhead {
            max_concurrent_calls: Arc::new(Semaphore::new(self.max_concurrent_calls)),
            max_wait_duration: self.max_wait_duration,
        })
    }
}

impl Default for BulkheadBuilder {
    /// Default configuration: 25 concurrent calls and a 1 ms maximum wait duration.
    fn default() -> Self {
        Self {
            max_concurrent_calls: 25,
            max_wait_duration: Duration::from_millis(1),
        }
    }
}
/// Limits the number of concurrent calls using a shared semaphore.
///
/// Cloning is cheap: clones share the same semaphore through an [`Arc`].
#[derive(Debug, Clone)]
pub struct Bulkhead {
    /// Semaphore whose permit count is the configured maximum number of concurrent calls.
    max_concurrent_calls: Arc<Semaphore>,
    /// Maximum wait duration configured through [`BulkheadBuilder::max_wait_duration`].
    max_wait_duration: Duration,
}

impl Bulkhead {
    /// Creates a bulkhead with the default configuration (same as [`Bulkhead::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a [`BulkheadBuilder`] initialised with the default configuration.
    #[must_use]
    pub fn builder() -> BulkheadBuilder {
        BulkheadBuilder::default()
    }
}

impl Default for Bulkhead {
    /// Equivalent to `BulkheadBuilder::default().build()`.
    ///
    /// # Panics
    ///
    /// Only if `BulkheadBuilder::default()` were changed to configure 0 concurrent calls.
    /// That is a programming error, so it should fail loudly rather than yield a
    /// zero-permit semaphore.
    fn default() -> Self {
        // Delegate to `build` so construction and validation live in a single place.
        BulkheadBuilder::default()
            .build()
            .expect("BulkheadBuilder::default() must configure at least one concurrent call")
    }
}
/// A collection of [`Bulkhead`]s keyed by resource name.
#[derive(Debug, Clone)]
pub struct BulkheadRegistry(HashMap<String, Bulkhead>);

impl BulkheadRegistry {
    /// Creates a registry containing no bulkheads.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `bulkhead` under `resource` and returns `self` so calls can be chained.
    ///
    /// A bulkhead previously registered under the same name is replaced and dropped.
    pub fn register(&mut self, resource: String, bulkhead: Bulkhead) -> &mut Self {
        let Self(entries) = &mut *self;
        let _ = entries.insert(resource, bulkhead);
        self
    }

    /// Looks up the bulkhead registered under `resource`, if any.
    pub fn get(&self, resource: &str) -> Option<&Bulkhead> {
        let Self(entries) = self;
        entries.get(resource)
    }
}

impl Default for BulkheadRegistry {
    /// Equivalent to [`BulkheadRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self(HashMap::new())
    }
}