Skip to main content

pe_core/
fallback_provider.rs

//! Fallback provider -- tries a primary LLM, falls back to secondary on transient failure.
//!
//! Only falls back when the primary returns a transient error (as classified by
//! [`PeError::is_transient()`]). Permanent errors propagate immediately. If the
//! secondary also fails, the secondary's error is returned.
//!
//! For streaming, the fallback decision is made when the stream is opened; errors
//! reported after a stream has been handed back are not retried on the secondary.
//!
//! # Example
//!
//! ```ignore
//! use pe_core::fallback_provider::FallbackProvider;
//!
//! let provider = FallbackProvider::new(openai_provider, anthropic_provider);
//! // Uses OpenAI by default; on transient failure, falls back to Anthropic
//! ```

15use std::future::Future;
16use std::pin::Pin;
17use std::sync::Arc;
18
19use crate::error::PeError;
20use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
21use crate::message::Message;
22
23/// Wraps a primary and secondary [`LlmProvider`]. On transient failure from
24/// the primary, automatically tries the secondary.
25///
26/// Implements `LlmProvider` itself, so it can be used anywhere a provider is expected
27/// (including inside a [`MiddlewareStack`](super::middleware_stack::MiddlewareStack)).
28pub struct FallbackProvider {
29    primary: Arc<dyn LlmProvider>,
30    secondary: Arc<dyn LlmProvider>,
31}
32
33impl FallbackProvider {
34    /// Create a fallback provider with a primary and secondary.
35    pub fn new(primary: impl LlmProvider, secondary: impl LlmProvider) -> Self {
36        Self {
37            primary: Arc::new(primary),
38            secondary: Arc::new(secondary),
39        }
40    }
41
42    async fn do_complete(
43        primary: &dyn LlmProvider,
44        secondary: &dyn LlmProvider,
45        messages: Vec<Message>,
46        tools: Vec<ToolSchema>,
47    ) -> Result<LlmResponse, PeError> {
48        match primary.complete(&messages, &tools).await {
49            Ok(resp) => Ok(resp),
50            Err(e) if e.is_transient() => secondary.complete(&messages, &tools).await,
51            Err(e) => Err(e),
52        }
53    }
54}
55
56impl LlmProvider for FallbackProvider {
57    fn complete(
58        &self,
59        messages: &[Message],
60        tools: &[ToolSchema],
61    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
62        let messages = messages.to_vec();
63        let tools = tools.to_vec();
64        Box::pin(Self::do_complete(
65            self.primary.as_ref(),
66            self.secondary.as_ref(),
67            messages,
68            tools,
69        ))
70    }
71
72    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
73        let messages = messages.to_vec();
74        let tools = tools.to_vec();
75        Box::pin(async move {
76            match self.primary.stream(&messages, &tools).await {
77                Ok(stream) => Ok(stream),
78                Err(e) if e.is_transient() => self.secondary.stream(&messages, &tools).await,
79                Err(e) => Err(e),
80            }
81        })
82    }
83
84    fn embed(
85        &self,
86        text: &str,
87    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
88        let text = text.to_owned();
89        Box::pin(async move {
90            match self.primary.embed(&text).await {
91                Ok(v) => Ok(v),
92                Err(e) if e.is_transient() => self.secondary.embed(&text).await,
93                Err(e) => Err(e),
94            }
95        })
96    }
97
98    fn provider_name(&self) -> &'static str {
99        self.primary.provider_name()
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    #[tokio::test]
    async fn test_primary_succeeds_no_fallback() {
        let primary = MockProvider::new().respond_with("primary");
        let secondary = MockProvider::new().respond_with("secondary");

        let fb = FallbackProvider::new(primary, secondary);
        let resp = fb.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("primary"));
    }

    #[tokio::test]
    async fn test_falls_back_on_transient_error() {
        let primary = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "503".into(),
        });
        let secondary = MockProvider::new().respond_with("fallback");

        let fb = FallbackProvider::new(primary, secondary);
        let resp = fb.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("fallback"));
    }

    #[tokio::test]
    async fn test_permanent_error_propagates_no_fallback() {
        let primary = MockProvider::new().respond_with_error(PeError::PermissionDenied {
            action: "call".into(),
        });
        let secondary = MockProvider::new().respond_with("should not reach");

        let fb = FallbackProvider::new(primary, secondary);
        let err = fb.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(err, PeError::PermissionDenied { .. }));
    }

    #[tokio::test]
    async fn test_both_fail_returns_secondary_error() {
        let primary = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "primary down".into(),
        });
        let secondary = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "secondary down".into(),
        });

        let fb = FallbackProvider::new(primary, secondary);
        let err = fb.complete(&[], &[]).await.unwrap_err();
        match err {
            PeError::LlmProvider { details } => assert_eq!(details, "secondary down"),
            other => panic!("expected LlmProvider, got {other:?}"),
        }
    }

    /// After a transient primary failure, a permanent error from the secondary
    /// is returned as-is rather than being masked or retried.
    #[tokio::test]
    async fn test_secondary_permanent_error_propagates() {
        let primary = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "503".into(),
        });
        let secondary = MockProvider::new().respond_with_error(PeError::PermissionDenied {
            action: "call".into(),
        });

        let fb = FallbackProvider::new(primary, secondary);
        let err = fb.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(err, PeError::PermissionDenied { .. }));
    }

    // `provider_name` is synchronous, so no async runtime is needed.
    #[test]
    fn test_provider_name_returns_primary() {
        let fb = FallbackProvider::new(MockProvider::new(), MockProvider::new());
        assert_eq!(fb.provider_name(), "mock");
    }
}