agentic_tools_utils/
async_control.rs1use std::time::Duration;
4use thiserror::Error;
5use tokio::sync::Semaphore;
6
7#[derive(Debug, Error)]
9pub enum AsyncControlError {
10 #[error("Semaphore closed")]
12 SemaphoreClosed,
13
14 #[error("Timed out after {0}s")]
16 Timeout(u64),
17
18 #[error("{0}")]
20 Operation(String),
21}
22
23pub async fn with_permit_and_timeout<F, Fut, T, E>(
35 semaphore: &Semaphore,
36 timeout_dur: Duration,
37 op: F,
38) -> Result<T, AsyncControlError>
39where
40 F: FnOnce() -> Fut,
41 Fut: std::future::Future<Output = Result<T, E>>,
42 E: std::fmt::Display,
43{
44 let _permit = semaphore
45 .acquire()
46 .await
47 .map_err(|_| AsyncControlError::SemaphoreClosed)?;
48
49 match tokio::time::timeout(timeout_dur, op()).await {
50 Ok(Ok(v)) => Ok(v),
51 Ok(Err(e)) => Err(AsyncControlError::Operation(e.to_string())),
52 Err(_) => Err(AsyncControlError::Timeout(timeout_dur.as_secs())),
53 }
54}
55
/// Runs `op` once per entry in `delays`, awaiting `sleep_fn(delay)` before
/// each attempt, and returns the first success.
///
/// `delays[i]` is the pause taken *before* attempt `i`, so a leading
/// `Duration::ZERO` makes the first attempt immediate. Sleeping is delegated
/// to `sleep_fn` so callers can supply `tokio::time::sleep` or a no-op in
/// tests.
///
/// If `delays` is empty, `op` is attempted exactly once without sleeping.
/// An empty slice previously caused a panic.
///
/// # Errors
///
/// If every attempt fails, returns the error from the final attempt. Errors
/// from earlier attempts are discarded.
pub async fn retry_fixed_delays<F, Fut, SleepFn, SleepFut, T, E>(
    delays: &[Duration],
    mut sleep_fn: SleepFn,
    mut op: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    SleepFn: FnMut(Duration) -> SleepFut,
    SleepFut: std::future::Future<Output = ()>,
    E: std::fmt::Debug,
{
    // Split off the final delay. Earlier attempts may fail and be retried,
    // but the last attempt's result is returned as-is. That makes "there is
    // always a result to return" structural, so no `expect` is needed.
    let Some((last_delay, earlier)) = delays.split_last() else {
        // No schedule given: a single immediate attempt.
        return op().await;
    };

    for &delay in earlier {
        sleep_fn(delay).await;
        if let Ok(value) = op().await {
            return Ok(value);
        }
    }

    sleep_fn(*last_delay).await;
    op().await
}
100
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Minimal error type that satisfies the `Display`/`Debug` bounds of the
    /// helpers under test.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    #[tokio::test]
    async fn semaphore_limits_concurrency() {
        let semaphore = Semaphore::new(2);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];
        for _ in 0..4 {
            let sem = &semaphore;
            let in_flight = Arc::clone(&in_flight);
            let max_observed = Arc::clone(&max_observed);

            handles.push(async move {
                let result: Result<(), AsyncControlError> =
                    with_permit_and_timeout(sem, Duration::from_secs(10), || async {
                        let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        max_observed.fetch_max(current, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok::<_, TestError>(())
                    })
                    .await;
                result
            });
        }

        // join_all drives every future on this task, so the operations
        // interleave at their `.await` points and compete for the 2 permits.
        let results = futures::future::join_all(handles).await;

        // A failed operation would also lower the observed concurrency, so
        // check the results, not just the peak.
        assert!(
            results.iter().all(Result::is_ok),
            "unexpected failure: {results:?}"
        );
        assert_eq!(max_observed.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn closed_semaphore_is_reported() {
        let semaphore = Semaphore::new(1);
        semaphore.close();

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(())
            })
            .await;

        assert!(
            matches!(result, Err(AsyncControlError::SemaphoreClosed)),
            "Expected SemaphoreClosed error, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn timeout_returns_error_when_exceeded() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_millis(10), || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok::<_, TestError>(())
            })
            .await;

        match result {
            Err(AsyncControlError::Timeout(_)) => {}
            other => panic!("Expected Timeout error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn timeout_returns_success_when_op_completes_in_time() {
        let semaphore = Semaphore::new(1);

        let result: Result<i32, AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(42)
            })
            .await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn operation_error_is_stringified() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Err(TestError("boom".into()))
            })
            .await;

        match result {
            Err(AsyncControlError::Operation(msg)) => assert_eq!(msg, "boom"),
            other => panic!("Expected Operation error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn retry_succeeds_on_third_attempt() {
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let delays_observed = Arc::new(std::sync::Mutex::new(Vec::new()));

        let delays = [
            Duration::from_millis(0),
            Duration::from_millis(10),
            Duration::from_millis(20),
        ];

        let result: Result<&str, TestError> = retry_fixed_delays(
            &delays,
            |d| {
                let delays_observed = Arc::clone(&delays_observed);
                async move {
                    delays_observed.lock().unwrap().push(d);
                }
            },
            || {
                let attempt_count = Arc::clone(&attempt_count);
                async move {
                    let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempt < 3 {
                        Err(TestError(format!("attempt {attempt} failed")))
                    } else {
                        Ok("success")
                    }
                }
            },
        )
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // Each attempt must be preceded by its configured delay, in order.
        assert_eq!(*delays_observed.lock().unwrap(), delays);
    }

    #[tokio::test]
    async fn retry_returns_last_error_when_all_fail() {
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let delays = [Duration::from_millis(0), Duration::from_millis(0)];

        let result: Result<(), TestError> = retry_fixed_delays(
            &delays,
            |_| async {},
            || {
                let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                async move { Err(TestError(format!("attempt {attempt} failed"))) }
            },
        )
        .await;

        // One attempt per delay, and the reported error is the final one
        // (distinct messages distinguish it from earlier failures).
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
        assert_eq!(result.unwrap_err().0, "attempt 2 failed");
    }

    #[tokio::test]
    async fn retry_succeeds_on_first_attempt() {
        let delays = [Duration::from_millis(0)];

        let result: Result<i32, TestError> =
            retry_fixed_delays(&delays, |_| async {}, || async { Ok(42) }).await;

        assert_eq!(result.unwrap(), 42);
    }
}