// forgeai_router/lib.rs

1use async_trait::async_trait;
2use forgeai_core::{
3    AdapterInfo, ChatAdapter, ChatRequest, ChatResponse, ForgeError, StreamEvent, StreamResult,
4};
5use std::sync::Arc;
6
7pub fn pick_first_healthy(adapters: &[AdapterInfo]) -> Option<&AdapterInfo> {
8    adapters.first()
9}
10
/// Tuning knobs for [`FailoverRouter`].
#[derive(Clone, Copy, Debug)]
pub struct FailoverPolicy {
    /// Upper bound on how many adapters, counted from the front of the
    /// router's list, a single request may be attempted against.
    pub max_adapters_to_try: usize,
}
15
16impl Default for FailoverPolicy {
17    fn default() -> Self {
18        Self {
19            max_adapters_to_try: usize::MAX,
20        }
21    }
22}
23
24pub struct FailoverRouter {
25    adapters: Vec<Arc<dyn ChatAdapter>>,
26    policy: FailoverPolicy,
27}
28
29impl FailoverRouter {
30    pub fn new(adapters: Vec<Arc<dyn ChatAdapter>>) -> Result<Self, ForgeError> {
31        Self::with_policy(adapters, FailoverPolicy::default())
32    }
33
34    pub fn with_policy(
35        adapters: Vec<Arc<dyn ChatAdapter>>,
36        policy: FailoverPolicy,
37    ) -> Result<Self, ForgeError> {
38        if adapters.is_empty() {
39            return Err(ForgeError::Validation(
40                "failover router requires at least one adapter".to_string(),
41            ));
42        }
43        Ok(Self { adapters, policy })
44    }
45
46    fn adapters_to_try(&self) -> impl Iterator<Item = &Arc<dyn ChatAdapter>> {
47        self.adapters.iter().take(self.policy.max_adapters_to_try)
48    }
49}
50
51#[async_trait]
52impl ChatAdapter for FailoverRouter {
53    fn info(&self) -> AdapterInfo {
54        let first = self.adapters[0].info();
55        AdapterInfo {
56            name: "failover-router".to_string(),
57            base_url: first.base_url,
58            capabilities: first.capabilities,
59        }
60    }
61
62    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
63        let mut last_error: Option<ForgeError> = None;
64        for adapter in self.adapters_to_try() {
65            match adapter.chat(request.clone()).await {
66                Ok(response) => return Ok(response),
67                Err(error) if should_failover(&error) => {
68                    last_error = Some(error);
69                }
70                Err(error) => return Err(error),
71            }
72        }
73        Err(last_error.unwrap_or_else(|| {
74            ForgeError::Internal("failover router exhausted adapters without error".to_string())
75        }))
76    }
77
78    async fn chat_stream(
79        &self,
80        request: ChatRequest,
81    ) -> Result<StreamResult<StreamEvent>, ForgeError> {
82        let mut last_error: Option<ForgeError> = None;
83        for adapter in self.adapters_to_try() {
84            match adapter.chat_stream(request.clone()).await {
85                Ok(stream) => return Ok(stream),
86                Err(error) if should_failover(&error) => {
87                    last_error = Some(error);
88                }
89                Err(error) => return Err(error),
90            }
91        }
92        Err(last_error.unwrap_or_else(|| {
93            ForgeError::Internal("failover router exhausted adapters without error".to_string())
94        }))
95    }
96}
97
98fn should_failover(error: &ForgeError) -> bool {
99    matches!(
100        error,
101        ForgeError::RateLimited | ForgeError::Transport(_) | ForgeError::Provider(_)
102    )
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{CapabilityMatrix, Message, Role};

    /// Test double whose `chat` replays the same canned result on every call.
    struct MockAdapter {
        name: String,
        result: Result<ChatResponse, ForgeError>,
    }

    #[async_trait]
    impl ChatAdapter for MockAdapter {
        fn info(&self) -> AdapterInfo {
            AdapterInfo {
                name: self.name.clone(),
                base_url: None,
                capabilities: CapabilityMatrix {
                    streaming: true,
                    tools: true,
                    structured_output: true,
                    multimodal_input: false,
                    citations: false,
                },
            }
        }

        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ForgeError> {
            // The stored error is rebuilt variant by variant instead of being
            // cloned — presumably `ForgeError` does not implement `Clone`.
            match &self.result {
                Ok(response) => Ok(response.clone()),
                Err(ForgeError::Validation(message)) => {
                    Err(ForgeError::Validation(message.clone()))
                }
                Err(ForgeError::Authentication) => Err(ForgeError::Authentication),
                Err(ForgeError::RateLimited) => Err(ForgeError::RateLimited),
                Err(ForgeError::Provider(message)) => Err(ForgeError::Provider(message.clone())),
                Err(ForgeError::Transport(message)) => Err(ForgeError::Transport(message.clone())),
                Err(ForgeError::Internal(message)) => Err(ForgeError::Internal(message.clone())),
            }
        }

        async fn chat_stream(
            &self,
            _request: ChatRequest,
        ) -> Result<StreamResult<StreamEvent>, ForgeError> {
            Err(ForgeError::Provider(
                "stream tests are out of scope for this unit test".to_string(),
            ))
        }
    }

    /// Minimal single-message request shared by all tests.
    fn request() -> ChatRequest {
        ChatRequest {
            model: "mock".to_string(),
            messages: vec![Message {
                role: Role::User,
                content: "hello".to_string(),
            }],
            temperature: None,
            max_tokens: None,
            tools: vec![],
            metadata: serde_json::json!({}),
        }
    }

    /// Wraps a canned result in a mock adapter, ready to hand to the router.
    fn mock(name: &str, result: Result<ChatResponse, ForgeError>) -> Arc<dyn ChatAdapter> {
        Arc::new(MockAdapter {
            name: name.to_string(),
            result,
        })
    }

    /// Builds a successful response carrying `output_text`.
    fn ok_response(output_text: &str) -> ChatResponse {
        ChatResponse {
            id: "2".to_string(),
            model: "mock".to_string(),
            output_text: output_text.to_string(),
            tool_calls: vec![],
            usage: None,
        }
    }

    #[tokio::test]
    async fn router_returns_first_successful_adapter() {
        let router = FailoverRouter::new(vec![
            mock("a", Err(ForgeError::Transport("timeout".to_string()))),
            mock("b", Ok(ok_response("ok"))),
        ])
        .unwrap();

        let response = router.chat(request()).await.unwrap();
        assert_eq!(response.output_text, "ok");
    }

    #[tokio::test]
    async fn router_stops_on_non_retryable_error() {
        let router = FailoverRouter::new(vec![
            mock("a", Err(ForgeError::Authentication)),
            mock("b", Ok(ok_response("should not be used"))),
        ])
        .unwrap();

        let err = router.chat(request()).await.unwrap_err();
        assert!(matches!(err, ForgeError::Authentication));
    }

    #[tokio::test]
    async fn router_returns_last_error_when_every_adapter_fails() {
        let router = FailoverRouter::new(vec![
            mock("a", Err(ForgeError::Transport("timeout".to_string()))),
            mock("b", Err(ForgeError::RateLimited)),
        ])
        .unwrap();

        // Both errors are retryable, so the router runs out of adapters and
        // surfaces the most recent failure.
        let err = router.chat(request()).await.unwrap_err();
        assert!(matches!(err, ForgeError::RateLimited));
    }

    #[tokio::test]
    async fn router_honours_max_adapters_to_try() {
        let router = FailoverRouter::with_policy(
            vec![
                mock("a", Err(ForgeError::Transport("timeout".to_string()))),
                mock("b", Ok(ok_response("beyond the cap"))),
            ],
            FailoverPolicy {
                max_adapters_to_try: 1,
            },
        )
        .unwrap();

        // Only adapter "a" is eligible, so its error is final even though
        // adapter "b" would have succeeded.
        let err = router.chat(request()).await.unwrap_err();
        assert!(matches!(err, ForgeError::Transport(_)));
    }

    #[test]
    fn router_rejects_empty_adapter_list() {
        // `matches!` rather than `unwrap_err`, which would need
        // `FailoverRouter: Debug`.
        assert!(matches!(
            FailoverRouter::new(Vec::new()),
            Err(ForgeError::Validation(_))
        ));
    }

    #[test]
    fn router_info_reports_router_name() {
        let router = FailoverRouter::new(vec![mock("a", Ok(ok_response("ok")))]).unwrap();
        assert_eq!(router.info().name, "failover-router");
    }
}