Skip to main content

simple_agents_router/
fallback.rs

1//! Fallback routing implementation.
2//!
3//! Attempts providers in order, falling back on retryable errors.
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderError, Result,
7    SimpleAgentsError,
8};
9use std::sync::Arc;
10
/// Settings that control when a [`FallbackRouter`] moves on to the next provider.
#[derive(Debug, Clone, Copy)]
pub struct FallbackRouterConfig {
    /// When `true`, only retryable errors trigger a fallback; when `false`,
    /// every error does.
    pub retryable_only: bool,
}

impl Default for FallbackRouterConfig {
    /// Builds the default policy: fall back on retryable errors only.
    fn default() -> Self {
        let retryable_only = true;
        Self { retryable_only }
    }
}
25
26/// Router that tries providers in order and falls back on eligible errors.
27pub struct FallbackRouter {
28    providers: Vec<Arc<dyn Provider>>,
29    config: FallbackRouterConfig,
30}
31
32impl FallbackRouter {
33    /// Create a new fallback router.
34    ///
35    /// # Errors
36    /// Returns a routing error if no providers are supplied.
37    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
38        Self::with_config(providers, FallbackRouterConfig::default())
39    }
40
41    /// Create a new fallback router with custom configuration.
42    ///
43    /// # Errors
44    /// Returns a routing error if no providers are supplied.
45    pub fn with_config(
46        providers: Vec<Arc<dyn Provider>>,
47        config: FallbackRouterConfig,
48    ) -> Result<Self> {
49        if providers.is_empty() {
50            return Err(SimpleAgentsError::Routing(
51                "no providers configured".to_string(),
52            ));
53        }
54
55        Ok(Self { providers, config })
56    }
57
58    /// Return the number of configured providers.
59    pub fn provider_count(&self) -> usize {
60        self.providers.len()
61    }
62
63    /// Execute a completion request with fallback logic.
64    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
65        let mut last_error: Option<SimpleAgentsError> = None;
66
67        for provider in &self.providers {
68            let attempt = self.execute_provider(provider, request).await;
69            match attempt {
70                Ok(response) => return Ok(response),
71                Err(err) => {
72                    if !self.should_fallback(&err) {
73                        return Err(err);
74                    }
75                    last_error = Some(err);
76                }
77            }
78        }
79
80        Err(last_error
81            .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
82    }
83
84    /// Execute a streaming request with fallback logic.
85    pub async fn stream(
86        &self,
87        request: &CompletionRequest,
88    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
89        let mut last_error: Option<SimpleAgentsError> = None;
90
91        for provider in &self.providers {
92            let provider_request = provider.transform_request(request)?;
93            match provider.execute_stream(provider_request).await {
94                Ok(stream) => return Ok(stream),
95                Err(err) => {
96                    if !self.should_fallback(&err) {
97                        return Err(err);
98                    }
99                    last_error = Some(err);
100                }
101            }
102        }
103
104        Err(last_error
105            .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
106    }
107
108    async fn execute_provider(
109        &self,
110        provider: &Arc<dyn Provider>,
111        request: &CompletionRequest,
112    ) -> Result<CompletionResponse> {
113        let provider_request = provider.transform_request(request)?;
114        let provider_response = provider.execute(provider_request).await?;
115        provider.transform_response(provider_response)
116    }
117
118    fn should_fallback(&self, error: &SimpleAgentsError) -> bool {
119        if !self.config.retryable_only {
120            return true;
121        }
122
123        matches!(
124            error,
125            SimpleAgentsError::Provider(
126                ProviderError::RateLimit { .. }
127                    | ProviderError::Timeout(_)
128                    | ProviderError::ServerError(_)
129            ) | SimpleAgentsError::Network(_)
130        )
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Outcome a [`MockProvider`] produces from `execute`.
    enum MockResult {
        Ok,
        RetryableError,
        NonRetryableError,
    }

    /// Provider double with a fixed outcome that counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        attempts: AtomicUsize,
        result: MockResult,
    }

    impl MockProvider {
        fn new(name: &'static str, result: MockResult) -> Self {
            let attempts = AtomicUsize::new(0);
            Self {
                name,
                attempts,
                result,
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.attempts.fetch_add(1, Ordering::Relaxed);
            match self.result {
                MockResult::Ok => Ok(ProviderResponse::new(200, serde_json::Value::Null)),
                MockResult::RetryableError => {
                    let timeout = ProviderError::Timeout(std::time::Duration::from_secs(1));
                    Err(SimpleAgentsError::Provider(timeout))
                }
                MockResult::NonRetryableError => {
                    Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
                }
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Wrap a freshly built mock as a shareable provider trait object.
    fn mock(name: &'static str, result: MockResult) -> Arc<dyn Provider> {
        Arc::new(MockProvider::new(name, result))
    }

    /// Minimal single-message request shared by every test.
    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        match FallbackRouter::new(Vec::new()).err() {
            Some(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Some(_) => panic!("unexpected error type"),
            None => panic!("expected error, got Ok"),
        }
    }

    #[tokio::test]
    async fn falls_back_on_retryable_error() {
        let providers = vec![
            mock("p1", MockResult::RetryableError),
            mock("p2", MockResult::Ok),
        ];
        let router = FallbackRouter::new(providers).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let providers = vec![
            mock("p1", MockResult::NonRetryableError),
            mock("p2", MockResult::Ok),
        ];
        let router = FallbackRouter::new(providers).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        assert!(
            matches!(
                err,
                SimpleAgentsError::Provider(ProviderError::InvalidApiKey)
            ),
            "unexpected error"
        );
    }

    #[tokio::test]
    async fn falls_back_on_all_errors_when_configured() {
        let providers = vec![
            mock("p1", MockResult::NonRetryableError),
            mock("p2", MockResult::Ok),
        ];
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let router = FallbackRouter::with_config(providers, config).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
    }

    #[tokio::test]
    async fn stream_returns_last_provider_error() {
        let providers = vec![
            mock("p1", MockResult::RetryableError),
            mock("p2", MockResult::RetryableError),
        ];
        let router = FallbackRouter::new(providers).unwrap();

        // The stream type is not `Debug`, so `unwrap_err` is unavailable here.
        let outcome = router.stream(&build_request()).await;
        let err = match outcome {
            Err(failure) => failure,
            Ok(_) => panic!("expected stream setup to fail"),
        };
        assert!(
            matches!(err, SimpleAgentsError::Provider(ProviderError::Timeout(_))),
            "unexpected error"
        );
    }
}