simple_agents_router/
fallback.rs1use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderError, Result,
7 SimpleAgentsError,
8};
9use std::sync::Arc;
10
/// Controls when [`FallbackRouter`] moves on to the next provider in its chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FallbackRouterConfig {
    /// When `true`, the router falls back only on errors it classifies as
    /// retryable (provider rate limits, timeouts, server errors, and network
    /// errors); any other error is returned to the caller immediately.
    /// When `false`, the router falls back on every error.
    pub retryable_only: bool,
}
17
18impl Default for FallbackRouterConfig {
19 fn default() -> Self {
20 Self {
21 retryable_only: true,
22 }
23 }
24}
25
26pub struct FallbackRouter {
28 providers: Vec<Arc<dyn Provider>>,
29 config: FallbackRouterConfig,
30}
31
32impl FallbackRouter {
33 pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
38 Self::with_config(providers, FallbackRouterConfig::default())
39 }
40
41 pub fn with_config(
46 providers: Vec<Arc<dyn Provider>>,
47 config: FallbackRouterConfig,
48 ) -> Result<Self> {
49 if providers.is_empty() {
50 return Err(SimpleAgentsError::Routing(
51 "no providers configured".to_string(),
52 ));
53 }
54
55 Ok(Self { providers, config })
56 }
57
58 pub fn provider_count(&self) -> usize {
60 self.providers.len()
61 }
62
63 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
65 let mut last_error: Option<SimpleAgentsError> = None;
66
67 for provider in &self.providers {
68 let attempt = self.execute_provider(provider, request).await;
69 match attempt {
70 Ok(response) => return Ok(response),
71 Err(err) => {
72 if !self.should_fallback(&err) {
73 return Err(err);
74 }
75 last_error = Some(err);
76 }
77 }
78 }
79
80 Err(last_error
81 .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
82 }
83
84 pub async fn stream(
86 &self,
87 request: &CompletionRequest,
88 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
89 for provider in &self.providers {
90 let provider_request = provider.transform_request(request)?;
91 match provider.execute_stream(provider_request).await {
92 Ok(stream) => return Ok(stream),
93 Err(err) => {
94 if !self.should_fallback(&err) {
95 return Err(err);
96 }
97 }
99 }
100 }
101
102 Err(SimpleAgentsError::Routing(
103 "no providers configured".to_string(),
104 ))
105 }
106
107 async fn execute_provider(
108 &self,
109 provider: &Arc<dyn Provider>,
110 request: &CompletionRequest,
111 ) -> Result<CompletionResponse> {
112 let provider_request = provider.transform_request(request)?;
113 let provider_response = provider.execute(provider_request).await?;
114 provider.transform_response(provider_response)
115 }
116
117 fn should_fallback(&self, error: &SimpleAgentsError) -> bool {
118 if !self.config.retryable_only {
119 return true;
120 }
121
122 matches!(
123 error,
124 SimpleAgentsError::Provider(
125 ProviderError::RateLimit { .. }
126 | ProviderError::Timeout(_)
127 | ProviderError::ServerError(_)
128 ) | SimpleAgentsError::Network(_)
129 )
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that produces a fixed outcome and counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        attempts: AtomicUsize,
        result: MockResult,
    }

    /// Outcome a [`MockProvider`] produces from `execute`.
    enum MockResult {
        /// Succeeds with an empty 200 response.
        Ok,
        /// Fails with `ProviderError::Timeout`, which the default config falls back on.
        RetryableError,
        /// Fails with `ProviderError::InvalidApiKey`, which the default config does not.
        NonRetryableError,
    }

    impl MockProvider {
        fn new(name: &'static str, result: MockResult) -> Self {
            Self {
                name,
                attempts: AtomicUsize::new(0),
                result,
            }
        }

        /// Number of times `execute` has been called on this provider.
        fn attempt_count(&self) -> usize {
            self.attempts.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.attempts.fetch_add(1, Ordering::Relaxed);
            match self.result {
                MockResult::Ok => Ok(ProviderResponse::new(200, serde_json::Value::Null)),
                MockResult::RetryableError => Err(SimpleAgentsError::Provider(
                    ProviderError::Timeout(std::time::Duration::from_secs(1)),
                )),
                MockResult::NonRetryableError => {
                    Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
                }
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        let result = FallbackRouter::new(Vec::new());
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(other) => panic!("unexpected error type: {:?}", other),
        }
    }

    #[tokio::test]
    async fn falls_back_on_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::RetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let providers: Vec<Arc<dyn Provider>> = vec![p1.clone(), p2.clone()];
        let router = FallbackRouter::new(providers).unwrap();
        assert_eq!(router.provider_count(), 2);

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempt_count(), 1);
        assert_eq!(p2.attempt_count(), 1);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let providers: Vec<Arc<dyn Provider>> = vec![p1.clone(), p2.clone()];
        let router = FallbackRouter::new(providers).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::InvalidApiKey) => {}
            other => panic!("unexpected error: {:?}", other),
        }
        assert_eq!(p1.attempt_count(), 1);
        assert_eq!(
            p2.attempt_count(),
            0,
            "a non-retryable error must not fall back to the next provider"
        );
    }

    #[tokio::test]
    async fn returns_last_error_when_all_providers_fail() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::RetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::RetryableError));
        let providers: Vec<Arc<dyn Provider>> = vec![p1.clone(), p2.clone()];
        let router = FallbackRouter::new(providers).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::Timeout(_)) => {}
            other => panic!("unexpected error: {:?}", other),
        }
        assert_eq!(p1.attempt_count(), 1);
        assert_eq!(p2.attempt_count(), 1);
    }

    #[tokio::test]
    async fn falls_back_on_all_errors_when_configured() {
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let providers: Vec<Arc<dyn Provider>> = vec![p1.clone(), p2.clone()];
        let router = FallbackRouter::with_config(providers, config).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempt_count(), 1);
        assert_eq!(p2.attempt_count(), 1);
    }
}