Skip to main content

cognate_providers/
fallback.rs

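//! Transparent provider fallback.
//!
//! [`FallbackProvider`] tries the primary provider first and, on any
//! retryable error (rate limit, 5xx, timeout), transparently retries once
//! with the fallback provider. Non-retryable errors from the primary are
//! returned as-is. For streaming requests, only an error while opening the
//! stream triggers the fallback; errors yielded by an already-open stream
//! are passed through to the caller.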
6
7use async_trait::async_trait;
8use cognate_core::{Chunk, Provider, Request, Response, Result};
9use futures::stream::BoxStream;
10use std::sync::Arc;
11
12/// A provider that falls back to a secondary provider on retryable errors.
13///
14/// # Example
15///
16/// ```rust,no_run
17/// use cognate_providers::{OpenAiProvider, AnthropicProvider, FallbackProvider};
18/// use cognate_core::{Provider, Request, Message};
19/// use std::sync::Arc;
20///
21/// # async fn run() -> cognate_core::Result<()> {
22/// let primary = Arc::new(OpenAiProvider::new(std::env::var("OPENAI_API_KEY").unwrap())?);
23/// let secondary = Arc::new(AnthropicProvider::new(std::env::var("ANTHROPIC_API_KEY").unwrap())?);
24/// let provider = FallbackProvider::new(primary, secondary);
25///
26/// let resp = provider
27///     .complete(Request::new().with_model("gpt-4o").with_message(Message::user("Hi")))
28///     .await?;
29/// println!("{}", resp.content());
30/// # Ok(())
31/// # }
32/// ```
33pub struct FallbackProvider {
34    primary: Arc<dyn Provider>,
35    fallback: Arc<dyn Provider>,
36}
37
38impl FallbackProvider {
39    /// Create a new fallback pair.
40    ///
41    /// On any [`retryable`](cognate_core::Error::is_retryable) error from
42    /// `primary`, `fallback` is tried instead.
43    pub fn new(primary: Arc<dyn Provider>, fallback: Arc<dyn Provider>) -> Self {
44        Self { primary, fallback }
45    }
46}
47
48#[async_trait]
49impl Provider for FallbackProvider {
50    async fn complete(&self, req: Request) -> Result<Response> {
51        match self.primary.complete(req.clone()).await {
52            Ok(resp) => Ok(resp),
53            Err(e) if e.is_retryable() => {
54                tracing::warn!(error = %e, "cognate: primary provider failed, trying fallback");
55                self.fallback.complete(req).await
56            }
57            Err(e) => Err(e),
58        }
59    }
60
61    async fn stream(&self, req: Request) -> Result<BoxStream<'static, Result<Chunk>>> {
62        match self.primary.stream(req.clone()).await {
63            Ok(s) => Ok(s),
64            Err(e) if e.is_retryable() => {
65                tracing::warn!(error = %e, "cognate: primary provider failed on stream, trying fallback");
66                self.fallback.stream(req).await
67            }
68            Err(e) => Err(e),
69        }
70    }
71}