Skip to main content

composio_sdk/
retry.rs

1use crate::error::ComposioError;
2use std::time::Duration;
3use tokio_retry::strategy::ExponentialBackoff;
4
/// Retry policy configuration for handling transient failures.
///
/// Delays grow exponentially: the first retry waits `initial_delay`, each
/// following retry waits twice as long as the previous one, and every delay is
/// clamped to `max_delay`. See [`RetryPolicy::strategy`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (the initial attempt is not counted).
    pub max_retries: u32,
    /// Delay before the first retry; later delays double from this value.
    pub initial_delay: Duration,
    /// Upper bound applied to every individual delay (caps exponential backoff).
    pub max_delay: Duration,
}

impl Default for RetryPolicy {
    /// Three retries starting at one second and capped at ten seconds,
    /// i.e. delays of 1s, 2s and 4s.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
        }
    }
}

impl RetryPolicy {
    /// Returns the delays to wait before each retry, in order.
    ///
    /// Yields exactly `max_retries` values: `initial_delay`, `2 * initial_delay`,
    /// `4 * initial_delay`, ... with each value clamped to `max_delay`.
    /// Full `Duration` precision is kept, so sub-millisecond delays are not
    /// truncated to zero.
    pub fn strategy(&self) -> impl Iterator<Item = Duration> {
        let max_delay = self.max_delay;
        // Doubling is done on `Duration` directly. (tokio-retry's `ExponentialBackoff`
        // computes `base^n` milliseconds, which would treat `initial_delay` as an
        // exponent base rather than a starting delay.) `saturating_mul` prevents
        // overflow panics for huge policies; the clamp then brings values down to
        // `max_delay`.
        std::iter::successors(Some(self.initial_delay), |delay| {
            Some(delay.saturating_mul(2))
        })
        .map(move |delay| delay.min(max_delay))
        .take(self.max_retries as usize)
    }
}
34
35/// Execute an async operation with retry logic
36pub async fn with_retry<F, Fut, T>(
37    policy: &RetryPolicy,
38    operation: F,
39) -> Result<T, ComposioError>
40where
41    F: Fn() -> Fut,
42    Fut: std::future::Future<Output = Result<T, ComposioError>>,
43{
44    let mut last_error = None;
45    
46    for delay in std::iter::once(Duration::ZERO).chain(policy.strategy()) {
47        if delay > Duration::ZERO {
48            tokio::time::sleep(delay).await;
49        }
50        
51        match operation().await {
52            Ok(value) => return Ok(value),
53            Err(e) if should_retry(&e) => {
54                last_error = Some(e);
55                continue;
56            }
57            Err(e) => return Err(e),
58        }
59    }
60    
61    Err(last_error.unwrap())
62}
63
64/// Check if an error should be retried
65pub fn should_retry(error: &ComposioError) -> bool {
66    error.is_retryable()
67}
68
69
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a `ComposioError::ApiError` carrying only a status code and message.
    ///
    /// This is a macro rather than a helper function so that the status literal is
    /// typed by the `status` field itself instead of by a hard-coded parameter type.
    macro_rules! api_error {
        ($status:expr, $message:expr) => {
            ComposioError::ApiError {
                status: $status,
                message: $message.to_string(),
                code: None,
                slug: None,
                request_id: None,
                suggested_fix: None,
                errors: None,
            }
        };
    }

    /// Policy with millisecond-scale delays, so tests that really retry stay fast.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy {
            max_retries,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(50),
        }
    }

    #[test]
    fn test_default_retry_policy() {
        let defaults = RetryPolicy::default();

        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.initial_delay, Duration::from_secs(1));
        assert_eq!(defaults.max_delay, Duration::from_secs(10));
    }

    #[test]
    fn test_custom_retry_policy() {
        let custom = RetryPolicy {
            max_retries: 5,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
        };

        assert_eq!(custom.max_retries, 5);
        assert_eq!(custom.initial_delay, Duration::from_millis(500));
        assert_eq!(custom.max_delay, Duration::from_secs(30));
    }

    #[test]
    fn test_strategy_yields_correct_number_of_delays() {
        let policy = RetryPolicy {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
        };

        assert_eq!(policy.strategy().count(), 3);
    }

    #[test]
    fn test_strategy_respects_max_delay() {
        let policy = RetryPolicy {
            max_retries: 10,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
        };

        let cap = policy.max_delay;
        assert!(policy.strategy().all(|delay| delay <= cap));
    }

    #[test]
    fn test_should_retry_for_rate_limit() {
        assert!(should_retry(&api_error!(429, "Rate limited")));
    }

    #[test]
    fn test_should_retry_for_server_errors() {
        for status in [500, 502, 503, 504] {
            let error = api_error!(status, "Server error");
            assert!(
                should_retry(&error),
                "Status {} should be retryable",
                status
            );
        }
    }

    #[test]
    fn test_should_not_retry_for_client_errors() {
        for status in [400, 401, 403, 404] {
            let error = api_error!(status, "Client error");
            assert!(
                !should_retry(&error),
                "Status {} should not be retryable",
                status
            );
        }
    }

    #[test]
    fn test_should_not_retry_for_serialization_error() {
        let parse_failure =
            serde_json::from_str::<serde_json::Value>("invalid json").unwrap_err();
        let error: ComposioError = parse_failure.into();

        assert!(!should_retry(&error));
    }

    #[test]
    fn test_should_not_retry_for_invalid_input() {
        let error = ComposioError::InvalidInput("Invalid API key".to_string());
        assert!(!should_retry(&error));
    }

    #[test]
    fn test_should_not_retry_for_config_error() {
        let error = ComposioError::ConfigError("Invalid base URL".to_string());
        assert!(!should_retry(&error));
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_on_first_attempt() {
        let policy = RetryPolicy::default();
        let calls = Arc::new(AtomicU32::new(0));

        let result = with_retry(&policy, {
            let calls = Arc::clone(&calls);
            move || {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, ComposioError>("success")
                }
            }
        })
        .await;

        assert_eq!(result.expect("first attempt should succeed"), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_with_retry_succeeds_after_retries() {
        let policy = fast_policy(3);
        let calls = Arc::new(AtomicU32::new(0));

        let result = with_retry(&policy, {
            let calls = Arc::clone(&calls);
            move || {
                let calls = Arc::clone(&calls);
                async move {
                    // Attempts one and two fail with a retryable 503; attempt three succeeds.
                    let attempt = calls.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempt >= 3 {
                        Ok::<_, ComposioError>("success")
                    } else {
                        Err(api_error!(503, "Service unavailable"))
                    }
                }
            }
        })
        .await;

        assert_eq!(result.expect("third attempt should succeed"), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_fails_after_max_retries() {
        let policy = fast_policy(2);
        let calls = Arc::new(AtomicU32::new(0));

        let result = with_retry(&policy, {
            let calls = Arc::clone(&calls);
            move || {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>(api_error!(503, "Service unavailable"))
                }
            }
        })
        .await;

        // One initial attempt plus two retries.
        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_does_not_retry_non_retryable_errors() {
        let policy = RetryPolicy::default();
        let calls = Arc::new(AtomicU32::new(0));

        let result = with_retry(&policy, {
            let calls = Arc::clone(&calls);
            move || {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>(api_error!(404, "Not found"))
                }
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}