Skip to main content

simple_agents_router/
fallback.rs

//! Fallback routing implementation.
//!
//! Attempts providers in order, falling back to the next provider on eligible
//! errors (by default, only retryable ones).
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderError, Result,
7    SimpleAgentsError,
8};
9use std::sync::Arc;
10
/// Configuration for fallback routing.
///
/// This is a small `Copy` value type; equality is derived so configurations
/// can be compared directly (e.g. in tests or when caching routers).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FallbackRouterConfig {
    /// If `true`, fall back to the next provider only on retryable errors:
    /// provider rate limits, timeouts and server errors, plus network errors.
    /// If `false`, fall back on any error.
    pub retryable_only: bool,
}
17
18impl Default for FallbackRouterConfig {
19    fn default() -> Self {
20        Self {
21            retryable_only: true,
22        }
23    }
24}
25
26/// Router that tries providers in order and falls back on eligible errors.
27pub struct FallbackRouter {
28    providers: Vec<Arc<dyn Provider>>,
29    config: FallbackRouterConfig,
30}
31
32impl FallbackRouter {
33    /// Create a new fallback router.
34    ///
35    /// # Errors
36    /// Returns a routing error if no providers are supplied.
37    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
38        Self::with_config(providers, FallbackRouterConfig::default())
39    }
40
41    /// Create a new fallback router with custom configuration.
42    ///
43    /// # Errors
44    /// Returns a routing error if no providers are supplied.
45    pub fn with_config(
46        providers: Vec<Arc<dyn Provider>>,
47        config: FallbackRouterConfig,
48    ) -> Result<Self> {
49        if providers.is_empty() {
50            return Err(SimpleAgentsError::Routing(
51                "no providers configured".to_string(),
52            ));
53        }
54
55        Ok(Self { providers, config })
56    }
57
58    /// Return the number of configured providers.
59    pub fn provider_count(&self) -> usize {
60        self.providers.len()
61    }
62
63    /// Execute a completion request with fallback logic.
64    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
65        let mut last_error: Option<SimpleAgentsError> = None;
66
67        for provider in &self.providers {
68            let attempt = self.execute_provider(provider, request).await;
69            match attempt {
70                Ok(response) => return Ok(response),
71                Err(err) => {
72                    if !self.should_fallback(&err) {
73                        return Err(err);
74                    }
75                    last_error = Some(err);
76                }
77            }
78        }
79
80        Err(last_error
81            .unwrap_or_else(|| SimpleAgentsError::Routing("no providers configured".to_string())))
82    }
83
84    /// Execute a streaming request with fallback logic.
85    pub async fn stream(
86        &self,
87        request: &CompletionRequest,
88    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
89        for provider in &self.providers {
90            let provider_request = provider.transform_request(request)?;
91            match provider.execute_stream(provider_request).await {
92                Ok(stream) => return Ok(stream),
93                Err(err) => {
94                    if !self.should_fallback(&err) {
95                        return Err(err);
96                    }
97                    // Continue to next provider
98                }
99            }
100        }
101
102        Err(SimpleAgentsError::Routing(
103            "no providers configured".to_string(),
104        ))
105    }
106
107    async fn execute_provider(
108        &self,
109        provider: &Arc<dyn Provider>,
110        request: &CompletionRequest,
111    ) -> Result<CompletionResponse> {
112        let provider_request = provider.transform_request(request)?;
113        let provider_response = provider.execute(provider_request).await?;
114        provider.transform_response(provider_response)
115    }
116
117    fn should_fallback(&self, error: &SimpleAgentsError) -> bool {
118        if !self.config.retryable_only {
119            return true;
120        }
121
122        matches!(
123            error,
124            SimpleAgentsError::Provider(
125                ProviderError::RateLimit { .. }
126                    | ProviderError::Timeout(_)
127                    | ProviderError::ServerError(_)
128            ) | SimpleAgentsError::Network(_)
129        )
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider test double with a fixed outcome that counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        attempts: AtomicUsize,
        result: MockResult,
    }

    /// Outcome a `MockProvider` produces on every `execute` call.
    enum MockResult {
        /// Succeed with an empty 200 response.
        Ok,
        /// Fail with a timeout, which the default config treats as retryable.
        RetryableError,
        /// Fail with an invalid API key, which the default config does not retry.
        NonRetryableError,
    }

    impl MockProvider {
        fn new(name: &'static str, result: MockResult) -> Self {
            Self {
                name,
                attempts: AtomicUsize::new(0),
                result,
            }
        }

        /// Number of times `execute` has been called on this provider.
        fn attempts(&self) -> usize {
            self.attempts.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.attempts.fetch_add(1, Ordering::Relaxed);
            match self.result {
                MockResult::Ok => Ok(ProviderResponse::new(200, serde_json::Value::Null)),
                MockResult::RetryableError => Err(SimpleAgentsError::Provider(
                    ProviderError::Timeout(std::time::Duration::from_secs(1)),
                )),
                MockResult::NonRetryableError => {
                    Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
                }
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Share a mock with the router while keeping a typed handle for
    /// inspecting its attempt count afterwards.
    fn shared(provider: &Arc<MockProvider>) -> Arc<dyn Provider> {
        Arc::clone(provider)
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        let result = FallbackRouter::new(Vec::new());
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
        }
    }

    #[tokio::test]
    async fn falls_back_on_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::RetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::new(vec![shared(&p1), shared(&p2)]).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::new(vec![shared(&p1), shared(&p2)]).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        match err {
            SimpleAgentsError::Provider(ProviderError::InvalidApiKey) => {}
            _ => panic!("unexpected error"),
        }
        // A non-retryable failure must not reach the next provider.
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 0);
    }

    #[tokio::test]
    async fn falls_back_on_all_errors_when_configured() {
        let config = FallbackRouterConfig {
            retryable_only: false,
        };
        let p1 = Arc::new(MockProvider::new("p1", MockResult::NonRetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::Ok));
        let router = FallbackRouter::with_config(vec![shared(&p1), shared(&p2)], config).unwrap();

        let response = router.complete(&build_request()).await.unwrap();
        assert_eq!(response.provider.as_deref(), Some("p2"));
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }

    #[tokio::test]
    async fn returns_provider_error_when_all_providers_fail() {
        let p1 = Arc::new(MockProvider::new("p1", MockResult::RetryableError));
        let p2 = Arc::new(MockProvider::new("p2", MockResult::RetryableError));
        let router = FallbackRouter::new(vec![shared(&p1), shared(&p2)]).unwrap();

        let err = router.complete(&build_request()).await.unwrap_err();
        // The caller should see the actual provider failure, not a routing error.
        match err {
            SimpleAgentsError::Provider(ProviderError::Timeout(_)) => {}
            _ => panic!("unexpected error"),
        }
        assert_eq!(p1.attempts(), 1);
        assert_eq!(p2.attempts(), 1);
    }
}