agentic_tools_utils/
async_control.rs1use std::time::Duration;
4use thiserror::Error;
5use tokio::sync::Semaphore;
6
7#[derive(Debug, Error)]
9pub enum AsyncControlError {
10 #[error("Semaphore closed")]
12 SemaphoreClosed,
13
14 #[error("Timed out after {0}s")]
16 Timeout(u64),
17
18 #[error("{0}")]
20 Operation(String),
21}
22
23pub async fn with_permit_and_timeout<F, Fut, T, E>(
35 semaphore: &Semaphore,
36 timeout_dur: Duration,
37 op: F,
38) -> Result<T, AsyncControlError>
39where
40 F: FnOnce() -> Fut,
41 Fut: std::future::Future<Output = Result<T, E>>,
42 E: std::fmt::Display,
43{
44 let _permit = semaphore
45 .acquire()
46 .await
47 .map_err(|_| AsyncControlError::SemaphoreClosed)?;
48
49 match tokio::time::timeout(timeout_dur, op()).await {
50 Ok(Ok(v)) => Ok(v),
51 Ok(Err(e)) => Err(AsyncControlError::Operation(e.to_string())),
52 Err(_) => Err(AsyncControlError::Timeout(timeout_dur.as_secs())),
53 }
54}
55
/// Runs `op` on a fixed retry schedule and returns the first success.
///
/// Each entry of `delays` buys one attempt: the helper awaits `sleep_fn(delay)`
/// and then calls `op`. The sleep also precedes the *first* attempt, so a schedule
/// that should try immediately ought to start with `Duration::ZERO`. The sleep is
/// injected as a parameter so tests can record delays instead of really waiting.
///
/// An empty `delays` slice results in exactly one attempt with no sleep.
///
/// # Errors
///
/// When every attempt fails, returns the error from the final attempt. Errors
/// from earlier attempts are discarded.
pub async fn retry_fixed_delays<F, Fut, SleepFn, SleepFut, T, E>(
    delays: &[Duration],
    mut sleep_fn: SleepFn,
    mut op: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    SleepFn: FnMut(Duration) -> SleepFut,
    SleepFut: std::future::Future<Output = ()>,
{
    // With no schedule there would be no error to return, so make a single
    // un-delayed attempt rather than panicking.
    let (final_delay, leading) = match delays.split_last() {
        Some((&last, rest)) => (last, rest),
        None => return op().await,
    };

    for &delay in leading {
        sleep_fn(delay).await;
        if let Ok(value) = op().await {
            return Ok(value);
        }
    }

    // The outcome of the last scheduled attempt, success or failure, is final.
    sleep_fn(final_delay).await;
    op().await
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Minimal `Display`-able error used as the operation error type in these tests.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.0)
        }
    }

    #[tokio::test]
    async fn semaphore_limits_concurrency() {
        let semaphore = Semaphore::new(2);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(4);
        for _ in 0..4 {
            let sem = &semaphore;
            let in_flight = Arc::clone(&in_flight);
            let max_observed = Arc::clone(&max_observed);

            handles.push(async move {
                with_permit_and_timeout(sem, Duration::from_secs(10), || async {
                    let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    max_observed.fetch_max(current, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, TestError>(())
                })
                .await
            });
        }

        // Every operation must actually succeed; otherwise the concurrency
        // assertion below could pass for the wrong reason.
        for result in futures::future::join_all(handles).await {
            result.expect("operation should complete within its timeout");
        }

        assert_eq!(max_observed.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn timeout_returns_error_when_exceeded() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_millis(10), || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok::<_, TestError>(())
            })
            .await;

        assert!(
            matches!(result, Err(AsyncControlError::Timeout(_))),
            "Expected Timeout error, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn timeout_returns_success_when_op_completes_in_time() {
        let semaphore = Semaphore::new(1);

        let result: Result<i32, AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(42)
            })
            .await;

        assert_eq!(result.expect("operation should succeed"), 42);
    }

    #[tokio::test]
    async fn operation_error_is_wrapped_with_display_text() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Err(TestError("boom".into()))
            })
            .await;

        match result {
            Err(AsyncControlError::Operation(msg)) => assert_eq!(msg, "boom"),
            other => panic!("Expected Operation error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn closed_semaphore_is_reported() {
        let semaphore = Semaphore::new(1);
        semaphore.close();

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(())
            })
            .await;

        assert!(
            matches!(result, Err(AsyncControlError::SemaphoreClosed)),
            "Expected SemaphoreClosed error, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn retry_succeeds_on_third_attempt() {
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let delays_observed = Arc::new(Mutex::new(Vec::new()));

        let delays = [
            Duration::from_millis(0),
            Duration::from_millis(10),
            Duration::from_millis(20),
        ];

        let result: Result<&str, TestError> = retry_fixed_delays(
            &delays,
            |d| {
                let delays_observed = Arc::clone(&delays_observed);
                async move {
                    delays_observed.lock().unwrap().push(d);
                }
            },
            || {
                let attempt_count = Arc::clone(&attempt_count);
                async move {
                    let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempt < 3 {
                        Err(TestError(format!("attempt {attempt} failed")))
                    } else {
                        Ok("success")
                    }
                }
            },
        )
        .await;

        assert_eq!(result.expect("third attempt should succeed"), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // One sleep precedes every attempt, including the first, in schedule order.
        assert_eq!(delays_observed.lock().unwrap().as_slice(), &delays[..]);
    }

    #[tokio::test]
    async fn retry_returns_last_error_when_all_fail() {
        let delays = [Duration::from_millis(0), Duration::from_millis(0)];
        let attempt_count = AtomicUsize::new(0);

        let result: Result<(), TestError> = retry_fixed_delays(
            &delays,
            |_| async {},
            || {
                let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                async move { Err(TestError(format!("attempt {attempt} failed"))) }
            },
        )
        .await;

        // Distinct per-attempt messages prove it is the *last* error that is returned.
        assert_eq!(result.unwrap_err().0, "attempt 2 failed");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retry_succeeds_on_first_attempt() {
        let delays = [Duration::from_millis(0)];

        let result: Result<i32, TestError> =
            retry_fixed_delays(&delays, |_| async {}, || async { Ok(42) }).await;

        assert_eq!(result.expect("first attempt should succeed"), 42);
    }
}