pipedash_plugin_api/
utils.rs

1use std::time::Duration;
2
3use crate::{
4    PluginError,
5    PluginResult,
6};
7
/// Configuration for re-running a fallible async operation, with an optional
/// exponential backoff between attempts. Used by [`RetryPolicy::retry`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Upper bound on how many times `retry` invokes the operation
    /// (the first call included).
    pub max_retries: usize,
    /// Pause taken after the first failed, retryable attempt.
    pub initial_delay: Duration,
    /// When `true`, the pause doubles after every failed, retryable attempt.
    pub exponential_backoff: bool,
}
13
14impl Default for RetryPolicy {
15    fn default() -> Self {
16        Self {
17            max_retries: 3,
18            initial_delay: Duration::from_millis(100),
19            exponential_backoff: true,
20        }
21    }
22}
23
24impl RetryPolicy {
25    pub fn new(max_retries: usize, initial_delay: Duration, exponential_backoff: bool) -> Self {
26        Self {
27            max_retries,
28            initial_delay,
29            exponential_backoff,
30        }
31    }
32
33    pub async fn retry<F, Fut, T>(&self, operation: F) -> PluginResult<T>
34    where
35        F: Fn() -> Fut,
36        Fut: std::future::Future<Output = PluginResult<T>>,
37    {
38        let mut delay = self.initial_delay;
39        let mut last_error = None;
40
41        for attempt in 0..self.max_retries {
42            match operation().await {
43                Ok(result) => return Ok(result),
44                Err(e) if attempt < self.max_retries - 1 => match &e {
45                    PluginError::NetworkError(_) | PluginError::ApiError(_) => {
46                        last_error = Some(e);
47                        tokio::time::sleep(delay).await;
48                        if self.exponential_backoff {
49                            delay *= 2;
50                        }
51                        continue;
52                    }
53                    _ => return Err(e),
54                },
55                Err(e) => {
56                    last_error = Some(e);
57                }
58            }
59        }
60
61        Err(last_error
62            .unwrap_or_else(|| PluginError::NetworkError("Max retries exceeded".to_string())))
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    #[tokio::test]
    async fn test_retry_success() {
        let policy = RetryPolicy::default();
        let result = policy.retry(|| async { Ok::<_, PluginError>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_eventual_success() {
        let policy = RetryPolicy::new(3, Duration::from_millis(10), false);
        let attempts = Cell::new(0);

        let result = policy
            .retry(|| async {
                let count = attempts.get() + 1;
                attempts.set(count);
                if count < 2 {
                    Err(PluginError::NetworkError("Temporary failure".to_string()))
                } else {
                    Ok(42)
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        // Success on the second call must stop the loop immediately.
        assert_eq!(attempts.get(), 2);
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        let policy = RetryPolicy::new(3, Duration::from_millis(1), true);
        let attempts = Cell::new(0);

        let result = policy
            .retry(|| async {
                attempts.set(attempts.get() + 1);
                Err::<i32, _>(PluginError::NetworkError("always down".to_string()))
            })
            .await;

        // The error from the last attempt is surfaced once the budget is spent.
        assert!(matches!(result, Err(PluginError::NetworkError(_))));
        assert_eq!(attempts.get(), 3);
    }
}