1use aidale_core::error::AiError;
4use aidale_core::layer::{Layer, LayeredProvider};
5use aidale_core::provider::{ChatCompletionStream, Provider};
6use aidale_core::types::*;
7use async_trait::async_trait;
8use std::fmt::Debug;
9use std::sync::Arc;
10use std::time::Duration;
11
/// Layer that wraps a provider so failed calls are retried with exponential backoff.
///
/// The delay before retry `n` (0-based) is `initial_delay * backoff_multiplier^n`,
/// capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct RetryLayer {
    /// Maximum number of retries performed after the initial attempt.
    max_retries: u32,
    /// Base delay used for the first retry.
    initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    max_delay: Duration,
    /// Factor the delay is multiplied by for each successive retry.
    backoff_multiplier: f64,
}

impl RetryLayer {
    /// Creates a layer with the default policy: 3 retries, 100 ms initial delay,
    /// 10 s maximum delay and a backoff multiplier of 2.0.
    pub fn new() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
        }
    }

    /// Sets how many retries are allowed after the initial attempt (`0` disables retrying).
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the base delay used before the first retry.
    #[must_use]
    pub fn with_initial_delay(mut self, initial_delay: Duration) -> Self {
        self.initial_delay = initial_delay;
        self
    }

    /// Sets the cap applied to every computed retry delay.
    #[must_use]
    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
        self.max_delay = max_delay;
        self
    }

    /// Sets the factor by which the delay grows on each successive retry.
    #[must_use]
    pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Computes the delay to wait before retry number `attempt` (0-based).
    ///
    /// The product `initial_delay * backoff_multiplier^attempt` is evaluated at
    /// nanosecond precision, so sub-millisecond delays are not truncated to zero.
    /// The result is capped at `max_delay`. Overflowing products saturate to
    /// `max_delay` instead of wrapping.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        // `powi` takes an i32: clamp rather than let a plain `as` cast wrap large
        // attempt counts into negative exponents (which would shrink the delay).
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let factor = self.backoff_multiplier.powi(exponent);
        let scaled_nanos = self.initial_delay.as_nanos() as f64 * factor;
        // Float-to-int `as` casts saturate: +inf -> u64::MAX, NaN/negative -> 0.
        let delay = Duration::from_nanos(scaled_nanos as u64);
        delay.min(self.max_delay)
    }
}

impl Default for RetryLayer {
    /// Equivalent to [`RetryLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
70
71impl<P: Provider> Layer<P> for RetryLayer {
72 type LayeredProvider = RetryProvider<P>;
73
74 fn layer(&self, inner: P) -> Self::LayeredProvider {
75 RetryProvider {
76 inner,
77 config: self.clone(),
78 }
79 }
80}
81
82#[derive(Debug)]
84pub struct RetryProvider<P> {
85 inner: P,
86 config: RetryLayer,
87}
88
89impl<P: Provider> RetryProvider<P> {
90 async fn execute_with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T, AiError>
92 where
93 F: FnMut() -> Fut,
94 Fut: std::future::Future<Output = Result<T, AiError>>,
95 {
96 let mut attempt = 0;
97
98 loop {
99 match operation().await {
100 Ok(result) => return Ok(result),
101 Err(e) => {
102 if !e.is_retryable() || attempt >= self.config.max_retries {
103 return Err(e);
104 }
105
106 let delay = self.config.calculate_delay(attempt);
107 tracing::debug!(
108 "Retry attempt {}/{}, waiting {:?}",
109 attempt + 1,
110 self.config.max_retries,
111 delay
112 );
113
114 tokio::time::sleep(delay).await;
115 attempt += 1;
116 }
117 }
118 }
119 }
120}
121
122#[async_trait]
123impl<P: Provider> LayeredProvider for RetryProvider<P> {
124 type Inner = P;
125
126 fn inner(&self) -> &Self::Inner {
127 &self.inner
128 }
129
130 async fn layered_chat_completion(
131 &self,
132 req: ChatCompletionRequest,
133 ) -> Result<ChatCompletionResponse, AiError> {
134 let req_clone = req.clone();
136 self.execute_with_retry(|| {
137 let req = req_clone.clone();
138 async move { self.inner.chat_completion(req).await }
139 })
140 .await
141 }
142
143 async fn layered_stream_chat_completion(
144 &self,
145 req: ChatCompletionRequest,
146 ) -> Result<Box<ChatCompletionStream>, AiError> {
147 let req_clone = req.clone();
149 self.execute_with_retry(|| {
150 let req = req_clone.clone();
151 async move { self.inner.stream_chat_completion(req).await }
152 })
153 .await
154 }
155}
156
157#[async_trait]
158impl<P: Provider> Provider for RetryProvider<P> {
159 fn info(&self) -> Arc<ProviderInfo> {
160 LayeredProvider::layered_info(self)
161 }
162
163 async fn chat_completion(
164 &self,
165 req: ChatCompletionRequest,
166 ) -> Result<ChatCompletionResponse, AiError> {
167 LayeredProvider::layered_chat_completion(self, req).await
168 }
169
170 async fn stream_chat_completion(
171 &self,
172 req: ChatCompletionRequest,
173 ) -> Result<Box<ChatCompletionStream>, AiError> {
174 LayeredProvider::layered_stream_chat_completion(self, req).await
175 }
176}