1use std::collections::HashMap;
26use std::future::Future;
27use std::pin::Pin;
28use std::time::Duration;
29
30use crate::error::{ClientError, ClientResult};
31use crate::streaming::EventStream;
32use crate::transport::Transport;
33
/// Exponential-backoff settings for retrying transient transport failures.
///
/// A call is attempted once and then retried at most `max_retries` more times.
/// Before each retry the current delay is jittered down to between half of
/// itself and itself. It is then grown by `backoff_multiplier` and clamped to
/// `max_backoff`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Retries allowed after the initial attempt (total attempts = `max_retries + 1`).
    pub max_retries: u32,
    /// Delay used before the first retry (before jitter).
    pub initial_backoff: Duration,
    /// Ceiling applied each time the delay is grown; note that
    /// `initial_backoff` itself is used as-is for the first retry.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by after every retry.
    pub backoff_multiplier: f64,
}
57
58impl Default for RetryPolicy {
59 fn default() -> Self {
60 Self {
61 max_retries: 3,
62 initial_backoff: Duration::from_millis(500),
63 max_backoff: Duration::from_secs(30),
64 backoff_multiplier: 2.0,
65 }
66 }
67}
68
69impl RetryPolicy {
70 #[must_use]
72 pub const fn with_max_retries(mut self, max_retries: u32) -> Self {
73 self.max_retries = max_retries;
74 self
75 }
76
77 #[must_use]
79 pub const fn with_initial_backoff(mut self, backoff: Duration) -> Self {
80 self.initial_backoff = backoff;
81 self
82 }
83
84 #[must_use]
86 pub const fn with_max_backoff(mut self, max: Duration) -> Self {
87 self.max_backoff = max;
88 self
89 }
90
91 #[must_use]
93 pub const fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
94 self.backoff_multiplier = multiplier;
95 self
96 }
97}
98
99impl ClientError {
102 #[must_use]
109 pub const fn is_retryable(&self) -> bool {
110 match self {
111 Self::Http(_) | Self::HttpClient(_) | Self::Timeout(_) => true,
112 Self::UnexpectedStatus { status, .. } => {
113 matches!(status, 429 | 502 | 503 | 504)
114 }
115 Self::Serialization(_)
117 | Self::Protocol(_)
118 | Self::Transport(_)
119 | Self::InvalidEndpoint(_)
120 | Self::AuthRequired { .. }
121 | Self::ProtocolBindingMismatch(_) => false,
122 }
123 }
124}
125
126pub(crate) struct RetryTransport {
131 inner: Box<dyn Transport>,
132 policy: RetryPolicy,
133}
134
135impl RetryTransport {
136 pub(crate) fn new(inner: Box<dyn Transport>, policy: RetryPolicy) -> Self {
138 Self { inner, policy }
139 }
140}
141
142impl Transport for RetryTransport {
143 fn send_request<'a>(
144 &'a self,
145 method: &'a str,
146 params: serde_json::Value,
147 extra_headers: &'a HashMap<String, String>,
148 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
149 Box::pin(async move {
150 let mut last_err = None;
151 let mut backoff = self.policy.initial_backoff;
152
153 let serialized = serde_json::to_vec(¶ms).map_err(ClientError::Serialization)?;
156
157 for attempt in 0..=self.policy.max_retries {
158 if attempt > 0 {
159 let jittered_backoff = jittered(backoff);
160 trace_info!(method, attempt, ?jittered_backoff, "retrying after backoff");
161 tokio::time::sleep(jittered_backoff).await;
162 backoff = cap_backoff(
163 backoff,
164 self.policy.backoff_multiplier,
165 self.policy.max_backoff,
166 );
167 }
168
169 let attempt_params: serde_json::Value =
170 serde_json::from_slice(&serialized).map_err(ClientError::Serialization)?;
171
172 match self
173 .inner
174 .send_request(method, attempt_params, extra_headers)
175 .await
176 {
177 Ok(result) => return Ok(result),
178 Err(e) if e.is_retryable() => {
179 trace_warn!(method, attempt, error = %e, "transient error, will retry");
180 last_err = Some(e);
181 }
182 Err(e) => return Err(e),
183 }
184 }
185
186 Err(last_err.expect("at least one attempt was made"))
187 })
188 }
189
190 fn send_streaming_request<'a>(
191 &'a self,
192 method: &'a str,
193 params: serde_json::Value,
194 extra_headers: &'a HashMap<String, String>,
195 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
196 Box::pin(async move {
197 let mut last_err = None;
198 let mut backoff = self.policy.initial_backoff;
199
200 let serialized = serde_json::to_vec(¶ms).map_err(ClientError::Serialization)?;
203
204 for attempt in 0..=self.policy.max_retries {
205 if attempt > 0 {
206 let jittered_backoff = jittered(backoff);
207 trace_info!(
208 method,
209 attempt,
210 ?jittered_backoff,
211 "retrying stream connect after backoff"
212 );
213 tokio::time::sleep(jittered_backoff).await;
214 backoff = cap_backoff(
215 backoff,
216 self.policy.backoff_multiplier,
217 self.policy.max_backoff,
218 );
219 }
220
221 let attempt_params: serde_json::Value =
222 serde_json::from_slice(&serialized).map_err(ClientError::Serialization)?;
223
224 match self
225 .inner
226 .send_streaming_request(method, attempt_params, extra_headers)
227 .await
228 {
229 Ok(stream) => return Ok(stream),
230 Err(e) if e.is_retryable() => {
231 trace_warn!(method, attempt, error = %e, "transient error, will retry");
232 last_err = Some(e);
233 }
234 Err(e) => return Err(e),
235 }
236 }
237
238 Err(last_err.expect("at least one attempt was made"))
239 })
240 }
241}
242
/// Computes the next backoff: `current * multiplier`, clamped to `max`.
///
/// Products that are NaN, infinite or negative clamp to `max`. Products at or
/// above `max` are clamped *before* being converted to a `Duration`. This also
/// avoids the panic `Duration::from_secs_f64` raises for finite values beyond
/// `Duration::MAX`, which is reachable with, for example, `max = Duration::MAX`
/// and enough doublings.
fn cap_backoff(current: Duration, multiplier: f64, max: Duration) -> Duration {
    let next_secs = current.as_secs_f64() * multiplier;
    if !next_secs.is_finite() || next_secs < 0.0 {
        return max;
    }
    if next_secs >= max.as_secs_f64() {
        return max;
    }
    // `next_secs < max <= Duration::MAX`, so the conversion cannot overflow;
    // `min` still guards against sub-nanosecond rounding above `max`.
    std::cmp::min(Duration::from_secs_f64(next_secs), max)
}
260
/// Maps 64 random bits linearly onto a jitter factor in `[0.5, 1.0]`.
///
/// `0` maps to `0.5` and `u64::MAX` maps to `1.0`.
// The `u64 -> f64` casts drop low-order bits for large inputs. That is
// irrelevant for a jitter fraction, so the precision-loss lint is silenced
// on purpose.
#[allow(clippy::cast_precision_loss)]
fn jitter_factor_from_bits(random_bits: u64) -> f64 {
    let unit_interval = random_bits as f64 / u64::MAX as f64;
    0.5_f64.mul_add(unit_interval, 0.5)
}
270
/// Scales `backoff` by `factor`.
///
/// If the scaled value cannot be represented as a `Duration` (NaN, infinite,
/// negative, or at/above `Duration::MAX`), the unjittered `backoff` is
/// returned instead of panicking. The upper-bound check matters because
/// `Duration::MAX.as_secs_f64()` rounds up to 2^64 seconds, so even
/// `factor == 1.0` on a near-maximal backoff would otherwise make
/// `Duration::from_secs_f64` panic.
fn apply_jitter(backoff: Duration, factor: f64) -> Duration {
    let jittered_secs = backoff.as_secs_f64() * factor;
    let representable = jittered_secs.is_finite()
        && jittered_secs >= 0.0
        && jittered_secs < Duration::MAX.as_secs_f64();
    if representable {
        Duration::from_secs_f64(jittered_secs)
    } else {
        backoff
    }
}
283
284fn jittered(backoff: Duration) -> Duration {
291 use std::hash::{BuildHasher, Hasher};
292 let mut hasher = std::collections::hash_map::RandomState::new().build_hasher();
293 hasher.write_u128(backoff.as_nanos());
295 let factor = jitter_factor_from_bits(hasher.finish());
296 apply_jitter(backoff, factor)
297}
298
299#[cfg(test)]
302mod tests {
303 use super::*;
304
305 #[test]
306 fn http_errors_are_retryable() {
307 let e = ClientError::HttpClient("connection refused".into());
308 assert!(e.is_retryable());
309 }
310
311 #[test]
312 fn timeout_is_retryable() {
313 let e = ClientError::Timeout("request timed out".into());
314 assert!(e.is_retryable());
315 }
316
317 #[test]
318 fn status_503_is_retryable() {
319 let e = ClientError::UnexpectedStatus {
320 status: 503,
321 body: "Service Unavailable".into(),
322 };
323 assert!(e.is_retryable());
324 }
325
326 #[test]
327 fn status_429_is_retryable() {
328 let e = ClientError::UnexpectedStatus {
329 status: 429,
330 body: "Too Many Requests".into(),
331 };
332 assert!(e.is_retryable());
333 }
334
335 #[test]
336 fn status_404_is_not_retryable() {
337 let e = ClientError::UnexpectedStatus {
338 status: 404,
339 body: "Not Found".into(),
340 };
341 assert!(!e.is_retryable());
342 }
343
344 #[test]
345 fn serialization_error_is_not_retryable() {
346 let e = ClientError::Serialization(serde_json::from_str::<String>("not json").unwrap_err());
347 assert!(!e.is_retryable());
348 }
349
350 #[test]
351 fn protocol_error_is_not_retryable() {
352 let e = ClientError::Protocol(a2a_protocol_types::A2aError::task_not_found("t1"));
353 assert!(!e.is_retryable());
354 }
355
356 #[test]
357 fn default_retry_policy() {
358 let p = RetryPolicy::default();
359 assert_eq!(p.max_retries, 3);
360 assert_eq!(p.initial_backoff, Duration::from_millis(500));
361 assert_eq!(p.max_backoff, Duration::from_secs(30));
362 assert!((p.backoff_multiplier - 2.0).abs() < f64::EPSILON);
363 }
364
365 #[test]
366 fn cap_backoff_works() {
367 let result = cap_backoff(Duration::from_secs(1), 2.0, Duration::from_secs(5));
368 assert_eq!(result, Duration::from_secs(2));
369
370 let result = cap_backoff(Duration::from_secs(4), 2.0, Duration::from_secs(5));
371 assert_eq!(result, Duration::from_secs(5));
372 }
373
374 #[test]
375 fn status_502_is_retryable() {
376 let e = ClientError::UnexpectedStatus {
377 status: 502,
378 body: "Bad Gateway".into(),
379 };
380 assert!(e.is_retryable());
381 }
382
383 #[test]
384 fn status_504_is_retryable() {
385 let e = ClientError::UnexpectedStatus {
386 status: 504,
387 body: "Gateway Timeout".into(),
388 };
389 assert!(e.is_retryable());
390 }
391
392 #[test]
394 fn status_boundary_not_retryable() {
395 for status in [428, 430, 500, 501, 505] {
396 let e = ClientError::UnexpectedStatus {
397 status,
398 body: String::new(),
399 };
400 assert!(!e.is_retryable(), "status {status} should not be retryable");
401 }
402 }
403
404 #[test]
405 fn retry_policy_builder_methods() {
406 let p = RetryPolicy::default()
407 .with_max_retries(5)
408 .with_initial_backoff(Duration::from_secs(1))
409 .with_max_backoff(Duration::from_secs(60))
410 .with_backoff_multiplier(3.0);
411 assert_eq!(p.max_retries, 5);
412 assert_eq!(p.initial_backoff, Duration::from_secs(1));
413 assert_eq!(p.max_backoff, Duration::from_secs(60));
414 assert!((p.backoff_multiplier - 3.0).abs() < f64::EPSILON);
415 }
416
417 #[test]
418 fn cap_backoff_exact_boundary() {
419 let result = cap_backoff(Duration::from_secs(5), 1.0, Duration::from_secs(5));
421 assert_eq!(result, Duration::from_secs(5));
422
423 let result = cap_backoff(Duration::from_millis(1), 2.0, Duration::from_secs(5));
425 assert_eq!(result, Duration::from_millis(2));
426 }
427
428 #[test]
429 fn cap_backoff_infinity_returns_max() {
430 let max = Duration::from_secs(30);
432 let result = cap_backoff(Duration::from_secs(u64::MAX / 2), f64::MAX, max);
433 assert_eq!(result, max, "infinity should clamp to max");
434 }
435
436 #[test]
438 fn jittered_backoff_in_expected_range() {
439 let backoff = Duration::from_secs(2);
440 for _ in 0..100 {
442 let result = jittered(backoff);
443 assert!(
444 result >= Duration::from_secs(1),
445 "jittered backoff should be >= backoff/2, got {result:?}"
446 );
447 assert!(
448 result <= backoff,
449 "jittered backoff should be <= backoff, got {result:?}"
450 );
451 }
452 }
453
454 #[test]
456 fn jittered_zero_backoff() {
457 let result = jittered(Duration::ZERO);
458 assert_eq!(result, Duration::ZERO);
459 }
460
461 #[test]
462 fn cap_backoff_nan_returns_max() {
463 let max = Duration::from_secs(30);
464 let result = cap_backoff(Duration::from_secs(0), f64::NAN, max);
465 assert_eq!(result, max, "NaN should clamp to max");
466 }
467
468 #[test]
473 fn jitter_factor_from_bits_zero() {
474 let f = jitter_factor_from_bits(0);
475 assert!(
476 (f - 0.5).abs() < f64::EPSILON,
477 "factor(0) should be 0.5, got {f}"
478 );
479 }
480
481 #[test]
483 fn jitter_factor_from_bits_midpoint() {
484 let f = jitter_factor_from_bits(u64::MAX / 2);
485 assert!(
487 (0.74..=0.76).contains(&f),
488 "factor(u64::MAX/2) should be ~0.75, got {f}"
489 );
490 }
491
492 #[test]
494 fn jitter_factor_from_bits_max() {
495 let f = jitter_factor_from_bits(u64::MAX);
496 assert!(
499 (0.9..=1.0).contains(&f),
500 "factor(u64::MAX) should be ~1.0, got {f}"
501 );
502 }
503
504 #[test]
508 fn jitter_factor_from_bits_always_in_half_to_one() {
509 for bits in [
510 0_u64,
511 1,
512 7,
513 42,
514 1 << 20,
515 1 << 50,
516 u64::MAX / 4,
517 u64::MAX / 2,
518 u64::MAX,
519 ] {
520 let f = jitter_factor_from_bits(bits);
521 assert!(
522 (0.5..=1.0).contains(&f),
523 "factor({bits}) = {f} out of [0.5, 1.0]"
524 );
525 }
526 }
527
528 #[test]
537 fn apply_jitter_normal_factor() {
538 assert_eq!(
540 apply_jitter(Duration::from_secs(2), 0.5),
541 Duration::from_secs(1)
542 );
543 assert_eq!(
545 apply_jitter(Duration::from_secs(4), 0.75),
546 Duration::from_secs(3)
547 );
548 assert_eq!(
550 apply_jitter(Duration::from_secs(5), 1.0),
551 Duration::from_secs(5)
552 );
553 }
554
555 #[test]
559 fn apply_jitter_zero_factor_returns_zero() {
560 assert_eq!(
561 apply_jitter(Duration::from_secs(5), 0.0),
562 Duration::ZERO,
563 "factor=0.0 must produce Duration::ZERO via from_secs_f64 path"
564 );
565 }
566
567 #[test]
571 fn apply_jitter_negative_factor_returns_backoff() {
572 assert_eq!(
573 apply_jitter(Duration::from_secs(3), -0.5),
574 Duration::from_secs(3),
575 "negative factor must short-circuit to backoff"
576 );
577 }
578
579 #[test]
588 fn apply_jitter_infinite_factor_returns_backoff() {
589 assert_eq!(
590 apply_jitter(Duration::from_secs(2), f64::INFINITY),
591 Duration::from_secs(2),
592 "infinite factor must short-circuit to backoff"
593 );
594 }
595
596 #[test]
597 fn apply_jitter_nan_factor_returns_backoff() {
598 assert_eq!(
599 apply_jitter(Duration::from_secs(4), f64::NAN),
600 Duration::from_secs(4),
601 "NaN factor must short-circuit to backoff"
602 );
603 }
604
605 use std::collections::HashMap;
608 use std::future::Future;
609 use std::pin::Pin;
610 use std::sync::atomic::{AtomicUsize, Ordering};
611 use std::sync::Arc;
612
613 use crate::streaming::EventStream;
614
615 struct FailNTransport {
617 failures_remaining: Arc<AtomicUsize>,
618 success_response: serde_json::Value,
619 call_count: Arc<AtomicUsize>,
620 }
621
622 impl FailNTransport {
623 fn new(fail_count: usize, response: serde_json::Value) -> Self {
624 Self {
625 failures_remaining: Arc::new(AtomicUsize::new(fail_count)),
626 success_response: response,
627 call_count: Arc::new(AtomicUsize::new(0)),
628 }
629 }
630 }
631
632 impl crate::transport::Transport for FailNTransport {
633 fn send_request<'a>(
634 &'a self,
635 _method: &'a str,
636 _params: serde_json::Value,
637 _extra_headers: &'a HashMap<String, String>,
638 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
639 self.call_count.fetch_add(1, Ordering::SeqCst);
640 let remaining = self.failures_remaining.fetch_sub(1, Ordering::SeqCst);
641 let resp = self.success_response.clone();
642 Box::pin(async move {
643 if remaining > 0 {
644 Err(ClientError::Timeout("transient".into()))
645 } else {
646 Ok(resp)
647 }
648 })
649 }
650
651 fn send_streaming_request<'a>(
652 &'a self,
653 _method: &'a str,
654 _params: serde_json::Value,
655 _extra_headers: &'a HashMap<String, String>,
656 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
657 self.call_count.fetch_add(1, Ordering::SeqCst);
658 let remaining = self.failures_remaining.fetch_sub(1, Ordering::SeqCst);
659 Box::pin(async move {
660 if remaining > 0 {
661 Err(ClientError::Timeout("transient".into()))
662 } else {
663 Err(ClientError::Transport("streaming not mocked".into()))
664 }
665 })
666 }
667 }
668
669 struct NonRetryableErrorTransport {
671 call_count: Arc<AtomicUsize>,
672 }
673
674 impl NonRetryableErrorTransport {
675 fn new() -> Self {
676 Self {
677 call_count: Arc::new(AtomicUsize::new(0)),
678 }
679 }
680 }
681
682 impl crate::transport::Transport for NonRetryableErrorTransport {
683 fn send_request<'a>(
684 &'a self,
685 _method: &'a str,
686 _params: serde_json::Value,
687 _extra_headers: &'a HashMap<String, String>,
688 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
689 self.call_count.fetch_add(1, Ordering::SeqCst);
690 Box::pin(async move { Err(ClientError::InvalidEndpoint("bad url".into())) })
691 }
692
693 fn send_streaming_request<'a>(
694 &'a self,
695 _method: &'a str,
696 _params: serde_json::Value,
697 _extra_headers: &'a HashMap<String, String>,
698 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
699 self.call_count.fetch_add(1, Ordering::SeqCst);
700 Box::pin(async move { Err(ClientError::InvalidEndpoint("bad url".into())) })
701 }
702 }
703
704 #[tokio::test]
705 async fn retry_transport_retries_on_transient_error() {
706 let inner = FailNTransport::new(2, serde_json::json!({"ok": true}));
707 let call_count = Arc::clone(&inner.call_count);
708 let transport = RetryTransport::new(
709 Box::new(inner),
710 RetryPolicy::default()
711 .with_initial_backoff(Duration::from_millis(1))
712 .with_max_retries(3),
713 );
714
715 let headers = HashMap::new();
716 let result = transport
717 .send_request("test", serde_json::Value::Null, &headers)
718 .await;
719 assert!(result.is_ok(), "should succeed after retries");
720 assert_eq!(
721 call_count.load(Ordering::SeqCst),
722 3,
723 "should have made 3 attempts (2 failures + 1 success)"
724 );
725 }
726
727 #[tokio::test]
728 async fn retry_transport_gives_up_after_max_retries() {
729 let inner = FailNTransport::new(10, serde_json::json!({"ok": true}));
731 let call_count = Arc::clone(&inner.call_count);
732 let transport = RetryTransport::new(
733 Box::new(inner),
734 RetryPolicy::default()
735 .with_initial_backoff(Duration::from_millis(1))
736 .with_max_retries(2),
737 );
738
739 let headers = HashMap::new();
740 let result = transport
741 .send_request("test", serde_json::Value::Null, &headers)
742 .await;
743 assert!(result.is_err(), "should fail after exhausting retries");
744 assert_eq!(
745 call_count.load(Ordering::SeqCst),
746 3,
747 "should have made 3 attempts (initial + 2 retries)"
748 );
749 }
750
751 #[tokio::test]
752 async fn retry_transport_no_retry_on_non_retryable() {
753 let inner = NonRetryableErrorTransport::new();
754 let call_count = Arc::clone(&inner.call_count);
755 let transport = RetryTransport::new(
756 Box::new(inner),
757 RetryPolicy::default()
758 .with_initial_backoff(Duration::from_millis(1))
759 .with_max_retries(3),
760 );
761
762 let headers = HashMap::new();
763 let result = transport
764 .send_request("test", serde_json::Value::Null, &headers)
765 .await;
766 assert!(result.is_err());
767 assert!(matches!(
768 result.unwrap_err(),
769 ClientError::InvalidEndpoint(_)
770 ));
771 assert_eq!(
772 call_count.load(Ordering::SeqCst),
773 1,
774 "non-retryable error should not be retried"
775 );
776 }
777
778 #[tokio::test]
779 async fn retry_transport_streaming_retries() {
780 let inner = FailNTransport::new(1, serde_json::json!(null));
781 let call_count = Arc::clone(&inner.call_count);
782 let transport = RetryTransport::new(
783 Box::new(inner),
784 RetryPolicy::default()
785 .with_initial_backoff(Duration::from_millis(1))
786 .with_max_retries(2),
787 );
788
789 let headers = HashMap::new();
790 let result = transport
791 .send_streaming_request("test", serde_json::Value::Null, &headers)
792 .await;
793 assert!(result.is_err());
796 assert_eq!(
797 call_count.load(Ordering::SeqCst),
798 2,
799 "should have retried once for streaming"
800 );
801 }
802
803 #[tokio::test]
804 async fn retry_transport_streaming_no_retry_on_non_retryable() {
805 let inner = NonRetryableErrorTransport::new();
806 let call_count = Arc::clone(&inner.call_count);
807 let transport = RetryTransport::new(
808 Box::new(inner),
809 RetryPolicy::default()
810 .with_initial_backoff(Duration::from_millis(1))
811 .with_max_retries(3),
812 );
813
814 let headers = HashMap::new();
815 let result = transport
816 .send_streaming_request("test", serde_json::Value::Null, &headers)
817 .await;
818 assert!(matches!(
819 result.unwrap_err(),
820 ClientError::InvalidEndpoint(_)
821 ));
822 assert_eq!(
823 call_count.load(Ordering::SeqCst),
824 1,
825 "non-retryable streaming error should not be retried"
826 );
827 }
828
829 #[tokio::test]
832 async fn retry_transport_streaming_succeeds_after_retry() {
833 use tokio::sync::mpsc;
834
835 struct FailThenStreamTransport {
837 call_count: Arc<AtomicUsize>,
838 }
839
840 impl crate::transport::Transport for FailThenStreamTransport {
841 fn send_request<'a>(
842 &'a self,
843 _method: &'a str,
844 _params: serde_json::Value,
845 _extra_headers: &'a HashMap<String, String>,
846 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
847 {
848 Box::pin(async move { Ok(serde_json::Value::Null) })
849 }
850
851 fn send_streaming_request<'a>(
852 &'a self,
853 _method: &'a str,
854 _params: serde_json::Value,
855 _extra_headers: &'a HashMap<String, String>,
856 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
857 let attempt = self.call_count.fetch_add(1, Ordering::SeqCst);
858 Box::pin(async move {
859 if attempt == 0 {
860 Err(ClientError::Timeout("transient timeout".into()))
861 } else {
862 let (tx, rx) = mpsc::channel(8);
864 drop(tx); Ok(EventStream::new(rx))
866 }
867 })
868 }
869 }
870
871 let call_count = Arc::new(AtomicUsize::new(0));
872 let inner = FailThenStreamTransport {
873 call_count: Arc::clone(&call_count),
874 };
875 let transport = RetryTransport::new(
876 Box::new(inner),
877 RetryPolicy::default()
878 .with_initial_backoff(Duration::from_millis(1))
879 .with_max_retries(2),
880 );
881
882 let headers = HashMap::new();
883 let result = transport
884 .send_streaming_request("test", serde_json::Value::Null, &headers)
885 .await;
886 assert!(result.is_ok(), "streaming should succeed after retry");
887 assert_eq!(
888 call_count.load(Ordering::SeqCst),
889 2,
890 "should have made 2 attempts (1 failure + 1 success)"
891 );
892 }
893
894 #[tokio::test]
895 async fn retry_transport_streaming_exhausts_retries() {
896 let inner = FailNTransport::new(10, serde_json::json!(null));
897 let call_count = Arc::clone(&inner.call_count);
898 let transport = RetryTransport::new(
899 Box::new(inner),
900 RetryPolicy::default()
901 .with_initial_backoff(Duration::from_millis(1))
902 .with_max_retries(2),
903 );
904
905 let headers = HashMap::new();
906 let result = transport
907 .send_streaming_request("test", serde_json::Value::Null, &headers)
908 .await;
909 assert!(result.is_err());
910 assert_eq!(
911 call_count.load(Ordering::SeqCst),
912 3,
913 "should make 3 attempts total for streaming"
914 );
915 }
916
917 #[tokio::test]
918 async fn retry_transport_succeeds_without_retry_on_first_attempt() {
919 let inner = FailNTransport::new(0, serde_json::json!({"ok": true}));
920 let call_count = Arc::clone(&inner.call_count);
921 let transport = RetryTransport::new(
922 Box::new(inner),
923 RetryPolicy::default()
924 .with_initial_backoff(Duration::from_millis(1))
925 .with_max_retries(3),
926 );
927
928 let headers = HashMap::new();
929 let result = transport
930 .send_request("test", serde_json::Value::Null, &headers)
931 .await;
932 assert!(result.is_ok());
933 assert_eq!(
934 call_count.load(Ordering::SeqCst),
935 1,
936 "should succeed on first try"
937 );
938 }
939
940 #[tokio::test(start_paused = true)]
946 async fn no_backoff_before_first_attempt() {
947 let inner = FailNTransport::new(0, serde_json::json!({"ok": true}));
948 let transport = RetryTransport::new(
949 Box::new(inner),
950 RetryPolicy::default()
951 .with_initial_backoff(Duration::from_secs(100))
952 .with_max_retries(1),
953 );
954
955 let start = tokio::time::Instant::now();
956 let headers = HashMap::new();
957 let result = transport
958 .send_request("test", serde_json::Value::Null, &headers)
959 .await;
960 assert!(result.is_ok());
961 assert!(
962 start.elapsed() < Duration::from_secs(1),
963 "first attempt must not sleep, elapsed: {:?}",
964 start.elapsed()
965 );
966 }
967
968 #[tokio::test(start_paused = true)]
972 async fn backoff_applied_on_retry() {
973 let inner = FailNTransport::new(1, serde_json::json!({"ok": true}));
974 let transport = RetryTransport::new(
975 Box::new(inner),
976 RetryPolicy::default()
977 .with_initial_backoff(Duration::from_secs(100))
978 .with_max_retries(2),
979 );
980
981 let start = tokio::time::Instant::now();
982 let headers = HashMap::new();
983 let result = transport
984 .send_request("test", serde_json::Value::Null, &headers)
985 .await;
986 assert!(result.is_ok());
987 assert!(
988 start.elapsed() >= Duration::from_secs(50),
989 "retry should sleep (jittered backoff), elapsed: {:?}",
990 start.elapsed()
991 );
992 }
993
994 #[tokio::test(start_paused = true)]
996 async fn no_backoff_before_first_streaming_attempt() {
997 use tokio::sync::mpsc;
998
999 struct ImmediateStreamTransport;
1000 impl crate::transport::Transport for ImmediateStreamTransport {
1001 fn send_request<'a>(
1002 &'a self,
1003 _method: &'a str,
1004 _params: serde_json::Value,
1005 _extra_headers: &'a HashMap<String, String>,
1006 ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>>
1007 {
1008 Box::pin(async { Ok(serde_json::Value::Null) })
1009 }
1010 fn send_streaming_request<'a>(
1011 &'a self,
1012 _method: &'a str,
1013 _params: serde_json::Value,
1014 _extra_headers: &'a HashMap<String, String>,
1015 ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
1016 Box::pin(async {
1017 let (tx, rx) = mpsc::channel(1);
1018 drop(tx);
1019 Ok(EventStream::new(rx))
1020 })
1021 }
1022 }
1023
1024 let transport = RetryTransport::new(
1025 Box::new(ImmediateStreamTransport),
1026 RetryPolicy::default()
1027 .with_initial_backoff(Duration::from_secs(100))
1028 .with_max_retries(1),
1029 );
1030
1031 let start = tokio::time::Instant::now();
1032 let headers = HashMap::new();
1033 let result = transport
1034 .send_streaming_request("test", serde_json::Value::Null, &headers)
1035 .await;
1036 assert!(result.is_ok());
1037 assert!(
1038 start.elapsed() < Duration::from_secs(1),
1039 "first streaming attempt must not sleep, elapsed: {:?}",
1040 start.elapsed()
1041 );
1042 }
1043
1044 #[tokio::test(start_paused = true)]
1046 async fn backoff_applied_on_streaming_retry() {
1047 let inner = FailNTransport::new(1, serde_json::json!(null));
1048 let transport = RetryTransport::new(
1049 Box::new(inner),
1050 RetryPolicy::default()
1051 .with_initial_backoff(Duration::from_secs(100))
1052 .with_max_retries(2),
1053 );
1054
1055 let start = tokio::time::Instant::now();
1056 let headers = HashMap::new();
1057 let _result = transport
1058 .send_streaming_request("test", serde_json::Value::Null, &headers)
1059 .await;
1060 assert!(
1063 start.elapsed() >= Duration::from_secs(50),
1064 "streaming retry should sleep, elapsed: {:?}",
1065 start.elapsed()
1066 );
1067 }
1068
1069 #[test]
1074 fn cap_backoff_zero_multiplier_returns_zero() {
1075 let max = Duration::from_secs(30);
1076 let result = cap_backoff(Duration::from_secs(5), 0.0, max);
1077 assert_eq!(
1078 result,
1079 Duration::ZERO,
1080 "0 * any = 0, should not clamp to max"
1081 );
1082 }
1083}