oxify_connect_llm/
fallback.rs

1//! Provider fallback mechanism for automatic failover
2
3use crate::{
4    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmError, LlmProvider, LlmRequest,
5    LlmResponse, LlmStream, Result, StreamingLlmProvider,
6};
7use async_trait::async_trait;
8
9/// A provider that automatically falls back to alternative providers on failure
/// A provider that automatically falls back to alternative providers on failure.
///
/// Providers are tried in the order they were supplied and the first successful
/// result is returned. Construct it with `FallbackProvider::new`.
#[derive(Debug, Clone)]
pub struct FallbackProvider<P> {
    /// Providers to try, in priority order. Never empty when built via `new`.
    providers: Vec<P>,
    /// Whether to retry on all errors or just retryable errors
    retry_all_errors: bool,
}
15
16impl<P> FallbackProvider<P> {
17    /// Create a new fallback provider with a list of providers
18    ///
19    /// # Panics
20    /// Panics if the providers list is empty
21    pub fn new(providers: Vec<P>) -> Self {
22        assert!(!providers.is_empty(), "Must provide at least one provider");
23        Self {
24            providers,
25            retry_all_errors: false,
26        }
27    }
28
29    /// Configure whether to fallback on all errors (default: only retryable errors)
30    pub fn with_retry_all_errors(mut self, retry_all: bool) -> Self {
31        self.retry_all_errors = retry_all;
32        self
33    }
34
35    /// Get the number of providers
36    pub fn provider_count(&self) -> usize {
37        self.providers.len()
38    }
39
40    /// Check if an error should trigger fallback
41    fn should_fallback(&self, error: &LlmError) -> bool {
42        if self.retry_all_errors {
43            true
44        } else {
45            // Only fallback on retryable errors
46            matches!(
47                error,
48                LlmError::RateLimited(_)
49                    | LlmError::NetworkError(_)
50                    | LlmError::ApiError(_)
51                    | LlmError::Timeout(_)
52            )
53        }
54    }
55}
56
57#[async_trait]
58impl<P: LlmProvider> LlmProvider for FallbackProvider<P> {
59    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
60        let mut last_error = None;
61
62        for (idx, provider) in self.providers.iter().enumerate() {
63            match provider.complete(request.clone()).await {
64                Ok(response) => {
65                    if idx > 0 {
66                        tracing::info!(
67                            provider_index = idx,
68                            "Successfully failed over to alternative provider"
69                        );
70                    }
71                    return Ok(response);
72                }
73                Err(e) => {
74                    if self.should_fallback(&e) && idx < self.providers.len() - 1 {
75                        tracing::warn!(
76                            provider_index = idx,
77                            error = %e,
78                            "Provider failed, trying next provider"
79                        );
80                        last_error = Some(e);
81                    } else {
82                        return Err(e);
83                    }
84                }
85            }
86        }
87
88        Err(last_error.unwrap_or(LlmError::ApiError("All providers failed".to_string())))
89    }
90}
91
92#[async_trait]
93impl<P: StreamingLlmProvider> StreamingLlmProvider for FallbackProvider<P> {
94    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
95        let mut last_error = None;
96
97        for (idx, provider) in self.providers.iter().enumerate() {
98            match provider.complete_stream(request.clone()).await {
99                Ok(stream) => {
100                    if idx > 0 {
101                        tracing::info!(
102                            provider_index = idx,
103                            "Successfully failed over to alternative provider for streaming"
104                        );
105                    }
106                    return Ok(stream);
107                }
108                Err(e) => {
109                    if self.should_fallback(&e) && idx < self.providers.len() - 1 {
110                        tracing::warn!(
111                            provider_index = idx,
112                            error = %e,
113                            "Provider failed for streaming, trying next provider"
114                        );
115                        last_error = Some(e);
116                    } else {
117                        return Err(e);
118                    }
119                }
120            }
121        }
122
123        Err(last_error.unwrap_or(LlmError::ApiError("All providers failed".to_string())))
124    }
125}
126
127#[async_trait]
128impl<P: EmbeddingProvider> EmbeddingProvider for FallbackProvider<P> {
129    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
130        let mut last_error = None;
131
132        for (idx, provider) in self.providers.iter().enumerate() {
133            match provider.embed(request.clone()).await {
134                Ok(response) => {
135                    if idx > 0 {
136                        tracing::info!(
137                            provider_index = idx,
138                            "Successfully failed over to alternative embedding provider"
139                        );
140                    }
141                    return Ok(response);
142                }
143                Err(e) => {
144                    if self.should_fallback(&e) && idx < self.providers.len() - 1 {
145                        tracing::warn!(
146                            provider_index = idx,
147                            error = %e,
148                            "Embedding provider failed, trying next provider"
149                        );
150                        last_error = Some(e);
151                    } else {
152                        return Err(e);
153                    }
154                }
155            }
156        }
157
158        Err(last_error.unwrap_or(LlmError::ApiError("All providers failed".to_string())))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Which error a failing [`MockProvider`] returns.
    #[derive(Clone)]
    enum MockErrorType {
        RateLimited,
        RateLimitedWithDelay(Duration),
        ApiError,
        InvalidRequest,
        Timeout,
    }

    /// Test double that either echoes the prompt or fails with a configured error.
    struct MockProvider {
        should_fail: bool,
        fail_with: MockErrorType,
    }

    impl MockProvider {
        /// A provider that always succeeds. `fail_with` is an unused placeholder.
        fn healthy() -> Self {
            Self {
                should_fail: false,
                fail_with: MockErrorType::RateLimited,
            }
        }

        /// A provider that always fails with the error described by `kind`.
        fn failing(kind: MockErrorType) -> Self {
            Self {
                should_fail: true,
                fail_with: kind,
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            if self.should_fail {
                let err = match &self.fail_with {
                    MockErrorType::RateLimited => LlmError::RateLimited(None),
                    MockErrorType::RateLimitedWithDelay(d) => LlmError::RateLimited(Some(*d)),
                    MockErrorType::ApiError => LlmError::ApiError("API error".to_string()),
                    MockErrorType::InvalidRequest => {
                        LlmError::InvalidRequest("bad request".to_string())
                    }
                    MockErrorType::Timeout => LlmError::Timeout(Duration::from_secs(30)),
                };
                Err(err)
            } else {
                Ok(LlmResponse {
                    content: format!("Response to: {}", request.prompt),
                    model: "mock".to_string(),
                    usage: None,
                    tool_calls: Vec::new(),
                })
            }
        }
    }

    /// Minimal request shared by all tests; `MockProvider` echoes its prompt.
    fn test_request() -> LlmRequest {
        LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_fallback_first_provider_success() {
        let fallback = FallbackProvider::new(vec![MockProvider::healthy(), MockProvider::healthy()]);

        let result = fallback.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[tokio::test]
    async fn test_fallback_to_second_provider() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::RateLimitedWithDelay(Duration::from_secs(5))),
            MockProvider::healthy(),
        ]);

        let result = fallback.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[tokio::test]
    async fn test_fallback_on_timeout() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::Timeout),
            MockProvider::healthy(),
        ]);

        let result = fallback.complete(test_request()).await;
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[tokio::test]
    async fn test_fallback_all_providers_fail() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::RateLimited),
            MockProvider::failing(MockErrorType::ApiError),
        ]);

        // The last provider's error is the one surfaced to the caller.
        let result = fallback.complete(test_request()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ApiError(_)));
    }

    #[tokio::test]
    async fn test_single_provider_failure_is_returned() {
        let fallback =
            FallbackProvider::new(vec![MockProvider::failing(MockErrorType::RateLimited)]);

        let result = fallback.complete(test_request()).await;
        assert!(matches!(result.unwrap_err(), LlmError::RateLimited(None)));
    }

    #[tokio::test]
    async fn test_fallback_non_retryable_error() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::InvalidRequest),
            MockProvider::healthy(),
        ]);

        // Should not fallback on non-retryable errors
        let result = fallback.complete(test_request()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::InvalidRequest(_)));
    }

    #[tokio::test]
    async fn test_fallback_retry_all_errors() {
        let fallback = FallbackProvider::new(vec![
            MockProvider::failing(MockErrorType::InvalidRequest),
            MockProvider::healthy(),
        ])
        .with_retry_all_errors(true);

        // Should fallback even on non-retryable errors when retry_all_errors is true
        let result = fallback.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().content, "Response to: test");
    }

    #[test]
    fn test_provider_count() {
        let fallback = FallbackProvider::new(vec![MockProvider::healthy(), MockProvider::healthy()]);
        assert_eq!(fallback.provider_count(), 2);
    }

    #[test]
    #[should_panic(expected = "Must provide at least one provider")]
    fn test_new_rejects_empty_list() {
        let _ = FallbackProvider::<MockProvider>::new(Vec::new());
    }
}