Skip to main content

cognate_providers/
retry.rs

//! Retry logic with exponential backoff.
//!
//! Provides a helper for retrying failed LLM provider requests, with a
//! configurable retry budget, initial delay, delay cap, and growth factor.
5
6use std::time::Duration;
7use cognate_core::{Error, Result};
8use futures::Future;
9
/// Tuning parameters for retrying a failed request with exponential backoff.
///
/// Derives `PartialEq` so configurations can be compared directly (for example
/// with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries allowed after the initial attempt; `0` disables retrying.
    pub max_retries: u32,
    /// Backoff delay used before the first retry; later delays grow from this value.
    pub min_delay: Duration,
    /// Upper bound applied to the backoff delay each time it is scaled by `factor`.
    pub max_delay: Duration,
    /// Multiplier applied to the backoff delay after each failed attempt.
    pub factor: f64,
}
22
23impl Default for RetryConfig {
24    fn default() -> Self {
25        Self {
26            max_retries: 3,
27            min_delay: Duration::from_millis(500),
28            max_delay: Duration::from_secs(30),
29            factor: 2.0,
30        }
31    }
32}
33
34/// Execute a request with retries
35pub async fn with_retry<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
36where
37    F: FnMut() -> Fut,
38    Fut: Future<Output = Result<T>>,
39{
40    let mut last_error = None;
41    let mut delay = config.min_delay;
42
43    for i in 0..=config.max_retries {
44        match f().await {
45            Ok(res) => return Ok(res),
46            Err(e) if e.is_retryable() && i < config.max_retries => {
47                let actual_delay = e.retry_after()
48                    .map(Duration::from_secs)
49                    .unwrap_or(delay);
50
51                tokio::time::sleep(actual_delay).await;
52
53                // Update delay for next iteration (exponential backoff)
54                delay = Duration::from_secs_f64(
55                    (delay.as_secs_f64() * config.factor).min(config.max_delay.as_secs_f64())
56                );
57                last_error = Some(e);
58            }
59            Err(e) => return Err(e),
60        }
61    }
62
63    Err(last_error.unwrap_or_else(|| Error::RetryExhausted(config.max_retries)))
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a config with millisecond-scale backoff so the tests do not spend
    /// seconds sleeping through the default 500 ms initial delay.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            min_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
            ..Default::default()
        }
    }

    /// Two retryable failures followed by a success: the call succeeds on the
    /// third invocation.
    #[tokio::test]
    async fn test_retry_success() {
        let config = fast_config(3);
        let counter = Arc::new(AtomicU32::new(0));

        let result = with_retry(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                let val = counter.fetch_add(1, Ordering::SeqCst);
                if val < 2 {
                    Err(Error::Timeout(1))
                } else {
                    Ok("success")
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A request that always fails is invoked `max_retries + 1` times and the
    /// error of the last attempt is handed back to the caller.
    #[tokio::test]
    async fn test_retry_failure() {
        let config = fast_config(2);
        let counter = Arc::new(AtomicU32::new(0));

        let result: Result<()> = with_retry(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(Error::Timeout(1))
            }
        })
        .await;

        assert!(matches!(result, Err(Error::Timeout(1))));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// With a retry budget of zero the request runs exactly once.
    #[tokio::test]
    async fn test_no_retry_when_budget_is_zero() {
        let config = fast_config(0);
        let counter = Arc::new(AtomicU32::new(0));

        let result: Result<()> = with_retry(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(Error::Timeout(1))
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}