1use std::collections::HashMap;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicUsize, Ordering};
12
13use crate::errors::{BitrouterError, Result};
14
15use super::admin::{AdminRoutingTable, DynamicRoute, RouteEndpoint, RouteStrategy};
16use super::routing_table::{RouteEntry, RoutingTable, RoutingTarget};
17
/// Per-model state for a route registered via `AdminRoutingTable::add_route`.
struct DynamicRouteData {
    /// How an endpoint is selected for each request (first-endpoint or round-robin).
    strategy: RouteStrategy,
    /// Candidate endpoints; `add_route` rejects empty lists, so this is non-empty
    /// for every route stored through that path.
    endpoints: Vec<RouteEndpoint>,
    /// Round-robin cursor; only advanced when `strategy` is `LoadBalance`.
    counter: AtomicUsize,
}
24
/// Routing table that layers runtime-editable routes over an inner [`RoutingTable`].
///
/// Lookups consult the dynamic routes first and fall back to `inner` when no
/// dynamic route resolves for the requested model name (including when the
/// route lock is poisoned).
pub struct DynamicRoutingTable<T> {
    /// The wrapped table used whenever no dynamic route applies.
    inner: T,
    /// Dynamic routes keyed by incoming model name, guarded for concurrent
    /// readers (routing) and writers (admin add/remove).
    routes: RwLock<HashMap<String, DynamicRouteData>>,
}
33
34impl<T> DynamicRoutingTable<T> {
35 pub fn new(inner: T) -> Self {
37 Self {
38 inner,
39 routes: RwLock::new(HashMap::new()),
40 }
41 }
42
43 fn resolve_dynamic(&self, model: &str) -> Option<RoutingTarget> {
47 let routes = self.routes.read().ok()?;
48 let data = routes.get(model)?;
49
50 if data.endpoints.is_empty() {
51 return None;
52 }
53
54 let endpoint = match data.strategy {
55 RouteStrategy::Priority => &data.endpoints[0],
56 RouteStrategy::LoadBalance => {
57 let idx = data.counter.fetch_add(1, Ordering::Relaxed) % data.endpoints.len();
58 &data.endpoints[idx]
59 }
60 };
61
62 Some(RoutingTarget {
63 provider_name: endpoint.provider.clone(),
64 model_id: endpoint.model_id.clone(),
65 })
66 }
67}
68
69impl<T: RoutingTable + Sync> RoutingTable for DynamicRoutingTable<T> {
70 async fn route(&self, incoming_model_name: &str) -> Result<RoutingTarget> {
71 if let Some(target) = self.resolve_dynamic(incoming_model_name) {
73 return Ok(target);
74 }
75 self.inner.route(incoming_model_name).await
77 }
78
79 fn list_routes(&self) -> Vec<RouteEntry> {
80 let mut entries = self.inner.list_routes();
81
82 if let Ok(routes) = self.routes.read() {
83 entries.retain(|e| !routes.contains_key(&e.model));
85
86 for (model, data) in routes.iter() {
88 if let Some(ep) = data.endpoints.first() {
89 entries.push(RouteEntry {
90 model: model.clone(),
91 provider: ep.provider.clone(),
92 protocol: ep.provider.clone(),
94 });
95 }
96 }
97 }
98
99 entries.sort_by(|a, b| a.model.cmp(&b.model));
100 entries
101 }
102}
103
104impl<T: RoutingTable + Sync> AdminRoutingTable for DynamicRoutingTable<T> {
105 fn add_route(&self, route: DynamicRoute) -> Result<()> {
106 if route.endpoints.is_empty() {
107 return Err(BitrouterError::invalid_request(
108 None,
109 "route must have at least one endpoint".to_owned(),
110 None,
111 ));
112 }
113
114 let data = DynamicRouteData {
115 strategy: route.strategy,
116 endpoints: route.endpoints,
117 counter: AtomicUsize::new(0),
118 };
119
120 let mut routes = self
121 .routes
122 .write()
123 .map_err(|_| BitrouterError::transport(None, "routing table lock poisoned"))?;
124 routes.insert(route.model, data);
125 Ok(())
126 }
127
128 fn remove_route(&self, model: &str) -> Result<bool> {
129 let mut routes = self
130 .routes
131 .write()
132 .map_err(|_| BitrouterError::transport(None, "routing table lock poisoned"))?;
133 Ok(routes.remove(model).is_some())
134 }
135
136 fn list_dynamic_routes(&self) -> Vec<DynamicRoute> {
137 let routes = match self.routes.read() {
138 Ok(r) => r,
139 Err(_) => return Vec::new(),
140 };
141 let mut result: Vec<DynamicRoute> = routes
142 .iter()
143 .map(|(model, data)| DynamicRoute {
144 model: model.clone(),
145 strategy: data.strategy.clone(),
146 endpoints: data.endpoints.clone(),
147 })
148 .collect();
149 result.sort_by(|a, b| a.model.cmp(&b.model));
150 result
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in for a config-backed table: only `"default"` resolves.
    struct StaticTable;

    impl RoutingTable for StaticTable {
        async fn route(&self, incoming: &str) -> Result<RoutingTarget> {
            if incoming == "default" {
                Ok(RoutingTarget {
                    provider_name: "openai".to_owned(),
                    model_id: "gpt-4o".to_owned(),
                })
            } else {
                Err(BitrouterError::invalid_request(
                    None,
                    format!("no route: {incoming}"),
                    None,
                ))
            }
        }

        fn list_routes(&self) -> Vec<RouteEntry> {
            vec![RouteEntry {
                model: "default".to_owned(),
                provider: "openai".to_owned(),
                protocol: "openai".to_owned(),
            }]
        }
    }

    /// Routes `model` through the `RoutingTable` impl of `DynamicRoutingTable`.
    async fn route(table: &DynamicRoutingTable<StaticTable>, model: &str) -> Result<RoutingTarget> {
        <DynamicRoutingTable<StaticTable> as RoutingTable>::route(table, model).await
    }

    /// Resolves `model`, failing the test immediately if routing returns an error.
    async fn resolve(table: &DynamicRoutingTable<StaticTable>, model: &str) -> RoutingTarget {
        match route(table, model).await {
            Ok(target) => target,
            Err(_) => panic!("expected a route for {model}"),
        }
    }

    /// Shorthand for building a `RouteEndpoint`.
    fn endpoint(provider: &str, model_id: &str) -> RouteEndpoint {
        RouteEndpoint {
            provider: provider.to_owned(),
            model_id: model_id.to_owned(),
        }
    }

    /// Registers a dynamic route, failing the test if `add_route` rejects it so
    /// setup errors surface here rather than as confusing later assertions.
    fn add(
        table: &DynamicRoutingTable<StaticTable>,
        model: &str,
        strategy: RouteStrategy,
        endpoints: Vec<RouteEndpoint>,
    ) {
        let added = table.add_route(DynamicRoute {
            model: model.to_owned(),
            strategy,
            endpoints,
        });
        assert!(added.is_ok(), "add_route({model}) should succeed");
    }

    #[tokio::test]
    async fn dynamic_route_takes_precedence() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let target = resolve(&table, "default").await;
        assert_eq!(target.provider_name, "anthropic");
        assert_eq!(target.model_id, "claude-sonnet-4-20250514");
    }

    #[tokio::test]
    async fn falls_back_to_inner_table() {
        let table = DynamicRoutingTable::new(StaticTable);

        let target = resolve(&table, "default").await;
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "gpt-4o");
    }

    #[tokio::test]
    async fn add_and_remove_dynamic_route() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(&table, "research", RouteStrategy::Priority, vec![endpoint("openai", "o1")]);

        assert!(route(&table, "research").await.is_ok());
        assert_eq!(table.list_dynamic_routes().len(), 1);

        let removed = table.remove_route("research").ok();
        assert_eq!(removed, Some(true));
        assert!(route(&table, "research").await.is_err());
        assert!(table.list_dynamic_routes().is_empty());
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let table = DynamicRoutingTable::new(StaticTable);
        let removed = table.remove_route("nope").ok();
        assert_eq!(removed, Some(false));
    }

    #[test]
    fn add_route_with_no_endpoints_fails() {
        let table = DynamicRoutingTable::new(StaticTable);
        let result = table.add_route(DynamicRoute {
            model: "empty".to_owned(),
            strategy: RouteStrategy::Priority,
            endpoints: vec![],
        });
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn priority_always_uses_first_endpoint() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "primary",
            RouteStrategy::Priority,
            vec![
                endpoint("openai", "gpt-4o"),
                endpoint("anthropic", "claude-sonnet-4-20250514"),
            ],
        );

        for _ in 0..3 {
            assert_eq!(resolve(&table, "primary").await.provider_name, "openai");
        }
    }

    #[tokio::test]
    async fn load_balance_round_robin() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "balanced",
            RouteStrategy::LoadBalance,
            vec![
                endpoint("openai", "gpt-4o"),
                endpoint("anthropic", "claude-sonnet-4-20250514"),
            ],
        );

        let t1 = resolve(&table, "balanced").await;
        let t2 = resolve(&table, "balanced").await;
        let t3 = resolve(&table, "balanced").await;

        assert_eq!(t1.provider_name, "openai");
        assert_eq!(t2.provider_name, "anthropic");
        assert_eq!(t3.provider_name, "openai");
    }

    #[test]
    fn list_routes_includes_dynamic() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "custom",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let routes = table.list_routes();
        assert!(routes.iter().any(|r| r.model == "custom"));
        assert!(routes.iter().any(|r| r.model == "default"));
    }

    #[test]
    fn dynamic_route_shadows_config_in_list() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let routes = table.list_routes();
        let defaults: Vec<_> = routes.iter().filter(|r| r.model == "default").collect();
        assert_eq!(defaults.len(), 1);
        assert_eq!(defaults[0].provider, "anthropic");
    }
}