1use crate::{
4 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5 LlmResponse, LlmStream, Result, StreamingLlmProvider,
6};
7use async_trait::async_trait;
8
/// Wraps an ordered list of providers of the same type and tries them one after
/// another, moving to the next provider when a call fails with an error that is
/// eligible for fallback.
#[derive(Debug, Clone)]
pub struct FallbackProvider<P> {
    /// Providers in priority order; `new` rejects an empty list.
    providers: Vec<P>,
    /// When `true`, any error moves on to the next provider; otherwise only
    /// rate-limit, network, API, and timeout errors do.
    retry_all_errors: bool,
}
15
16impl<P> FallbackProvider<P> {
17 pub fn new(providers: Vec<P>) -> Self {
22 assert!(!providers.is_empty(), "Must provide at least one provider");
23 Self {
24 providers,
25 retry_all_errors: false,
26 }
27 }
28
29 pub fn with_retry_all_errors(mut self, retry_all: bool) -> Self {
31 self.retry_all_errors = retry_all;
32 self
33 }
34
35 pub fn provider_count(&self) -> usize {
37 self.providers.len()
38 }
39
40 fn should_fallback(&self, error: &LlmError) -> bool {
42 if self.retry_all_errors {
43 true
44 } else {
45 matches!(
47 error,
48 LlmError::RateLimited(_)
49 | LlmError::NetworkError(_)
50 | LlmError::ApiError(_)
51 | LlmError::Timeout(_)
52 )
53 }
54 }
55}
56
57#[async_trait]
58impl<P: LlmProvider> LlmProvider for FallbackProvider<P> {
59 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
60 let mut last_error = None;
61
62 for (idx, provider) in self.providers.iter().enumerate() {
63 match provider.complete(request.clone()).await {
64 Ok(response) => {
65 if idx > 0 {
66 tracing::info!(
67 provider_index = idx,
68 "Successfully failed over to alternative provider"
69 );
70 }
71 return Ok(response);
72 }
73 Err(e) => {
74 if self.should_fallback(&e) && idx < self.providers.len() - 1 {
75 tracing::warn!(
76 provider_index = idx,
77 error = %e,
78 "Provider failed, trying next provider"
79 );
80 last_error = Some(e);
81 } else {
82 return Err(e);
83 }
84 }
85 }
86 }
87
88 Err(last_error.unwrap_or(LlmError::ApiError("All providers failed".to_string())))
89 }
90}
91
92#[async_trait]
93impl<P: StreamingLlmProvider> StreamingLlmProvider for FallbackProvider<P> {
94 async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
95 let mut last_error = None;
96
97 for (idx, provider) in self.providers.iter().enumerate() {
98 match provider.complete_stream(request.clone()).await {
99 Ok(stream) => {
100 if idx > 0 {
101 tracing::info!(
102 provider_index = idx,
103 "Successfully failed over to alternative provider for streaming"
104 );
105 }
106 return Ok(stream);
107 }
108 Err(e) => {
109 if self.should_fallback(&e) && idx < self.providers.len() - 1 {
110 tracing::warn!(
111 provider_index = idx,
112 error = %e,
113 "Provider failed for streaming, trying next provider"
114 );
115 last_error = Some(e);
116 } else {
117 return Err(e);
118 }
119 }
120 }
121 }
122
123 Err(last_error.unwrap_or(LlmError::ApiError("All providers failed".to_string())))
124 }
125}
126
127#[async_trait]
128impl<P: EmbeddingProvider> EmbeddingProvider for FallbackProvider<P> {
129 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
130 let mut last_error = None;
131
132 for (idx, provider) in self.providers.iter().enumerate() {
133 match provider.embed(request.clone()).await {
134 Ok(response) => {
135 if idx > 0 {
136 tracing::info!(
137 provider_index = idx,
138 "Successfully failed over to alternative embedding provider"
139 );
140 }
141 return Ok(response);
142 }
143 Err(e) => {
144 if self.should_fallback(&e) && idx < self.providers.len() - 1 {
145 tracing::warn!(
146 provider_index = idx,
147 error = %e,
148 "Embedding provider failed, trying next provider"
149 );
150 last_error = Some(e);
151 } else {
152 return Err(e);
153 }
154 }
155 }
156 }
157
158 Err(last_error.unwrap_or(LlmError::ApiError("All providers failed".to_string())))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Kind of error a failing `MockProvider` reports.
    #[derive(Clone)]
    enum MockErrorType {
        RateLimited,
        RateLimitedWithDelay(Duration),
        ApiError,
        InvalidRequest,
        Timeout,
    }

    impl MockErrorType {
        /// Builds the `LlmError` this mock error kind stands for.
        fn to_error(&self) -> LlmError {
            match self {
                MockErrorType::RateLimited => LlmError::RateLimited(None),
                MockErrorType::RateLimitedWithDelay(d) => LlmError::RateLimited(Some(*d)),
                MockErrorType::ApiError => LlmError::ApiError("API error".to_string()),
                MockErrorType::InvalidRequest => {
                    LlmError::InvalidRequest("bad request".to_string())
                }
                MockErrorType::Timeout => LlmError::Timeout(Duration::from_secs(30)),
            }
        }
    }

    /// Test double: echoes the prompt when `fail_with` is `None`, otherwise fails
    /// with the configured error.
    struct MockProvider {
        fail_with: Option<MockErrorType>,
    }

    impl MockProvider {
        fn succeeding() -> Self {
            Self { fail_with: None }
        }

        fn failing(kind: MockErrorType) -> Self {
            Self {
                fail_with: Some(kind),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            match &self.fail_with {
                Some(kind) => Err(kind.to_error()),
                None => Ok(LlmResponse {
                    content: format!("Response to: {}", request.prompt),
                    model: "mock".to_string(),
                    usage: None,
                    tool_calls: Vec::new(),
                }),
            }
        }
    }

    /// Minimal request shared by every test.
    fn test_request() -> LlmRequest {
        LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_fallback_first_provider_success() {
        let fallback =
            FallbackProvider::new(vec![MockProvider::succeeding(), MockProvider::succeeding()]);

        let result = fallback.complete(test_request()).await;
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[tokio::test]
    async fn test_fallback_to_second_provider() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::RateLimitedWithDelay(Duration::from_secs(5))),
            MockProvider::succeeding(),
        ]);

        let result = fallback.complete(test_request()).await;
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[tokio::test]
    async fn test_fallback_on_timeout() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::Timeout),
            MockProvider::succeeding(),
        ]);

        let result = fallback.complete(test_request()).await;
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[tokio::test]
    async fn test_fallback_all_providers_fail() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::RateLimited),
            MockProvider::failing(MockErrorType::ApiError),
        ]);

        // The error surfaced is the one from the last provider tried.
        let result = fallback.complete(test_request()).await;
        assert!(matches!(
            result,
            Err(LlmError::ApiError(ref msg)) if msg == "API error"
        ));
    }

    #[tokio::test]
    async fn test_fallback_non_retryable_error() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::InvalidRequest),
            MockProvider::succeeding(),
        ]);

        let result = fallback.complete(test_request()).await;
        assert!(matches!(result, Err(LlmError::InvalidRequest(_))));
    }

    #[tokio::test]
    async fn test_fallback_retry_all_errors() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::InvalidRequest),
            MockProvider::succeeding(),
        ])
        .with_retry_all_errors(true);

        let result = fallback.complete(test_request()).await;
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[test]
    fn test_provider_count() {
        let fallback =
            FallbackProvider::new(vec![MockProvider::succeeding(), MockProvider::succeeding()]);
        assert_eq!(fallback.provider_count(), 2);
    }

    #[test]
    #[should_panic(expected = "Must provide at least one provider")]
    fn test_new_rejects_empty_provider_list() {
        let _ = FallbackProvider::<MockProvider>::new(Vec::new());
    }
}