simple_agents_router/
fallback.rs1use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderError, Result, SimpleAgentsError,
7};
8use std::sync::Arc;
9
/// Configuration controlling when a fallback router moves on to the next provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FallbackRouterConfig {
    /// When `true` (the default), only errors classified as retryable — rate
    /// limits, timeouts, server errors and network failures — cause the router
    /// to try the next provider; any other error is returned immediately.
    /// When `false`, every error triggers fallback to the next provider.
    pub retryable_only: bool,
}
16
17impl Default for FallbackRouterConfig {
18 fn default() -> Self {
19 Self {
20 retryable_only: true,
21 }
22 }
23}
24
25pub struct FallbackRouter {
27 providers: Vec<Arc<dyn Provider>>,
28 config: FallbackRouterConfig,
29}
30
31impl FallbackRouter {
32 pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
37 Self::with_config(providers, FallbackRouterConfig::default())
38 }
39
40 pub fn with_config(
45 providers: Vec<Arc<dyn Provider>>,
46 config: FallbackRouterConfig,
47 ) -> Result<Self> {
48 if providers.is_empty() {
49 return Err(SimpleAgentsError::Routing(
50 "no providers configured".to_string(),
51 ));
52 }
53
54 Ok(Self { providers, config })
55 }
56
57 pub fn provider_count(&self) -> usize {
59 self.providers.len()
60 }
61
62 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
64 let mut last_error: Option<SimpleAgentsError> = None;
65
66 for provider in &self.providers {
67 let attempt = self.execute_provider(provider, request).await;
68 match attempt {
69 Ok(response) => return Ok(response),
70 Err(err) => {
71 if !self.should_fallback(&err) {
72 return Err(err);
73 }
74 last_error = Some(err);
75 }
76 }
77 }
78
79 Err(last_error
80 .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
81 }
82
83 pub async fn stream(
85 &self,
86 request: &CompletionRequest,
87 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
88 for provider in &self.providers {
89 let provider_request = provider.transform_request(request)?;
90 match provider.execute_stream(provider_request).await {
91 Ok(stream) => return Ok(stream),
92 Err(err) => {
93 if !self.should_fallback(&err) {
94 return Err(err);
95 }
96 }
98 }
99 }
100
101 Err(SimpleAgentsError::Routing("no providers configured".to_string()))
102 }
103
104 async fn execute_provider(
105 &self,
106 provider: &Arc<dyn Provider>,
107 request: &CompletionRequest,
108 ) -> Result<CompletionResponse> {
109 let provider_request = provider.transform_request(request)?;
110 let provider_response = provider.execute(provider_request).await?;
111 provider.transform_response(provider_response)
112 }
113
114 fn should_fallback(&self, error: &SimpleAgentsError) -> bool {
115 if !self.config.retryable_only {
116 return true;
117 }
118
119 matches!(
120 error,
121 SimpleAgentsError::Provider(
122 ProviderError::RateLimit { .. }
123 | ProviderError::Timeout(_)
124 | ProviderError::ServerError(_)
125 ) | SimpleAgentsError::Network(_)
126 )
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that returns a fixed outcome and counts how often `execute` runs.
    struct MockProvider {
        name: &'static str,
        attempts: AtomicUsize,
        result: MockResult,
    }

    /// Outcome a `MockProvider` produces from `execute`.
    enum MockResult {
        /// Succeeds with an empty 200 response.
        Ok,
        /// Fails with a timeout, which the default config treats as retryable.
        RetryableError,
        /// Fails with an invalid API key, which the default config does not retry.
        NonRetryableError,
    }

    impl MockProvider {
        fn new(name: &'static str, result: MockResult) -> Self {
            Self {
                name,
                attempts: AtomicUsize::new(0),
                result,
            }
        }

        /// Number of times `execute` has been called on this provider.
        fn attempts(&self) -> usize {
            self.attempts.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.attempts.fetch_add(1, Ordering::Relaxed);
            match self.result {
                MockResult::Ok => Ok(ProviderResponse::new(200, serde_json::Value::Null)),
                MockResult::RetryableError => Err(SimpleAgentsError::Provider(
                    ProviderError::Timeout(std::time::Duration::from_secs(1)),
                )),
                MockResult::NonRetryableError => {
                    Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
                }
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    /// Converts typed mock handles into the trait objects the router expects, so
    /// tests keep the typed handles to inspect attempt counts afterwards.
    fn to_dyn(providers: &[Arc<MockProvider>]) -> Vec<Arc<dyn Provider>> {
        providers
            .iter()
            .map(|p| p.clone() as Arc<dyn Provider>)
            .collect()
    }

    #[test]
    fn empty_router_returns_error() {
        let result = FallbackRouter::new(Vec::new());
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
        }
    }

    #[test]
    fn default_config_falls_back_on_retryable_only() {
        assert!(FallbackRouterConfig::default().retryable_only);
    }

    #[test]
    fn provider_count_matches_configured_providers() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::Ok));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::new(to_dyn(&[p1, p2])).unwrap();
        assert_eq!(router.provider_count(), 2);
    }

    #[tokio::test]
    async fn falls_back_on_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::RetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::new(to_dyn(&[p1.clone(), p2.clone()])).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::new(to_dyn(&[p1.clone(), p2.clone()])).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::InvalidApiKey) => {}
            _ => panic!("unexpected error"),
        }
        // The second provider must never be contacted after a non-retryable failure.
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 0);
    }

    #[tokio::test]
    async fn returns_last_error_when_every_provider_fails() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::RetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::RetryableError));
        let router = FallbackRouter::new(to_dyn(&[p1.clone(), p2.clone()])).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        assert!(matches!(
            err,
            SimpleAgentsError::Provider(ProviderError::Timeout(_))
        ));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn falls_back_on_all_errors_when_configured() {
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router =
            FallbackRouter::with_config(to_dyn(&[p1.clone(), p2.clone()]), config).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }
}