cognate_providers/fallback.rs
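//! Transparent provider fallback.
//!
//! [`FallbackProvider`] tries the primary provider first and, when that call
//! fails with a retryable error (rate limit, 5xx, timeout), transparently
//! retries the same request with the fallback provider. For streaming, only a
//! failure to open the stream triggers the fallback; errors emitted mid-stream
//! are not retried.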
6
7use async_trait::async_trait;
8use cognate_core::{Chunk, Provider, Request, Response, Result};
9use futures::stream::BoxStream;
10use std::sync::Arc;
11
12/// A provider that falls back to a secondary provider on retryable errors.
13///
14/// # Example
15///
16/// ```rust,no_run
17/// use cognate_providers::{OpenAiProvider, AnthropicProvider, FallbackProvider};
18/// use cognate_core::{Provider, Request, Message};
19/// use std::sync::Arc;
20///
21/// # async fn run() -> cognate_core::Result<()> {
22/// let primary = Arc::new(OpenAiProvider::new(std::env::var("OPENAI_API_KEY").unwrap())?);
23/// let secondary = Arc::new(AnthropicProvider::new(std::env::var("ANTHROPIC_API_KEY").unwrap())?);
24/// let provider = FallbackProvider::new(primary, secondary);
25///
26/// let resp = provider
27/// .complete(Request::new().with_model("gpt-4o").with_message(Message::user("Hi")))
28/// .await?;
29/// println!("{}", resp.content());
30/// # Ok(())
31/// # }
32/// ```
33pub struct FallbackProvider {
34 primary: Arc<dyn Provider>,
35 fallback: Arc<dyn Provider>,
36}
37
38impl FallbackProvider {
39 /// Create a new fallback pair.
40 ///
41 /// On any [`retryable`](cognate_core::Error::is_retryable) error from
42 /// `primary`, `fallback` is tried instead.
43 pub fn new(primary: Arc<dyn Provider>, fallback: Arc<dyn Provider>) -> Self {
44 Self { primary, fallback }
45 }
46}
47
48#[async_trait]
49impl Provider for FallbackProvider {
50 async fn complete(&self, req: Request) -> Result<Response> {
51 match self.primary.complete(req.clone()).await {
52 Ok(resp) => Ok(resp),
53 Err(e) if e.is_retryable() => {
54 tracing::warn!(error = %e, "cognate: primary provider failed, trying fallback");
55 self.fallback.complete(req).await
56 }
57 Err(e) => Err(e),
58 }
59 }
60
61 async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>> {
62 match self.primary.stream(req.clone()).await {
63 Ok(s) => Ok(s),
64 Err(e) if e.is_retryable() => {
65 tracing::warn!(error = %e, "cognate: primary provider failed on stream, trying fallback");
66 self.fallback.stream(req).await
67 }
68 Err(e) => Err(e),
69 }
70 }
71}