use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
use crate::message::Message;
/// An [`LlmProvider`] that sends every request to a primary provider and, when
/// the primary fails with an error whose `is_transient()` is `true`, repeats the
/// same request against a secondary provider.
///
/// Errors that are not transient are returned to the caller as-is; the
/// secondary provider is not consulted for them.
pub struct FallbackProvider {
    /// Provider tried first for every request.
    primary: Arc<dyn LlmProvider>,
    /// Provider used only after the primary fails with a transient error.
    secondary: Arc<dyn LlmProvider>,
}

// `Debug` cannot be derived because the fields are trait objects; report the
// wrapped providers by their names instead.
impl std::fmt::Debug for FallbackProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FallbackProvider")
            .field("primary", &self.primary.provider_name())
            .field("secondary", &self.secondary.provider_name())
            .finish()
    }
}
impl FallbackProvider {
    /// Creates a provider that tries `primary` first and falls back to
    /// `secondary` on transient errors.
    ///
    /// Both providers are moved into `Arc`s and stored as `dyn LlmProvider`
    /// trait objects, which requires them to be `'static`.
    pub fn new(
        primary: impl LlmProvider + 'static,
        secondary: impl LlmProvider + 'static,
    ) -> Self {
        Self {
            primary: Arc::new(primary),
            secondary: Arc::new(secondary),
        }
    }

    /// Awaits `first`; if it fails with a transient error, builds the fallback
    /// future by calling `fallback` and returns that future's result instead.
    ///
    /// `fallback` is only invoked on a transient error, so the secondary
    /// provider is never contacted otherwise. Successes and non-transient
    /// errors from `first` are returned unchanged. When the fallback runs, the
    /// transient error from `first` is dropped and only the fallback's outcome
    /// (success or error) reaches the caller.
    async fn with_fallback<T, Fut>(
        first: impl Future<Output = Result<T, PeError>>,
        fallback: impl FnOnce() -> Fut,
    ) -> Result<T, PeError>
    where
        Fut: Future<Output = Result<T, PeError>>,
    {
        match first.await {
            Err(e) if e.is_transient() => fallback().await,
            other => other,
        }
    }
}

impl LlmProvider for FallbackProvider {
    /// Runs the completion on the primary provider, retrying it on the
    /// secondary provider after a transient error.
    fn complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        // The returned future's lifetime is tied only to `&self`, so it cannot
        // borrow the caller's slices; copy them so a retry can reuse them.
        let messages = messages.to_vec();
        let tools = tools.to_vec();
        Box::pin(async move {
            Self::with_fallback(self.primary.complete(&messages, &tools), || {
                self.secondary.complete(&messages, &tools)
            })
            .await
        })
    }

    /// Opens a stream on the primary provider, or on the secondary provider if
    /// opening it on the primary fails with a transient error.
    ///
    /// Only the error returned while obtaining the stream is inspected; once a
    /// stream has been returned it is handed to the caller unchanged.
    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
        // Owned copies for the same reason as in `complete`.
        let messages = messages.to_vec();
        let tools = tools.to_vec();
        Box::pin(async move {
            Self::with_fallback(self.primary.stream(&messages, &tools), || {
                self.secondary.stream(&messages, &tools)
            })
            .await
        })
    }

    /// Embeds `text` with the primary provider, retrying on the secondary
    /// provider after a transient error.
    fn embed(
        &self,
        text: &str,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        // Owned copy so the future does not borrow the caller's `text`.
        let text = text.to_owned();
        Box::pin(async move {
            Self::with_fallback(self.primary.embed(&text), || self.secondary.embed(&text)).await
        })
    }

    /// Returns the primary provider's name, including for calls that were
    /// actually served by the secondary provider.
    fn provider_name(&self) -> &'static str {
        self.primary.provider_name()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Builds the `PeError::LlmProvider` error these tests feed to the mocks.
    /// The fallback tests below depend on this variant being treated as
    /// transient.
    fn llm_provider_error(details: &'static str) -> PeError {
        PeError::LlmProvider {
            details: details.into(),
        }
    }

    #[tokio::test]
    async fn test_primary_succeeds_no_fallback() {
        let fb = FallbackProvider::new(
            MockProvider::new().respond_with("primary"),
            MockProvider::new().respond_with("secondary"),
        );
        let response = fb.complete(&[], &[]).await.unwrap();
        assert_eq!(response.message.content.as_text(), Some("primary"));
    }

    #[tokio::test]
    async fn test_falls_back_on_transient_error() {
        let fb = FallbackProvider::new(
            MockProvider::new().respond_with_error(llm_provider_error("503")),
            MockProvider::new().respond_with("fallback"),
        );
        let response = fb.complete(&[], &[]).await.unwrap();
        assert_eq!(response.message.content.as_text(), Some("fallback"));
    }

    #[tokio::test]
    async fn test_permanent_error_propagates_no_fallback() {
        let denied = PeError::PermissionDenied {
            action: "call".into(),
        };
        let fb = FallbackProvider::new(
            MockProvider::new().respond_with_error(denied),
            MockProvider::new().respond_with("should not reach"),
        );
        let failure = fb.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(failure, PeError::PermissionDenied { .. }));
    }

    #[tokio::test]
    async fn test_both_fail_returns_secondary_error() {
        let fb = FallbackProvider::new(
            MockProvider::new().respond_with_error(llm_provider_error("primary down")),
            MockProvider::new().respond_with_error(llm_provider_error("secondary down")),
        );
        let details = match fb.complete(&[], &[]).await.unwrap_err() {
            PeError::LlmProvider { details } => details,
            other => panic!("expected LlmProvider, got {other:?}"),
        };
        // Only the secondary's error is surfaced; the primary's is discarded.
        assert_eq!(details, "secondary down");
    }

    #[tokio::test]
    async fn test_provider_name_returns_primary() {
        let name = FallbackProvider::new(MockProvider::new(), MockProvider::new()).provider_name();
        assert_eq!(name, "mock");
    }
}