async_bulkhead/bulkhead/mod.rs

1use async_lock::Semaphore;
2use cfg_if::cfg_if;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use thiserror::Error;
7
// Select the runtime-specific implementation module.
//
// Exactly one of the `rt-tokio`, `rt-async-std` and `rt-smol` features must be
// enabled: each branch requires its own feature AND the absence of the other
// two, so enabling none of them or more than one falls through to the final
// branch and aborts compilation with `compile_error!`.
cfg_if!(
    if #[cfg(all(not(any(feature = "rt-async-std", feature = "rt-smol")), feature = "rt-tokio"))] {
        mod tokio;
    } else if #[cfg(all(not(any(feature = "rt-tokio", feature = "rt-smol")), feature = "rt-async-std"))] {
        mod async_std;
    } else if #[cfg(all(not(any(feature = "rt-tokio", feature = "rt-async-std")), feature = "rt-smol"))] {
        mod smol;
    } else {
        compile_error!("you must enable one feature between `rt-tokio`, `rt-async-std` and `rt-smol`");
    }
);
19
// The `tests` submodule is compiled only for test builds that also enable the
// `rt-tokio` runtime feature.
#[cfg(all(test, feature = "rt-tokio"))]
mod tests;
22
23/// The error type for operations with [`Bulkhead`].
24#[derive(Debug, Error)]
25pub enum BulkheadError {
26    /// The error returned when the bulkhead semaphore permit could not be acquired before the
27    /// specified maximum wait duration.
28    #[error("the maximum number of concurrent calls is met")]
29    Timeout,
30    /// The error returned when a non-positive maximum number of concurrent calls is specified
31    #[error("max concurrent calls must be at least 1")]
32    InvalidConcurrentCalls,
33}
34
/// A builder type for a [`Bulkhead`].
///
/// Every setter consumes the builder and returns an updated copy, so the value
/// is marked `#[must_use]`: dropping it silently discards the configuration.
#[must_use = "a BulkheadBuilder does nothing until `build` is called"]
#[derive(Debug, Copy, Clone)]
pub struct BulkheadBuilder {
    /// Number of semaphore permits the built bulkhead will hold.
    max_concurrent_calls: usize,
    /// How long a caller may wait for a permit.
    max_wait_duration: Duration,
}
41
42impl BulkheadBuilder {
43    /// Specifies the maximum number of concurrent calls the bulkhead will allow.
44    ///
45    /// Defaults to 25.
46    pub fn max_concurrent_calls(mut self, max_concurrent_calls: usize) -> Self {
47        self.max_concurrent_calls = max_concurrent_calls;
48        self
49    }
50
51    /// Specifies the maximum wait duration for the bulkhead's semaphore guard to be acquired.
52    ///
53    /// Defaults to `Duration::from_millis(1)`.
54    pub fn max_wait_duration(mut self, max_wait_duration: Duration) -> Self {
55        self.max_wait_duration = max_wait_duration;
56        self
57    }
58
59    /// Builds the [`Bulkhead`]. This returns an `Err(BulkheadError::InvalidConcurrentCalls)`
60    /// value if the number of concurrent calls is not positive.
61    pub fn build(self) -> Result<Bulkhead, BulkheadError> {
62        if self.max_concurrent_calls > 0 {
63            Ok(Bulkhead {
64                max_concurrent_calls: Arc::new(Semaphore::new(self.max_concurrent_calls)),
65                max_wait_duration: self.max_wait_duration,
66            })
67        } else {
68            Err(BulkheadError::InvalidConcurrentCalls)
69        }
70    }
71}
72
73impl Default for BulkheadBuilder {
74    fn default() -> Self {
75        Self {
76            max_concurrent_calls: 25,
77            max_wait_duration: Duration::from_millis(1),
78        }
79    }
80}
81
82/// A semaphore-based bulkhead for limiting the number of concurrent
83/// calls to a resource.
84///
85/// This type can be safely cloned and sent across threads while
86/// maintaining the correct number of allowed concurrent calls.
87#[derive(Debug, Clone)]
88pub struct Bulkhead {
89    max_concurrent_calls: Arc<Semaphore>,
90    max_wait_duration: Duration,
91}
92
93impl Bulkhead {
94    /// Creates a new bulkhead with the default configuration
95    pub fn new() -> Self {
96        Self::default()
97    }
98
99    /// Creates a new [`BulkheadBuilder`] containing the default configuration
100    pub fn builder() -> BulkheadBuilder {
101        BulkheadBuilder::default()
102    }
103}
104
105impl Default for Bulkhead {
106    fn default() -> Self {
107        let BulkheadBuilder {
108            max_concurrent_calls,
109            max_wait_duration,
110        } = BulkheadBuilder::default();
111        Self {
112            max_concurrent_calls: Arc::new(Semaphore::new(max_concurrent_calls)),
113            max_wait_duration,
114        }
115    }
116}
117
118/// A structure for tracking multiple bulkheads for different resources.
119///
120/// This type can be safely cloned and sent across threads while
121/// maintaining the correct number of allowed concurrent calls in each
122/// resource's corresponding bulkhead.
123#[derive(Debug, Clone)]
124pub struct BulkheadRegistry(HashMap<String, Bulkhead>);
125
126impl BulkheadRegistry {
127    /// Creates an empty registry
128    pub fn new() -> Self {
129        Self(HashMap::new())
130    }
131
132    /// Adds a new bulkhead for the specified resource to the registry
133    pub fn register(&mut self, resource: String, bulkhead: Bulkhead) -> &mut Self {
134        self.0.insert(resource, bulkhead);
135        self
136    }
137
138    /// Retrieves the requested resource bulkhead from the registry
139    pub fn get(&self, resource: &str) -> Option<&Bulkhead> {
140        self.0.get(resource)
141    }
142}
143
144impl Default for BulkheadRegistry {
145    fn default() -> Self {
146        Self::new()
147    }
148}