Skip to main content

nenjo_models/
router.rs

1//! Multi-model router that dispatches requests to different provider+model
2//! combinations based on hint-prefixed model names.
3
4use crate::ModelProvider;
5use crate::traits::{ChatRequest, ChatResponse};
6use async_trait::async_trait;
7use std::collections::HashMap;
8
/// A single route: maps a task hint to a provider + model combo.
///
/// Value-comparable so route tables can be checked for equality (e.g. in tests
/// or when diffing configuration).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Route {
    /// Name of the provider (as registered with the router) that serves this hint.
    pub provider_name: String,
    /// Model name forwarded to that provider in place of the `hint:...` string.
    pub model: String,
}
15
16/// Multi-model router — routes requests to different provider+model combos
17/// based on a task hint encoded in the model parameter.
18///
19/// The model parameter can be:
20/// - A regular model name (e.g. "anthropic/claude-sonnet-4") → uses default provider
21/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table
22///
23/// This wraps multiple pre-created providers and selects the right one per request.
24pub struct RouterProvider {
25    routes: HashMap<String, (usize, String)>, // hint → (provider_index, model)
26    providers: Vec<(String, Box<dyn ModelProvider>)>,
27    default_index: usize,
28}
29
30impl RouterProvider {
31    /// Create a new router with a default provider and optional routes.
32    ///
33    /// `providers` is a list of (name, provider) pairs. The first one is the default.
34    /// `routes` maps hint names to Route structs containing provider_name and model.
35    pub fn new(
36        providers: Vec<(String, Box<dyn ModelProvider>)>,
37        routes: Vec<(String, Route)>,
38        _default_model: String,
39    ) -> Self {
40        // Build provider name → index lookup
41        let name_to_index: HashMap<&str, usize> = providers
42            .iter()
43            .enumerate()
44            .map(|(i, (name, _))| (name.as_str(), i))
45            .collect();
46
47        // Resolve routes to provider indices
48        let resolved_routes: HashMap<String, (usize, String)> = routes
49            .into_iter()
50            .filter_map(|(hint, route)| {
51                let index = name_to_index.get(route.provider_name.as_str()).copied();
52                match index {
53                    Some(i) => Some((hint, (i, route.model))),
54                    None => {
55                        tracing::warn!(
56                            hint = hint,
57                            provider = route.provider_name,
58                            "Route references unknown provider, skipping"
59                        );
60                        None
61                    }
62                }
63            })
64            .collect();
65
66        Self {
67            routes: resolved_routes,
68            providers,
69            default_index: 0,
70        }
71    }
72
73    /// Resolve a model parameter to a (provider, actual_model) pair.
74    ///
75    /// If the model starts with "hint:", look up the hint in the route table.
76    /// Otherwise, use the default provider with the given model name.
77    /// Resolve a model parameter to a (provider_index, actual_model) pair.
78    fn resolve(&self, model: &str) -> (usize, String) {
79        if let Some(hint) = model.strip_prefix("hint:") {
80            if let Some((idx, resolved_model)) = self.routes.get(hint) {
81                return (*idx, resolved_model.clone());
82            }
83            tracing::warn!(
84                hint = hint,
85                "Unknown route hint, falling back to default provider"
86            );
87        }
88
89        // Not a hint or hint not found — use default provider with the model as-is
90        (self.default_index, model.to_string())
91    }
92}
93
94#[async_trait]
95impl ModelProvider for RouterProvider {
96    async fn chat(
97        &self,
98        request: ChatRequest<'_>,
99        model: &str,
100        temperature: f64,
101    ) -> anyhow::Result<ChatResponse> {
102        let (provider_idx, resolved_model) = self.resolve(model);
103        let (_, provider) = &self.providers[provider_idx];
104        provider.chat(request, &resolved_model, temperature).await
105    }
106
107    fn context_window(&self, model: &str) -> Option<usize> {
108        self.providers
109            .get(self.default_index)
110            .and_then(|(_, p)| p.context_window(model))
111    }
112
113    fn supports_native_tools(&self) -> bool {
114        self.providers
115            .get(self.default_index)
116            .map(|(_, p)| p.supports_native_tools())
117            .unwrap_or(false)
118    }
119
120    fn supports_developer_role(&self, model: &str) -> bool {
121        self.providers
122            .get(self.default_index)
123            .map(|(_, p)| p.supports_developer_role(model))
124            .unwrap_or(false)
125    }
126
127    async fn warmup(&self) -> anyhow::Result<()> {
128        for (name, provider) in &self.providers {
129            tracing::info!(provider = name, "Warming up routed provider");
130            if let Err(e) = provider.warmup().await {
131                tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
132            }
133        }
134        Ok(())
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::{ChatRequest, ChatResponse, TokenUsage, one_shot};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that returns a fixed response and records chat calls,
    /// the last model it was asked for, and how many times it was warmed up.
    struct MockProvider {
        calls: Arc<AtomicUsize>,
        warmups: AtomicUsize,
        response: &'static str,
        last_model: std::sync::Mutex<String>,
    }

    impl MockProvider {
        fn new(response: &'static str) -> Self {
            Self {
                calls: Arc::new(AtomicUsize::new(0)),
                warmups: AtomicUsize::new(0),
                response,
                last_model: std::sync::Mutex::new(String::new()),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        fn warmup_count(&self) -> usize {
            self.warmups.load(Ordering::SeqCst)
        }

        fn last_model(&self) -> String {
            self.last_model.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            model: &str,
            _temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_model.lock().unwrap() = model.to_string();
            Ok(ChatResponse {
                text: Some(self.response.to_string()),
                tool_calls: vec![],
                provider_tool_calls: vec![],
                usage: TokenUsage::default(),
            })
        }

        async fn warmup(&self) -> anyhow::Result<()> {
            self.warmups.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Build a router from `(name, response)` provider specs and
    /// `(hint, provider_name, model)` route specs, returning the router and
    /// shared handles to the mocks (in the same order as `providers`).
    fn make_router(
        providers: Vec<(&'static str, &'static str)>,
        routes: Vec<(&str, &str, &str)>,
    ) -> (RouterProvider, Vec<Arc<MockProvider>>) {
        let mocks: Vec<Arc<MockProvider>> = providers
            .iter()
            .map(|(_, response)| Arc::new(MockProvider::new(response)))
            .collect();

        let provider_list: Vec<(String, Box<dyn ModelProvider>)> = providers
            .iter()
            .zip(mocks.iter())
            .map(|((name, _), mock)| {
                (
                    name.to_string(),
                    Box::new(Arc::clone(mock)) as Box<dyn ModelProvider>,
                )
            })
            .collect();

        let route_list: Vec<(String, Route)> = routes
            .iter()
            .map(|(hint, provider_name, model)| {
                (
                    hint.to_string(),
                    Route {
                        provider_name: provider_name.to_string(),
                        model: model.to_string(),
                    },
                )
            })
            .collect();

        let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());

        (router, mocks)
    }

    // Arc<MockProvider> should also be a provider, so the test keeps a handle
    // to each mock after boxing it into the router. Every method the tests
    // observe is forwarded to the inner mock.
    #[async_trait]
    impl ModelProvider for Arc<MockProvider> {
        async fn chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<ChatResponse> {
            self.as_ref().chat(request, model, temperature).await
        }

        async fn warmup(&self) -> anyhow::Result<()> {
            self.as_ref().warmup().await
        }
    }

    #[tokio::test]
    async fn routes_hint_to_correct_provider() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![
                ("fast", "fast", "llama-3-70b"),
                ("reasoning", "smart", "claude-opus"),
            ],
        );

        let result = one_shot(&router, None, "hello", "hint:reasoning", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "smart-response");
        assert_eq!(mocks[1].call_count(), 1);
        assert_eq!(mocks[1].last_model(), "claude-opus");
        assert_eq!(mocks[0].call_count(), 0);
    }

    #[tokio::test]
    async fn routes_fast_hint() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![("fast", "fast", "llama-3-70b")],
        );

        let result = one_shot(&router, None, "hello", "hint:fast", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "fast-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "llama-3-70b");
        assert_eq!(mocks[1].call_count(), 0);
    }

    #[tokio::test]
    async fn unknown_hint_falls_back_to_default() {
        let (router, mocks) = make_router(
            vec![("default", "default-response"), ("other", "other-response")],
            vec![],
        );

        let result = one_shot(&router, None, "hello", "hint:nonexistent", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "default-response");
        assert_eq!(mocks[0].call_count(), 1);
        // Falls back to default with the hint as model name
        assert_eq!(mocks[0].last_model(), "hint:nonexistent");
        assert_eq!(mocks[1].call_count(), 0);
    }

    #[tokio::test]
    async fn non_hint_model_uses_default_provider() {
        let (router, mocks) = make_router(
            vec![
                ("primary", "primary-response"),
                ("secondary", "secondary-response"),
            ],
            vec![("code", "secondary", "codellama")],
        );

        let result = one_shot(
            &router,
            None,
            "hello",
            "anthropic/claude-sonnet-4-20250514",
            0.5,
        )
        .await
        .unwrap();
        assert_eq!(result, "primary-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
        assert_eq!(mocks[1].call_count(), 0);
    }

    #[test]
    fn resolve_preserves_model_for_non_hints() {
        let (router, _) = make_router(vec![("default", "ok")], vec![]);

        let (idx, model) = router.resolve("gpt-4o");
        assert_eq!(idx, 0);
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn resolve_strips_hint_prefix() {
        let (router, _) = make_router(
            vec![("fast", "ok"), ("smart", "ok")],
            vec![("reasoning", "smart", "claude-opus")],
        );

        let (idx, model) = router.resolve("hint:reasoning");
        assert_eq!(idx, 1);
        assert_eq!(model, "claude-opus");
    }

    #[test]
    fn skips_routes_with_unknown_provider() {
        let (router, _) = make_router(
            vec![("default", "ok")],
            vec![("broken", "nonexistent", "model")],
        );

        // Route should not exist
        assert!(!router.routes.contains_key("broken"));
    }

    #[tokio::test]
    async fn warmup_calls_all_providers() {
        let (router, mocks) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);

        assert!(router.warmup().await.is_ok());
        // Every registered provider must have been warmed exactly once.
        for mock in &mocks {
            assert_eq!(mock.warmup_count(), 1);
        }
    }

    #[tokio::test]
    async fn chat_dispatches_to_correct_provider() {
        let mock = Arc::new(MockProvider::new("response"));
        let router = RouterProvider::new(
            vec![(
                "default".into(),
                Box::new(Arc::clone(&mock)) as Box<dyn ModelProvider>,
            )],
            vec![],
            "model".into(),
        );

        let result = one_shot(&router, Some("system"), "hello", "model", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "response");
        assert_eq!(mock.call_count(), 1);
    }
}