use crate::error::{LlmError, LlmResult};
use crate::logging::{log_debug, log_error, log_warn};

use std::time::{Duration, Instant};
use tokio::time::sleep;

/// Retry configuration: exponential backoff with jitter, plus per-request
/// and total deadlines.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of attempts before giving up.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the backoff delay.
    pub max_delay: Duration,
    /// Factor applied to the delay after each failed attempt.
    pub backoff_multiplier: f64,
    /// Deadline for all attempts combined.
    pub total_timeout: Duration,
    /// Deadline for a single attempt.
    pub request_timeout: Duration,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self {
            max_attempts: 5,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(16),
            backoff_multiplier: 2.0,
            total_timeout: Duration::from_secs(300),
            request_timeout: Duration::from_secs(120),
        }
    }
}

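// Illustrative override sketch (the values here are assumptions, not crate
// defaults): callers can tighten the policy with struct-update syntax, e.g.
//
//     let fast = RetryPolicy {
//         max_attempts: 3,
//         initial_delay: Duration::from_millis(250),
//         ..RetryPolicy::default()
//     };
//
// With the defaults above, unjittered retry delays run 1s, 2s, 4s, 8s, 16s;
// see `calculate_delay` below for the jitter and cap.
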
/// States of the circuit breaker.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum CircuitState {
    /// Requests flow normally.
    Closed,
    /// Too many failures; requests are rejected until the recovery timeout.
    Open,
    /// Probing for recovery: requests are allowed through to test the service.
    HalfOpen,
}

/// A simple circuit breaker: opens after `failure_threshold` failures and
/// probes for recovery once `recovery_timeout` has elapsed.
#[derive(Debug)]
pub(crate) struct CircuitBreaker {
    pub(crate) state: CircuitState,
    pub(crate) failure_count: u32,
    pub(crate) last_failure_time: Option<Instant>,
    pub(crate) failure_threshold: u32,
    pub(crate) recovery_timeout: Duration,
}

impl Default for CircuitBreaker {
    fn default() -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            last_failure_time: None,
            failure_threshold: 5,
            recovery_timeout: Duration::from_secs(30),
        }
    }
}

impl CircuitBreaker {
    /// Returns whether a request should be attempted in the current state.
    pub fn should_allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => self.check_recovery_timeout(),
            CircuitState::HalfOpen => true,
        }
    }

    /// Moves an open breaker to half-open once the recovery timeout elapses.
    fn check_recovery_timeout(&mut self) -> bool {
        let Some(last_failure) = self.last_failure_time else {
            return false;
        };

        if last_failure.elapsed() >= self.recovery_timeout {
            log_debug!(
                circuit_breaker = "transitioning_to_half_open",
                recovery_timeout_seconds = self.recovery_timeout.as_secs(),
                "Circuit breaker attempting recovery"
            );
            self.state = CircuitState::HalfOpen;
            true
        } else {
            false
        }
    }

    /// Records a successful request, resetting failure tracking.
    pub fn record_success(&mut self) {
        match self.state {
            CircuitState::HalfOpen => {
                log_debug!(
                    circuit_breaker = "recovered",
                    "Circuit breaker recovered, returning to closed state"
                );
                self.state = CircuitState::Closed;
                self.failure_count = 0;
                self.last_failure_time = None;
            }
            CircuitState::Closed => {
                self.failure_count = 0;
            }
            CircuitState::Open => {
                self.failure_count = 0;
                // Keep `last_failure_time` so the recovery timeout can still
                // elapse; clearing it here would leave the breaker open forever.
            }
        }
    }

    /// Records a failed request, opening the breaker once the threshold is hit.
    pub fn record_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(Instant::now());

        if self.failure_count >= self.failure_threshold {
            if self.state != CircuitState::Open {
                log_warn!(
                    circuit_breaker = "opened",
                    failure_count = self.failure_count,
                    failure_threshold = self.failure_threshold,
                    recovery_timeout_seconds = self.recovery_timeout.as_secs(),
                    "Circuit breaker opened due to repeated failures"
                );
            }
            self.state = CircuitState::Open;
        }
    }

    /// Returns the current state (cloned; `CircuitState` is cheap to clone).
    pub fn state(&self) -> CircuitState {
        self.state.clone()
    }
}

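// Transition summary for the breaker above, derived from its methods:
//
//     Closed   --(failure_count >= failure_threshold)--> Open
//     Open     --(recovery_timeout elapsed)------------> HalfOpen
//     HalfOpen --(record_success)----------------------> Closed
//     HalfOpen --(record_failure over threshold)-------> Open
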
/// Drives an operation through the retry policy and circuit breaker.
#[derive(Debug)]
pub(crate) struct RetryExecutor {
    pub(crate) policy: RetryPolicy,
    pub(crate) circuit_breaker: CircuitBreaker,
}

impl Default for RetryExecutor {
    fn default() -> Self {
        Self::new(RetryPolicy::default())
    }
}

impl RetryExecutor {
    /// Creates an executor with the given policy and a fresh circuit breaker.
    pub fn new(policy: RetryPolicy) -> Self {
        Self {
            policy,
            circuit_breaker: CircuitBreaker::default(),
        }
    }

    /// Runs `operation`, retrying per the policy until it succeeds, retryable
    /// attempts are exhausted, the total timeout elapses, or the circuit
    /// breaker opens.
    pub async fn execute<F, Fut, T>(&mut self, operation: F) -> LlmResult<T>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = LlmResult<T>>,
    {
        let start_time = Instant::now();
        let mut attempt = 0;
        let mut last_error = None;

        while attempt < self.policy.max_attempts {
            self.check_circuit_breaker()?;
            self.check_total_timeout(&start_time)?;

            attempt += 1;

            match self
                .execute_single_attempt(&operation, attempt, &mut last_error)
                .await
            {
                Ok(response) => return Ok(response),
                Err(should_continue) => {
                    if !should_continue {
                        break;
                    }
                }
            }
        }

        self.handle_exhausted_retries(attempt, last_error, &start_time)
    }

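    // Call-site sketch (hypothetical: `client` and `request` stand in for
    // this crate's provider types):
    //
    //     let mut executor = RetryExecutor::default();
    //     let response = executor.execute(|| client.send(request.clone())).await?;
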
    /// Runs one attempt. `Ok` carries the response; `Err(should_continue)`
    /// tells the caller whether to keep retrying.
    async fn execute_single_attempt<F, Fut, T>(
        &mut self,
        operation: &F,
        attempt: u32,
        last_error: &mut Option<LlmError>,
    ) -> Result<T, bool>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = LlmResult<T>>,
    {
        self.log_attempt(attempt);

        let operation_start = Instant::now();
        let result = tokio::time::timeout(self.policy.request_timeout, operation()).await;

        match result {
            Ok(Ok(response)) => {
                self.circuit_breaker.record_success();
                log_debug!(
                    attempt = attempt,
                    duration_ms = operation_start.elapsed().as_millis(),
                    "Request succeeded"
                );
                Ok(response)
            }
            Ok(Err(error)) => {
                let should_continue = self.handle_error(error, attempt, last_error).await;
                Err(should_continue)
            }
            Err(_timeout) => {
                let should_continue = self.handle_timeout(attempt, last_error).await;
                Err(should_continue)
            }
        }
    }

    /// Fails fast if the circuit breaker is rejecting requests.
    fn check_circuit_breaker(&mut self) -> LlmResult<()> {
        if !self.circuit_breaker.should_allow_request() {
            return Err(LlmError::request_failed(
                "Circuit breaker is open - service temporarily unavailable".to_string(),
                None,
            ));
        }
        Ok(())
    }

    /// Fails fast once the overall deadline has elapsed.
    fn check_total_timeout(&self, start_time: &Instant) -> LlmResult<()> {
        if start_time.elapsed() >= self.policy.total_timeout {
            return Err(LlmError::timeout(self.policy.total_timeout.as_secs()));
        }
        Ok(())
    }

    fn log_attempt(&self, attempt: u32) {
        log_debug!(
            attempt = attempt,
            max_attempts = self.policy.max_attempts,
            circuit_state = ?self.circuit_breaker.state(),
            "Executing request with retry logic"
        );
    }

    /// Records the failure and, if the error is retryable and attempts remain,
    /// sleeps for the backoff delay. Returns whether to retry.
    async fn handle_error(
        &mut self,
        error: LlmError,
        attempt: u32,
        last_error: &mut Option<LlmError>,
    ) -> bool {
        let should_retry = self.should_retry_error(&error);
        *last_error = Some(error);
        self.circuit_breaker.record_failure();

        if should_retry && attempt < self.policy.max_attempts {
            let delay = self.calculate_delay(attempt);
            log_debug!(
                attempt = attempt,
                max_attempts = self.policy.max_attempts,
                delay_ms = delay.as_millis(),
                error = ?last_error.as_ref(),
                "Request failed, retrying after delay"
            );
            sleep(delay).await;
            true
        } else {
            false
        }
    }

    /// Records a per-request timeout and sleeps before the next attempt if any
    /// remain. Returns whether to retry.
    async fn handle_timeout(&mut self, attempt: u32, last_error: &mut Option<LlmError>) -> bool {
        let timeout_error = LlmError::timeout(self.policy.request_timeout.as_secs());
        *last_error = Some(timeout_error);
        self.circuit_breaker.record_failure();

        if attempt < self.policy.max_attempts {
            let delay = self.calculate_delay(attempt);
            log_debug!(
                attempt = attempt,
                max_attempts = self.policy.max_attempts,
                delay_ms = delay.as_millis(),
                timeout_seconds = self.policy.request_timeout.as_secs(),
                "Request timed out, retrying after delay"
            );
            sleep(delay).await;
            true
        } else {
            false
        }
    }

    /// Logs the final failure and surfaces the last error seen.
    fn handle_exhausted_retries<T>(
        &mut self,
        attempt: u32,
        last_error: Option<LlmError>,
        start_time: &Instant,
    ) -> LlmResult<T> {
        let final_error = last_error.unwrap_or_else(|| {
            LlmError::request_failed("Maximum retry attempts exceeded".to_string(), None)
        });

        log_error!(
            attempts = attempt,
            total_duration_ms = start_time.elapsed().as_millis(),
            circuit_state = ?self.circuit_breaker.state(),
            error = %final_error,
            "Request failed after all retry attempts"
        );

        Err(final_error)
    }

    /// Transient failures are retried; client and configuration errors are not.
    fn should_retry_error(&self, error: &LlmError) -> bool {
        match error {
            LlmError::RequestFailed { .. } => true,
            LlmError::Timeout { .. } => true,
            LlmError::RateLimitExceeded { .. } => true,
            LlmError::AuthenticationFailed { .. } => false,
            LlmError::ConfigurationError { .. } => false,
            LlmError::TokenLimitExceeded { .. } => false,
            LlmError::UnsupportedProvider { .. } => false,
            LlmError::ResponseParsingError { .. } => false,
            LlmError::ToolExecutionFailed { .. } => false,
            LlmError::SchemaValidationFailed { .. } => false,
        }
    }

    /// Computes the backoff delay for a 1-indexed attempt: exponential growth
    /// with up to 10% additive jitter, capped at `max_delay`.
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        let delay_seconds = self.policy.initial_delay.as_secs_f64()
            * self
                .policy
                .backoff_multiplier
                .powi(attempt.saturating_sub(1) as i32);

        // Add up to 10% jitter to avoid thundering-herd retries, then cap so
        // the jittered delay never exceeds `max_delay`.
        let jitter = fastrand::f64() * 0.1;
        let jittered = delay_seconds * (1.0 + jitter);

        Duration::from_secs_f64(jittered.min(self.policy.max_delay.as_secs_f64()))
    }
}
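
// Minimal illustrative tests: a sketch of the intended breaker and backoff
// behavior rather than an exhaustive suite. They rely only on items defined
// above and need no extra dev-dependencies.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn circuit_breaker_opens_after_threshold_and_recovers() {
        let mut breaker = CircuitBreaker {
            failure_threshold: 2,
            recovery_timeout: Duration::from_millis(0),
            ..CircuitBreaker::default()
        };

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);

        // With a zero recovery timeout, the next probe moves the breaker to
        // half-open, and a success then closes it.
        assert!(breaker.should_allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn calculate_delay_backs_off_and_respects_cap() {
        let executor = RetryExecutor::default();
        let base = executor.policy.initial_delay.as_secs_f64();

        // Attempt 1: the initial delay plus at most 10% jitter.
        let d1 = executor.calculate_delay(1).as_secs_f64();
        assert!(d1 >= base && d1 <= base * 1.1);

        // Attempt 10 would be 512s unjittered; it must be capped at max_delay.
        assert!(executor.calculate_delay(10) <= executor.policy.max_delay);
    }
}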