1use std::fmt::Debug;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13
14use crate::callbacks::CallbackManagerForChainRun;
15use crate::error::{Error, Result};
16
17use super::base::Runnable;
18use super::config::{
19 ConfigOrList, RunnableConfig, ensure_config, get_callback_manager_for_config, get_config_list,
20 patch_config,
21};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ExponentialJitterParams {
30 #[serde(default = "default_initial")]
32 pub initial: f64,
33
34 #[serde(default = "default_max")]
36 pub max: f64,
37
38 #[serde(default = "default_exp_base")]
40 pub exp_base: f64,
41
42 #[serde(default = "default_jitter")]
44 pub jitter: f64,
45}
46
/// Serde fallback for `ExponentialJitterParams::initial`: one second.
fn default_initial() -> f64 {
    const INITIAL_SECONDS: f64 = 1.0;
    INITIAL_SECONDS
}
50
/// Serde fallback for `ExponentialJitterParams::max`: sixty seconds.
fn default_max() -> f64 {
    const MAX_SECONDS: f64 = 60.0;
    MAX_SECONDS
}
54
/// Serde fallback for `ExponentialJitterParams::exp_base`: doubling per attempt.
fn default_exp_base() -> f64 {
    const EXP_BASE: f64 = 2.0;
    EXP_BASE
}
58
/// Serde fallback for `ExponentialJitterParams::jitter`: up to one second.
fn default_jitter() -> f64 {
    const JITTER_SECONDS: f64 = 1.0;
    JITTER_SECONDS
}
62
63impl Default for ExponentialJitterParams {
64 fn default() -> Self {
65 Self {
66 initial: 1.0,
67 max: 60.0,
68 exp_base: 2.0,
69 jitter: 1.0,
70 }
71 }
72}
73
74impl ExponentialJitterParams {
75 pub fn new() -> Self {
77 Self::default()
78 }
79
80 pub fn with_initial(mut self, initial: f64) -> Self {
82 self.initial = initial;
83 self
84 }
85
86 pub fn with_max(mut self, max: f64) -> Self {
88 self.max = max;
89 self
90 }
91
92 pub fn with_exp_base(mut self, exp_base: f64) -> Self {
94 self.exp_base = exp_base;
95 self
96 }
97
98 pub fn with_jitter(mut self, jitter: f64) -> Self {
100 self.jitter = jitter;
101 self
102 }
103
104 pub fn calculate_wait(&self, attempt: usize) -> Duration {
106 let exp_wait = self.initial * self.exp_base.powi(attempt.saturating_sub(1) as i32);
107 let capped_wait = exp_wait.min(self.max);
108 let jitter_amount = if self.jitter > 0.0 {
109 let mut rng = rand::rng();
110 rng.random_range(0.0..self.jitter)
111 } else {
112 0.0
113 };
114 let total_seconds = capped_wait + jitter_amount;
115 Duration::from_secs_f64(total_seconds)
116 }
117}
118
/// Snapshot describing one attempt of a retried call.
#[derive(Clone, Debug)]
pub struct RetryCallState {
    /// 1-based index of the attempt this state describes.
    pub attempt_number: usize,
    /// Whether the attempt succeeded; freshly created states start as `false`.
    pub succeeded: bool,
}

impl RetryCallState {
    /// Creates the state for attempt `attempt_number`, not yet marked as succeeded.
    fn new(attempt_number: usize) -> Self {
        let succeeded = false;
        RetryCallState {
            attempt_number,
            succeeded,
        }
    }
}
136
137#[derive(Debug, Clone, Default)]
141pub enum RetryErrorPredicate {
142 #[default]
144 All,
145 HttpErrors,
147 Custom(fn(&Error) -> bool),
149}
150
151impl RetryErrorPredicate {
152 pub fn should_retry(&self, error: &Error) -> bool {
154 match self {
155 RetryErrorPredicate::All => true,
156 RetryErrorPredicate::HttpErrors => matches!(error, Error::Http(_) | Error::Api { .. }),
157 RetryErrorPredicate::Custom(predicate) => predicate(error),
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
164pub struct RunnableRetryConfig {
165 pub retry_predicate: RetryErrorPredicate,
167
168 pub wait_exponential_jitter: bool,
170
171 pub exponential_jitter_params: Option<ExponentialJitterParams>,
173
174 pub max_attempt_number: usize,
176}
177
178impl Default for RunnableRetryConfig {
179 fn default() -> Self {
180 Self {
181 retry_predicate: RetryErrorPredicate::All,
182 wait_exponential_jitter: true,
183 exponential_jitter_params: None,
184 max_attempt_number: 3,
185 }
186 }
187}
188
189impl RunnableRetryConfig {
190 pub fn new() -> Self {
192 Self::default()
193 }
194
195 pub fn with_retry_predicate(mut self, predicate: RetryErrorPredicate) -> Self {
197 self.retry_predicate = predicate;
198 self
199 }
200
201 pub fn with_wait_exponential_jitter(mut self, wait: bool) -> Self {
203 self.wait_exponential_jitter = wait;
204 self
205 }
206
207 pub fn with_exponential_jitter_params(mut self, params: ExponentialJitterParams) -> Self {
209 self.exponential_jitter_params = Some(params);
210 self
211 }
212
213 pub fn with_max_attempt_number(mut self, max: usize) -> Self {
215 self.max_attempt_number = max;
216 self
217 }
218}
219
220pub struct RunnableRetry<R>
245where
246 R: Runnable,
247{
248 bound: R,
250
251 config: RunnableRetryConfig,
253}
254
255impl<R> Debug for RunnableRetry<R>
256where
257 R: Runnable,
258{
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.debug_struct("RunnableRetry")
261 .field("bound", &self.bound)
262 .field("max_attempt_number", &self.config.max_attempt_number)
263 .field(
264 "wait_exponential_jitter",
265 &self.config.wait_exponential_jitter,
266 )
267 .finish()
268 }
269}
270
271impl<R> RunnableRetry<R>
272where
273 R: Runnable,
274{
275 pub fn new(bound: R, config: RunnableRetryConfig) -> Self {
277 Self { bound, config }
278 }
279
280 pub fn with_simple(bound: R, max_attempts: usize, wait_exponential_jitter: bool) -> Self {
282 Self {
283 bound,
284 config: RunnableRetryConfig {
285 max_attempt_number: max_attempts,
286 wait_exponential_jitter,
287 ..Default::default()
288 },
289 }
290 }
291
292 fn get_jitter_params(&self) -> ExponentialJitterParams {
294 self.config
295 .exponential_jitter_params
296 .clone()
297 .unwrap_or_default()
298 }
299
300 fn should_retry(&self, error: &Error) -> bool {
302 self.config.retry_predicate.should_retry(error)
303 }
304
305 fn calculate_wait(&self, attempt: usize) -> Duration {
307 if self.config.wait_exponential_jitter {
308 self.get_jitter_params().calculate_wait(attempt)
309 } else {
310 Duration::ZERO
311 }
312 }
313
314 fn patch_config_for_retry(
316 config: &RunnableConfig,
317 run_manager: &CallbackManagerForChainRun,
318 retry_state: &RetryCallState,
319 ) -> RunnableConfig {
320 let tag = if retry_state.attempt_number > 1 {
321 Some(format!("retry:attempt:{}", retry_state.attempt_number))
322 } else {
323 None
324 };
325
326 patch_config(
327 Some(config.clone()),
328 Some(run_manager.get_child(tag.as_deref())),
329 None,
330 None,
331 None,
332 None,
333 )
334 }
335
336 fn patch_config_list_for_retry(
338 configs: &[RunnableConfig],
339 run_managers: &[CallbackManagerForChainRun],
340 retry_state: &RetryCallState,
341 ) -> Vec<RunnableConfig> {
342 configs
343 .iter()
344 .zip(run_managers.iter())
345 .map(|(config, run_manager)| {
346 Self::patch_config_for_retry(config, run_manager, retry_state)
347 })
348 .collect()
349 }
350}
351
352#[async_trait]
353impl<R> Runnable for RunnableRetry<R>
354where
355 R: Runnable + 'static,
356{
357 type Input = R::Input;
358 type Output = R::Output;
359
360 fn name(&self) -> Option<String> {
361 self.bound.name()
362 }
363
364 fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
365 let config = ensure_config(config);
366 let callback_manager = get_callback_manager_for_config(&config);
367
368 let run_manager = callback_manager.on_chain_start(
370 &std::collections::HashMap::new(),
371 &std::collections::HashMap::new(),
372 config.run_id,
373 );
374
375 let mut last_error = None;
376
377 for attempt in 1..=self.config.max_attempt_number {
378 let retry_state = RetryCallState::new(attempt);
379 let patched_config = Self::patch_config_for_retry(&config, &run_manager, &retry_state);
380
381 match self.bound.invoke(input.clone(), Some(patched_config)) {
382 Ok(output) => {
383 run_manager.on_chain_end(&std::collections::HashMap::new());
384 return Ok(output);
385 }
386 Err(e) => {
387 if !self.should_retry(&e) || attempt == self.config.max_attempt_number {
388 run_manager.on_chain_error(&e);
389 return Err(e);
390 }
391 last_error = Some(e);
392
393 if self.config.wait_exponential_jitter
395 && attempt < self.config.max_attempt_number
396 {
397 let wait = self.calculate_wait(attempt);
398 std::thread::sleep(wait);
399 }
400 }
401 }
402 }
403
404 let error = last_error.unwrap_or_else(|| Error::other("Max retries exceeded"));
405 run_manager.on_chain_error(&error);
406 Err(error)
407 }
408
409 async fn ainvoke(
410 &self,
411 input: Self::Input,
412 config: Option<RunnableConfig>,
413 ) -> Result<Self::Output>
414 where
415 Self: 'static,
416 {
417 let config = ensure_config(config);
418 let callback_manager = get_callback_manager_for_config(&config);
419
420 let run_manager = callback_manager.on_chain_start(
422 &std::collections::HashMap::new(),
423 &std::collections::HashMap::new(),
424 config.run_id,
425 );
426
427 let mut last_error = None;
428
429 for attempt in 1..=self.config.max_attempt_number {
430 let retry_state = RetryCallState::new(attempt);
431 let patched_config = Self::patch_config_for_retry(&config, &run_manager, &retry_state);
432
433 match self
434 .bound
435 .ainvoke(input.clone(), Some(patched_config))
436 .await
437 {
438 Ok(output) => {
439 run_manager.on_chain_end(&std::collections::HashMap::new());
440 return Ok(output);
441 }
442 Err(e) => {
443 if !self.should_retry(&e) || attempt == self.config.max_attempt_number {
444 run_manager.on_chain_error(&e);
445 return Err(e);
446 }
447 last_error = Some(e);
448
449 if self.config.wait_exponential_jitter
451 && attempt < self.config.max_attempt_number
452 {
453 let wait = self.calculate_wait(attempt);
454 tokio::time::sleep(wait).await;
455 }
456 }
457 }
458 }
459
460 let error = last_error.unwrap_or_else(|| Error::other("Max retries exceeded"));
461 run_manager.on_chain_error(&error);
462 Err(error)
463 }
464
465 fn batch(
466 &self,
467 inputs: Vec<Self::Input>,
468 config: Option<ConfigOrList>,
469 return_exceptions: bool,
470 ) -> Vec<Result<Self::Output>>
471 where
472 Self: 'static,
473 {
474 if inputs.is_empty() {
475 return Vec::new();
476 }
477
478 let configs = get_config_list(config, inputs.len());
479 let n = inputs.len();
480
481 let run_managers: Vec<CallbackManagerForChainRun> = configs
483 .iter()
484 .map(|config| {
485 let callback_manager = get_callback_manager_for_config(config);
486 callback_manager.on_chain_start(
487 &std::collections::HashMap::new(),
488 &std::collections::HashMap::new(),
489 config.run_id,
490 )
491 })
492 .collect();
493
494 let mut results: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
496
497 let mut remaining: Vec<usize> = (0..n).collect();
499
500 for attempt in 1..=self.config.max_attempt_number {
501 if remaining.is_empty() {
502 break;
503 }
504
505 let retry_state = RetryCallState::new(attempt);
506
507 let pending_inputs: Vec<Self::Input> =
509 remaining.iter().map(|&i| inputs[i].clone()).collect();
510 let pending_configs: Vec<RunnableConfig> =
511 remaining.iter().map(|&i| configs[i].clone()).collect();
512 let pending_managers: Vec<CallbackManagerForChainRun> =
513 remaining.iter().map(|&i| run_managers[i].clone()).collect();
514
515 let patched_configs = Self::patch_config_list_for_retry(
516 &pending_configs,
517 &pending_managers,
518 &retry_state,
519 );
520
521 let batch_results = self.bound.batch(
523 pending_inputs,
524 Some(ConfigOrList::List(patched_configs)),
525 true, );
527
528 let mut next_remaining = Vec::new();
530 let mut first_non_retryable_error: Option<Error> = None;
531
532 for (offset, result) in batch_results.into_iter().enumerate() {
533 let orig_idx = remaining[offset];
534
535 match result {
536 Ok(output) => {
537 results[orig_idx] = Some(Ok(output));
538 }
539 Err(e) => {
540 if self.should_retry(&e) && attempt < self.config.max_attempt_number {
541 results[orig_idx] = Some(Err(e));
543 next_remaining.push(orig_idx);
544 } else if !self.should_retry(&e) && !return_exceptions {
545 if first_non_retryable_error.is_none() {
547 first_non_retryable_error = Some(e);
548 }
549 results[orig_idx] = Some(Err(Error::other("Batch aborted")));
550 } else {
551 results[orig_idx] = Some(Err(e));
553 }
554 }
555 }
556 }
557
558 if first_non_retryable_error.is_some() && !return_exceptions {
560 for result in results.iter_mut().take(n) {
562 if result.is_none() {
563 *result = Some(Err(Error::other("Batch aborted due to error")));
564 }
565 }
566 break;
567 }
568
569 remaining = next_remaining;
570
571 if !remaining.is_empty()
573 && self.config.wait_exponential_jitter
574 && attempt < self.config.max_attempt_number
575 {
576 let wait = self.calculate_wait(attempt);
577 std::thread::sleep(wait);
578 }
579 }
580
581 results
583 .into_iter()
584 .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result"))))
585 .collect()
586 }
587
588 async fn abatch(
589 &self,
590 inputs: Vec<Self::Input>,
591 config: Option<ConfigOrList>,
592 return_exceptions: bool,
593 ) -> Vec<Result<Self::Output>>
594 where
595 Self: 'static,
596 {
597 if inputs.is_empty() {
598 return Vec::new();
599 }
600
601 let configs = get_config_list(config, inputs.len());
602 let n = inputs.len();
603
604 let run_managers: Vec<CallbackManagerForChainRun> = configs
606 .iter()
607 .map(|config| {
608 let callback_manager = get_callback_manager_for_config(config);
609 callback_manager.on_chain_start(
610 &std::collections::HashMap::new(),
611 &std::collections::HashMap::new(),
612 config.run_id,
613 )
614 })
615 .collect();
616
617 let mut results: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
619
620 let mut remaining: Vec<usize> = (0..n).collect();
622
623 for attempt in 1..=self.config.max_attempt_number {
624 if remaining.is_empty() {
625 break;
626 }
627
628 let retry_state = RetryCallState::new(attempt);
629
630 let pending_inputs: Vec<Self::Input> =
632 remaining.iter().map(|&i| inputs[i].clone()).collect();
633 let pending_configs: Vec<RunnableConfig> =
634 remaining.iter().map(|&i| configs[i].clone()).collect();
635 let pending_managers: Vec<CallbackManagerForChainRun> =
636 remaining.iter().map(|&i| run_managers[i].clone()).collect();
637
638 let patched_configs = Self::patch_config_list_for_retry(
639 &pending_configs,
640 &pending_managers,
641 &retry_state,
642 );
643
644 let batch_results = self
646 .bound
647 .abatch(
648 pending_inputs,
649 Some(ConfigOrList::List(patched_configs)),
650 true, )
652 .await;
653
654 let mut next_remaining = Vec::new();
656 let mut first_non_retryable_error: Option<Error> = None;
657
658 for (offset, result) in batch_results.into_iter().enumerate() {
659 let orig_idx = remaining[offset];
660
661 match result {
662 Ok(output) => {
663 results[orig_idx] = Some(Ok(output));
664 }
665 Err(e) => {
666 if self.should_retry(&e) && attempt < self.config.max_attempt_number {
667 results[orig_idx] = Some(Err(e));
669 next_remaining.push(orig_idx);
670 } else if !self.should_retry(&e) && !return_exceptions {
671 if first_non_retryable_error.is_none() {
673 first_non_retryable_error = Some(e);
674 }
675 results[orig_idx] = Some(Err(Error::other("Batch aborted")));
676 } else {
677 results[orig_idx] = Some(Err(e));
679 }
680 }
681 }
682 }
683
684 if first_non_retryable_error.is_some() && !return_exceptions {
686 for result in results.iter_mut().take(n) {
688 if result.is_none() {
689 *result = Some(Err(Error::other("Batch aborted due to error")));
690 }
691 }
692 break;
693 }
694
695 remaining = next_remaining;
696
697 if !remaining.is_empty()
699 && self.config.wait_exponential_jitter
700 && attempt < self.config.max_attempt_number
701 {
702 let wait = self.calculate_wait(attempt);
703 tokio::time::sleep(wait).await;
704 }
705 }
706
707 results
709 .into_iter()
710 .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result"))))
711 .collect()
712 }
713
714 }
717
718pub trait RunnableRetryExt: Runnable {
720 fn with_retry_config(self, config: RunnableRetryConfig) -> RunnableRetry<Self>
728 where
729 Self: Sized,
730 {
731 RunnableRetry::new(self, config)
732 }
733}
734
735impl<R: Runnable> RunnableRetryExt for R {}
737
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::base::RunnableLambda;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a config with the given attempt budget and waiting disabled, so
    /// tests never sleep.
    fn no_wait_config(max_attempts: usize) -> RunnableRetryConfig {
        RunnableRetryConfig::new()
            .with_max_attempt_number(max_attempts)
            .with_wait_exponential_jitter(false)
    }

    /// A runnable that succeeds immediately is called once and its output is returned.
    #[test]
    fn test_retry_succeeds_first_attempt() {
        let add_one = RunnableLambda::new(|x: i32| Ok(x + 1));
        let retry = RunnableRetry::new(add_one, no_wait_config(3));

        let output = retry.invoke(1, None).unwrap();
        assert_eq!(output, 2);
    }

    /// Two transient failures followed by a success take exactly three calls.
    #[test]
    fn test_retry_succeeds_after_failures() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_lambda = Arc::clone(&calls);

        let flaky = RunnableLambda::new(move |x: i32| {
            let seen = calls_in_lambda.fetch_add(1, Ordering::SeqCst);
            if seen < 2 {
                Err(Error::other("transient failure"))
            } else {
                Ok(x * 2)
            }
        });
        let retry = RunnableRetry::new(flaky, no_wait_config(5));

        let output = retry.invoke(5, None).unwrap();
        assert_eq!(output, 10);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A runnable that always fails is called once per allowed attempt.
    #[test]
    fn test_retry_exhausted() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_lambda = Arc::clone(&calls);

        let always_fails = RunnableLambda::new(move |_x: i32| {
            calls_in_lambda.fetch_add(1, Ordering::SeqCst);
            Err::<i32, _>(Error::other("always fails"))
        });
        let retry = RunnableRetry::new(always_fails, no_wait_config(3));

        let outcome = retry.invoke(1, None);
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// With the HTTP-only predicate, a non-HTTP error is not retried.
    #[test]
    fn test_retry_predicate_http_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_lambda = Arc::clone(&calls);

        let non_http_failure = RunnableLambda::new(move |_x: i32| {
            calls_in_lambda.fetch_add(1, Ordering::SeqCst);
            Err::<i32, _>(Error::other("not an HTTP error"))
        });
        let config = no_wait_config(3).with_retry_predicate(RetryErrorPredicate::HttpErrors);
        let retry = RunnableRetry::new(non_http_failure, config);

        let outcome = retry.invoke(1, None);
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Without jitter the wait doubles with each attempt.
    #[test]
    fn test_exponential_jitter_params() {
        let params = ExponentialJitterParams::new()
            .with_initial(0.1)
            .with_max(1.0)
            .with_exp_base(2.0)
            .with_jitter(0.0);

        let first = params.calculate_wait(1).as_secs_f64();
        assert!((0.1..0.2).contains(&first));

        let second = params.calculate_wait(2).as_secs_f64();
        assert!((0.2..0.3).contains(&second));

        let third = params.calculate_wait(3).as_secs_f64();
        assert!((0.4..0.5).contains(&third));
    }

    /// The exponential part of the wait never exceeds `max`.
    #[test]
    fn test_exponential_jitter_max_cap() {
        let params = ExponentialJitterParams::new()
            .with_initial(1.0)
            .with_max(2.0)
            .with_exp_base(10.0)
            .with_jitter(0.0);

        let capped = params.calculate_wait(10).as_secs_f64();
        assert!((2.0..2.1).contains(&capped));
    }

    /// `with_retry_config` from the extension trait produces a working wrapper.
    #[test]
    fn test_retry_ext_trait() {
        let add_one = RunnableLambda::new(|x: i32| Ok(x + 1));
        let config = RunnableRetryConfig::new().with_max_attempt_number(3);
        let retry = add_one.with_retry_config(config);

        let output = retry.invoke(1, None).unwrap();
        assert_eq!(output, 2);
    }

    /// `with_retry` produces a working wrapper.
    #[test]
    fn test_retry_with_simple() {
        let add_one = RunnableLambda::new(|x: i32| Ok(x + 1));
        let retry = add_one.with_retry(3, false);

        let output = retry.invoke(1, None).unwrap();
        assert_eq!(output, 2);
    }

    /// Inputs that succeed on the first pass keep their results while a
    /// failing input is retried.
    #[test]
    fn test_batch_retry_partial_failures() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_lambda = Arc::clone(&calls);

        let fails_on_negative = RunnableLambda::new(move |x: i32| {
            let seen = calls_in_lambda.fetch_add(1, Ordering::SeqCst);
            if x < 0 && seen < 4 {
                Err(Error::other("negative input"))
            } else {
                Ok(x * 2)
            }
        });
        let retry = RunnableRetry::new(fails_on_negative, no_wait_config(3));

        let outcomes = retry.batch(vec![1, -1, 2], None, true);

        assert!(outcomes[0].is_ok());
        assert!(outcomes[2].is_ok());
    }

    /// The async path retries once after a single transient failure.
    #[tokio::test]
    async fn test_async_retry() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_lambda = Arc::clone(&calls);

        let flaky = RunnableLambda::new(move |x: i32| {
            let seen = calls_in_lambda.fetch_add(1, Ordering::SeqCst);
            if seen < 1 {
                Err(Error::other("transient failure"))
            } else {
                Ok(x * 2)
            }
        });
        let retry = RunnableRetry::new(flaky, no_wait_config(3));

        let output = retry.ainvoke(5, None).await.unwrap();
        assert_eq!(output, 10);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}