codex_memory/mcp/retry.rs

1use std::future::Future;
2use std::time::Duration;
3use tokio::time::sleep;
4use tracing::{debug, warn};
5
/// Tuning knobs for [`RetryPolicy`]'s retry loop.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of calls to make, counting the first one.
    pub max_attempts: u32,
    /// Delay slept after the first failed attempt.
    pub initial_delay: Duration,
    /// Cap for the backoff delay as it grows between attempts.
    pub max_delay: Duration,
    /// Multiplier applied to the delay after each failed attempt.
    pub exponential_base: f64,
    /// Whether grown delays get a small random extra (jitter).
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and doubling up to 10 s, with jitter on.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(10);

        Self {
            max_attempts: 3,
            initial_delay,
            max_delay,
            exponential_base: 2.0,
            jitter: true,
        }
    }
}

27pub struct RetryPolicy {
28    config: RetryConfig,
29}
30
31impl RetryPolicy {
32    pub fn new(config: RetryConfig) -> Self {
33        Self { config }
34    }
35
36    pub async fn execute<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
37    where
38        F: FnMut() -> Fut,
39        Fut: Future<Output = Result<T, E>>,
40        E: std::fmt::Display,
41    {
42        let mut attempt = 0;
43        let mut delay = self.config.initial_delay;
44
45        loop {
46            attempt += 1;
47
48            match f().await {
49                Ok(result) => {
50                    if attempt > 1 {
51                        debug!("Retry succeeded on attempt {}", attempt);
52                    }
53                    return Ok(result);
54                }
55                Err(error) if attempt >= self.config.max_attempts => {
56                    warn!("All {} retry attempts exhausted", self.config.max_attempts);
57                    return Err(error);
58                }
59                Err(error) => {
60                    warn!(
61                        "Attempt {} failed: {}. Retrying in {:?}",
62                        attempt, error, delay
63                    );
64
65                    sleep(delay).await;
66
67                    // Calculate next delay with exponential backoff
68                    delay = self.calculate_next_delay(delay);
69                }
70            }
71        }
72    }
73
74    fn calculate_next_delay(&self, current_delay: Duration) -> Duration {
75        let mut next_delay =
76            Duration::from_secs_f64(current_delay.as_secs_f64() * self.config.exponential_base);
77
78        // Apply jitter if enabled
79        if self.config.jitter {
80            let jitter_amount = next_delay.as_secs_f64() * 0.1 * rand::random::<f64>();
81            next_delay = Duration::from_secs_f64(next_delay.as_secs_f64() + jitter_amount);
82        }
83
84        // Cap at max delay
85        if next_delay > self.config.max_delay {
86            next_delay = self.config.max_delay;
87        }
88
89        next_delay
90    }
91
92    pub async fn execute_with_circuit_breaker<F, Fut, T, E>(
93        &self,
94        _circuit_breaker: &crate::mcp::circuit_breaker::CircuitBreaker,
95        f: F,
96    ) -> Result<T, E>
97    where
98        F: Fn() -> Fut + Clone,
99        Fut: Future<Output = Result<T, E>>,
100        E: std::fmt::Display,
101    {
102        // For now, bypass circuit breaker and just use retry policy
103        // TODO: Implement proper async circuit breaker integration
104        self.execute(|| {
105            let f = f.clone();
106            f()
107        })
108        .await
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a policy with a short initial delay so retry tests run quickly.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy::new(RetryConfig {
            max_attempts,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        })
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_second_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let policy = fast_policy(3);

        let result = policy
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    if calls.fetch_add(1, Ordering::SeqCst) == 0 {
                        Err("First attempt fails")
                    } else {
                        Ok("Success")
                    }
                }
            })
            .await;

        assert_eq!(result, Ok("Success"));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let policy = fast_policy(2);

        let result: Result<(), &str> = policy
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err("Always fails")
                }
            })
            .await;

        // The error from the last attempt is what the caller gets back.
        assert_eq!(result, Err("Always fails"));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_zero_max_attempts_still_runs_once() {
        let calls = Arc::new(AtomicU32::new(0));
        let policy = fast_policy(0);

        let result: Result<(), &str> = policy
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err("Always fails")
                }
            })
            .await;

        assert_eq!(result, Err("Always fails"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_calculate_next_delay() {
        let policy = RetryPolicy::new(RetryConfig {
            exponential_base: 2.0,
            max_delay: Duration::from_secs(5),
            jitter: false,
            ..Default::default()
        });

        assert_eq!(
            policy.calculate_next_delay(Duration::from_secs(1)),
            Duration::from_secs(2)
        );
        // 3s * 2 = 6s, capped at max_delay.
        assert_eq!(
            policy.calculate_next_delay(Duration::from_secs(3)),
            Duration::from_secs(5)
        );
    }

    #[test]
    fn test_calculate_next_delay_with_jitter_stays_in_bounds() {
        let policy = RetryPolicy::new(RetryConfig {
            exponential_base: 2.0,
            max_delay: Duration::from_secs(5),
            jitter: true,
            ..Default::default()
        });

        for _ in 0..100 {
            let next = policy.calculate_next_delay(Duration::from_secs(1));
            assert!(next >= Duration::from_secs(2), "{next:?} below base delay");
            assert!(next <= Duration::from_millis(2200), "{next:?} above 10% jitter");

            let capped = policy.calculate_next_delay(Duration::from_secs(3));
            assert!(capped <= Duration::from_secs(5), "{capped:?} exceeds max_delay");
        }
    }
}