Skip to main content

agentic_tools_utils/
async_control.rs

1//! Async control flow utilities: retry, semaphore, and timeout helpers.
2
3use std::time::Duration;
4use thiserror::Error;
5use tokio::sync::Semaphore;
6
/// Error type for async control operations.
///
/// Returned by [`with_permit_and_timeout`].
#[derive(Debug, Error)]
pub enum AsyncControlError {
    /// The semaphore was closed, so no permit could be acquired.
    #[error("Semaphore closed")]
    SemaphoreClosed,

    /// The operation timed out.
    ///
    /// The payload is the timeout in whole seconds. `with_permit_and_timeout` fills it
    /// with `Duration::as_secs`, which discards any fractional part, so a sub-second
    /// timeout is reported as `0`.
    #[error("Timed out after {0}s")]
    Timeout(u64),

    /// An operation-specific error occurred.
    ///
    /// Holds the `Display` text of the underlying error; the original error value
    /// itself is not retained.
    #[error("{0}")]
    Operation(String),
}
22
23/// Generic helper that acquires a semaphore permit and wraps an operation in a timeout.
24///
25/// This is a testable building block: tests can inject a local semaphore and short timeout
26/// to verify behavior without real Claude sessions.
27///
28/// # Errors
29///
30/// Returns an error if:
31/// - The semaphore is closed
32/// - The operation times out
33/// - The operation itself returns an error
34pub async fn with_permit_and_timeout<F, Fut, T, E>(
35    semaphore: &Semaphore,
36    timeout_dur: Duration,
37    op: F,
38) -> Result<T, AsyncControlError>
39where
40    F: FnOnce() -> Fut,
41    Fut: std::future::Future<Output = Result<T, E>>,
42    E: std::fmt::Display,
43{
44    let _permit = semaphore
45        .acquire()
46        .await
47        .map_err(|_| AsyncControlError::SemaphoreClosed)?;
48
49    match tokio::time::timeout(timeout_dur, op()).await {
50        Ok(Ok(v)) => Ok(v),
51        Ok(Err(e)) => Err(AsyncControlError::Operation(e.to_string())),
52        Err(_) => Err(AsyncControlError::Timeout(timeout_dur.as_secs())),
53    }
54}
55
/// Generic retry helper with fixed delays.
///
/// Makes one attempt per entry in `delays`. Before *every* attempt, including the
/// first, it awaits `sleep_fn` with that entry's duration. It returns as soon as an
/// attempt succeeds, and no further sleeps happen after a success.
///
/// This is a testable building block: tests can inject a custom sleep function
/// to verify retry behavior without real waits.
///
/// # Arguments
///
/// * `delays` - Durations to wait before each attempt (the first attempt waits `delays[0]`);
///   its length is the total number of attempts
/// * `sleep_fn` - Async function called to perform each wait
/// * `op` - The operation to retry
///
/// # Errors
///
/// Returns the error produced by the final attempt if every attempt fails.
///
/// # Panics
///
/// Panics if `delays` is empty: no attempt is made, so there is neither a value nor an
/// error to return.
pub async fn retry_fixed_delays<F, Fut, SleepFn, SleepFut, T, E>(
    delays: &[Duration],
    mut sleep_fn: SleepFn,
    mut op: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    SleepFn: FnMut(Duration) -> SleepFut,
    SleepFut: std::future::Future<Output = ()>,
{
    let mut last_err = None;

    for &delay in delays {
        sleep_fn(delay).await;

        match op().await {
            Ok(v) => return Ok(v),
            // Only the most recent failure is kept; earlier ones are superseded.
            Err(e) => last_err = Some(e),
        }
    }

    // Preserve the last underlying error.
    // This expect is intentional: calling with empty delays is a programmer error
    // (see `# Panics`).
    #[expect(clippy::expect_used)]
    Err(last_err.expect("retry_fixed_delays called with empty delays"))
}
100
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Minimal error type; `with_permit_and_timeout` needs `Display`, and `unwrap`/`{:?}`
    /// in the assertions need `Debug`.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    #[tokio::test]
    async fn semaphore_limits_concurrency() {
        let semaphore = Semaphore::new(2);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];
        for _ in 0..4 {
            let sem = &semaphore;
            let in_flight = Arc::clone(&in_flight);
            let max_observed = Arc::clone(&max_observed);

            handles.push(async move {
                let result: Result<(), AsyncControlError> =
                    with_permit_and_timeout(sem, Duration::from_secs(10), || async {
                        let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        max_observed.fetch_max(current, Ordering::SeqCst);
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        Ok::<_, TestError>(())
                    })
                    .await;
                result
            });
        }

        let results = futures::future::join_all(handles).await;

        // A failed operation would never contribute to `max_observed`, so check every
        // one actually succeeded before trusting the concurrency measurement.
        assert!(
            results.iter().all(Result::is_ok),
            "unexpected failures: {results:?}"
        );
        // Max in-flight should be exactly 2 (the semaphore limit)
        assert_eq!(max_observed.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn timeout_returns_error_when_exceeded() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_millis(10), || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok::<_, TestError>(())
            })
            .await;

        assert!(
            matches!(result, Err(AsyncControlError::Timeout(_))),
            "Expected Timeout error, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn timeout_returns_success_when_op_completes_in_time() {
        let semaphore = Semaphore::new(1);

        let result: Result<i32, AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(42)
            })
            .await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn operation_error_is_wrapped_with_display_text() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Err(TestError("boom".into()))
            })
            .await;

        match result {
            Err(AsyncControlError::Operation(msg)) => assert_eq!(msg, "boom"),
            other => panic!("Expected Operation error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn closed_semaphore_returns_error_without_running_op() {
        let semaphore = Semaphore::new(1);
        semaphore.close();
        let runs = AtomicUsize::new(0);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                runs.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(())
            })
            .await;

        assert!(
            matches!(result, Err(AsyncControlError::SemaphoreClosed)),
            "Expected SemaphoreClosed error, got: {result:?}"
        );
        assert_eq!(runs.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn retry_succeeds_on_third_attempt() {
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let delays_observed = Arc::new(std::sync::Mutex::new(Vec::new()));

        let delays = [
            Duration::from_millis(0),
            Duration::from_millis(10),
            Duration::from_millis(20),
        ];

        let result: Result<&str, TestError> = retry_fixed_delays(
            &delays,
            |d| {
                let delays_observed = Arc::clone(&delays_observed);
                async move {
                    delays_observed.lock().unwrap().push(d);
                }
            },
            || {
                let attempt_count = Arc::clone(&attempt_count);
                async move {
                    let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempt < 3 {
                        Err(TestError(format!("attempt {attempt} failed")))
                    } else {
                        Ok("success")
                    }
                }
            },
        )
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // Every configured delay must have been slept, in order, before its attempt.
        assert_eq!(*delays_observed.lock().unwrap(), delays);
    }

    #[tokio::test]
    async fn retry_returns_last_error_when_all_fail() {
        let attempt_count = AtomicUsize::new(0);
        let delays = [Duration::from_millis(0), Duration::from_millis(0)];

        // Each attempt fails with a distinct message so the test can tell that the
        // *last* error (not the first) is the one returned.
        let result: Result<(), TestError> = retry_fixed_delays(
            &delays,
            |_| async {},
            || {
                let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                async move { Err(TestError(format!("attempt {attempt} failed"))) }
            },
        )
        .await;

        assert_eq!(result.unwrap_err().0, "attempt 2 failed");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retry_succeeds_on_first_attempt() {
        let delays = [Duration::from_millis(0)];

        let result: Result<i32, TestError> =
            retry_fixed_delays(&delays, |_| async {}, || async { Ok(42) }).await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    #[should_panic(expected = "retry_fixed_delays called with empty delays")]
    async fn retry_panics_on_empty_delays() {
        let _: Result<(), TestError> = retry_fixed_delays(&[], |_| async {}, || async {
            Err(TestError("never attempted".into()))
        })
        .await;
    }
}