use std::sync::Arc;
use async_trait::async_trait;
use llmkit_core::{
ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmError,
LlmProvider, LlmResult,
};
/// Wraps an ordered list of [`LlmProvider`]s behind a single provider.
///
/// The `LlmProvider` implementation for this type tries the wrapped
/// providers in order and moves on to the next one when a call fails with
/// an error that `LlmError::is_retryable` reports as retryable.
pub struct FallbackProvider {
    /// Providers in priority order; index 0 is tried first and is also the
    /// one used for `name`, `model` and `estimate_cost`.
    /// `FallbackProvider::new` asserts that this list is non-empty.
    providers: Vec<Arc<dyn LlmProvider>>,
}
impl FallbackProvider {
    /// Creates a fallback chain over `providers`, which are tried in the
    /// order given (index 0 first).
    ///
    /// # Panics
    ///
    /// Panics if `providers` is empty: the rest of the implementation indexes
    /// the first provider and computes `len() - 1`, so an empty chain is a
    /// programming error rather than a recoverable condition.
    pub fn new(providers: Vec<Arc<dyn LlmProvider>>) -> Self {
        if providers.is_empty() {
            panic!("FallbackProvider requires at least one provider");
        }
        Self { providers }
    }

    /// Decides whether a failure from one provider justifies trying the next.
    ///
    /// Currently this is exactly `LlmError::is_retryable`; errors it reports
    /// as non-retryable stop the chain immediately.
    fn should_advance(err: &LlmError) -> bool {
        err.is_retryable()
    }
}
// `LlmProvider` implementation that walks the wrapped providers in order.
//
// Shared policy for `chat`, `chat_stream` and `embed`:
//   * the first `Ok` result is returned immediately;
//   * an error for which `should_advance` is true, on any provider but the
//     last, is logged at WARN level, recorded, and the next provider is tried;
//   * any other error (non-retryable, or from the last provider) ends the
//     chain and is returned inside `LlmError::AllProvidersFailed` together
//     with every error collected so far. Note that this variant is therefore
//     also returned when a non-retryable error stops the chain before every
//     provider has been tried.
#[async_trait]
impl LlmProvider for FallbackProvider {
    /// Sends `req` to each provider in turn until one returns a response.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::AllProvidersFailed` carrying the errors of every
    /// provider that was attempted, in attempt order.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let mut errors = Vec::new();
        // `FallbackProvider::new` asserts a non-empty list, so this does not underflow.
        let last = self.providers.len() - 1;
        for (i, p) in self.providers.iter().enumerate() {
            // Each attempt gets its own copy; the request must survive for later providers.
            match p.chat(req.clone()).await {
                Ok(resp) => return Ok(resp),
                Err(e) if i < last && Self::should_advance(&e) => {
                    tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
                    errors.push(e);
                }
                Err(e) => {
                    errors.push(e);
                    return Err(LlmError::AllProvidersFailed(errors));
                }
            }
        }
        // Only reachable if the provider list were empty.
        Err(LlmError::AllProvidersFailed(errors))
    }

    /// Opens a chat stream on the first provider that accepts the request.
    ///
    /// Fallback applies only to errors returned by the provider's
    /// `chat_stream` call itself; once a stream has been handed back to the
    /// caller, this wrapper no longer observes it.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::AllProvidersFailed` carrying the errors of every
    /// provider that was attempted, in attempt order.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let mut errors = Vec::new();
        // `FallbackProvider::new` asserts a non-empty list, so this does not underflow.
        let last = self.providers.len() - 1;
        for (i, p) in self.providers.iter().enumerate() {
            match p.chat_stream(req.clone()).await {
                Ok(s) => return Ok(s),
                Err(e) if i < last && Self::should_advance(&e) => {
                    // Log the failover just like `chat` does, so it is visible on this path too.
                    tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
                    errors.push(e);
                }
                Err(e) => {
                    errors.push(e);
                    return Err(LlmError::AllProvidersFailed(errors));
                }
            }
        }
        // Only reachable if the provider list were empty.
        Err(LlmError::AllProvidersFailed(errors))
    }

    /// Sends the embedding request to each provider in turn until one succeeds.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::AllProvidersFailed` carrying the errors of every
    /// provider that was attempted, in attempt order.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        let mut errors = Vec::new();
        // `FallbackProvider::new` asserts a non-empty list, so this does not underflow.
        let last = self.providers.len() - 1;
        for (i, p) in self.providers.iter().enumerate() {
            match p.embed(req.clone()).await {
                Ok(r) => return Ok(r),
                Err(e) if i < last && Self::should_advance(&e) => {
                    // Log the failover just like `chat` does, so it is visible on this path too.
                    tracing::warn!(provider = p.name(), error = %e, "falling back to next provider");
                    errors.push(e);
                }
                Err(e) => {
                    errors.push(e);
                    return Err(LlmError::AllProvidersFailed(errors));
                }
            }
        }
        // Only reachable if the provider list were empty.
        Err(LlmError::AllProvidersFailed(errors))
    }

    /// Name of the first (highest-priority) provider, regardless of which
    /// provider ends up serving a given request.
    fn name(&self) -> &'static str {
        self.providers[0].name()
    }

    /// Model of the first (highest-priority) provider, regardless of which
    /// provider ends up serving a given request.
    fn model(&self) -> &str {
        self.providers[0].model()
    }

    /// Cost estimate as computed by the first provider only; the estimate
    /// does not account for a request being served by a later provider.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        self.providers[0].estimate_cost(req)
    }
}