1use std::collections::HashMap;
26use std::future::Future;
27use std::pin::Pin;
28use std::time::Duration;
29
30use crate::error::{ClientError, ClientResult};
31use crate::streaming::EventStream;
32use crate::transport::Transport;
33
/// Retry configuration: how many times a failed request is re-sent and how
/// long to wait between attempts.
///
/// Delays follow an exponential schedule. The first retry waits on
/// `initial_backoff`. After every retry the delay is multiplied by
/// `backoff_multiplier` and capped at `max_backoff`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Number of retries after the initial attempt; `0` disables retrying.
    pub max_retries: u32,
    /// Base delay used before the first retry.
    pub initial_backoff: Duration,
    /// Upper bound for the delay once it has been grown by
    /// `backoff_multiplier`.
    pub max_backoff: Duration,
    /// Growth factor applied to the delay after each retry.
    pub backoff_multiplier: f64,
}

impl Default for RetryPolicy {
    /// Three retries, starting at 500 ms and doubling up to a 30 s cap.
    fn default() -> Self {
        Self {
            backoff_multiplier: 2.0,
            initial_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(30),
            max_retries: 3,
        }
    }
}

impl RetryPolicy {
    /// Returns this policy with `max_retries` replaced; other fields are kept.
    #[must_use]
    pub const fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Returns this policy with `initial_backoff` replaced; other fields are kept.
    #[must_use]
    pub const fn with_initial_backoff(self, backoff: Duration) -> Self {
        Self {
            initial_backoff: backoff,
            ..self
        }
    }

    /// Returns this policy with `max_backoff` replaced; other fields are kept.
    #[must_use]
    pub const fn with_max_backoff(self, max: Duration) -> Self {
        Self {
            max_backoff: max,
            ..self
        }
    }

    /// Returns this policy with `backoff_multiplier` replaced; other fields
    /// are kept.
    #[must_use]
    pub const fn with_backoff_multiplier(self, multiplier: f64) -> Self {
        Self {
            backoff_multiplier: multiplier,
            ..self
        }
    }
}
98
99impl ClientError {
102 #[must_use]
109 pub const fn is_retryable(&self) -> bool {
110 match self {
111 Self::Http(_) | Self::HttpClient(_) | Self::Timeout(_) => true,
112 Self::UnexpectedStatus { status, .. } => {
113 matches!(status, 429 | 502 | 503 | 504)
114 }
115 Self::Serialization(_)
117 | Self::Protocol(_)
118 | Self::Transport(_)
119 | Self::InvalidEndpoint(_)
120 | Self::AuthRequired { .. }
121 | Self::ProtocolBindingMismatch(_) => false,
122 }
123 }
124}
125
126pub(crate) struct RetryTransport {
131 inner: Box<dyn Transport>,
132 policy: RetryPolicy,
133}
134
135impl RetryTransport {
136 pub(crate) fn new(inner: Box<dyn Transport>, policy: RetryPolicy) -> Self {
138 Self { inner, policy }
139 }
140}
141
142impl Transport for RetryTransport {
143 fn send_request<'a>(
144 &'a self,
145 method: &'a str,
146 params: serde_json::Value,
147 extra_headers: &'a HashMap<String, String>,
148 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
149 Box::pin(async move {
150 let mut last_err = None;
151 let mut backoff = self.policy.initial_backoff;
152
153 let serialized = serde_json::to_vec(¶ms).map_err(ClientError::Serialization)?;
156
157 for attempt in 0..=self.policy.max_retries {
158 if attempt > 0 {
159 let jittered_backoff = jittered(backoff);
160 trace_info!(method, attempt, ?jittered_backoff, "retrying after backoff");
161 tokio::time::sleep(jittered_backoff).await;
162 backoff = cap_backoff(
163 backoff,
164 self.policy.backoff_multiplier,
165 self.policy.max_backoff,
166 );
167 }
168
169 let attempt_params: serde_json::Value =
170 serde_json::from_slice(&serialized).map_err(ClientError::Serialization)?;
171
172 match self
173 .inner
174 .send_request(method, attempt_params, extra_headers)
175 .await
176 {
177 Ok(result) => return Ok(result),
178 Err(e) if e.is_retryable() => {
179 trace_warn!(method, attempt, error = %e, "transient error, will retry");
180 last_err = Some(e);
181 }
182 Err(e) => return Err(e),
183 }
184 }
185
186 Err(last_err.expect("at least one attempt was made"))
187 })
188 }
189
190 fn send_streaming_request<'a>(
191 &'a self,
192 method: &'a str,
193 params: serde_json::Value,
194 extra_headers: &'a HashMap<String, String>,
195 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
196 Box::pin(async move {
197 let mut last_err = None;
198 let mut backoff = self.policy.initial_backoff;
199
200 let serialized = serde_json::to_vec(¶ms).map_err(ClientError::Serialization)?;
203
204 for attempt in 0..=self.policy.max_retries {
205 if attempt > 0 {
206 let jittered_backoff = jittered(backoff);
207 trace_info!(
208 method,
209 attempt,
210 ?jittered_backoff,
211 "retrying stream connect after backoff"
212 );
213 tokio::time::sleep(jittered_backoff).await;
214 backoff = cap_backoff(
215 backoff,
216 self.policy.backoff_multiplier,
217 self.policy.max_backoff,
218 );
219 }
220
221 let attempt_params: serde_json::Value =
222 serde_json::from_slice(&serialized).map_err(ClientError::Serialization)?;
223
224 match self
225 .inner
226 .send_streaming_request(method, attempt_params, extra_headers)
227 .await
228 {
229 Ok(stream) => return Ok(stream),
230 Err(e) if e.is_retryable() => {
231 trace_warn!(method, attempt, error = %e, "transient error, will retry");
232 last_err = Some(e);
233 }
234 Err(e) => return Err(e),
235 }
236 }
237
238 Err(last_err.expect("at least one attempt was made"))
239 })
240 }
241}
242
/// Computes the next backoff: `current * multiplier`, capped at `max`.
///
/// Returns `max` when the product is not a representable [`Duration`]. That
/// covers NaN, negative, infinite, and finite values too large for
/// `Duration`; the last case would make `Duration::from_secs_f64` panic.
/// A zero product yields `Duration::ZERO`.
fn cap_backoff(current: Duration, multiplier: f64, max: Duration) -> Duration {
    let next_secs = current.as_secs_f64() * multiplier;
    // `try_from_secs_f64` rejects NaN, negative, infinite and overflowing
    // inputs; all of them saturate at the cap.
    Duration::try_from_secs_f64(next_secs).map_or(max, |next| next.min(max))
}
260
/// Scales `backoff` by a pseudo-random factor in `[0.5, 1.0]`.
///
/// The randomness comes from a freshly constructed `RandomState` (randomly
/// keyed per std's documentation), so no external RNG crate is needed.
/// If the scaled value is NaN, infinite or negative, `backoff` is returned
/// unchanged.
fn jittered(backoff: Duration) -> Duration {
    use std::hash::{BuildHasher, Hasher};

    let entropy = {
        let mut state = std::collections::hash_map::RandomState::new().build_hasher();
        state.write_u128(backoff.as_nanos());
        state.finish()
    };

    // Map the 64-bit value onto [0, 1], then onto [0.5, 1.0].
    #[allow(clippy::cast_precision_loss)] // jitter does not need full 64-bit precision
    let unit = entropy as f64 / u64::MAX as f64;
    let scale = unit.mul_add(0.5, 0.5);

    let scaled_secs = backoff.as_secs_f64() * scale;
    if scaled_secs.is_finite() && scaled_secs >= 0.0 {
        Duration::from_secs_f64(scaled_secs)
    } else {
        backoff
    }
}
283
#[cfg(test)]
mod tests {
    //! Tests for error retryability, backoff arithmetic, jitter bounds, and the
    //! `RetryTransport` retry loop driven by in-file mock transports.

    use super::*;

    #[test]
    fn http_errors_are_retryable() {
        let e = ClientError::HttpClient("connection refused".into());
        assert!(e.is_retryable());
    }

    #[test]
    fn timeout_is_retryable() {
        let e = ClientError::Timeout("request timed out".into());
        assert!(e.is_retryable());
    }

    #[test]
    fn status_503_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 503,
            body: "Service Unavailable".into(),
        };
        assert!(e.is_retryable());
    }

    #[test]
    fn status_429_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 429,
            body: "Too Many Requests".into(),
        };
        assert!(e.is_retryable());
    }

    #[test]
    fn status_404_is_not_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 404,
            body: "Not Found".into(),
        };
        assert!(!e.is_retryable());
    }

    #[test]
    fn serialization_error_is_not_retryable() {
        let e = ClientError::Serialization(serde_json::from_str::<String>("not json").unwrap_err());
        assert!(!e.is_retryable());
    }

    #[test]
    fn protocol_error_is_not_retryable() {
        let e = ClientError::Protocol(a2a_protocol_types::A2aError::task_not_found("t1"));
        assert!(!e.is_retryable());
    }

    #[test]
    fn default_retry_policy() {
        let p = RetryPolicy::default();
        assert_eq!(p.max_retries, 3);
        assert_eq!(p.initial_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(30));
        assert!((p.backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cap_backoff_works() {
        let result = cap_backoff(Duration::from_secs(1), 2.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(2));

        let result = cap_backoff(Duration::from_secs(4), 2.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(5));
    }

    #[test]
    fn status_502_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 502,
            body: "Bad Gateway".into(),
        };
        assert!(e.is_retryable());
    }

    #[test]
    fn status_504_is_retryable() {
        let e = ClientError::UnexpectedStatus {
            status: 504,
            body: "Gateway Timeout".into(),
        };
        assert!(e.is_retryable());
    }

    // Statuses adjacent to the retryable set (and plain 500/501) must stay
    // non-retryable.
    #[test]
    fn status_boundary_not_retryable() {
        for status in [428, 430, 500, 501, 505] {
            let e = ClientError::UnexpectedStatus {
                status,
                body: String::new(),
            };
            assert!(!e.is_retryable(), "status {status} should not be retryable");
        }
    }

    #[test]
    fn retry_policy_builder_methods() {
        let p = RetryPolicy::default()
            .with_max_retries(5)
            .with_initial_backoff(Duration::from_secs(1))
            .with_max_backoff(Duration::from_secs(60))
            .with_backoff_multiplier(3.0);
        assert_eq!(p.max_retries, 5);
        assert_eq!(p.initial_backoff, Duration::from_secs(1));
        assert_eq!(p.max_backoff, Duration::from_secs(60));
        assert!((p.backoff_multiplier - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cap_backoff_exact_boundary() {
        // Landing exactly on the cap returns the cap.
        let result = cap_backoff(Duration::from_secs(5), 1.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(5));

        // Sub-second values are scaled without losing precision.
        let result = cap_backoff(Duration::from_millis(1), 2.0, Duration::from_secs(5));
        assert_eq!(result, Duration::from_millis(2));
    }

    #[test]
    fn cap_backoff_infinity_returns_max() {
        let max = Duration::from_secs(30);
        // The product overflows f64 to +inf.
        let result = cap_backoff(Duration::from_secs(u64::MAX / 2), f64::MAX, max);
        assert_eq!(result, max, "infinity should clamp to max");
    }

    // Jitter must stay within [backoff / 2, backoff]; sampled repeatedly
    // because the factor is random.
    #[test]
    fn jittered_backoff_in_expected_range() {
        let backoff = Duration::from_secs(2);
        for _ in 0..100 {
            let result = jittered(backoff);
            assert!(
                result >= Duration::from_secs(1),
                "jittered backoff should be >= backoff/2, got {result:?}"
            );
            assert!(
                result <= backoff,
                "jittered backoff should be <= backoff, got {result:?}"
            );
        }
    }

    // Zero scaled by any factor is still zero.
    #[test]
    fn jittered_zero_backoff() {
        let result = jittered(Duration::ZERO);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn cap_backoff_nan_returns_max() {
        let max = Duration::from_secs(30);
        let result = cap_backoff(Duration::from_secs(0), f64::NAN, max);
        assert_eq!(result, max, "NaN should clamp to max");
    }

    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::streaming::EventStream;

    /// Mock transport that fails its first `fail_count` calls with a retryable
    /// `Timeout`, then succeeds.
    ///
    /// On success, `send_request` returns `success_response`.
    /// `send_streaming_request` instead returns a non-retryable `Transport`
    /// error, because streaming success is not mocked here.
    struct FailNTransport {
        failures_remaining: Arc<AtomicUsize>,
        success_response: serde_json::Value,
        call_count: Arc<AtomicUsize>,
    }

    impl FailNTransport {
        fn new(fail_count: usize, response: serde_json::Value) -> Self {
            Self {
                failures_remaining: Arc::new(AtomicUsize::new(fail_count)),
                success_response: response,
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Records one call and reports whether it should fail.
        ///
        /// The counter saturates at zero. A plain `fetch_sub` would wrap to
        /// `usize::MAX` after the first success, making later calls fail
        /// again.
        fn next_call_fails(&self) -> bool {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            self.failures_remaining
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_sub(1))
                .is_ok()
        }
    }

    impl crate::transport::Transport for FailNTransport {
        fn send_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
            let fail = self.next_call_fails();
            let resp = self.success_response.clone();
            Box::pin(async move {
                if fail {
                    Err(ClientError::Timeout("transient".into()))
                } else {
                    Ok(resp)
                }
            })
        }

        fn send_streaming_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
            let fail = self.next_call_fails();
            Box::pin(async move {
                if fail {
                    Err(ClientError::Timeout("transient".into()))
                } else {
                    Err(ClientError::Transport("streaming not mocked".into()))
                }
            })
        }
    }

    /// Mock transport whose every call fails with a non-retryable
    /// `InvalidEndpoint` error.
    struct NonRetryableErrorTransport {
        call_count: Arc<AtomicUsize>,
    }

    impl NonRetryableErrorTransport {
        fn new() -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    impl crate::transport::Transport for NonRetryableErrorTransport {
        fn send_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Err(ClientError::InvalidEndpoint("bad url".into())) })
        }

        fn send_streaming_request<'a>(
            &'a self,
            _method: &'a str,
            _params: serde_json::Value,
            _extra_headers: &'a HashMap<String, String>,
        ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Err(ClientError::InvalidEndpoint("bad url".into())) })
        }
    }

    #[tokio::test]
    async fn retry_transport_retries_on_transient_error() {
        let inner = FailNTransport::new(2, serde_json::json!({"ok": true}));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );

        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok(), "should succeed after retries");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "should have made 3 attempts (2 failures + 1 success)"
        );
    }

    #[tokio::test]
    async fn retry_transport_gives_up_after_max_retries() {
        // More configured failures than allowed attempts.
        let inner = FailNTransport::new(10, serde_json::json!({"ok": true}));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );

        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_err(), "should fail after exhausting retries");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "should have made 3 attempts (initial + 2 retries)"
        );
    }

    #[tokio::test]
    async fn retry_transport_no_retry_on_non_retryable() {
        let inner = NonRetryableErrorTransport::new();
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );

        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ClientError::InvalidEndpoint(_)
        ));
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "non-retryable error should not be retried"
        );
    }

    #[tokio::test]
    async fn retry_transport_streaming_retries() {
        let inner = FailNTransport::new(1, serde_json::json!(null));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );

        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_err());
        // The second call hits the mock's non-retryable `Transport` error,
        // so the loop stops there even though one retry remains.
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "should have retried once for streaming"
        );
    }

    #[tokio::test]
    async fn retry_transport_streaming_no_retry_on_non_retryable() {
        let inner = NonRetryableErrorTransport::new();
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );

        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(matches!(
            result.unwrap_err(),
            ClientError::InvalidEndpoint(_)
        ));
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "non-retryable streaming error should not be retried"
        );
    }

    // The first streaming attempt times out and the second returns a real
    // (empty) stream, which must be handed back as success.
    #[tokio::test]
    async fn retry_transport_streaming_succeeds_after_retry() {
        use tokio::sync::mpsc;

        struct FailThenStreamTransport {
            call_count: Arc<AtomicUsize>,
        }

        impl crate::transport::Transport for FailThenStreamTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async move { Ok(serde_json::Value::Null) })
            }

            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                let attempt = self.call_count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    if attempt == 0 {
                        Err(ClientError::Timeout("transient timeout".into()))
                    } else {
                        let (tx, rx) = mpsc::channel(8);
                        // Dropping the sender closes the channel, giving an
                        // empty stream.
                        drop(tx);
                        Ok(EventStream::new(rx))
                    }
                })
            }
        }

        let call_count = Arc::new(AtomicUsize::new(0));
        let inner = FailThenStreamTransport {
            call_count: Arc::clone(&call_count),
        };
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );

        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok(), "streaming should succeed after retry");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            2,
            "should have made 2 attempts (1 failure + 1 success)"
        );
    }

    #[tokio::test]
    async fn retry_transport_streaming_exhausts_retries() {
        let inner = FailNTransport::new(10, serde_json::json!(null));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(2),
        );

        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "should make 3 attempts total for streaming"
        );
    }

    #[tokio::test]
    async fn retry_transport_succeeds_without_retry_on_first_attempt() {
        let inner = FailNTransport::new(0, serde_json::json!({"ok": true}));
        let call_count = Arc::clone(&inner.call_count);
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_retries(3),
        );

        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "should succeed on first try"
        );
    }

    // Paused tokio clock: sleeps advance virtual time, so large backoffs are
    // measurable without real waiting. A successful first attempt must not
    // sleep at all.
    #[tokio::test(start_paused = true)]
    async fn no_backoff_before_first_attempt() {
        let inner = FailNTransport::new(0, serde_json::json!({"ok": true}));
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(1),
        );

        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "first attempt must not sleep, elapsed: {:?}",
            start.elapsed()
        );
    }

    // One retry with a 100 s initial backoff: the jittered delay is at least
    // half of it. This relies on the first delay NOT being clamped to the
    // default 30 s `max_backoff`.
    #[tokio::test(start_paused = true)]
    async fn backoff_applied_on_retry() {
        let inner = FailNTransport::new(1, serde_json::json!({"ok": true}));
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(2),
        );

        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let result = transport
            .send_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        assert!(
            start.elapsed() >= Duration::from_secs(50),
            "retry should sleep (jittered backoff), elapsed: {:?}",
            start.elapsed()
        );
    }

    // Streaming counterpart of `no_backoff_before_first_attempt`.
    #[tokio::test(start_paused = true)]
    async fn no_backoff_before_first_streaming_attempt() {
        use tokio::sync::mpsc;

        struct ImmediateStreamTransport;
        impl crate::transport::Transport for ImmediateStreamTransport {
            fn send_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
            {
                Box::pin(async { Ok(serde_json::Value::Null) })
            }
            fn send_streaming_request<'a>(
                &'a self,
                _method: &'a str,
                _params: serde_json::Value,
                _extra_headers: &'a HashMap<String, String>,
            ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
                Box::pin(async {
                    let (tx, rx) = mpsc::channel(1);
                    drop(tx);
                    Ok(EventStream::new(rx))
                })
            }
        }

        let transport = RetryTransport::new(
            Box::new(ImmediateStreamTransport),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(1),
        );

        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(result.is_ok());
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "first streaming attempt must not sleep, elapsed: {:?}",
            start.elapsed()
        );
    }

    // Streaming counterpart of `backoff_applied_on_retry`. The final result
    // is an error (streaming success is not mocked); only the elapsed time
    // matters here.
    #[tokio::test(start_paused = true)]
    async fn backoff_applied_on_streaming_retry() {
        let inner = FailNTransport::new(1, serde_json::json!(null));
        let transport = RetryTransport::new(
            Box::new(inner),
            RetryPolicy::default()
                .with_initial_backoff(Duration::from_secs(100))
                .with_max_retries(2),
        );

        let start = tokio::time::Instant::now();
        let headers = HashMap::new();
        let _result = transport
            .send_streaming_request("test", serde_json::Value::Null, &headers)
            .await;
        assert!(
            start.elapsed() >= Duration::from_secs(50),
            "streaming retry should sleep, elapsed: {:?}",
            start.elapsed()
        );
    }

    // A zero multiplier collapses the backoff to zero; the cap only bounds
    // values from above.
    #[test]
    fn cap_backoff_zero_multiplier_returns_zero() {
        let max = Duration::from_secs(30);
        let result = cap_backoff(Duration::from_secs(5), 0.0, max);
        assert_eq!(
            result,
            Duration::ZERO,
            "0 * any = 0, should not clamp to max"
        );
    }
}