pe_core/
fallback_provider.rs1use std::future::Future;
16use std::pin::Pin;
17use std::sync::Arc;
18
19use crate::error::PeError;
20use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
21use crate::message::Message;
22
23pub struct FallbackProvider {
29 primary: Arc<dyn LlmProvider>,
30 secondary: Arc<dyn LlmProvider>,
31}
32
33impl FallbackProvider {
34 pub fn new(primary: impl LlmProvider, secondary: impl LlmProvider) -> Self {
36 Self {
37 primary: Arc::new(primary),
38 secondary: Arc::new(secondary),
39 }
40 }
41
42 async fn do_complete(
43 primary: &dyn LlmProvider,
44 secondary: &dyn LlmProvider,
45 messages: Vec<Message>,
46 tools: Vec<ToolSchema>,
47 ) -> Result<LlmResponse, PeError> {
48 match primary.complete(&messages, &tools).await {
49 Ok(resp) => Ok(resp),
50 Err(e) if e.is_transient() => secondary.complete(&messages, &tools).await,
51 Err(e) => Err(e),
52 }
53 }
54}
55
56impl LlmProvider for FallbackProvider {
57 fn complete(
58 &self,
59 messages: &[Message],
60 tools: &[ToolSchema],
61 ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
62 let messages = messages.to_vec();
63 let tools = tools.to_vec();
64 Box::pin(Self::do_complete(
65 self.primary.as_ref(),
66 self.secondary.as_ref(),
67 messages,
68 tools,
69 ))
70 }
71
72 fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
73 let messages = messages.to_vec();
74 let tools = tools.to_vec();
75 Box::pin(async move {
76 match self.primary.stream(&messages, &tools).await {
77 Ok(stream) => Ok(stream),
78 Err(e) if e.is_transient() => self.secondary.stream(&messages, &tools).await,
79 Err(e) => Err(e),
80 }
81 })
82 }
83
84 fn embed(
85 &self,
86 text: &str,
87 ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
88 let text = text.to_owned();
89 Box::pin(async move {
90 match self.primary.embed(&text).await {
91 Ok(v) => Ok(v),
92 Err(e) if e.is_transient() => self.secondary.embed(&text).await,
93 Err(e) => Err(e),
94 }
95 })
96 }
97
98 fn provider_name(&self) -> &'static str {
99 self.primary.provider_name()
100 }
101}
102
103#[cfg(test)]
104mod tests {
105 use super::*;
106 use crate::mock_provider::MockProvider;
107
108 #[tokio::test]
109 async fn test_primary_succeeds_no_fallback() {
110 let primary = MockProvider::new().respond_with("primary");
111 let secondary = MockProvider::new().respond_with("secondary");
112
113 let fb = FallbackProvider::new(primary, secondary);
114 let resp = fb.complete(&[], &[]).await.unwrap();
115 assert_eq!(resp.message.content.as_text(), Some("primary"));
116 }
117
118 #[tokio::test]
119 async fn test_falls_back_on_transient_error() {
120 let primary = MockProvider::new().respond_with_error(PeError::LlmProvider {
121 details: "503".into(),
122 });
123 let secondary = MockProvider::new().respond_with("fallback");
124
125 let fb = FallbackProvider::new(primary, secondary);
126 let resp = fb.complete(&[], &[]).await.unwrap();
127 assert_eq!(resp.message.content.as_text(), Some("fallback"));
128 }
129
130 #[tokio::test]
131 async fn test_permanent_error_propagates_no_fallback() {
132 let primary = MockProvider::new().respond_with_error(PeError::PermissionDenied {
133 action: "call".into(),
134 });
135 let secondary = MockProvider::new().respond_with("should not reach");
136
137 let fb = FallbackProvider::new(primary, secondary);
138 let err = fb.complete(&[], &[]).await.unwrap_err();
139 assert!(matches!(err, PeError::PermissionDenied { .. }));
140 }
141
142 #[tokio::test]
143 async fn test_both_fail_returns_secondary_error() {
144 let primary = MockProvider::new().respond_with_error(PeError::LlmProvider {
145 details: "primary down".into(),
146 });
147 let secondary = MockProvider::new().respond_with_error(PeError::LlmProvider {
148 details: "secondary down".into(),
149 });
150
151 let fb = FallbackProvider::new(primary, secondary);
152 let err = fb.complete(&[], &[]).await.unwrap_err();
153 match err {
154 PeError::LlmProvider { details } => assert_eq!(details, "secondary down"),
155 other => panic!("expected LlmProvider, got {other:?}"),
156 }
157 }
158
159 #[tokio::test]
160 async fn test_provider_name_returns_primary() {
161 let fb = FallbackProvider::new(MockProvider::new(), MockProvider::new());
162 assert_eq!(fb.provider_name(), "mock");
163 }
164}