agent_chain_core/runnables/
retry.rs

1//! Runnable that retries a Runnable if it fails.
2//!
3//! This module provides `RunnableRetry`, a Runnable that wraps another Runnable
4//! and retries it on failure with configurable retry logic.
5//! This mirrors `langchain_core.runnables.retry`.
6
7use std::fmt::Debug;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13
14use crate::callbacks::CallbackManagerForChainRun;
15use crate::error::{Error, Result};
16
17use super::base::Runnable;
18use super::config::{
19    ConfigOrList, RunnableConfig, ensure_config, get_callback_manager_for_config, get_config_list,
20    patch_config,
21};
22
23/// Parameters for exponential backoff with jitter.
24///
25/// These parameters control the wait time between retry attempts.
26/// The wait time is calculated as:
27/// `min(max, initial * exp_base^attempt) + random(0, jitter)`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExponentialJitterParams {
    /// Initial wait time in seconds for the first retry. Default: 1.0
    /// (serde default provided by `default_initial`).
    #[serde(default = "default_initial")]
    pub initial: f64,

    /// Maximum wait time in seconds; caps the exponential growth. Default: 60.0
    /// (serde default provided by `default_max`).
    #[serde(default = "default_max")]
    pub max: f64,

    /// Base for exponential backoff; wait grows as `exp_base^(attempt - 1)`.
    /// Default: 2.0 (serde default provided by `default_exp_base`).
    #[serde(default = "default_exp_base")]
    pub exp_base: f64,

    /// Maximum random jitter in seconds added on top of the capped wait.
    /// Default: 1.0 (serde default provided by `default_jitter`).
    #[serde(default = "default_jitter")]
    pub jitter: f64,
}
46
/// Serde default for `ExponentialJitterParams::initial` (seconds).
fn default_initial() -> f64 {
    1.0
}

/// Serde default for `ExponentialJitterParams::max` (seconds).
fn default_max() -> f64 {
    60.0
}

/// Serde default for `ExponentialJitterParams::exp_base`.
fn default_exp_base() -> f64 {
    2.0
}

/// Serde default for `ExponentialJitterParams::jitter` (seconds).
fn default_jitter() -> f64 {
    1.0
}
62
63impl Default for ExponentialJitterParams {
64    fn default() -> Self {
65        Self {
66            initial: 1.0,
67            max: 60.0,
68            exp_base: 2.0,
69            jitter: 1.0,
70        }
71    }
72}
73
74impl ExponentialJitterParams {
75    /// Create new exponential jitter parameters with default values.
76    pub fn new() -> Self {
77        Self::default()
78    }
79
80    /// Set the initial wait time.
81    pub fn with_initial(mut self, initial: f64) -> Self {
82        self.initial = initial;
83        self
84    }
85
86    /// Set the maximum wait time.
87    pub fn with_max(mut self, max: f64) -> Self {
88        self.max = max;
89        self
90    }
91
92    /// Set the exponential base.
93    pub fn with_exp_base(mut self, exp_base: f64) -> Self {
94        self.exp_base = exp_base;
95        self
96    }
97
98    /// Set the jitter value.
99    pub fn with_jitter(mut self, jitter: f64) -> Self {
100        self.jitter = jitter;
101        self
102    }
103
104    /// Calculate the wait time for a given attempt number (1-indexed).
105    pub fn calculate_wait(&self, attempt: usize) -> Duration {
106        let exp_wait = self.initial * self.exp_base.powi(attempt.saturating_sub(1) as i32);
107        let capped_wait = exp_wait.min(self.max);
108        let jitter_amount = if self.jitter > 0.0 {
109            let mut rng = rand::rng();
110            rng.random_range(0.0..self.jitter)
111        } else {
112            0.0
113        };
114        let total_seconds = capped_wait + jitter_amount;
115        Duration::from_secs_f64(total_seconds)
116    }
117}
118
/// Snapshot of a single retry attempt, used to tag child runs.
#[derive(Debug, Clone)]
pub struct RetryCallState {
    /// The current attempt number (1-indexed).
    pub attempt_number: usize,
    /// Whether the attempt succeeded.
    pub succeeded: bool,
}

impl RetryCallState {
    /// Fresh state for the given attempt; every attempt starts out
    /// marked as not-yet-successful.
    fn new(attempt_number: usize) -> Self {
        RetryCallState {
            attempt_number,
            succeeded: false,
        }
    }
}
136
137/// Error type predicate for retry logic.
138///
139/// This enum allows specifying which error types should trigger a retry.
140#[derive(Debug, Clone, Default)]
141pub enum RetryErrorPredicate {
142    /// Retry on all errors (default).
143    #[default]
144    All,
145    /// Retry only on HTTP/API errors.
146    HttpErrors,
147    /// Retry only on specific error variants using a custom predicate.
148    Custom(fn(&Error) -> bool),
149}
150
151impl RetryErrorPredicate {
152    /// Check if the given error should trigger a retry.
153    pub fn should_retry(&self, error: &Error) -> bool {
154        match self {
155            RetryErrorPredicate::All => true,
156            RetryErrorPredicate::HttpErrors => matches!(error, Error::Http(_) | Error::Api { .. }),
157            RetryErrorPredicate::Custom(predicate) => predicate(error),
158        }
159    }
160}
161
/// Configuration for creating a RunnableRetry.
#[derive(Debug, Clone)]
pub struct RunnableRetryConfig {
    /// Which errors are retried. By default all errors are retried.
    pub retry_predicate: RetryErrorPredicate,

    /// Whether to wait with exponential backoff (plus jitter) between attempts.
    /// When `false`, retries happen immediately.
    pub wait_exponential_jitter: bool,

    /// Parameters for exponential backoff with jitter.
    /// `None` means the `ExponentialJitterParams` defaults are used.
    pub exponential_jitter_params: Option<ExponentialJitterParams>,

    /// The maximum number of attempts (including the first call), not the
    /// number of retries after the first failure.
    pub max_attempt_number: usize,
}
177
178impl Default for RunnableRetryConfig {
179    fn default() -> Self {
180        Self {
181            retry_predicate: RetryErrorPredicate::All,
182            wait_exponential_jitter: true,
183            exponential_jitter_params: None,
184            max_attempt_number: 3,
185        }
186    }
187}
188
189impl RunnableRetryConfig {
190    /// Create a new retry configuration with default values.
191    pub fn new() -> Self {
192        Self::default()
193    }
194
195    /// Set the retry predicate.
196    pub fn with_retry_predicate(mut self, predicate: RetryErrorPredicate) -> Self {
197        self.retry_predicate = predicate;
198        self
199    }
200
201    /// Set whether to use exponential jitter.
202    pub fn with_wait_exponential_jitter(mut self, wait: bool) -> Self {
203        self.wait_exponential_jitter = wait;
204        self
205    }
206
207    /// Set the exponential jitter parameters.
208    pub fn with_exponential_jitter_params(mut self, params: ExponentialJitterParams) -> Self {
209        self.exponential_jitter_params = Some(params);
210        self
211    }
212
213    /// Set the maximum number of attempts.
214    pub fn with_max_attempt_number(mut self, max: usize) -> Self {
215        self.max_attempt_number = max;
216        self
217    }
218}
219
220/// A Runnable that retries on failure.
221///
222/// `RunnableRetry` wraps another `Runnable` and retries it if it fails.
223/// This is particularly useful for network calls that may fail due to transient errors.
224///
225/// # Example
226///
227/// ```ignore
228/// use agent_chain_core::runnables::{RunnableLambda, RunnableRetry, RunnableRetryConfig};
229///
230/// // Create a runnable that might fail
231/// let runnable = RunnableLambda::new(|x: i32| {
232///     // Simulated unreliable operation
233///     if x > 0 { Ok(x * 2) }
234///     else { Err(Error::other("negative input")) }
235/// });
236///
237/// // Wrap it with retry logic
238/// let config = RunnableRetryConfig::new()
239///     .with_max_attempt_number(3)
240///     .with_wait_exponential_jitter(true);
241///
242/// let with_retry = RunnableRetry::new(runnable, config);
243/// ```
244pub struct RunnableRetry<R>
245where
246    R: Runnable,
247{
248    /// The wrapped runnable.
249    bound: R,
250
251    /// Retry configuration.
252    config: RunnableRetryConfig,
253}
254
255impl<R> Debug for RunnableRetry<R>
256where
257    R: Runnable,
258{
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        f.debug_struct("RunnableRetry")
261            .field("bound", &self.bound)
262            .field("max_attempt_number", &self.config.max_attempt_number)
263            .field(
264                "wait_exponential_jitter",
265                &self.config.wait_exponential_jitter,
266            )
267            .finish()
268    }
269}
270
271impl<R> RunnableRetry<R>
272where
273    R: Runnable,
274{
275    /// Create a new RunnableRetry with the given configuration.
276    pub fn new(bound: R, config: RunnableRetryConfig) -> Self {
277        Self { bound, config }
278    }
279
280    /// Create a new RunnableRetry with simple parameters.
281    pub fn with_simple(bound: R, max_attempts: usize, wait_exponential_jitter: bool) -> Self {
282        Self {
283            bound,
284            config: RunnableRetryConfig {
285                max_attempt_number: max_attempts,
286                wait_exponential_jitter,
287                ..Default::default()
288            },
289        }
290    }
291
292    /// Get the exponential jitter parameters, using defaults if not set.
293    fn get_jitter_params(&self) -> ExponentialJitterParams {
294        self.config
295            .exponential_jitter_params
296            .clone()
297            .unwrap_or_default()
298    }
299
300    /// Check if the error should trigger a retry.
301    fn should_retry(&self, error: &Error) -> bool {
302        self.config.retry_predicate.should_retry(error)
303    }
304
305    /// Calculate the wait time for a given attempt.
306    fn calculate_wait(&self, attempt: usize) -> Duration {
307        if self.config.wait_exponential_jitter {
308            self.get_jitter_params().calculate_wait(attempt)
309        } else {
310            Duration::ZERO
311        }
312    }
313
314    /// Patch the config for a retry attempt.
315    fn patch_config_for_retry(
316        config: &RunnableConfig,
317        run_manager: &CallbackManagerForChainRun,
318        retry_state: &RetryCallState,
319    ) -> RunnableConfig {
320        let tag = if retry_state.attempt_number > 1 {
321            Some(format!("retry:attempt:{}", retry_state.attempt_number))
322        } else {
323            None
324        };
325
326        patch_config(
327            Some(config.clone()),
328            Some(run_manager.get_child(tag.as_deref())),
329            None,
330            None,
331            None,
332            None,
333        )
334    }
335
336    /// Patch configs for batch retry.
337    fn patch_config_list_for_retry(
338        configs: &[RunnableConfig],
339        run_managers: &[CallbackManagerForChainRun],
340        retry_state: &RetryCallState,
341    ) -> Vec<RunnableConfig> {
342        configs
343            .iter()
344            .zip(run_managers.iter())
345            .map(|(config, run_manager)| {
346                Self::patch_config_for_retry(config, run_manager, retry_state)
347            })
348            .collect()
349    }
350}
351
352#[async_trait]
353impl<R> Runnable for RunnableRetry<R>
354where
355    R: Runnable + 'static,
356{
357    type Input = R::Input;
358    type Output = R::Output;
359
360    fn name(&self) -> Option<String> {
361        self.bound.name()
362    }
363
364    fn invoke(&self, input: Self::Input, config: Option<RunnableConfig>) -> Result<Self::Output> {
365        let config = ensure_config(config);
366        let callback_manager = get_callback_manager_for_config(&config);
367
368        // Start the chain run
369        let run_manager = callback_manager.on_chain_start(
370            &std::collections::HashMap::new(),
371            &std::collections::HashMap::new(),
372            config.run_id,
373        );
374
375        let mut last_error = None;
376
377        for attempt in 1..=self.config.max_attempt_number {
378            let retry_state = RetryCallState::new(attempt);
379            let patched_config = Self::patch_config_for_retry(&config, &run_manager, &retry_state);
380
381            match self.bound.invoke(input.clone(), Some(patched_config)) {
382                Ok(output) => {
383                    run_manager.on_chain_end(&std::collections::HashMap::new());
384                    return Ok(output);
385                }
386                Err(e) => {
387                    if !self.should_retry(&e) || attempt == self.config.max_attempt_number {
388                        run_manager.on_chain_error(&e);
389                        return Err(e);
390                    }
391                    last_error = Some(e);
392
393                    // Wait before next attempt
394                    if self.config.wait_exponential_jitter
395                        && attempt < self.config.max_attempt_number
396                    {
397                        let wait = self.calculate_wait(attempt);
398                        std::thread::sleep(wait);
399                    }
400                }
401            }
402        }
403
404        let error = last_error.unwrap_or_else(|| Error::other("Max retries exceeded"));
405        run_manager.on_chain_error(&error);
406        Err(error)
407    }
408
409    async fn ainvoke(
410        &self,
411        input: Self::Input,
412        config: Option<RunnableConfig>,
413    ) -> Result<Self::Output>
414    where
415        Self: 'static,
416    {
417        let config = ensure_config(config);
418        let callback_manager = get_callback_manager_for_config(&config);
419
420        // Start the chain run
421        let run_manager = callback_manager.on_chain_start(
422            &std::collections::HashMap::new(),
423            &std::collections::HashMap::new(),
424            config.run_id,
425        );
426
427        let mut last_error = None;
428
429        for attempt in 1..=self.config.max_attempt_number {
430            let retry_state = RetryCallState::new(attempt);
431            let patched_config = Self::patch_config_for_retry(&config, &run_manager, &retry_state);
432
433            match self
434                .bound
435                .ainvoke(input.clone(), Some(patched_config))
436                .await
437            {
438                Ok(output) => {
439                    run_manager.on_chain_end(&std::collections::HashMap::new());
440                    return Ok(output);
441                }
442                Err(e) => {
443                    if !self.should_retry(&e) || attempt == self.config.max_attempt_number {
444                        run_manager.on_chain_error(&e);
445                        return Err(e);
446                    }
447                    last_error = Some(e);
448
449                    // Wait before next attempt
450                    if self.config.wait_exponential_jitter
451                        && attempt < self.config.max_attempt_number
452                    {
453                        let wait = self.calculate_wait(attempt);
454                        tokio::time::sleep(wait).await;
455                    }
456                }
457            }
458        }
459
460        let error = last_error.unwrap_or_else(|| Error::other("Max retries exceeded"));
461        run_manager.on_chain_error(&error);
462        Err(error)
463    }
464
465    fn batch(
466        &self,
467        inputs: Vec<Self::Input>,
468        config: Option<ConfigOrList>,
469        return_exceptions: bool,
470    ) -> Vec<Result<Self::Output>>
471    where
472        Self: 'static,
473    {
474        if inputs.is_empty() {
475            return Vec::new();
476        }
477
478        let configs = get_config_list(config, inputs.len());
479        let n = inputs.len();
480
481        // Create callback managers and start chain runs for each input
482        let run_managers: Vec<CallbackManagerForChainRun> = configs
483            .iter()
484            .map(|config| {
485                let callback_manager = get_callback_manager_for_config(config);
486                callback_manager.on_chain_start(
487                    &std::collections::HashMap::new(),
488                    &std::collections::HashMap::new(),
489                    config.run_id,
490                )
491            })
492            .collect();
493
494        // Track results: None means not yet successful
495        let mut results: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
496
497        // Track which inputs still need to be processed
498        let mut remaining: Vec<usize> = (0..n).collect();
499
500        for attempt in 1..=self.config.max_attempt_number {
501            if remaining.is_empty() {
502                break;
503            }
504
505            let retry_state = RetryCallState::new(attempt);
506
507            // Prepare inputs and configs for remaining items
508            let pending_inputs: Vec<Self::Input> =
509                remaining.iter().map(|&i| inputs[i].clone()).collect();
510            let pending_configs: Vec<RunnableConfig> =
511                remaining.iter().map(|&i| configs[i].clone()).collect();
512            let pending_managers: Vec<CallbackManagerForChainRun> =
513                remaining.iter().map(|&i| run_managers[i].clone()).collect();
514
515            let patched_configs = Self::patch_config_list_for_retry(
516                &pending_configs,
517                &pending_managers,
518                &retry_state,
519            );
520
521            // Invoke the batch on remaining items
522            let batch_results = self.bound.batch(
523                pending_inputs,
524                Some(ConfigOrList::List(patched_configs)),
525                true, // Always return exceptions to handle ourselves
526            );
527
528            // Process results
529            let mut next_remaining = Vec::new();
530            let mut first_non_retryable_error: Option<Error> = None;
531
532            for (offset, result) in batch_results.into_iter().enumerate() {
533                let orig_idx = remaining[offset];
534
535                match result {
536                    Ok(output) => {
537                        results[orig_idx] = Some(Ok(output));
538                    }
539                    Err(e) => {
540                        if self.should_retry(&e) && attempt < self.config.max_attempt_number {
541                            // Will retry this one
542                            results[orig_idx] = Some(Err(e));
543                            next_remaining.push(orig_idx);
544                        } else if !self.should_retry(&e) && !return_exceptions {
545                            // Non-retryable error and we're not returning exceptions
546                            if first_non_retryable_error.is_none() {
547                                first_non_retryable_error = Some(e);
548                            }
549                            results[orig_idx] = Some(Err(Error::other("Batch aborted")));
550                        } else {
551                            // Final attempt or returning exceptions
552                            results[orig_idx] = Some(Err(e));
553                        }
554                    }
555                }
556            }
557
558            // If we had a non-retryable error and we're not returning exceptions, abort
559            if first_non_retryable_error.is_some() && !return_exceptions {
560                // Fill remaining results with errors
561                for result in results.iter_mut().take(n) {
562                    if result.is_none() {
563                        *result = Some(Err(Error::other("Batch aborted due to error")));
564                    }
565                }
566                break;
567            }
568
569            remaining = next_remaining;
570
571            // Wait before next attempt if there are remaining items
572            if !remaining.is_empty()
573                && self.config.wait_exponential_jitter
574                && attempt < self.config.max_attempt_number
575            {
576                let wait = self.calculate_wait(attempt);
577                std::thread::sleep(wait);
578            }
579        }
580
581        // Convert results, using error for any None values
582        results
583            .into_iter()
584            .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result"))))
585            .collect()
586    }
587
588    async fn abatch(
589        &self,
590        inputs: Vec<Self::Input>,
591        config: Option<ConfigOrList>,
592        return_exceptions: bool,
593    ) -> Vec<Result<Self::Output>>
594    where
595        Self: 'static,
596    {
597        if inputs.is_empty() {
598            return Vec::new();
599        }
600
601        let configs = get_config_list(config, inputs.len());
602        let n = inputs.len();
603
604        // Create callback managers and start chain runs for each input
605        let run_managers: Vec<CallbackManagerForChainRun> = configs
606            .iter()
607            .map(|config| {
608                let callback_manager = get_callback_manager_for_config(config);
609                callback_manager.on_chain_start(
610                    &std::collections::HashMap::new(),
611                    &std::collections::HashMap::new(),
612                    config.run_id,
613                )
614            })
615            .collect();
616
617        // Track results: None means not yet successful
618        let mut results: Vec<Option<Result<Self::Output>>> = (0..n).map(|_| None).collect();
619
620        // Track which inputs still need to be processed
621        let mut remaining: Vec<usize> = (0..n).collect();
622
623        for attempt in 1..=self.config.max_attempt_number {
624            if remaining.is_empty() {
625                break;
626            }
627
628            let retry_state = RetryCallState::new(attempt);
629
630            // Prepare inputs and configs for remaining items
631            let pending_inputs: Vec<Self::Input> =
632                remaining.iter().map(|&i| inputs[i].clone()).collect();
633            let pending_configs: Vec<RunnableConfig> =
634                remaining.iter().map(|&i| configs[i].clone()).collect();
635            let pending_managers: Vec<CallbackManagerForChainRun> =
636                remaining.iter().map(|&i| run_managers[i].clone()).collect();
637
638            let patched_configs = Self::patch_config_list_for_retry(
639                &pending_configs,
640                &pending_managers,
641                &retry_state,
642            );
643
644            // Invoke the batch on remaining items
645            let batch_results = self
646                .bound
647                .abatch(
648                    pending_inputs,
649                    Some(ConfigOrList::List(patched_configs)),
650                    true, // Always return exceptions to handle ourselves
651                )
652                .await;
653
654            // Process results
655            let mut next_remaining = Vec::new();
656            let mut first_non_retryable_error: Option<Error> = None;
657
658            for (offset, result) in batch_results.into_iter().enumerate() {
659                let orig_idx = remaining[offset];
660
661                match result {
662                    Ok(output) => {
663                        results[orig_idx] = Some(Ok(output));
664                    }
665                    Err(e) => {
666                        if self.should_retry(&e) && attempt < self.config.max_attempt_number {
667                            // Will retry this one
668                            results[orig_idx] = Some(Err(e));
669                            next_remaining.push(orig_idx);
670                        } else if !self.should_retry(&e) && !return_exceptions {
671                            // Non-retryable error and we're not returning exceptions
672                            if first_non_retryable_error.is_none() {
673                                first_non_retryable_error = Some(e);
674                            }
675                            results[orig_idx] = Some(Err(Error::other("Batch aborted")));
676                        } else {
677                            // Final attempt or returning exceptions
678                            results[orig_idx] = Some(Err(e));
679                        }
680                    }
681                }
682            }
683
684            // If we had a non-retryable error and we're not returning exceptions, abort
685            if first_non_retryable_error.is_some() && !return_exceptions {
686                // Fill remaining results with errors
687                for result in results.iter_mut().take(n) {
688                    if result.is_none() {
689                        *result = Some(Err(Error::other("Batch aborted due to error")));
690                    }
691                }
692                break;
693            }
694
695            remaining = next_remaining;
696
697            // Wait before next attempt if there are remaining items
698            if !remaining.is_empty()
699                && self.config.wait_exponential_jitter
700                && attempt < self.config.max_attempt_number
701            {
702                let wait = self.calculate_wait(attempt);
703                tokio::time::sleep(wait).await;
704            }
705        }
706
707        // Convert results, using error for any None values
708        results
709            .into_iter()
710            .map(|opt| opt.unwrap_or_else(|| Err(Error::other("No result"))))
711            .collect()
712    }
713
714    // Note: stream() and transform() are not retried because retrying a stream
715    // is not very intuitive, matching the Python implementation.
716}
717
/// Extension trait to add retry configuration method to any Runnable.
pub trait RunnableRetryExt: Runnable {
    /// Create a new Runnable that retries this runnable on failure with full config.
    ///
    /// Consumes `self` and wraps it in a [`RunnableRetry`] governed by `config`
    /// (retry predicate, backoff, and attempt budget).
    ///
    /// # Arguments
    /// * `config` - Retry configuration
    ///
    /// # Returns
    /// A new `RunnableRetry` instance
    fn with_retry_config(self, config: RunnableRetryConfig) -> RunnableRetry<Self>
    where
        Self: Sized,
    {
        RunnableRetry::new(self, config)
    }
}
734
// Blanket implementation: every Runnable gains `with_retry_config` for free.
impl<R: Runnable> RunnableRetryExt for R {}
737
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::base::RunnableLambda;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_retry_succeeds_first_attempt() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        let config = RunnableRetryConfig::new()
            .with_max_attempt_number(3)
            .with_wait_exponential_jitter(false);
        let retry = RunnableRetry::new(runnable, config);

        let result = retry.invoke(1, None).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn test_retry_succeeds_after_failures() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        // Fails on the first two calls, then succeeds.
        let runnable = RunnableLambda::new(move |x: i32| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 2 {
                Err(Error::other("transient failure"))
            } else {
                Ok(x * 2)
            }
        });

        let config = RunnableRetryConfig::new()
            .with_max_attempt_number(5)
            .with_wait_exponential_jitter(false);
        let retry = RunnableRetry::new(runnable, config);

        let result = retry.invoke(5, None).unwrap();
        assert_eq!(result, 10);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_exhausted() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let runnable = RunnableLambda::new(move |_x: i32| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            Err::<i32, _>(Error::other("always fails"))
        });

        let config = RunnableRetryConfig::new()
            .with_max_attempt_number(3)
            .with_wait_exponential_jitter(false);
        let retry = RunnableRetry::new(runnable, config);

        let result = retry.invoke(1, None);
        assert!(result.is_err());
        // Every attempt in the budget should have been used.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_predicate_http_errors() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        // This will not retry because it's not an HTTP error
        let runnable = RunnableLambda::new(move |_x: i32| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            Err::<i32, _>(Error::other("not an HTTP error"))
        });

        let config = RunnableRetryConfig::new()
            .with_max_attempt_number(3)
            .with_retry_predicate(RetryErrorPredicate::HttpErrors)
            .with_wait_exponential_jitter(false);
        let retry = RunnableRetry::new(runnable, config);

        let result = retry.invoke(1, None);
        assert!(result.is_err());
        // Should only try once since it's not an HTTP error
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_exponential_jitter_params() {
        // jitter = 0 makes the wait deterministic: initial * exp_base^(n-1).
        let params = ExponentialJitterParams::new()
            .with_initial(0.1)
            .with_max(1.0)
            .with_exp_base(2.0)
            .with_jitter(0.0);

        // Attempt 1: 0.1 * 2^0 = 0.1
        let wait1 = params.calculate_wait(1);
        assert!(wait1.as_secs_f64() >= 0.1 && wait1.as_secs_f64() < 0.2);

        // Attempt 2: 0.1 * 2^1 = 0.2
        let wait2 = params.calculate_wait(2);
        assert!(wait2.as_secs_f64() >= 0.2 && wait2.as_secs_f64() < 0.3);

        // Attempt 3: 0.1 * 2^2 = 0.4
        let wait3 = params.calculate_wait(3);
        assert!(wait3.as_secs_f64() >= 0.4 && wait3.as_secs_f64() < 0.5);
    }

    #[test]
    fn test_exponential_jitter_max_cap() {
        let params = ExponentialJitterParams::new()
            .with_initial(1.0)
            .with_max(2.0)
            .with_exp_base(10.0)
            .with_jitter(0.0);

        // Large attempt should be capped at max
        let wait = params.calculate_wait(10);
        assert!(wait.as_secs_f64() >= 2.0 && wait.as_secs_f64() < 2.1);
    }

    #[test]
    fn test_retry_ext_trait() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        let config = RunnableRetryConfig::new().with_max_attempt_number(3);
        let retry = runnable.with_retry_config(config);

        let result = retry.invoke(1, None).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn test_retry_with_simple() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x + 1));
        // Use the `with_simple` constructor this test is named after.
        // (Previously this called `.with_retry(3, false)`, a method not
        // defined anywhere in this module.)
        let retry = RunnableRetry::with_simple(runnable, 3, false);

        let result = retry.invoke(1, None).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn test_batch_retry_partial_failures() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        // Fails for negative numbers until a few calls have accumulated.
        let runnable = RunnableLambda::new(move |x: i32| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if x < 0 && count < 4 {
                Err(Error::other("negative input"))
            } else {
                Ok(x * 2)
            }
        });

        let config = RunnableRetryConfig::new()
            .with_max_attempt_number(3)
            .with_wait_exponential_jitter(false);
        let retry = RunnableRetry::new(runnable, config);

        let results = retry.batch(vec![1, -1, 2], None, true);

        // 1 and 2 should succeed on first try
        // -1 should fail first 2 times, then succeed
        assert!(results[0].is_ok());
        assert!(results[2].is_ok());
        // -1 might succeed or fail depending on retry order
    }

    #[tokio::test]
    async fn test_async_retry() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        // Fails on the first call, then succeeds.
        let runnable = RunnableLambda::new(move |x: i32| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 1 {
                Err(Error::other("transient failure"))
            } else {
                Ok(x * 2)
            }
        });

        let config = RunnableRetryConfig::new()
            .with_max_attempt_number(3)
            .with_wait_exponential_jitter(false);
        let retry = RunnableRetry::new(runnable, config);

        let result = retry.ainvoke(5, None).await.unwrap();
        assert_eq!(result, 10);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }
}