mcpkit_testing/
async_helpers.rs1use std::future::Future;
7use std::time::Duration;
8
/// Deadline used by [`with_default_timeout`]: five seconds.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
11
12pub async fn with_timeout<T, F>(timeout: Duration, future: F) -> T
33where
34 F: Future<Output = T>,
35{
36 tokio::time::timeout(timeout, future)
37 .await
38 .expect("Test timed out")
39}
40
41pub async fn with_default_timeout<T, F>(future: F) -> T
45where
46 F: Future<Output = T>,
47{
48 with_timeout(DEFAULT_TIMEOUT, future).await
49}
50
51pub async fn assert_completes_within<T, F>(timeout: Duration, future: F) -> T
57where
58 F: Future<Output = T>,
59{
60 tokio::time::timeout(timeout, future)
61 .await
62 .expect("Operation did not complete within timeout")
63}
64
65pub async fn assert_times_out<T, F>(timeout: Duration, future: F)
71where
72 F: Future<Output = T>,
73{
74 let result = tokio::time::timeout(timeout, future).await;
75 assert!(
76 result.is_err(),
77 "Expected operation to timeout, but it completed"
78 );
79}
80
81pub async fn wait_for<F>(timeout: Duration, interval: Duration, mut condition: F)
90where
91 F: FnMut() -> bool,
92{
93 let start = std::time::Instant::now();
94 while !condition() {
95 assert!(
96 start.elapsed() <= timeout,
97 "Condition not met within timeout"
98 );
99 tokio::time::sleep(interval).await;
100 }
101}
102
103pub async fn wait_for_async<F, Fut>(timeout: Duration, interval: Duration, mut condition: F)
109where
110 F: FnMut() -> Fut,
111 Fut: Future<Output = bool>,
112{
113 let start = std::time::Instant::now();
114 loop {
115 if condition().await {
116 return;
117 }
118 assert!(
119 start.elapsed() <= timeout,
120 "Condition not met within timeout"
121 );
122 tokio::time::sleep(interval).await;
123 }
124}
125
126pub async fn retry<T, E, F, Fut>(
132 max_attempts: usize,
133 delay: Duration,
134 mut operation: F,
135) -> Result<T, E>
136where
137 F: FnMut() -> Fut,
138 Fut: Future<Output = Result<T, E>>,
139{
140 let mut last_error = None;
141
142 for attempt in 0..max_attempts {
143 match operation().await {
144 Ok(result) => return Ok(result),
145 Err(e) => {
146 last_error = Some(e);
147 if attempt < max_attempts - 1 {
148 tokio::time::sleep(delay).await;
149 }
150 }
151 }
152 }
153
154 Err(last_error.expect("At least one attempt should have been made"))
155}
156
157#[derive(Debug)]
161pub struct TestBarrier {
162 notify: tokio::sync::Notify,
163 count: std::sync::atomic::AtomicUsize,
164 target: usize,
165}
166
167impl TestBarrier {
168 #[must_use]
170 pub fn new(target: usize) -> Self {
171 Self {
172 notify: tokio::sync::Notify::new(),
173 count: std::sync::atomic::AtomicUsize::new(0),
174 target,
175 }
176 }
177
178 pub async fn arrive_and_wait(&self) {
180 let count = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
181 if count >= self.target {
182 self.notify.notify_waiters();
183 } else {
184 self.notify.notified().await;
185 }
186 }
187
188 pub fn reset(&self) {
190 self.count.store(0, std::sync::atomic::Ordering::SeqCst);
191 }
192}
193
194#[derive(Debug, Default)]
196pub struct TestLatch {
197 notify: tokio::sync::Notify,
198 triggered: std::sync::atomic::AtomicBool,
199}
200
201impl TestLatch {
202 #[must_use]
204 pub fn new() -> Self {
205 Self::default()
206 }
207
208 pub fn trigger(&self) {
210 self.triggered
211 .store(true, std::sync::atomic::Ordering::SeqCst);
212 self.notify.notify_waiters();
213 }
214
215 pub async fn wait(&self) {
217 if self.triggered.load(std::sync::atomic::Ordering::SeqCst) {
218 return;
219 }
220 self.notify.notified().await;
221 }
222
223 pub async fn wait_timeout(&self, timeout: Duration) -> bool {
225 if self.triggered.load(std::sync::atomic::Ordering::SeqCst) {
226 return true;
227 }
228 tokio::time::timeout(timeout, self.notify.notified())
229 .await
230 .is_ok()
231 }
232
233 #[must_use]
235 pub fn is_triggered(&self) -> bool {
236 self.triggered.load(std::sync::atomic::Ordering::SeqCst)
237 }
238}
239
240pub async fn collect_with_timeout<S, T>(timeout: Duration, mut stream: S) -> Vec<T>
246where
247 S: futures::Stream<Item = T> + Unpin,
248{
249 use futures::StreamExt;
250
251 let mut items = Vec::new();
252 let deadline = tokio::time::Instant::now() + timeout;
253
254 loop {
255 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
256 if remaining.is_zero() {
257 break;
258 }
259
260 match tokio::time::timeout(remaining, stream.next()).await {
261 Ok(Some(item)) => items.push(item),
262 Ok(None) => break,
263 Err(_) => break,
264 }
265 }
266
267 items
268}
269
#[cfg(test)]
mod tests {
    //! Behavioral tests for every public helper in this module.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_with_timeout_success() {
        let result = with_timeout(Duration::from_secs(1), async { 42 }).await;
        assert_eq!(result, 42);
    }

    #[tokio::test]
    #[should_panic(expected = "timed out")]
    async fn test_with_timeout_failure() {
        with_timeout(Duration::from_millis(10), async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        })
        .await;
    }

    #[tokio::test]
    async fn test_with_default_timeout_success() {
        assert_eq!(with_default_timeout(async { "ok" }).await, "ok");
    }

    #[tokio::test]
    async fn test_assert_completes_within_returns_output() {
        let value = assert_completes_within(Duration::from_secs(1), async { 7 }).await;
        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn test_assert_times_out_on_pending_future() {
        assert_times_out(
            Duration::from_millis(10),
            tokio::time::sleep(Duration::from_secs(10)),
        )
        .await;
    }

    #[tokio::test]
    #[should_panic(expected = "Expected operation to timeout")]
    async fn test_assert_times_out_panics_when_completed() {
        assert_times_out(Duration::from_secs(1), async {}).await;
    }

    #[tokio::test]
    async fn test_wait_for() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            counter_clone.store(5, Ordering::SeqCst);
        });

        wait_for(Duration::from_secs(1), Duration::from_millis(10), || {
            counter.load(Ordering::SeqCst) >= 5
        })
        .await;

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    #[should_panic(expected = "Condition not met within timeout")]
    async fn test_wait_for_times_out() {
        wait_for(Duration::from_millis(20), Duration::from_millis(5), || false).await;
    }

    #[tokio::test]
    async fn test_wait_for_async() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            counter_clone.store(5, Ordering::SeqCst);
        });

        wait_for_async(Duration::from_secs(1), Duration::from_millis(10), || {
            let counter = Arc::clone(&counter);
            async move { counter.load(Ordering::SeqCst) >= 5 }
        })
        .await;

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn test_retry_success() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let result: Result<&str, &str> = retry(3, Duration::from_millis(10), || {
            let attempts = Arc::clone(&attempts_clone);
            async move {
                let count = attempts.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("not yet")
                } else {
                    Ok("success")
                }
            }
        })
        .await;

        assert_eq!(result, Ok("success"));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_returns_last_error_when_exhausted() {
        let attempts = AtomicUsize::new(0);

        let result: Result<(), usize> = retry(3, Duration::from_millis(1), || {
            let attempt = attempts.fetch_add(1, Ordering::SeqCst);
            async move { Err::<(), usize>(attempt) }
        })
        .await;

        // Attempts are numbered 0, 1, 2, so the final error carries 2.
        assert_eq!(result, Err(2));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    #[should_panic]
    async fn test_retry_zero_attempts_panics() {
        let _ = retry(0, Duration::from_millis(1), || async { Err::<(), ()>(()) }).await;
    }

    #[tokio::test]
    async fn test_test_latch() {
        let latch = Arc::new(TestLatch::new());
        let latch_clone = Arc::clone(&latch);

        let handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            latch_clone.trigger();
        });

        assert!(!latch.is_triggered());
        latch.wait().await;
        assert!(latch.is_triggered());

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_latch_already_triggered_returns_immediately() {
        let latch = TestLatch::new();
        latch.trigger();
        with_timeout(Duration::from_secs(1), latch.wait()).await;
        assert!(latch.wait_timeout(Duration::from_millis(10)).await);
    }

    #[tokio::test]
    async fn test_latch_wait_timeout_expires() {
        let latch = TestLatch::new();
        assert!(!latch.wait_timeout(Duration::from_millis(10)).await);
        assert!(!latch.is_triggered());
    }

    #[tokio::test]
    async fn test_test_barrier() {
        let barrier = Arc::new(TestBarrier::new(2));
        let barrier_clone = Arc::clone(&barrier);

        let handle = tokio::spawn(async move {
            barrier_clone.arrive_and_wait().await;
            "done"
        });

        barrier.arrive_and_wait().await;
        let result = handle.await.unwrap();
        assert_eq!(result, "done");
    }

    #[tokio::test]
    async fn test_barrier_single_participant_does_not_block() {
        let barrier = TestBarrier::new(1);
        with_timeout(Duration::from_secs(1), barrier.arrive_and_wait()).await;
    }

    #[tokio::test]
    async fn test_collect_with_timeout_finite_stream() {
        let items =
            collect_with_timeout(Duration::from_secs(1), futures::stream::iter(vec![1, 2, 3]))
                .await;
        assert_eq!(items, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_collect_with_timeout_stops_at_deadline() {
        let items = collect_with_timeout(
            Duration::from_millis(20),
            futures::stream::pending::<i32>(),
        )
        .await;
        assert!(items.is_empty());
    }
}