1use async_trait::async_trait;
4use backon::{ExponentialBuilder, Retryable};
5use std::sync::Arc;
6use std::time::Duration;
7
8use super::provider::{CompletionRequest, CompletionResponse, CompletionStream, LlmProvider};
9use crate::{Error, Result};
10
11pub struct RetryWrapper {
13 inner: Arc<dyn LlmProvider>,
14 max_attempts: u32,
15 initial_delay: Duration,
16 max_delay: Duration,
17}
18
19impl RetryWrapper {
20 pub fn new(provider: Arc<dyn LlmProvider>) -> Self {
28 Self {
29 inner: provider,
30 max_attempts: 3,
31 initial_delay: Duration::from_secs(1),
32 max_delay: Duration::from_secs(10),
33 }
34 }
35
36 pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
38 self.max_attempts = max_attempts;
39 self
40 }
41
42 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
44 self.initial_delay = delay;
45 self
46 }
47
48 pub fn with_max_delay(mut self, delay: Duration) -> Self {
50 self.max_delay = delay;
51 self
52 }
53
54 fn should_retry(error: &Error) -> bool {
56 error.is_retryable()
57 }
58}
59
60#[async_trait]
61impl LlmProvider for RetryWrapper {
62 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
63 let backoff = ExponentialBuilder::default()
64 .with_min_delay(self.initial_delay)
65 .with_max_delay(self.max_delay)
66 .with_max_times(self.max_attempts as usize);
67
68 let provider = self.inner.clone();
69 let request_clone = request.clone();
70
71 (|| async { provider.complete(request_clone.clone()).await })
73 .retry(backoff)
74 .when(Self::should_retry)
75 .await
76 }
77
78 async fn complete_streaming(&self, request: CompletionRequest) -> Result<CompletionStream> {
79 self.inner.complete_streaming(request).await
81 }
82}
83
#[cfg(test)]
// Test failures should panic immediately; `unwrap` keeps the assertions terse.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::llm::{Message, MockLlmProvider};

    /// A successful inner response passes straight through the wrapper.
    #[tokio::test]
    async fn test_retry_wrapper_success() {
        let provider = Arc::new(MockLlmProvider::with_response("Success"));
        let wrapper = RetryWrapper::new(provider);

        let messages = vec![Message::user("Test")];
        let reply = wrapper
            .complete(CompletionRequest::new(messages))
            .await
            .unwrap();

        assert_eq!(reply.content, "Success");
    }

    /// Every builder setter stores its value on the wrapper.
    #[test]
    fn test_retry_wrapper_builder() {
        let provider = Arc::new(MockLlmProvider::with_response("Test"));
        let wrapper = RetryWrapper::new(provider)
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(500))
            .with_max_delay(Duration::from_secs(30));

        let RetryWrapper {
            max_attempts,
            initial_delay,
            max_delay,
            ..
        } = wrapper;

        assert_eq!(max_attempts, 5);
        assert_eq!(initial_delay, Duration::from_millis(500));
        assert_eq!(max_delay, Duration::from_secs(30));
    }
}