mcpkit_testing/
async_helpers.rs

1//! Async testing utilities.
2//!
3//! This module provides helpers for testing async MCP code,
4//! including timeout wrappers and assertion helpers.
5
6use std::future::Future;
7use std::time::Duration;
8
/// Timeout applied by helpers that are not given an explicit one: five seconds.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
11
12/// Run an async function with a timeout.
13///
14/// # Panics
15///
16/// Panics if the future does not complete within the timeout.
17///
18/// # Example
19///
20/// ```rust,ignore
21/// use mcpkit_testing::async_helpers::with_timeout;
22/// use std::time::Duration;
23///
24/// #[tokio::test]
25/// async fn test_with_timeout() {
26///     let result = with_timeout(Duration::from_secs(1), async {
27///         "hello"
28///     }).await;
29///     assert_eq!(result, "hello");
30/// }
31/// ```
32pub async fn with_timeout<T, F>(timeout: Duration, future: F) -> T
33where
34    F: Future<Output = T>,
35{
36    tokio::time::timeout(timeout, future)
37        .await
38        .expect("Test timed out")
39}
40
41/// Run an async function with the default timeout.
42///
43/// Uses [`DEFAULT_TIMEOUT`] (5 seconds) as the timeout.
44pub async fn with_default_timeout<T, F>(future: F) -> T
45where
46    F: Future<Output = T>,
47{
48    with_timeout(DEFAULT_TIMEOUT, future).await
49}
50
51/// Assert that an async operation completes within a timeout.
52///
53/// # Panics
54///
55/// Panics if the future does not complete within the timeout.
56pub async fn assert_completes_within<T, F>(timeout: Duration, future: F) -> T
57where
58    F: Future<Output = T>,
59{
60    tokio::time::timeout(timeout, future)
61        .await
62        .expect("Operation did not complete within timeout")
63}
64
65/// Assert that an async operation times out.
66///
67/// # Panics
68///
69/// Panics if the future completes before the timeout.
70pub async fn assert_times_out<T, F>(timeout: Duration, future: F)
71where
72    F: Future<Output = T>,
73{
74    let result = tokio::time::timeout(timeout, future).await;
75    assert!(
76        result.is_err(),
77        "Expected operation to timeout, but it completed"
78    );
79}
80
81/// Wait for a condition to become true.
82///
83/// Polls the condition function at regular intervals until it returns true
84/// or the timeout is reached.
85///
86/// # Panics
87///
88/// Panics if the condition is not met within the timeout.
89pub async fn wait_for<F>(timeout: Duration, interval: Duration, mut condition: F)
90where
91    F: FnMut() -> bool,
92{
93    let start = std::time::Instant::now();
94    while !condition() {
95        assert!(
96            start.elapsed() <= timeout,
97            "Condition not met within timeout"
98        );
99        tokio::time::sleep(interval).await;
100    }
101}
102
103/// Wait for an async condition to become true.
104///
105/// # Panics
106///
107/// Panics if the condition is not met within the timeout.
108pub async fn wait_for_async<F, Fut>(timeout: Duration, interval: Duration, mut condition: F)
109where
110    F: FnMut() -> Fut,
111    Fut: Future<Output = bool>,
112{
113    let start = std::time::Instant::now();
114    loop {
115        if condition().await {
116            return;
117        }
118        assert!(
119            start.elapsed() <= timeout,
120            "Condition not met within timeout"
121        );
122        tokio::time::sleep(interval).await;
123    }
124}
125
126/// Retry an async operation until it succeeds or max attempts is reached.
127///
128/// # Errors
129///
130/// Returns the last error if all attempts fail.
131pub async fn retry<T, E, F, Fut>(
132    max_attempts: usize,
133    delay: Duration,
134    mut operation: F,
135) -> Result<T, E>
136where
137    F: FnMut() -> Fut,
138    Fut: Future<Output = Result<T, E>>,
139{
140    let mut last_error = None;
141
142    for attempt in 0..max_attempts {
143        match operation().await {
144            Ok(result) => return Ok(result),
145            Err(e) => {
146                last_error = Some(e);
147                if attempt < max_attempts - 1 {
148                    tokio::time::sleep(delay).await;
149                }
150            }
151        }
152    }
153
154    Err(last_error.expect("At least one attempt should have been made"))
155}
156
157/// A test barrier for synchronizing async tests.
158///
159/// Useful for coordinating between client and server in integration tests.
160#[derive(Debug)]
161pub struct TestBarrier {
162    notify: tokio::sync::Notify,
163    count: std::sync::atomic::AtomicUsize,
164    target: usize,
165}
166
167impl TestBarrier {
168    /// Create a new barrier with the specified target count.
169    #[must_use]
170    pub fn new(target: usize) -> Self {
171        Self {
172            notify: tokio::sync::Notify::new(),
173            count: std::sync::atomic::AtomicUsize::new(0),
174            target,
175        }
176    }
177
178    /// Arrive at the barrier and wait for all parties.
179    pub async fn arrive_and_wait(&self) {
180        let count = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
181        if count >= self.target {
182            self.notify.notify_waiters();
183        } else {
184            self.notify.notified().await;
185        }
186    }
187
188    /// Reset the barrier for reuse.
189    pub fn reset(&self) {
190        self.count.store(0, std::sync::atomic::Ordering::SeqCst);
191    }
192}
193
194/// A test latch that can be awaited once.
195#[derive(Debug, Default)]
196pub struct TestLatch {
197    notify: tokio::sync::Notify,
198    triggered: std::sync::atomic::AtomicBool,
199}
200
201impl TestLatch {
202    /// Create a new latch.
203    #[must_use]
204    pub fn new() -> Self {
205        Self::default()
206    }
207
208    /// Trigger the latch.
209    pub fn trigger(&self) {
210        self.triggered
211            .store(true, std::sync::atomic::Ordering::SeqCst);
212        self.notify.notify_waiters();
213    }
214
215    /// Wait for the latch to be triggered.
216    pub async fn wait(&self) {
217        if self.triggered.load(std::sync::atomic::Ordering::SeqCst) {
218            return;
219        }
220        self.notify.notified().await;
221    }
222
223    /// Wait for the latch with a timeout.
224    pub async fn wait_timeout(&self, timeout: Duration) -> bool {
225        if self.triggered.load(std::sync::atomic::Ordering::SeqCst) {
226            return true;
227        }
228        tokio::time::timeout(timeout, self.notify.notified())
229            .await
230            .is_ok()
231    }
232
233    /// Check if the latch has been triggered.
234    #[must_use]
235    pub fn is_triggered(&self) -> bool {
236        self.triggered.load(std::sync::atomic::Ordering::SeqCst)
237    }
238}
239
240/// Collect async stream items into a vector with timeout.
241///
242/// # Panics
243///
244/// Panics if the collection times out.
245pub async fn collect_with_timeout<S, T>(timeout: Duration, mut stream: S) -> Vec<T>
246where
247    S: futures::Stream<Item = T> + Unpin,
248{
249    use futures::StreamExt;
250
251    let mut items = Vec::new();
252    let deadline = tokio::time::Instant::now() + timeout;
253
254    loop {
255        let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
256        if remaining.is_zero() {
257            break;
258        }
259
260        match tokio::time::timeout(remaining, stream.next()).await {
261            Ok(Some(item)) => items.push(item),
262            Ok(None) => break,
263            Err(_) => break,
264        }
265    }
266
267    items
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// A future that is ready at once passes its value through `with_timeout`.
    #[tokio::test]
    async fn test_with_timeout_success() {
        let value = with_timeout(Duration::from_secs(1), async { 42 }).await;
        assert_eq!(value, 42);
    }

    /// A future slower than the timeout makes `with_timeout` panic.
    #[tokio::test]
    #[should_panic(expected = "timed out")]
    async fn test_with_timeout_failure() {
        let too_slow = tokio::time::sleep(Duration::from_secs(10));
        with_timeout(Duration::from_millis(10), too_slow).await;
    }

    /// `wait_for` returns once a background task flips the observed value.
    #[tokio::test]
    async fn test_wait_for() {
        let counter = Arc::new(AtomicUsize::new(0));
        let writer = Arc::clone(&counter);

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            writer.store(5, Ordering::SeqCst);
        });

        let reached_five = || counter.load(Ordering::SeqCst) >= 5;
        wait_for(
            Duration::from_secs(1),
            Duration::from_millis(10),
            reached_five,
        )
        .await;

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// `retry` keeps calling until the third attempt succeeds.
    #[tokio::test]
    async fn test_retry_success() {
        let attempts = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str, &str> = retry(3, Duration::from_millis(10), || {
            let seen = Arc::clone(&attempts);
            async move {
                // The first two calls fail; the third one succeeds.
                if seen.fetch_add(1, Ordering::SeqCst) < 2 {
                    Err("not yet")
                } else {
                    Ok("success")
                }
            }
        })
        .await;

        assert_eq!(outcome, Ok("success"));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A waiter on the latch is released by a trigger from another task.
    #[tokio::test]
    async fn test_test_latch() {
        let latch = Arc::new(TestLatch::new());
        let remote = Arc::clone(&latch);

        let trigger_task = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            remote.trigger();
        });

        assert!(!latch.is_triggered());
        latch.wait().await;
        assert!(latch.is_triggered());

        trigger_task.await.unwrap();
    }

    /// Two parties arriving at a two-party barrier both get through.
    #[tokio::test]
    async fn test_test_barrier() {
        let barrier = Arc::new(TestBarrier::new(2));
        let peer = Arc::clone(&barrier);

        let peer_task = tokio::spawn(async move {
            peer.arrive_and_wait().await;
            "done"
        });

        barrier.arrive_and_wait().await;
        assert_eq!(peer_task.await.unwrap(), "done");
    }
}