1use async_trait::async_trait;
2use forgeai_core::{
3 AdapterInfo, ChatAdapter, ChatRequest, ChatResponse, ForgeError, StreamEvent, StreamResult,
4};
5use std::sync::Arc;
6
7pub fn pick_first_healthy(adapters: &[AdapterInfo]) -> Option<&AdapterInfo> {
8 adapters.first()
9}
10
/// Limits on how far a failover router walks down its adapter list for a
/// single request.
#[derive(Clone, Copy, Debug)]
pub struct FailoverPolicy {
    /// Upper bound on the number of adapters tried per request, counted from
    /// the front of the adapter list. The default (`usize::MAX`) is
    /// effectively unlimited.
    pub max_adapters_to_try: usize,
}

impl Default for FailoverPolicy {
    /// Returns a policy with no practical cap, so every adapter may be tried.
    fn default() -> Self {
        let unlimited = usize::MAX;
        Self {
            max_adapters_to_try: unlimited,
        }
    }
}
23
24pub struct FailoverRouter {
25 adapters: Vec<Arc<dyn ChatAdapter>>,
26 policy: FailoverPolicy,
27}
28
29impl FailoverRouter {
30 pub fn new(adapters: Vec<Arc<dyn ChatAdapter>>) -> Result<Self, ForgeError> {
31 Self::with_policy(adapters, FailoverPolicy::default())
32 }
33
34 pub fn with_policy(
35 adapters: Vec<Arc<dyn ChatAdapter>>,
36 policy: FailoverPolicy,
37 ) -> Result<Self, ForgeError> {
38 if adapters.is_empty() {
39 return Err(ForgeError::Validation(
40 "failover router requires at least one adapter".to_string(),
41 ));
42 }
43 Ok(Self { adapters, policy })
44 }
45
46 fn adapters_to_try(&self) -> impl Iterator<Item = &Arc<dyn ChatAdapter>> {
47 self.adapters.iter().take(self.policy.max_adapters_to_try)
48 }
49}
50
51#[async_trait]
52impl ChatAdapter for FailoverRouter {
53 fn info(&self) -> AdapterInfo {
54 let first = self.adapters[0].info();
55 AdapterInfo {
56 name: "failover-router".to_string(),
57 base_url: first.base_url,
58 capabilities: first.capabilities,
59 }
60 }
61
62 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
63 let mut last_error: Option<ForgeError> = None;
64 for adapter in self.adapters_to_try() {
65 match adapter.chat(request.clone()).await {
66 Ok(response) => return Ok(response),
67 Err(error) if should_failover(&error) => {
68 last_error = Some(error);
69 }
70 Err(error) => return Err(error),
71 }
72 }
73 Err(last_error.unwrap_or_else(|| {
74 ForgeError::Internal("failover router exhausted adapters without error".to_string())
75 }))
76 }
77
78 async fn chat_stream(
79 &self,
80 request: ChatRequest,
81 ) -> Result<StreamResult<StreamEvent>, ForgeError> {
82 let mut last_error: Option<ForgeError> = None;
83 for adapter in self.adapters_to_try() {
84 match adapter.chat_stream(request.clone()).await {
85 Ok(stream) => return Ok(stream),
86 Err(error) if should_failover(&error) => {
87 last_error = Some(error);
88 }
89 Err(error) => return Err(error),
90 }
91 }
92 Err(last_error.unwrap_or_else(|| {
93 ForgeError::Internal("failover router exhausted adapters without error".to_string())
94 }))
95 }
96}
97
98fn should_failover(error: &ForgeError) -> bool {
99 matches!(
100 error,
101 ForgeError::RateLimited | ForgeError::Transport(_) | ForgeError::Provider(_)
102 )
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{CapabilityMatrix, Message, Role};

    /// Adapter double: `chat` replays a fixed result, and `chat_stream` always
    /// fails with a retryable `Provider` error.
    struct MockAdapter {
        name: String,
        result: Result<ChatResponse, ForgeError>,
    }

    #[async_trait]
    impl ChatAdapter for MockAdapter {
        fn info(&self) -> AdapterInfo {
            AdapterInfo {
                name: self.name.clone(),
                base_url: None,
                capabilities: CapabilityMatrix {
                    streaming: true,
                    tools: true,
                    structured_output: true,
                    multimodal_input: false,
                    citations: false,
                },
            }
        }

        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ForgeError> {
            // Rebuilds the stored error variant by variant; presumably
            // `ForgeError` does not implement `Clone` — TODO confirm.
            match &self.result {
                Ok(response) => Ok(response.clone()),
                Err(ForgeError::Validation(message)) => {
                    Err(ForgeError::Validation(message.clone()))
                }
                Err(ForgeError::Authentication) => Err(ForgeError::Authentication),
                Err(ForgeError::RateLimited) => Err(ForgeError::RateLimited),
                Err(ForgeError::Provider(message)) => Err(ForgeError::Provider(message.clone())),
                Err(ForgeError::Transport(message)) => Err(ForgeError::Transport(message.clone())),
                Err(ForgeError::Internal(message)) => Err(ForgeError::Internal(message.clone())),
            }
        }

        async fn chat_stream(
            &self,
            _request: ChatRequest,
        ) -> Result<StreamResult<StreamEvent>, ForgeError> {
            Err(ForgeError::Provider(
                "stream tests are out of scope for this unit test".to_string(),
            ))
        }
    }

    /// Wraps a `MockAdapter` so it can go straight into a router's adapter list.
    fn mock(name: &str, result: Result<ChatResponse, ForgeError>) -> Arc<dyn ChatAdapter> {
        Arc::new(MockAdapter {
            name: name.to_string(),
            result,
        })
    }

    /// Builds a minimal successful response carrying `text`.
    fn response(text: &str) -> ChatResponse {
        ChatResponse {
            id: "2".to_string(),
            model: "mock".to_string(),
            output_text: text.to_string(),
            tool_calls: vec![],
            usage: None,
        }
    }

    fn request() -> ChatRequest {
        ChatRequest {
            model: "mock".to_string(),
            messages: vec![Message {
                role: Role::User,
                content: "hello".to_string(),
            }],
            temperature: None,
            max_tokens: None,
            tools: vec![],
            metadata: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn router_returns_first_successful_adapter() {
        let router = FailoverRouter::new(vec![
            mock("a", Err(ForgeError::Transport("timeout".to_string()))),
            mock("b", Ok(response("ok"))),
        ])
        .unwrap();

        let response = router.chat(request()).await.unwrap();
        assert_eq!(response.output_text, "ok");
    }

    #[tokio::test]
    async fn router_stops_on_non_retryable_error() {
        let router = FailoverRouter::new(vec![
            mock("a", Err(ForgeError::Authentication)),
            mock("b", Ok(response("should not be used"))),
        ])
        .unwrap();

        let err = router.chat(request()).await.unwrap_err();
        assert!(matches!(err, ForgeError::Authentication));
    }

    #[tokio::test]
    async fn router_returns_last_error_when_every_adapter_fails_over() {
        let router = FailoverRouter::new(vec![
            mock("a", Err(ForgeError::Transport("timeout".to_string()))),
            mock("b", Err(ForgeError::RateLimited)),
        ])
        .unwrap();

        let err = router.chat(request()).await.unwrap_err();
        assert!(matches!(err, ForgeError::RateLimited));
    }

    #[tokio::test]
    async fn router_respects_max_adapters_to_try() {
        let router = FailoverRouter::with_policy(
            vec![
                mock("a", Err(ForgeError::Transport("timeout".to_string()))),
                mock("b", Ok(response("should not be used"))),
            ],
            FailoverPolicy {
                max_adapters_to_try: 1,
            },
        )
        .unwrap();

        let err = router.chat(request()).await.unwrap_err();
        assert!(matches!(err, ForgeError::Transport(_)));
    }

    #[tokio::test]
    async fn router_stream_fails_over_and_returns_last_error() {
        let router = FailoverRouter::new(vec![
            mock("a", Ok(response("unused"))),
            mock("b", Ok(response("unused"))),
        ])
        .unwrap();

        // `matches!` avoids requiring `Debug` on the stream type.
        assert!(matches!(
            router.chat_stream(request()).await,
            Err(ForgeError::Provider(_))
        ));
    }

    #[test]
    fn new_rejects_empty_adapter_list() {
        assert!(matches!(
            FailoverRouter::new(Vec::new()),
            Err(ForgeError::Validation(_))
        ));
    }

    #[test]
    fn pick_first_healthy_returns_first_entry_or_none() {
        let infos = vec![
            mock("a", Err(ForgeError::RateLimited)).info(),
            mock("b", Ok(response("ok"))).info(),
        ];
        assert_eq!(
            pick_first_healthy(&infos).map(|info| info.name.as_str()),
            Some("a")
        );
        assert!(pick_first_healthy(&[]).is_none());
    }
}