1use std::sync::Arc;
4
5use async_trait::async_trait;
6use llmkit_core::{
7 ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
8 LlmProvider, LlmResult,
9};
10
11pub struct FallbackProvider {
16 providers: Vec<Arc<dyn LlmProvider>>,
17}
18
19impl FallbackProvider {
20 pub fn new(providers: Vec<Arc<dyn LlmProvider>>) -> Self {
22 assert!(!providers.is_empty(), "FallbackProvider requires at least one provider");
23 Self { providers }
24 }
25
26 fn should_advance(err: &LlmError) -> bool {
27 err.is_retryable()
28 }
29}
30
31#[async_trait]
32impl LlmProvider for FallbackProvider {
33 async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
34 let mut errors = Vec::new();
35 let last = self.providers.len() - 1;
36 for (i, p) in self.providers.iter().enumerate() {
37 match p.chat(req.clone()).await {
38 Ok(resp) => return Ok(resp),
39 Err(e) if i < last && Self::should_advance(&e) => {
40 tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
41 errors.push(e);
42 }
43 Err(e) => {
44 errors.push(e);
45 return Err(LlmError::AllProvidersFailed(errors));
46 }
47 }
48 }
49 Err(LlmError::AllProvidersFailed(errors))
50 }
51
52 async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
53 let mut errors = Vec::new();
54 let last = self.providers.len() - 1;
55 for (i, p) in self.providers.iter().enumerate() {
56 match p.chat_stream(req.clone()).await {
57 Ok(s) => return Ok(s),
58 Err(e) if i < last && Self::should_advance(&e) => errors.push(e),
59 Err(e) => {
60 errors.push(e);
61 return Err(LlmError::AllProvidersFailed(errors));
62 }
63 }
64 }
65 Err(LlmError::AllProvidersFailed(errors))
66 }
67
68 async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
69 let mut errors = Vec::new();
70 let last = self.providers.len() - 1;
71 for (i, p) in self.providers.iter().enumerate() {
72 match p.embed(req.clone()).await {
73 Ok(r) => return Ok(r),
74 Err(e) if i < last && Self::should_advance(&e) => errors.push(e),
75 Err(e) => {
76 errors.push(e);
77 return Err(LlmError::AllProvidersFailed(errors));
78 }
79 }
80 }
81 Err(LlmError::AllProvidersFailed(errors))
82 }
83
84 fn name(&self) -> &'static str {
85 self.providers[0].name()
86 }
87
88 fn model(&self) -> &str {
89 self.providers[0].model()
90 }
91
92 fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
93 self.providers[0].estimate_cost(req)
94 }
95}