do_over/bulkhead.rs
1//! Bulkhead policy for concurrency limiting.
2//!
3//! The bulkhead policy limits the number of concurrent executions,
4//! preventing resource exhaustion and isolating failures.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use do_over::{policy::Policy, bulkhead::Bulkhead, error::DoOverError};
10//! use std::time::Duration;
11//!
12//! # async fn example() -> Result<(), DoOverError<std::io::Error>> {
13//! // Allow maximum 10 concurrent executions
14//! let bulkhead = Bulkhead::new(10);
15//!
16//! // With queue timeout - wait up to 1 second for a slot
17//! let bulkhead = Bulkhead::new(10)
18//! .with_queue_timeout(Duration::from_secs(1));
19//!
20//! match bulkhead.execute(|| async {
21//! Ok::<_, DoOverError<std::io::Error>>("completed")
22//! }).await {
23//! Ok(result) => println!("Success: {}", result),
24//! Err(DoOverError::BulkheadFull) => println!("No capacity available"),
25//! Err(e) => println!("Error: {:?}", e),
26//! }
27//! # Ok(())
28//! # }
29//! ```
30
31use tokio::sync::{Semaphore, OwnedSemaphorePermit};
32use std::{time::Duration, sync::Arc};
33use crate::{policy::Policy, error::DoOverError};
34
35/// A policy that limits concurrent executions.
36///
37/// The bulkhead uses a semaphore to control how many operations can run
38/// simultaneously. When the limit is reached, new requests are either
39/// rejected immediately or queued (if a queue timeout is configured).
40///
41/// # Examples
42///
43/// ```rust
44/// use do_over::{policy::Policy, bulkhead::Bulkhead, error::DoOverError};
45/// use std::time::Duration;
46///
47/// # async fn example() {
48/// // Basic bulkhead - reject immediately when full
49/// let bulkhead = Bulkhead::new(5);
50///
51/// // With queue timeout - wait for a slot
52/// let bulkhead = Bulkhead::new(5)
53/// .with_queue_timeout(Duration::from_millis(500));
54/// # }
55/// ```
56pub struct Bulkhead {
57 semaphore: Arc<Semaphore>,
58 queue_timeout: Option<Duration>,
59}
60
61impl Clone for Bulkhead {
62 fn clone(&self) -> Self {
63 Self {
64 semaphore: Arc::clone(&self.semaphore),
65 queue_timeout: self.queue_timeout,
66 }
67 }
68}
69
70impl Bulkhead {
71 /// Create a new bulkhead with the specified concurrency limit.
72 ///
73 /// Without a queue timeout, requests are rejected immediately when
74 /// no slots are available.
75 ///
76 /// # Arguments
77 ///
78 /// * `max_concurrent` - Maximum number of concurrent executions
79 ///
80 /// # Examples
81 ///
82 /// ```rust
83 /// use do_over::bulkhead::Bulkhead;
84 ///
85 /// // Allow up to 10 concurrent operations
86 /// let bulkhead = Bulkhead::new(10);
87 /// ```
88 pub fn new(max_concurrent: usize) -> Self {
89 Self { semaphore: Arc::new(Semaphore::new(max_concurrent)), queue_timeout: None }
90 }
91
92 /// Set a queue timeout for waiting on a slot.
93 ///
94 /// When set, requests will wait up to the specified duration for a slot
95 /// to become available before being rejected.
96 ///
97 /// # Arguments
98 ///
99 /// * `timeout` - Maximum time to wait for a slot
100 ///
101 /// # Examples
102 ///
103 /// ```rust
104 /// use do_over::bulkhead::Bulkhead;
105 /// use std::time::Duration;
106 ///
107 /// let bulkhead = Bulkhead::new(10)
108 /// .with_queue_timeout(Duration::from_secs(1));
109 /// ```
110 pub fn with_queue_timeout(mut self, timeout: Duration) -> Self {
111 self.queue_timeout = Some(timeout);
112 self
113 }
114
115 async fn acquire(&self) -> Result<OwnedSemaphorePermit, DoOverError<()>> {
116 match self.queue_timeout {
117 Some(t) => tokio::time::timeout(t, self.semaphore.clone().acquire_owned())
118 .await
119 .map_err(|_| DoOverError::BulkheadFull)?
120 .map_err(|_| DoOverError::BulkheadFull),
121 None => self.semaphore.clone().try_acquire_owned()
122 .map_err(|_| DoOverError::BulkheadFull),
123 }
124 }
125}
126
127#[async_trait::async_trait]
128impl<E> Policy<DoOverError<E>> for Bulkhead
129where
130 E: Send + Sync,
131{
132 async fn execute<F, Fut, T>(&self, f: F) -> Result<T, DoOverError<E>>
133 where
134 F: Fn() -> Fut + Send + Sync,
135 Fut: std::future::Future<Output = Result<T, DoOverError<E>>> + Send,
136 T: Send,
137 {
138 let permit = self.acquire().await.map_err(|_| DoOverError::BulkheadFull)?;
139 let r = f().await;
140 drop(permit);
141 r
142 }
143}