1use crate::ModelProvider;
5use crate::traits::{ChatRequest, ChatResponse};
6use async_trait::async_trait;
7use std::collections::HashMap;
8
/// A single routing rule: which provider serves a hint, and with which model.
///
/// `provider_name` is matched by name against the providers handed to
/// [`RouterProvider::new`]; a rule naming a provider that is not in that list
/// is skipped (with a warning) when the router is built.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Route {
    /// Name of the target provider, matched against the router's provider list.
    pub provider_name: String,
    /// Model name sent to that provider in place of the `hint:<name>` string.
    pub model: String,
}
15
16pub struct RouterProvider {
25 routes: HashMap<String, (usize, String)>, providers: Vec<(String, Box<dyn ModelProvider>)>,
27 default_index: usize,
28}
29
30impl RouterProvider {
31 pub fn new(
36 providers: Vec<(String, Box<dyn ModelProvider>)>,
37 routes: Vec<(String, Route)>,
38 _default_model: String,
39 ) -> Self {
40 let name_to_index: HashMap<&str, usize> = providers
42 .iter()
43 .enumerate()
44 .map(|(i, (name, _))| (name.as_str(), i))
45 .collect();
46
47 let resolved_routes: HashMap<String, (usize, String)> = routes
49 .into_iter()
50 .filter_map(|(hint, route)| {
51 let index = name_to_index.get(route.provider_name.as_str()).copied();
52 match index {
53 Some(i) => Some((hint, (i, route.model))),
54 None => {
55 tracing::warn!(
56 hint = hint,
57 provider = route.provider_name,
58 "Route references unknown provider, skipping"
59 );
60 None
61 }
62 }
63 })
64 .collect();
65
66 Self {
67 routes: resolved_routes,
68 providers,
69 default_index: 0,
70 }
71 }
72
73 fn resolve(&self, model: &str) -> (usize, String) {
79 if let Some(hint) = model.strip_prefix("hint:") {
80 if let Some((idx, resolved_model)) = self.routes.get(hint) {
81 return (*idx, resolved_model.clone());
82 }
83 tracing::warn!(
84 hint = hint,
85 "Unknown route hint, falling back to default provider"
86 );
87 }
88
89 (self.default_index, model.to_string())
91 }
92}
93
94#[async_trait]
95impl ModelProvider for RouterProvider {
96 async fn chat(
97 &self,
98 request: ChatRequest<'_>,
99 model: &str,
100 temperature: f64,
101 ) -> anyhow::Result<ChatResponse> {
102 let (provider_idx, resolved_model) = self.resolve(model);
103 let (_, provider) = &self.providers[provider_idx];
104 provider.chat(request, &resolved_model, temperature).await
105 }
106
107 fn context_window(&self, model: &str) -> Option<usize> {
108 self.providers
109 .get(self.default_index)
110 .and_then(|(_, p)| p.context_window(model))
111 }
112
113 fn supports_native_tools(&self) -> bool {
114 self.providers
115 .get(self.default_index)
116 .map(|(_, p)| p.supports_native_tools())
117 .unwrap_or(false)
118 }
119
120 fn supports_developer_role(&self, model: &str) -> bool {
121 self.providers
122 .get(self.default_index)
123 .map(|(_, p)| p.supports_developer_role(model))
124 .unwrap_or(false)
125 }
126
127 async fn warmup(&self) -> anyhow::Result<()> {
128 for (name, provider) in &self.providers {
129 tracing::info!(provider = name, "Warming up routed provider");
130 if let Err(e) = provider.warmup().await {
131 tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
132 }
133 }
134 Ok(())
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ChatRequest, ChatResponse, TokenUsage, one_shot};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that records how it was used and always answers with a
    /// fixed text.
    struct MockProvider {
        /// Number of `chat` calls received.
        calls: AtomicUsize,
        /// Number of `warmup` calls received.
        warmups: AtomicUsize,
        /// Text returned in every `ChatResponse`.
        response: &'static str,
        /// Model name passed to the most recent `chat` call.
        last_model: std::sync::Mutex<String>,
    }

    impl MockProvider {
        fn new(response: &'static str) -> Self {
            Self {
                calls: AtomicUsize::new(0),
                warmups: AtomicUsize::new(0),
                response,
                last_model: std::sync::Mutex::new(String::new()),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        fn warmup_count(&self) -> usize {
            self.warmups.load(Ordering::SeqCst)
        }

        fn last_model(&self) -> String {
            self.last_model.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_model.lock().unwrap() = model.to_string();
            Ok(ChatResponse {
                text: Some(self.response.to_string()),
                tool_calls: vec![],
                provider_tool_calls: vec![],
                usage: TokenUsage::default(),
            })
        }

        async fn warmup(&self) -> anyhow::Result<()> {
            self.warmups.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Build a router over fresh mocks.
    ///
    /// `providers` is `(name, canned_response)`; `routes` is
    /// `(hint, provider_name, model)`. The returned mocks are in the same
    /// order as `providers`, so `mocks[0]` is the default provider.
    fn make_router(
        providers: Vec<(&'static str, &'static str)>,
        routes: Vec<(&str, &str, &str)>,
    ) -> (RouterProvider, Vec<Arc<MockProvider>>) {
        let mocks: Vec<Arc<MockProvider>> = providers
            .iter()
            .map(|(_, response)| Arc::new(MockProvider::new(response)))
            .collect();

        let provider_list: Vec<(String, Box<dyn ModelProvider>)> = providers
            .iter()
            .zip(mocks.iter())
            .map(|((name, _), mock)| {
                (
                    name.to_string(),
                    Box::new(Arc::clone(mock)) as Box<dyn ModelProvider>,
                )
            })
            .collect();

        let route_list: Vec<(String, Route)> = routes
            .iter()
            .map(|(hint, provider_name, model)| {
                (
                    hint.to_string(),
                    Route {
                        provider_name: provider_name.to_string(),
                        model: model.to_string(),
                    },
                )
            })
            .collect();

        let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());

        (router, mocks)
    }

    // Lets a test keep a handle on a mock while the router owns a boxed clone.
    #[async_trait]
    impl ModelProvider for Arc<MockProvider> {
        async fn chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.as_ref().chat(request, model, temperature).await
        }

        async fn warmup(&self) -> anyhow::Result<()> {
            self.as_ref().warmup().await
        }
    }

    #[tokio::test]
    async fn routes_hint_to_correct_provider() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![
                ("fast", "fast", "llama-3-70b"),
                ("reasoning", "smart", "claude-opus"),
            ],
        );

        let result = one_shot(&router, None, "hello", "hint:reasoning", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "smart-response");
        assert_eq!(mocks[1].call_count(), 1);
        assert_eq!(mocks[1].last_model(), "claude-opus");
        assert_eq!(mocks[0].call_count(), 0);
    }

    #[tokio::test]
    async fn routes_fast_hint() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![("fast", "fast", "llama-3-70b")],
        );

        let result = one_shot(&router, None, "hello", "hint:fast", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "fast-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "llama-3-70b");
    }

    #[tokio::test]
    async fn unknown_hint_falls_back_to_default() {
        let (router, mocks) = make_router(
            vec![("default", "default-response"), ("other", "other-response")],
            vec![],
        );

        let result = one_shot(&router, None, "hello", "hint:nonexistent", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "default-response");
        assert_eq!(mocks[0].call_count(), 1);
        // The unresolved hint string is forwarded to the default provider as-is.
        assert_eq!(mocks[0].last_model(), "hint:nonexistent");
    }

    #[tokio::test]
    async fn non_hint_model_uses_default_provider() {
        let (router, mocks) = make_router(
            vec![
                ("primary", "primary-response"),
                ("secondary", "secondary-response"),
            ],
            vec![("code", "secondary", "codellama")],
        );

        let result = one_shot(
            &router,
            None,
            "hello",
            "anthropic/claude-sonnet-4-20250514",
            0.5,
        )
        .await
        .unwrap();
        assert_eq!(result, "primary-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
    }

    #[test]
    fn resolve_preserves_model_for_non_hints() {
        let (router, _) = make_router(vec![("default", "ok")], vec![]);

        let (idx, model) = router.resolve("gpt-4o");
        assert_eq!(idx, 0);
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn resolve_strips_hint_prefix() {
        let (router, _) = make_router(
            vec![("fast", "ok"), ("smart", "ok")],
            vec![("reasoning", "smart", "claude-opus")],
        );

        let (idx, model) = router.resolve("hint:reasoning");
        assert_eq!(idx, 1);
        assert_eq!(model, "claude-opus");
    }

    #[test]
    fn skips_routes_with_unknown_provider() {
        let (router, _) = make_router(
            vec![("default", "ok")],
            vec![("broken", "nonexistent", "model")],
        );

        assert!(!router.routes.contains_key("broken"));
    }

    #[tokio::test]
    async fn warmup_calls_all_providers() {
        let (router, mocks) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);

        assert!(router.warmup().await.is_ok());
        // Every registered provider, not only the default, is warmed up once.
        for mock in &mocks {
            assert_eq!(mock.warmup_count(), 1);
        }
    }

    #[tokio::test]
    async fn chat_dispatches_to_correct_provider() {
        let mock = Arc::new(MockProvider::new("response"));
        let router = RouterProvider::new(
            vec![(
                "default".into(),
                Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
            )],
            vec![],
            "model".into(),
        );

        let result = one_shot(&router, Some("system"), "hello", "model", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "response");
        assert_eq!(mock.call_count(), 1);
    }
}