1use std::collections::HashMap;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicUsize, Ordering};
12
13use crate::errors::{BitrouterError, Result};
14
15use super::admin::{AdminRoutingTable, DynamicRoute, RouteEndpoint, RouteStrategy};
16use super::routing_table::{ModelEntry, RouteEntry, RoutingTable, RoutingTarget};
17
18struct DynamicRouteData {
20 strategy: RouteStrategy,
21 endpoints: Vec<RouteEndpoint>,
22 counter: AtomicUsize,
23}
24
25pub struct DynamicRoutingTable<T> {
30 inner: T,
31 routes: RwLock<HashMap<String, DynamicRouteData>>,
32}
33
34impl<T> DynamicRoutingTable<T> {
35 pub fn new(inner: T) -> Self {
37 Self {
38 inner,
39 routes: RwLock::new(HashMap::new()),
40 }
41 }
42
43 pub fn inner(&self) -> &T {
45 &self.inner
46 }
47
48 fn resolve_dynamic(&self, model: &str) -> Option<RoutingTarget> {
52 let routes = self.routes.read().ok()?;
53 let data = routes.get(model)?;
54
55 if data.endpoints.is_empty() {
56 return None;
57 }
58
59 let endpoint = match data.strategy {
60 RouteStrategy::Priority => &data.endpoints[0],
61 RouteStrategy::LoadBalance => {
62 let idx = data.counter.fetch_add(1, Ordering::Relaxed) % data.endpoints.len();
63 &data.endpoints[idx]
64 }
65 };
66
67 Some(RoutingTarget {
68 provider_name: endpoint.provider.clone(),
69 model_id: endpoint.model_id.clone(),
70 })
71 }
72}
73
74impl<T: RoutingTable + Sync> RoutingTable for DynamicRoutingTable<T> {
75 async fn route(&self, incoming_model_name: &str) -> Result<RoutingTarget> {
76 if let Some(target) = self.resolve_dynamic(incoming_model_name) {
78 return Ok(target);
79 }
80 self.inner.route(incoming_model_name).await
82 }
83
84 fn list_routes(&self) -> Vec<RouteEntry> {
85 let mut entries = self.inner.list_routes();
86
87 if let Ok(routes) = self.routes.read() {
88 entries.retain(|e| !routes.contains_key(&e.model));
90
91 for (model, data) in routes.iter() {
93 if let Some(ep) = data.endpoints.first() {
94 entries.push(RouteEntry {
95 model: model.clone(),
96 provider: ep.provider.clone(),
97 protocol: ep.provider.clone(),
99 });
100 }
101 }
102 }
103
104 entries.sort_by(|a, b| a.model.cmp(&b.model));
105 entries
106 }
107
108 fn list_models(&self) -> Vec<ModelEntry> {
109 self.inner.list_models()
110 }
111}
112
113impl<T: RoutingTable + Sync> AdminRoutingTable for DynamicRoutingTable<T> {
114 fn add_route(&self, route: DynamicRoute) -> Result<()> {
115 if route.endpoints.is_empty() {
116 return Err(BitrouterError::invalid_request(
117 None,
118 "route must have at least one endpoint".to_owned(),
119 None,
120 ));
121 }
122
123 let data = DynamicRouteData {
124 strategy: route.strategy,
125 endpoints: route.endpoints,
126 counter: AtomicUsize::new(0),
127 };
128
129 let mut routes = self
130 .routes
131 .write()
132 .map_err(|_| BitrouterError::transport(None, "routing table lock poisoned"))?;
133 routes.insert(route.model, data);
134 Ok(())
135 }
136
137 fn remove_route(&self, model: &str) -> Result<bool> {
138 let mut routes = self
139 .routes
140 .write()
141 .map_err(|_| BitrouterError::transport(None, "routing table lock poisoned"))?;
142 Ok(routes.remove(model).is_some())
143 }
144
145 fn list_dynamic_routes(&self) -> Vec<DynamicRoute> {
146 let routes = match self.routes.read() {
147 Ok(r) => r,
148 Err(_) => return Vec::new(),
149 };
150 let mut result: Vec<DynamicRoute> = routes
151 .iter()
152 .map(|(model, data)| DynamicRoute {
153 model: model.clone(),
154 strategy: data.strategy.clone(),
155 endpoints: data.endpoints.clone(),
156 })
157 .collect();
158 result.sort_by(|a, b| a.model.cmp(&b.model));
159 result
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Config-backed stand-in: only `"default"` resolves (to `openai/gpt-4o`);
    /// every other model name is an invalid-request error.
    struct StaticTable;

    impl RoutingTable for StaticTable {
        async fn route(&self, incoming: &str) -> Result<RoutingTarget> {
            if incoming == "default" {
                Ok(RoutingTarget {
                    provider_name: "openai".to_owned(),
                    model_id: "gpt-4o".to_owned(),
                })
            } else {
                Err(BitrouterError::invalid_request(
                    None,
                    format!("no route: {incoming}"),
                    None,
                ))
            }
        }

        fn list_routes(&self) -> Vec<RouteEntry> {
            vec![RouteEntry {
                model: "default".to_owned(),
                provider: "openai".to_owned(),
                protocol: "openai".to_owned(),
            }]
        }
    }

    /// Calls the `RoutingTable` implementation through a fully qualified path.
    async fn route(table: &DynamicRoutingTable<StaticTable>, model: &str) -> Result<RoutingTarget> {
        <DynamicRoutingTable<StaticTable> as RoutingTable>::route(table, model).await
    }

    /// Resolves `model` and panics with a readable message if routing fails.
    ///
    /// Matches on the result instead of calling `expect`, so the error type
    /// does not need to implement `Debug`.
    async fn route_ok(table: &DynamicRoutingTable<StaticTable>, model: &str) -> RoutingTarget {
        match route(table, model).await {
            Ok(target) => target,
            Err(_) => panic!("expected `{model}` to resolve"),
        }
    }

    /// Builds one endpoint from borrowed strings.
    fn endpoint(provider: &str, model_id: &str) -> RouteEndpoint {
        RouteEndpoint {
            provider: provider.to_owned(),
            model_id: model_id.to_owned(),
        }
    }

    /// Registers a dynamic route and fails the test immediately if it is
    /// rejected. Discarding the result would surface a setup failure only as
    /// a confusing later assertion.
    fn add_dynamic(
        table: &DynamicRoutingTable<StaticTable>,
        model: &str,
        strategy: RouteStrategy,
        endpoints: Vec<RouteEndpoint>,
    ) {
        let accepted = table
            .add_route(DynamicRoute {
                model: model.to_owned(),
                strategy,
                endpoints,
            })
            .is_ok();
        assert!(accepted, "add_route rejected `{model}`");
    }

    #[tokio::test]
    async fn dynamic_route_takes_precedence() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let target = route_ok(&table, "default").await;
        assert_eq!(target.provider_name, "anthropic");
        assert_eq!(target.model_id, "claude-sonnet-4-20250514");
    }

    #[tokio::test]
    async fn falls_back_to_inner_table() {
        let table = DynamicRoutingTable::new(StaticTable);

        let target = route_ok(&table, "default").await;
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "gpt-4o");
    }

    #[tokio::test]
    async fn add_and_remove_dynamic_route() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(
            &table,
            "research",
            RouteStrategy::Priority,
            vec![endpoint("openai", "o1")],
        );

        assert!(route(&table, "research").await.is_ok());
        assert_eq!(table.list_dynamic_routes().len(), 1);

        let removed = table.remove_route("research").ok();
        assert_eq!(removed, Some(true));
        assert!(route(&table, "research").await.is_err());
        assert!(table.list_dynamic_routes().is_empty());
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let table = DynamicRoutingTable::new(StaticTable);
        let removed = table.remove_route("nope").ok();
        assert_eq!(removed, Some(false));
    }

    #[test]
    fn add_route_with_no_endpoints_fails() {
        let table = DynamicRoutingTable::new(StaticTable);
        let result = table.add_route(DynamicRoute {
            model: "empty".to_owned(),
            strategy: RouteStrategy::Priority,
            endpoints: vec![],
        });
        assert!(result.is_err());
        assert!(table.list_dynamic_routes().is_empty());
    }

    #[tokio::test]
    async fn load_balance_round_robin() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(
            &table,
            "balanced",
            RouteStrategy::LoadBalance,
            vec![
                endpoint("openai", "gpt-4o"),
                endpoint("anthropic", "claude-sonnet-4-20250514"),
            ],
        );

        let t1 = route_ok(&table, "balanced").await;
        let t2 = route_ok(&table, "balanced").await;
        let t3 = route_ok(&table, "balanced").await;

        assert_eq!(t1.provider_name, "openai");
        assert_eq!(t2.provider_name, "anthropic");
        assert_eq!(t3.provider_name, "openai");
    }

    #[tokio::test]
    async fn priority_always_picks_first_endpoint() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(
            &table,
            "primary",
            RouteStrategy::Priority,
            vec![
                endpoint("anthropic", "claude-sonnet-4-20250514"),
                endpoint("openai", "gpt-4o"),
            ],
        );

        for _ in 0..3 {
            let target = route_ok(&table, "primary").await;
            assert_eq!(target.provider_name, "anthropic");
        }
    }

    #[tokio::test]
    async fn re_adding_route_replaces_it() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );
        add_dynamic(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("openai", "o1")],
        );

        let target = route_ok(&table, "default").await;
        assert_eq!(target.model_id, "o1");
        assert_eq!(table.list_dynamic_routes().len(), 1);
    }

    #[test]
    fn list_dynamic_routes_sorted_by_model() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(&table, "zeta", RouteStrategy::Priority, vec![endpoint("openai", "o1")]);
        add_dynamic(&table, "alpha", RouteStrategy::Priority, vec![endpoint("openai", "o1")]);

        let names: Vec<String> = table
            .list_dynamic_routes()
            .into_iter()
            .map(|r| r.model)
            .collect();
        assert_eq!(names, ["alpha", "zeta"]);
    }

    #[test]
    fn list_routes_includes_dynamic() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(
            &table,
            "custom",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let routes = table.list_routes();
        assert!(routes.iter().any(|r| r.model == "custom"));
        assert!(routes.iter().any(|r| r.model == "default"));
    }

    #[test]
    fn dynamic_route_shadows_config_in_list() {
        let table = DynamicRoutingTable::new(StaticTable);
        add_dynamic(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let routes = table.list_routes();
        let defaults: Vec<_> = routes.iter().filter(|r| r.model == "default").collect();
        assert_eq!(defaults.len(), 1);
        assert_eq!(defaults[0].provider, "anthropic");
    }
}
340}