Skip to main content

llmkit_tower/
fallback.rs

//! Provider fallback chaining: try the primary, then each secondary in turn, on retryable failures.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use llmkit_core::{
7    ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
8    LlmProvider, LlmResult,
9};
10
11/// Tries providers in order, advancing to the next on a retryable failure.
12///
13/// Non-retryable errors (invalid request, auth, budget) short-circuit — there
14/// is no point trying the next provider if the request itself is malformed.
15pub struct FallbackProvider {
16    providers: Vec<Arc<dyn LlmProvider>>,
17}
18
19impl FallbackProvider {
20    /// Build a chain from an ordered list of providers (primary first).
21    pub fn new(providers: Vec<Arc<dyn LlmProvider>>) -> Self {
22        assert!(!providers.is_empty(), "FallbackProvider requires at least one provider");
23        Self { providers }
24    }
25
26    fn should_advance(err: &LlmError) -> bool {
27        err.is_retryable()
28    }
29}
30
31#[async_trait]
32impl LlmProvider for FallbackProvider {
33    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
34        let mut errors = Vec::new();
35        let last = self.providers.len() - 1;
36        for (i, p) in self.providers.iter().enumerate() {
37            match p.chat(req.clone()).await {
38                Ok(resp) => return Ok(resp),
39                Err(e) if i < last && Self::should_advance(&e) => {
40                    tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
41                    errors.push(e);
42                }
43                Err(e) => {
44                    errors.push(e);
45                    return Err(LlmError::AllProvidersFailed(errors));
46                }
47            }
48        }
49        Err(LlmError::AllProvidersFailed(errors))
50    }
51
52    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
53        let mut errors = Vec::new();
54        let last = self.providers.len() - 1;
55        for (i, p) in self.providers.iter().enumerate() {
56            match p.chat_stream(req.clone()).await {
57                Ok(s) => return Ok(s),
58                Err(e) if i < last && Self::should_advance(&e) => errors.push(e),
59                Err(e) => {
60                    errors.push(e);
61                    return Err(LlmError::AllProvidersFailed(errors));
62                }
63            }
64        }
65        Err(LlmError::AllProvidersFailed(errors))
66    }
67
68    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
69        let mut errors = Vec::new();
70        let last = self.providers.len() - 1;
71        for (i, p) in self.providers.iter().enumerate() {
72            match p.embed(req.clone()).await {
73                Ok(r) => return Ok(r),
74                Err(e) if i < last && Self::should_advance(&e) => errors.push(e),
75                Err(e) => {
76                    errors.push(e);
77                    return Err(LlmError::AllProvidersFailed(errors));
78                }
79            }
80        }
81        Err(LlmError::AllProvidersFailed(errors))
82    }
83
84    fn name(&self) -> &'static str {
85        self.providers[0].name()
86    }
87
88    fn model(&self) -> &str {
89        self.providers[0].model()
90    }
91
92    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
93        self.providers[0].estimate_cost(req)
94    }
95}