1use std::collections::HashMap;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicUsize, Ordering};
12
13use crate::errors::{BitrouterError, Result};
14
15use super::admin::{AdminRoutingTable, DynamicRoute, RouteEndpoint, RouteStrategy};
16use super::registry::{ModelEntry, ModelRegistry};
17use super::routing_table::{RouteEntry, RoutingTable, RoutingTarget};
18
19struct DynamicRouteData {
21 strategy: RouteStrategy,
22 endpoints: Vec<RouteEndpoint>,
23 counter: AtomicUsize,
24}
25
26pub struct DynamicRoutingTable<T> {
31 inner: T,
32 routes: RwLock<HashMap<String, DynamicRouteData>>,
33}
34
35impl<T> DynamicRoutingTable<T> {
36 pub fn new(inner: T) -> Self {
38 Self {
39 inner,
40 routes: RwLock::new(HashMap::new()),
41 }
42 }
43
44 pub fn inner(&self) -> &T {
46 &self.inner
47 }
48
49 fn resolve_dynamic(&self, model: &str) -> Option<RoutingTarget> {
53 let routes = self.routes.read().ok()?;
54 let data = routes.get(model)?;
55
56 if data.endpoints.is_empty() {
57 return None;
58 }
59
60 let endpoint = match data.strategy {
61 RouteStrategy::Priority => &data.endpoints[0],
62 RouteStrategy::LoadBalance => {
63 let idx = data.counter.fetch_add(1, Ordering::Relaxed) % data.endpoints.len();
64 &data.endpoints[idx]
65 }
66 };
67
68 Some(RoutingTarget {
69 provider_name: endpoint.provider.clone(),
70 model_id: endpoint.model_id.clone(),
71 })
72 }
73}
74
75impl<T: RoutingTable + Sync> RoutingTable for DynamicRoutingTable<T> {
76 async fn route(&self, incoming_model_name: &str) -> Result<RoutingTarget> {
77 if let Some(target) = self.resolve_dynamic(incoming_model_name) {
79 return Ok(target);
80 }
81 self.inner.route(incoming_model_name).await
83 }
84
85 fn list_routes(&self) -> Vec<RouteEntry> {
86 let mut entries = self.inner.list_routes();
87
88 if let Ok(routes) = self.routes.read() {
89 entries.retain(|e| !routes.contains_key(&e.model));
91
92 for (model, data) in routes.iter() {
94 if let Some(ep) = data.endpoints.first() {
95 entries.push(RouteEntry {
96 model: model.clone(),
97 provider: ep.provider.clone(),
98 protocol: ep.provider.clone(),
100 });
101 }
102 }
103 }
104
105 entries.sort_by(|a, b| a.model.cmp(&b.model));
106 entries
107 }
108}
109
110impl<T: ModelRegistry> ModelRegistry for DynamicRoutingTable<T> {
111 fn list_models(&self) -> Vec<ModelEntry> {
112 self.inner.list_models()
113 }
114}
115
116impl<T: RoutingTable + Sync> AdminRoutingTable for DynamicRoutingTable<T> {
117 fn add_route(&self, route: DynamicRoute) -> Result<()> {
118 if route.endpoints.is_empty() {
119 return Err(BitrouterError::invalid_request(
120 None,
121 "route must have at least one endpoint".to_owned(),
122 None,
123 ));
124 }
125
126 let data = DynamicRouteData {
127 strategy: route.strategy,
128 endpoints: route.endpoints,
129 counter: AtomicUsize::new(0),
130 };
131
132 let mut routes = self
133 .routes
134 .write()
135 .map_err(|_| BitrouterError::transport(None, "routing table lock poisoned"))?;
136 routes.insert(route.model, data);
137 Ok(())
138 }
139
140 fn remove_route(&self, model: &str) -> Result<bool> {
141 let mut routes = self
142 .routes
143 .write()
144 .map_err(|_| BitrouterError::transport(None, "routing table lock poisoned"))?;
145 Ok(routes.remove(model).is_some())
146 }
147
148 fn list_dynamic_routes(&self) -> Vec<DynamicRoute> {
149 let routes = match self.routes.read() {
150 Ok(r) => r,
151 Err(_) => return Vec::new(),
152 };
153 let mut result: Vec<DynamicRoute> = routes
154 .iter()
155 .map(|(model, data)| DynamicRoute {
156 model: model.clone(),
157 strategy: data.strategy.clone(),
158 endpoints: data.endpoints.clone(),
159 })
160 .collect();
161 result.sort_by(|a, b| a.model.cmp(&b.model));
162 result
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Inner table with a single static route: `default` -> `openai` / `gpt-4o`.
    struct StaticTable;

    impl RoutingTable for StaticTable {
        async fn route(&self, incoming: &str) -> Result<RoutingTarget> {
            if incoming == "default" {
                Ok(RoutingTarget {
                    provider_name: "openai".to_owned(),
                    model_id: "gpt-4o".to_owned(),
                })
            } else {
                Err(BitrouterError::invalid_request(
                    None,
                    format!("no route: {incoming}"),
                    None,
                ))
            }
        }

        fn list_routes(&self) -> Vec<RouteEntry> {
            vec![RouteEntry {
                model: "default".to_owned(),
                provider: "openai".to_owned(),
                protocol: "openai".to_owned(),
            }]
        }
    }

    /// Builds a `RouteEndpoint` from string slices.
    fn endpoint(provider: &str, model_id: &str) -> RouteEndpoint {
        RouteEndpoint {
            provider: provider.to_owned(),
            model_id: model_id.to_owned(),
        }
    }

    /// Registers a dynamic route and fails the test immediately if it is rejected.
    ///
    /// Setup errors used to be discarded with `.ok()`, hiding registration
    /// failures behind unrelated downstream assertions. `is_ok()` is used so no
    /// `Debug` bound on the error type is needed.
    fn add(
        table: &DynamicRoutingTable<StaticTable>,
        model: &str,
        strategy: RouteStrategy,
        endpoints: Vec<RouteEndpoint>,
    ) {
        let added = table.add_route(DynamicRoute {
            model: model.to_owned(),
            strategy,
            endpoints,
        });
        assert!(added.is_ok(), "add_route({model}) should succeed");
    }

    /// Resolves `model` through the dynamic table's `RoutingTable` impl.
    async fn route(table: &DynamicRoutingTable<StaticTable>, model: &str) -> Result<RoutingTarget> {
        <DynamicRoutingTable<StaticTable> as RoutingTable>::route(table, model).await
    }

    /// Resolves `model`, failing the test with context if no route is found.
    async fn route_ok(table: &DynamicRoutingTable<StaticTable>, model: &str) -> RoutingTarget {
        match route(table, model).await {
            Ok(target) => target,
            Err(_) => panic!("expected a route for {model}"),
        }
    }

    #[tokio::test]
    async fn dynamic_route_takes_precedence() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let target = route_ok(&table, "default").await;
        assert_eq!(target.provider_name, "anthropic");
        assert_eq!(target.model_id, "claude-sonnet-4-20250514");
    }

    #[tokio::test]
    async fn falls_back_to_inner_table() {
        let table = DynamicRoutingTable::new(StaticTable);

        let target = route_ok(&table, "default").await;
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "gpt-4o");
    }

    #[tokio::test]
    async fn add_and_remove_dynamic_route() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "research",
            RouteStrategy::Priority,
            vec![endpoint("openai", "o1")],
        );

        assert!(route(&table, "research").await.is_ok());
        assert_eq!(table.list_dynamic_routes().len(), 1);

        assert_eq!(table.remove_route("research").ok(), Some(true));
        assert!(route(&table, "research").await.is_err());
        assert!(table.list_dynamic_routes().is_empty());
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let table = DynamicRoutingTable::new(StaticTable);
        assert_eq!(table.remove_route("nope").ok(), Some(false));
    }

    #[test]
    fn add_route_with_no_endpoints_fails() {
        let table = DynamicRoutingTable::new(StaticTable);
        let result = table.add_route(DynamicRoute {
            model: "empty".to_owned(),
            strategy: RouteStrategy::Priority,
            endpoints: vec![],
        });
        assert!(result.is_err());
        // A rejected route must not be stored.
        assert!(table.list_dynamic_routes().is_empty());
    }

    #[tokio::test]
    async fn priority_always_picks_first_endpoint() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "primary",
            RouteStrategy::Priority,
            vec![
                endpoint("anthropic", "claude-sonnet-4-20250514"),
                endpoint("openai", "gpt-4o"),
            ],
        );

        for _ in 0..3 {
            assert_eq!(route_ok(&table, "primary").await.provider_name, "anthropic");
        }
    }

    #[tokio::test]
    async fn load_balance_round_robin() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "balanced",
            RouteStrategy::LoadBalance,
            vec![
                endpoint("openai", "gpt-4o"),
                endpoint("anthropic", "claude-sonnet-4-20250514"),
            ],
        );

        let t1 = route_ok(&table, "balanced").await;
        let t2 = route_ok(&table, "balanced").await;
        let t3 = route_ok(&table, "balanced").await;

        assert_eq!(t1.provider_name, "openai");
        assert_eq!(t2.provider_name, "anthropic");
        assert_eq!(t3.provider_name, "openai");
    }

    #[tokio::test]
    async fn re_adding_route_replaces_previous_one() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "custom",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );
        add(
            &table,
            "custom",
            RouteStrategy::Priority,
            vec![endpoint("openai", "o1")],
        );

        let target = route_ok(&table, "custom").await;
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "o1");
        assert_eq!(table.list_dynamic_routes().len(), 1);
    }

    #[test]
    fn list_routes_includes_dynamic() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "custom",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let routes = table.list_routes();
        assert!(routes.iter().any(|r| r.model == "custom"));
        assert!(routes.iter().any(|r| r.model == "default"));
    }

    #[test]
    fn dynamic_route_shadows_config_in_list() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "default",
            RouteStrategy::Priority,
            vec![endpoint("anthropic", "claude-sonnet-4-20250514")],
        );

        let routes = table.list_routes();
        let defaults: Vec<_> = routes.iter().filter(|r| r.model == "default").collect();
        assert_eq!(defaults.len(), 1);
        assert_eq!(defaults[0].provider, "anthropic");
    }

    #[test]
    fn list_dynamic_routes_is_sorted_by_model() {
        let table = DynamicRoutingTable::new(StaticTable);
        add(
            &table,
            "zeta",
            RouteStrategy::Priority,
            vec![endpoint("openai", "gpt-4o")],
        );
        add(
            &table,
            "alpha",
            RouteStrategy::Priority,
            vec![endpoint("openai", "o1")],
        );

        let models: Vec<String> = table
            .list_dynamic_routes()
            .into_iter()
            .map(|r| r.model)
            .collect();
        assert_eq!(models, ["alpha", "zeta"]);
    }
}