aidale_layer/
retry.rs

//! Retry layer with exponential backoff.
2
3use aidale_core::error::AiError;
4use aidale_core::layer::{Layer, LayeredProvider};
5use aidale_core::provider::{ChatCompletionStream, Provider};
6use aidale_core::types::*;
7use async_trait::async_trait;
8use std::fmt::Debug;
9use std::sync::Arc;
10use std::time::Duration;
11
/// Configuration for retrying failed calls with exponential backoff.
///
/// The wait before retry number `attempt` (zero-based) is
/// `initial_delay * backoff_multiplier^attempt`, capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct RetryLayer {
    /// Upper bound on the number of retries.
    max_retries: u32,
    /// Delay used for the first retry (attempt 0).
    initial_delay: Duration,
    /// Ceiling applied to every computed delay.
    max_delay: Duration,
    /// Factor the delay is multiplied by for each further attempt.
    backoff_multiplier: f64,
}

impl RetryLayer {
    /// Create a new retry layer with default settings.
    ///
    /// Defaults: 3 retries, 100 ms initial delay, 10 s maximum delay, multiplier 2.0.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
        }
    }

    /// Set maximum number of retries.
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Set the delay used for the first retry.
    #[must_use]
    pub fn with_initial_delay(mut self, initial_delay: Duration) -> Self {
        self.initial_delay = initial_delay;
        self
    }

    /// Set the ceiling applied to every computed delay.
    #[must_use]
    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
        self.max_delay = max_delay;
        self
    }

    /// Set the factor the delay grows by on each successive attempt.
    ///
    /// A NaN or negative product yields a zero delay (see `calculate_delay`).
    #[must_use]
    pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Calculate the delay for a given zero-based attempt.
    ///
    /// Returns `initial_delay * backoff_multiplier^attempt`, never more than `max_delay`.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        // Saturate instead of wrapping: a plain `attempt as i32` turns values above
        // `i32::MAX` into negative exponents, which would shrink the delay.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let factor = self.backoff_multiplier.powi(exponent);

        // Work in nanoseconds so sub-millisecond initial delays are not truncated to zero.
        let delay_ns = self.initial_delay.as_nanos() as f64 * factor;

        // Float-to-integer `as` casts saturate: NaN and negative values become 0, and
        // anything beyond `u64::MAX` (including +inf) becomes `u64::MAX`, which the
        // `max_delay` cap below then bounds.
        Duration::from_nanos(delay_ns as u64).min(self.max_delay)
    }
}

impl Default for RetryLayer {
    /// Same as [`RetryLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
70
71impl<P: Provider> Layer<P> for RetryLayer {
72    type LayeredProvider = RetryProvider<P>;
73
74    fn layer(&self, inner: P) -> Self::LayeredProvider {
75        RetryProvider {
76            inner,
77            config: self.clone(),
78        }
79    }
80}
81
82/// Provider wrapped with retry logic
83#[derive(Debug)]
84pub struct RetryProvider<P> {
85    inner: P,
86    config: RetryLayer,
87}
88
89impl<P: Provider> RetryProvider<P> {
90    /// Execute with retry logic
91    async fn execute_with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T, AiError>
92    where
93        F: FnMut() -> Fut,
94        Fut: std::future::Future<Output = Result<T, AiError>>,
95    {
96        let mut attempt = 0;
97
98        loop {
99            match operation().await {
100                Ok(result) => return Ok(result),
101                Err(e) => {
102                    if !e.is_retryable() || attempt >= self.config.max_retries {
103                        return Err(e);
104                    }
105
106                    let delay = self.config.calculate_delay(attempt);
107                    tracing::debug!(
108                        "Retry attempt {}/{}, waiting {:?}",
109                        attempt + 1,
110                        self.config.max_retries,
111                        delay
112                    );
113
114                    tokio::time::sleep(delay).await;
115                    attempt += 1;
116                }
117            }
118        }
119    }
120}
121
122#[async_trait]
123impl<P: Provider> LayeredProvider for RetryProvider<P> {
124    type Inner = P;
125
126    fn inner(&self) -> &Self::Inner {
127        &self.inner
128    }
129
130    async fn layered_chat_completion(
131        &self,
132        req: ChatCompletionRequest,
133    ) -> Result<ChatCompletionResponse, AiError> {
134        // Clone req for retry attempts
135        let req_clone = req.clone();
136        self.execute_with_retry(|| {
137            let req = req_clone.clone();
138            async move { self.inner.chat_completion(req).await }
139        })
140        .await
141    }
142
143    async fn layered_stream_chat_completion(
144        &self,
145        req: ChatCompletionRequest,
146    ) -> Result<Box<ChatCompletionStream>, AiError> {
147        // For streaming, we don't retry mid-stream - only retry the initial connection
148        let req_clone = req.clone();
149        self.execute_with_retry(|| {
150            let req = req_clone.clone();
151            async move { self.inner.stream_chat_completion(req).await }
152        })
153        .await
154    }
155}
156
157#[async_trait]
158impl<P: Provider> Provider for RetryProvider<P> {
159    fn info(&self) -> Arc<ProviderInfo> {
160        LayeredProvider::layered_info(self)
161    }
162
163    async fn chat_completion(
164        &self,
165        req: ChatCompletionRequest,
166    ) -> Result<ChatCompletionResponse, AiError> {
167        LayeredProvider::layered_chat_completion(self, req).await
168    }
169
170    async fn stream_chat_completion(
171        &self,
172        req: ChatCompletionRequest,
173    ) -> Result<Box<ChatCompletionStream>, AiError> {
174        LayeredProvider::layered_stream_chat_completion(self, req).await
175    }
176}