Skip to main content

agentic_tools_utils/
async_control.rs

1//! Async control flow utilities: retry, semaphore, and timeout helpers.
2
3use std::time::Duration;
4use thiserror::Error;
5use tokio::sync::Semaphore;
6
7/// Error type for async control operations.
8#[derive(Debug, Error)]
9pub enum AsyncControlError {
10    /// The semaphore was closed.
11    #[error("Semaphore closed")]
12    SemaphoreClosed,
13
14    /// The operation timed out.
15    #[error("Timed out after {0}s")]
16    Timeout(u64),
17
18    /// An operation-specific error occurred.
19    #[error("{0}")]
20    Operation(String),
21}
22
23/// Generic helper that acquires a semaphore permit and wraps an operation in a timeout.
24///
25/// This is a testable building block: tests can inject a local semaphore and short timeout
26/// to verify behavior without real Claude sessions.
27///
28/// # Errors
29///
30/// Returns an error if:
31/// - The semaphore is closed
32/// - The operation times out
33/// - The operation itself returns an error
34pub async fn with_permit_and_timeout<F, Fut, T, E>(
35    semaphore: &Semaphore,
36    timeout_dur: Duration,
37    op: F,
38) -> Result<T, AsyncControlError>
39where
40    F: FnOnce() -> Fut,
41    Fut: std::future::Future<Output = Result<T, E>>,
42    E: std::fmt::Display,
43{
44    let _permit = semaphore
45        .acquire()
46        .await
47        .map_err(|_| AsyncControlError::SemaphoreClosed)?;
48
49    match tokio::time::timeout(timeout_dur, op()).await {
50        Ok(Ok(v)) => Ok(v),
51        Ok(Err(e)) => Err(AsyncControlError::Operation(e.to_string())),
52        Err(_) => Err(AsyncControlError::Timeout(timeout_dur.as_secs())),
53    }
54}
55
/// Generic retry helper with fixed delays; returns the first successful result.
///
/// This is a testable building block: tests can inject a custom sleep function
/// to verify retry behavior without real waits.
///
/// Each entry of `delays` corresponds to one attempt. `sleep_fn(delays[i])` is awaited
/// immediately *before* attempt `i`, so the first attempt also waits for `delays[0]`, and the
/// number of attempts equals `delays.len()`. An empty `delays` slice makes a single attempt
/// without sleeping; there is no error value to return otherwise.
///
/// # Arguments
///
/// * `delays` - Durations to wait before each attempt (the first attempt uses `delays[0]`)
/// * `sleep_fn` - Async function called to perform each wait
/// * `op` - The operation to retry
///
/// # Errors
///
/// Returns the error from the final attempt if every attempt fails. Errors from earlier
/// attempts are discarded.
pub async fn retry_fixed_delays<F, Fut, SleepFn, SleepFut, T, E>(
    delays: &[Duration],
    mut sleep_fn: SleepFn,
    mut op: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    SleepFn: FnMut(Duration) -> SleepFut,
    SleepFut: std::future::Future<Output = ()>,
{
    // Split off the final delay so the last attempt's result can be returned as-is. This
    // preserves the last underlying error without an `Option` + `expect`.
    let (last_delay, earlier) = match delays.split_last() {
        Some(parts) => parts,
        // No schedule at all: try exactly once, immediately, instead of panicking.
        None => return op().await,
    };

    for &delay in earlier {
        sleep_fn(delay).await;
        if let Ok(value) = op().await {
            return Ok(value);
        }
    }

    sleep_fn(*last_delay).await;
    op().await
}
98
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal error type used as the `E` of the operations under test.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    #[tokio::test]
    async fn semaphore_limits_concurrency() {
        let semaphore = Semaphore::new(2);
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_observed = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];
        for _ in 0..4 {
            let sem = &semaphore;
            let in_flight = Arc::clone(&in_flight);
            let max_observed = Arc::clone(&max_observed);

            handles.push(async move {
                with_permit_and_timeout(sem, Duration::from_secs(10), || async {
                    let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    max_observed.fetch_max(current, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, TestError>(())
                })
                .await
            });
        }

        let results = futures::future::join_all(handles).await;

        // Every operation must eventually get a permit and succeed...
        assert!(results.iter().all(Result::is_ok));
        // ...but never more than 2 (the semaphore limit) may run at once.
        assert_eq!(max_observed.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn timeout_returns_error_when_exceeded() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_millis(10), || async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok::<_, TestError>(())
            })
            .await;

        match result {
            Err(AsyncControlError::Timeout(_)) => {}
            other => panic!("Expected Timeout error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn timeout_returns_success_when_op_completes_in_time() {
        let semaphore = Semaphore::new(1);

        let result: Result<i32, AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(42)
            })
            .await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn operation_error_is_reported_with_its_message() {
        let semaphore = Semaphore::new(1);

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Err(TestError("boom".into()))
            })
            .await;

        match result {
            Err(AsyncControlError::Operation(msg)) => assert_eq!(msg, "boom"),
            other => panic!("Expected Operation error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn closed_semaphore_is_reported() {
        let semaphore = Semaphore::new(1);
        semaphore.close();

        let result: Result<(), AsyncControlError> =
            with_permit_and_timeout(&semaphore, Duration::from_secs(10), || async {
                Ok::<_, TestError>(())
            })
            .await;

        assert!(matches!(result, Err(AsyncControlError::SemaphoreClosed)));
    }

    #[tokio::test]
    async fn retry_succeeds_on_third_attempt() {
        let attempt_count = AtomicUsize::new(0);
        let delays_observed = std::sync::Mutex::new(Vec::new());

        let delays = [
            Duration::from_millis(0),
            Duration::from_millis(10),
            Duration::from_millis(20),
        ];

        let result: Result<&str, TestError> = retry_fixed_delays(
            &delays,
            |d| {
                // Record synchronously; the returned future does no real waiting.
                delays_observed.lock().unwrap().push(d);
                async {}
            },
            || {
                let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                async move {
                    if attempt < 3 {
                        Err(TestError(format!("attempt {attempt} failed")))
                    } else {
                        Ok("success")
                    }
                }
            },
        )
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // One sleep per attempt, in schedule order (including before the first attempt).
        assert_eq!(delays_observed.lock().unwrap().as_slice(), &delays[..]);
    }

    #[tokio::test]
    async fn retry_returns_last_error_when_all_fail() {
        let delays = [Duration::from_millis(0), Duration::from_millis(0)];
        let attempt_count = AtomicUsize::new(0);

        let result: Result<(), TestError> = retry_fixed_delays(
            &delays,
            |_| async {},
            || {
                let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
                async move { Err(TestError(format!("attempt {attempt} failed"))) }
            },
        )
        .await;

        // Each attempt fails with a distinct message, which proves the *last* error is kept.
        assert_eq!(result.unwrap_err().0, "attempt 2 failed");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retry_succeeds_on_first_attempt() {
        let delays = [Duration::from_millis(0)];

        let result: Result<i32, TestError> =
            retry_fixed_delays(&delays, |_| async {}, || async { Ok(42) }).await;

        assert_eq!(result.unwrap(), 42);
    }
}