1use crate::error::ComposioError;
2use std::time::Duration;
3use tokio_retry::strategy::ExponentialBackoff;
4
/// Configuration for retrying fallible operations with exponential backoff.
///
/// The delay before retry `n` (0-based) is `initial_delay * 2^n`, clamped to
/// `max_delay`; see [`RetryPolicy::strategy`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries performed after the first attempt
    /// (`0` means the operation is tried exactly once).
    pub max_retries: u32,
    /// Delay before the first retry; every subsequent retry doubles it.
    pub initial_delay: Duration,
    /// Upper bound applied to each individual delay.
    pub max_delay: Duration,
}

impl Default for RetryPolicy {
    /// Three retries, starting at one second and capped at ten seconds.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
        }
    }
}

impl RetryPolicy {
    /// Returns the delays to wait before each retry, in order.
    ///
    /// Yields exactly `max_retries` durations: `initial_delay`,
    /// `2 * initial_delay`, `4 * initial_delay`, ... with every value clamped to
    /// `max_delay`. The arithmetic saturates, so very large retry counts cannot
    /// overflow, and sub-millisecond delays keep their full precision.
    ///
    /// The delays are computed directly instead of via
    /// `tokio_retry::strategy::ExponentialBackoff::from_millis`, because that
    /// constructor's argument is the *base* of the exponent (delays of
    /// `base^n` ms). Passing `initial_delay` there made a 1 s policy jump
    /// straight from 1 s to the cap, left a 1 ms policy stuck at 1 ms forever,
    /// and truncated sub-millisecond delays to zero.
    pub fn strategy(&self) -> impl Iterator<Item = Duration> {
        // Copy the fields out so the iterator does not borrow `self`.
        let initial = self.initial_delay;
        let cap = self.max_delay;
        (0..self.max_retries).map(move |attempt| {
            // 2^attempt; shifts of 32 or more would overflow u32, so saturate.
            let multiplier = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
            initial.saturating_mul(multiplier).min(cap)
        })
    }
}
34
35pub async fn with_retry<F, Fut, T>(
37 policy: &RetryPolicy,
38 operation: F,
39) -> Result<T, ComposioError>
40where
41 F: Fn() -> Fut,
42 Fut: std::future::Future<Output = Result<T, ComposioError>>,
43{
44 let mut last_error = None;
45
46 for delay in std::iter::once(Duration::ZERO).chain(policy.strategy()) {
47 if delay > Duration::ZERO {
48 tokio::time::sleep(delay).await;
49 }
50
51 match operation().await {
52 Ok(value) => return Ok(value),
53 Err(e) if should_retry(&e) => {
54 last_error = Some(e);
55 continue;
56 }
57 Err(e) => return Err(e),
58 }
59 }
60
61 Err(last_error.unwrap())
62}
63
64pub fn should_retry(error: &ComposioError) -> bool {
66 error.is_retryable()
67}
68
69
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a `ComposioError::ApiError` with the given status and message,
    /// leaving every optional field empty.
    ///
    /// A macro rather than a function keeps the `status` literal's type
    /// inferred from the `ApiError` field itself.
    macro_rules! api_error {
        ($status:expr, $message:expr) => {
            ComposioError::ApiError {
                status: $status,
                message: $message.to_string(),
                code: None,
                slug: None,
                request_id: None,
                suggested_fix: None,
                errors: None,
            }
        };
    }

    #[test]
    fn test_default_retry_policy() {
        let RetryPolicy {
            max_retries,
            initial_delay,
            max_delay,
        } = RetryPolicy::default();

        assert_eq!(max_retries, 3);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(10));
    }

    #[test]
    fn test_custom_retry_policy() {
        let custom = RetryPolicy {
            max_retries: 5,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
        };

        let RetryPolicy {
            max_retries,
            initial_delay,
            max_delay,
        } = custom;
        assert_eq!(max_retries, 5);
        assert_eq!(initial_delay, Duration::from_millis(500));
        assert_eq!(max_delay, Duration::from_secs(30));
    }

    #[test]
    fn test_strategy_yields_correct_number_of_delays() {
        let policy = RetryPolicy {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
        };

        assert_eq!(policy.strategy().count(), 3);
    }

    #[test]
    fn test_strategy_respects_max_delay() {
        let policy = RetryPolicy {
            max_retries: 10,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
        };
        let ceiling = policy.max_delay;

        assert!(policy.strategy().all(|delay| delay <= ceiling));
    }

    #[test]
    fn test_should_retry_for_rate_limit() {
        let throttled = api_error!(429, "Rate limited");
        assert!(should_retry(&throttled));
    }

    #[test]
    fn test_should_retry_for_server_errors() {
        for status in [500, 502, 503, 504] {
            let failure = api_error!(status, "Server error");
            assert!(
                should_retry(&failure),
                "Status {} should be retryable",
                status
            );
        }
    }

    #[test]
    fn test_should_not_retry_for_client_errors() {
        for status in [400, 401, 403, 404] {
            let failure = api_error!(status, "Client error");
            assert!(
                !should_retry(&failure),
                "Status {} should not be retryable",
                status
            );
        }
    }

    #[test]
    fn test_should_not_retry_for_serialization_error() {
        let parse_failure = serde_json::from_str::<serde_json::Value>("invalid json")
            .unwrap_err();
        let error = ComposioError::from(parse_failure);

        assert!(!should_retry(&error));
    }

    #[test]
    fn test_should_not_retry_for_invalid_input() {
        assert!(!should_retry(&ComposioError::InvalidInput(
            "Invalid API key".to_string()
        )));
    }

    #[test]
    fn test_should_not_retry_for_config_error() {
        assert!(!should_retry(&ComposioError::ConfigError(
            "Invalid base URL".to_string()
        )));
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_on_first_attempt() {
        let policy = RetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));
        let tally = Arc::clone(&attempts);

        let outcome = with_retry(&policy, move || {
            let tally = Arc::clone(&tally);
            async move {
                tally.fetch_add(1, Ordering::SeqCst);
                Ok::<_, ComposioError>("success")
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_after_retries() {
        let policy = RetryPolicy {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(50),
        };
        let attempts = Arc::new(AtomicU32::new(0));
        let tally = Arc::clone(&attempts);

        let outcome = with_retry(&policy, move || {
            let tally = Arc::clone(&tally);
            async move {
                let attempt = tally.fetch_add(1, Ordering::SeqCst) + 1;
                if attempt >= 3 {
                    Ok::<_, ComposioError>("success")
                } else {
                    Err(api_error!(503, "Service unavailable"))
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_fails_after_max_retries() {
        let policy = RetryPolicy {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(50),
        };
        let attempts = Arc::new(AtomicU32::new(0));
        let tally = Arc::clone(&attempts);

        let outcome = with_retry(&policy, move || {
            let tally = Arc::clone(&tally);
            async move {
                tally.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(api_error!(503, "Service unavailable"))
            }
        })
        .await;

        assert!(outcome.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_does_not_retry_non_retryable_errors() {
        let policy = RetryPolicy::default();
        let attempts = Arc::new(AtomicU32::new(0));
        let tally = Arc::clone(&attempts);

        let outcome = with_retry(&policy, move || {
            let tally = Arc::clone(&tally);
            async move {
                tally.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(api_error!(404, "Not found"))
            }
        })
        .await;

        assert!(outcome.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }
}