codex_memory/mcp/
retry.rs1use std::future::Future;
2use std::time::Duration;
3use tokio::time::sleep;
4use tracing::{debug, warn};
5
/// Tuning knobs for [`RetryPolicy`]'s exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of times the operation is invoked, counting the first call
    /// (a value of `0` still results in one call).
    pub max_attempts: u32,
    /// Delay slept after the first failed attempt.
    pub initial_delay: Duration,
    /// Upper bound applied to each backoff delay computed from the previous one.
    pub max_delay: Duration,
    /// Factor the previous delay is multiplied by to obtain the next one.
    pub exponential_base: f64,
    /// When `true`, up to 10% random extra time is added to each computed delay.
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and doubling up to a 10 s ceiling, with jitter.
    fn default() -> Self {
        RetryConfig {
            jitter: true,
            exponential_base: 2.0,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(10_000),
            max_attempts: 3,
        }
    }
}
26
27pub struct RetryPolicy {
28 config: RetryConfig,
29}
30
31impl RetryPolicy {
32 pub fn new(config: RetryConfig) -> Self {
33 Self { config }
34 }
35
36 pub async fn execute<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
37 where
38 F: FnMut() -> Fut,
39 Fut: Future<Output = Result<T, E>>,
40 E: std::fmt::Display,
41 {
42 let mut attempt = 0;
43 let mut delay = self.config.initial_delay;
44
45 loop {
46 attempt += 1;
47
48 match f().await {
49 Ok(result) => {
50 if attempt > 1 {
51 debug!("Retry succeeded on attempt {}", attempt);
52 }
53 return Ok(result);
54 }
55 Err(error) if attempt >= self.config.max_attempts => {
56 warn!("All {} retry attempts exhausted", self.config.max_attempts);
57 return Err(error);
58 }
59 Err(error) => {
60 warn!(
61 "Attempt {} failed: {}. Retrying in {:?}",
62 attempt, error, delay
63 );
64
65 sleep(delay).await;
66
67 delay = self.calculate_next_delay(delay);
69 }
70 }
71 }
72 }
73
74 fn calculate_next_delay(&self, current_delay: Duration) -> Duration {
75 let mut next_delay =
76 Duration::from_secs_f64(current_delay.as_secs_f64() * self.config.exponential_base);
77
78 if self.config.jitter {
80 let jitter_amount = next_delay.as_secs_f64() * 0.1 * rand::random::<f64>();
81 next_delay = Duration::from_secs_f64(next_delay.as_secs_f64() + jitter_amount);
82 }
83
84 if next_delay > self.config.max_delay {
86 next_delay = self.config.max_delay;
87 }
88
89 next_delay
90 }
91
92 pub async fn execute_with_circuit_breaker<F, Fut, T, E>(
93 &self,
94 _circuit_breaker: &crate::mcp::circuit_breaker::CircuitBreaker,
95 f: F,
96 ) -> Result<T, E>
97 where
98 F: Fn() -> Fut + Clone,
99 Fut: Future<Output = Result<T, E>>,
100 E: std::fmt::Display,
101 {
102 self.execute(|| {
105 let f = f.clone();
106 f()
107 })
108 .await
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Policy with a short initial delay so the async tests finish quickly.
    fn quick_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy::new(RetryConfig {
            max_attempts,
            initial_delay: Duration::from_millis(10),
            ..RetryConfig::default()
        })
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_second_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);

        let outcome = quick_policy(3)
            .execute(move || {
                let calls = Arc::clone(&shared);
                async move {
                    // Fail only the very first invocation.
                    match calls.fetch_add(1, Ordering::SeqCst) {
                        0 => Err("First attempt fails"),
                        _ => Ok("Success"),
                    }
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);

        let outcome: Result<(), &str> = quick_policy(2)
            .execute(move || {
                let calls = Arc::clone(&shared);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err("Always fails")
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_calculate_next_delay() {
        let policy = RetryPolicy::new(RetryConfig {
            exponential_base: 2.0,
            max_delay: Duration::from_secs(5),
            jitter: false,
            ..RetryConfig::default()
        });

        // 1s doubles to 2s, which is still under the 5s ceiling.
        assert_eq!(
            policy.calculate_next_delay(Duration::from_secs(1)),
            Duration::from_secs(2)
        );
        // 3s would double to 6s, so it is clamped to the ceiling.
        assert_eq!(
            policy.calculate_next_delay(Duration::from_secs(3)),
            Duration::from_secs(5)
        );
    }
}