//! construct/providers/router.rs — multi-model router provider.
1use super::Provider;
2use super::traits::{
3    ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use crate::config::schema::ModelPricing;
6use async_trait::async_trait;
7use futures_util::stream::BoxStream;
8use std::collections::HashMap;
9
/// A single route: maps a task hint to a provider + model combo.
///
/// `provider_name` must match one of the names passed to `RouterProvider::new`;
/// routes that name an unknown provider are logged and dropped at construction.
#[derive(Debug, Clone)]
pub struct Route {
    pub provider_name: String,
    pub model: String,
}
16
/// Multi-model router — routes requests to different provider+model combos
/// based on a task hint encoded in the model parameter.
///
/// The model parameter can be:
/// - A regular model name (e.g. "anthropic/claude-sonnet-4") → uses default provider
/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table
///
/// This wraps multiple pre-created providers and selects the right one per request.
pub struct RouterProvider {
    // Resolved route table. Indices point into `providers` and are validated
    // at construction time (see `RouterProvider::new`).
    routes: HashMap<String, (usize, String)>, // hint → (provider_index, model)
    // All wrapped providers, keyed by name for logging. Index 0 is the default.
    providers: Vec<(String, Box<dyn Provider>)>,
    // Index of the default provider; always 0 as constructed today.
    default_index: usize,
    // Model returned when cost-optimized resolution finds no priced candidate.
    default_model: String,
}
31
32impl RouterProvider {
33    /// Create a new router with a default provider and optional routes.
34    ///
35    /// `providers` is a list of (name, provider) pairs. The first one is the default.
36    /// `routes` maps hint names to Route structs containing provider_name and model.
37    pub fn new(
38        providers: Vec<(String, Box<dyn Provider>)>,
39        routes: Vec<(String, Route)>,
40        default_model: String,
41    ) -> Self {
42        // Build provider name → index lookup
43        let name_to_index: HashMap<&str, usize> = providers
44            .iter()
45            .enumerate()
46            .map(|(i, (name, _))| (name.as_str(), i))
47            .collect();
48
49        // Resolve routes to provider indices
50        let resolved_routes: HashMap<String, (usize, String)> = routes
51            .into_iter()
52            .filter_map(|(hint, route)| {
53                let index = name_to_index.get(route.provider_name.as_str()).copied();
54                match index {
55                    Some(i) => Some((hint, (i, route.model))),
56                    None => {
57                        tracing::warn!(
58                            hint = hint,
59                            provider = route.provider_name,
60                            "Route references unknown provider, skipping"
61                        );
62                        None
63                    }
64                }
65            })
66            .collect();
67
68        Self {
69            routes: resolved_routes,
70            providers,
71            default_index: 0,
72            default_model,
73        }
74    }
75
76    /// Resolve a model parameter to the cheapest qualifying route based on pricing.
77    ///
78    /// If the model starts with `"hint:cost-optimized"` or `"hint:cheapest"`, this
79    /// method scores each route by `input_price + output_price` (a simple proxy for
80    /// total cost), optionally filtering by capability requirements, and returns the
81    /// cheapest qualifying route.
82    ///
83    /// Falls back to the default route when no pricing data matches.
84    pub fn resolve_cost_optimized(
85        &self,
86        model: &str,
87        prices: &HashMap<String, ModelPricing>,
88        required_vision: bool,
89        required_tools: bool,
90    ) -> (usize, String) {
91        let hint = model.strip_prefix("hint:");
92        let is_cost_hint = matches!(hint, Some("cost-optimized" | "cheapest"));
93
94        if !is_cost_hint {
95            return self.resolve(model);
96        }
97
98        let mut candidates: Vec<(usize, String, f64)> = Vec::new();
99
100        for (idx, route_model) in self.routes.values() {
101            // Capability filtering
102            if let Some((_, provider)) = self.providers.get(*idx) {
103                if required_vision && !provider.supports_vision() {
104                    continue;
105                }
106                if required_tools && !provider.supports_native_tools() {
107                    continue;
108                }
109            }
110
111            if let Some(pricing) = prices.get(route_model) {
112                let total_cost = pricing.input + pricing.output;
113                candidates.push((*idx, route_model.clone(), total_cost));
114            }
115        }
116
117        // Sort by total cost (ascending) and pick the cheapest
118        candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
119
120        if let Some((idx, route_model, _)) = candidates.into_iter().next() {
121            return (idx, route_model);
122        }
123
124        // Fallback to default
125        tracing::warn!(
126            "No cost-optimized route found with matching pricing data, \
127             falling back to default"
128        );
129        (self.default_index, self.default_model.clone())
130    }
131
132    /// Resolve a model parameter to a (provider, actual_model) pair.
133    ///
134    /// If the model starts with "hint:", look up the hint in the route table.
135    /// Otherwise, use the default provider with the given model name.
136    /// Resolve a model parameter to a (provider_index, actual_model) pair.
137    fn resolve(&self, model: &str) -> (usize, String) {
138        if let Some(hint) = model.strip_prefix("hint:") {
139            if let Some((idx, resolved_model)) = self.routes.get(hint) {
140                return (*idx, resolved_model.clone());
141            }
142            tracing::warn!(
143                hint = hint,
144                "Unknown route hint, falling back to default provider"
145            );
146        }
147
148        // Not a hint or hint not found — use default provider with the model as-is
149        (self.default_index, model.to_string())
150    }
151}
152
/// A cost-optimized routing strategy that selects the cheapest qualifying
/// provider from the route table based on `ModelPricing` data.
///
/// This wraps pricing config and capability requirements, scoring candidates
/// by their combined input + output cost per 1M tokens (see
/// [`CostOptimizedStrategy::score`]).
#[derive(Debug, Clone)]
pub struct CostOptimizedStrategy {
    /// Per-model pricing data (keyed by model name).
    pub prices: HashMap<String, ModelPricing>,
    /// Whether the request requires vision support.
    pub required_vision: bool,
    /// Whether the request requires native tool support.
    pub required_tools: bool,
}
167
168impl CostOptimizedStrategy {
169    /// Create a new cost-optimized strategy with the given pricing data.
170    pub fn new(prices: HashMap<String, ModelPricing>) -> Self {
171        Self {
172            prices,
173            required_vision: false,
174            required_tools: false,
175        }
176    }
177
178    /// Set whether vision support is required.
179    pub fn with_vision(mut self, required: bool) -> Self {
180        self.required_vision = required;
181        self
182    }
183
184    /// Set whether native tool support is required.
185    pub fn with_tools(mut self, required: bool) -> Self {
186        self.required_tools = required;
187        self
188    }
189
190    /// Score a model by total cost (input + output per 1M tokens).
191    /// Returns `None` if no pricing data is available for the model.
192    pub fn score(&self, model: &str) -> Option<f64> {
193        self.prices.get(model).map(|p| p.input + p.output)
194    }
195}
196
197#[async_trait]
198impl Provider for RouterProvider {
199    async fn chat_with_system(
200        &self,
201        system_prompt: Option<&str>,
202        message: &str,
203        model: &str,
204        temperature: f64,
205    ) -> anyhow::Result<String> {
206        let (provider_idx, resolved_model) = self.resolve(model);
207
208        let (provider_name, provider) = &self.providers[provider_idx];
209        tracing::info!(
210            provider = provider_name.as_str(),
211            model = resolved_model.as_str(),
212            "Router dispatching request"
213        );
214
215        provider
216            .chat_with_system(system_prompt, message, &resolved_model, temperature)
217            .await
218    }
219
220    async fn chat_with_history(
221        &self,
222        messages: &[ChatMessage],
223        model: &str,
224        temperature: f64,
225    ) -> anyhow::Result<String> {
226        let (provider_idx, resolved_model) = self.resolve(model);
227        let (_, provider) = &self.providers[provider_idx];
228        provider
229            .chat_with_history(messages, &resolved_model, temperature)
230            .await
231    }
232
233    async fn chat(
234        &self,
235        request: ChatRequest<'_>,
236        model: &str,
237        temperature: f64,
238    ) -> anyhow::Result<ChatResponse> {
239        let (provider_idx, resolved_model) = self.resolve(model);
240        let (_, provider) = &self.providers[provider_idx];
241        provider.chat(request, &resolved_model, temperature).await
242    }
243
244    async fn chat_with_tools(
245        &self,
246        messages: &[ChatMessage],
247        tools: &[serde_json::Value],
248        model: &str,
249        temperature: f64,
250    ) -> anyhow::Result<ChatResponse> {
251        let (provider_idx, resolved_model) = self.resolve(model);
252        let (_, provider) = &self.providers[provider_idx];
253        provider
254            .chat_with_tools(messages, tools, &resolved_model, temperature)
255            .await
256    }
257
258    fn supports_native_tools(&self) -> bool {
259        self.providers
260            .get(self.default_index)
261            .map(|(_, p)| p.supports_native_tools())
262            .unwrap_or(false)
263    }
264
265    fn supports_streaming(&self) -> bool {
266        self.providers
267            .iter()
268            .any(|(_, provider)| provider.supports_streaming())
269    }
270
271    fn supports_streaming_tool_events(&self) -> bool {
272        self.providers
273            .iter()
274            .any(|(_, provider)| provider.supports_streaming_tool_events())
275    }
276
277    fn stream_chat_with_history(
278        &self,
279        messages: &[ChatMessage],
280        model: &str,
281        temperature: f64,
282        options: StreamOptions,
283    ) -> BoxStream<'static, StreamResult<StreamChunk>> {
284        let (provider_idx, resolved_model) = self.resolve(model);
285        let (_, provider) = &self.providers[provider_idx];
286        provider.stream_chat_with_history(messages, &resolved_model, temperature, options)
287    }
288
289    fn stream_chat(
290        &self,
291        request: ChatRequest<'_>,
292        model: &str,
293        temperature: f64,
294        options: StreamOptions,
295    ) -> BoxStream<'static, StreamResult<StreamEvent>> {
296        let (provider_idx, resolved_model) = self.resolve(model);
297        let (_, provider) = &self.providers[provider_idx];
298        provider.stream_chat(request, &resolved_model, temperature, options)
299    }
300
301    fn supports_vision(&self) -> bool {
302        self.providers
303            .iter()
304            .any(|(_, provider)| provider.supports_vision())
305    }
306
307    async fn warmup(&self) -> anyhow::Result<()> {
308        for (name, provider) in &self.providers {
309            tracing::info!(provider = name, "Warming up routed provider");
310            if let Err(e) = provider.warmup().await {
311                tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
312            }
313        }
314        Ok(())
315    }
316}
317
318#[cfg(test)]
319mod tests {
320    use super::*;
321    use crate::tools::ToolSpec;
322    use futures_util::StreamExt;
323    use std::sync::Arc;
324    use std::sync::atomic::{AtomicUsize, Ordering};
325
    /// Minimal mock: counts chat calls, records the last model seen, and
    /// returns a canned reply.
    struct MockProvider {
        calls: Arc<AtomicUsize>,
        response: &'static str,
        last_model: parking_lot::Mutex<String>,
    }
331
332    impl MockProvider {
333        fn new(response: &'static str) -> Self {
334            Self {
335                calls: Arc::new(AtomicUsize::new(0)),
336                response,
337                last_model: parking_lot::Mutex::new(String::new()),
338            }
339        }
340
341        fn call_count(&self) -> usize {
342            self.calls.load(Ordering::SeqCst)
343        }
344
345        fn last_model(&self) -> String {
346            self.last_model.lock().clone()
347        }
348    }
349
350    #[async_trait]
351    impl Provider for MockProvider {
352        async fn chat_with_system(
353            &self,
354            _system_prompt: Option<&str>,
355            _message: &str,
356            model: &str,
357            _temperature: f64,
358        ) -> anyhow::Result<String> {
359            self.calls.fetch_add(1, Ordering::SeqCst);
360            *self.last_model.lock() = model.to_string();
361            Ok(self.response.to_string())
362        }
363    }
364
365    fn make_router(
366        providers: Vec<(&'static str, &'static str)>,
367        routes: Vec<(&str, &str, &str)>,
368    ) -> (RouterProvider, Vec<Arc<MockProvider>>) {
369        let mocks: Vec<Arc<MockProvider>> = providers
370            .iter()
371            .map(|(_, response)| Arc::new(MockProvider::new(response)))
372            .collect();
373
374        let provider_list: Vec<(String, Box<dyn Provider>)> = providers
375            .iter()
376            .zip(mocks.iter())
377            .map(|((name, _), mock)| {
378                (
379                    (*name).to_string(),
380                    Box::new(Arc::clone(mock)) as Box<dyn Provider>,
381                )
382            })
383            .collect();
384
385        let route_list: Vec<(String, Route)> = routes
386            .iter()
387            .map(|(hint, provider_name, model)| {
388                (
389                    (*hint).to_string(),
390                    Route {
391                        provider_name: (*provider_name).to_string(),
392                        model: (*model).to_string(),
393                    },
394                )
395            })
396            .collect();
397
398        let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());
399
400        (router, mocks)
401    }
402
    // Arc<MockProvider> should also be a Provider, so tests can keep a handle
    // to the mock (for counters) while the router owns a boxed clone of the
    // same instance.
    #[async_trait]
    impl Provider for Arc<MockProvider> {
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            // Delegate to the inner mock so call counters are shared.
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }
    }
418
    /// Streaming mock: counts stream calls, records the last streamed model,
    /// and emits one delta chunk followed by a final chunk.
    struct StreamingMockProvider {
        stream_calls: Arc<AtomicUsize>,
        last_stream_model: parking_lot::Mutex<String>,
        response: &'static str,
    }
424
425    impl StreamingMockProvider {
426        fn new(response: &'static str) -> Self {
427            Self {
428                stream_calls: Arc::new(AtomicUsize::new(0)),
429                last_stream_model: parking_lot::Mutex::new(String::new()),
430                response,
431            }
432        }
433    }
434
435    #[async_trait]
436    impl Provider for StreamingMockProvider {
437        async fn chat_with_system(
438            &self,
439            _system_prompt: Option<&str>,
440            _message: &str,
441            _model: &str,
442            _temperature: f64,
443        ) -> anyhow::Result<String> {
444            Ok("ok".to_string())
445        }
446
447        fn supports_streaming(&self) -> bool {
448            true
449        }
450
451        fn stream_chat_with_history(
452            &self,
453            _messages: &[ChatMessage],
454            model: &str,
455            _temperature: f64,
456            _options: StreamOptions,
457        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
458            self.stream_calls.fetch_add(1, Ordering::SeqCst);
459            *self.last_stream_model.lock() = model.to_string();
460            let chunks = vec![
461                Ok(StreamChunk::delta(self.response)),
462                Ok(StreamChunk::final_chunk()),
463            ];
464            futures_util::stream::iter(chunks).boxed()
465        }
466    }
467
    // Forwarding impl so tests can hold an Arc handle to the streaming mock
    // while the router owns a boxed clone sharing the same counters.
    #[async_trait]
    impl Provider for Arc<StreamingMockProvider> {
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }

        fn supports_streaming(&self) -> bool {
            self.as_ref().supports_streaming()
        }

        fn stream_chat_with_history(
            &self,
            messages: &[ChatMessage],
            model: &str,
            temperature: f64,
            options: StreamOptions,
        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
            self.as_ref()
                .stream_chat_with_history(messages, model, temperature, options)
        }
    }
497
    /// Streaming mock that also advertises tool-event streaming. Counts all
    /// stream calls, plus the subset made with a non-empty tool list.
    struct ToolEventStreamingMockProvider {
        stream_calls: Arc<AtomicUsize>,
        tool_event_calls: Arc<AtomicUsize>,
        last_stream_model: parking_lot::Mutex<String>,
    }
503
504    impl ToolEventStreamingMockProvider {
505        fn new() -> Self {
506            Self {
507                stream_calls: Arc::new(AtomicUsize::new(0)),
508                tool_event_calls: Arc::new(AtomicUsize::new(0)),
509                last_stream_model: parking_lot::Mutex::new(String::new()),
510            }
511        }
512    }
513
514    #[async_trait]
515    impl Provider for ToolEventStreamingMockProvider {
516        async fn chat_with_system(
517            &self,
518            _system_prompt: Option<&str>,
519            _message: &str,
520            _model: &str,
521            _temperature: f64,
522        ) -> anyhow::Result<String> {
523            Ok("ok".to_string())
524        }
525
526        fn supports_streaming(&self) -> bool {
527            true
528        }
529
530        fn supports_streaming_tool_events(&self) -> bool {
531            true
532        }
533
534        fn stream_chat(
535            &self,
536            request: ChatRequest<'_>,
537            model: &str,
538            _temperature: f64,
539            _options: StreamOptions,
540        ) -> BoxStream<'static, StreamResult<StreamEvent>> {
541            self.stream_calls.fetch_add(1, Ordering::SeqCst);
542            if request.tools.is_some_and(|tools| !tools.is_empty()) {
543                self.tool_event_calls.fetch_add(1, Ordering::SeqCst);
544            }
545            *self.last_stream_model.lock() = model.to_string();
546            futures_util::stream::iter(vec![
547                Ok(StreamEvent::ToolCall(crate::providers::ToolCall {
548                    id: "call_router_1".to_string(),
549                    name: "shell".to_string(),
550                    arguments: r#"{"command":"date"}"#.to_string(),
551                })),
552                Ok(StreamEvent::Final),
553            ])
554            .boxed()
555        }
556    }
557
    // Forwarding impl so tests can hold an Arc handle to the tool-event mock
    // while the router owns a boxed clone sharing the same counters.
    #[async_trait]
    impl Provider for Arc<ToolEventStreamingMockProvider> {
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }

        fn supports_streaming(&self) -> bool {
            self.as_ref().supports_streaming()
        }

        fn supports_streaming_tool_events(&self) -> bool {
            self.as_ref().supports_streaming_tool_events()
        }

        fn stream_chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            temperature: f64,
            options: StreamOptions,
        ) -> BoxStream<'static, StreamResult<StreamEvent>> {
            self.as_ref()
                .stream_chat(request, model, temperature, options)
        }
    }
591
592    #[tokio::test]
593    async fn routes_hint_to_correct_provider() {
594        let (router, mocks) = make_router(
595            vec![("fast", "fast-response"), ("smart", "smart-response")],
596            vec![
597                ("fast", "fast", "llama-3-70b"),
598                ("reasoning", "smart", "claude-opus"),
599            ],
600        );
601
602        let result = router
603            .simple_chat("hello", "hint:reasoning", 0.5)
604            .await
605            .unwrap();
606        assert_eq!(result, "smart-response");
607        assert_eq!(mocks[1].call_count(), 1);
608        assert_eq!(mocks[1].last_model(), "claude-opus");
609        assert_eq!(mocks[0].call_count(), 0);
610    }
611
612    #[tokio::test]
613    async fn routes_fast_hint() {
614        let (router, mocks) = make_router(
615            vec![("fast", "fast-response"), ("smart", "smart-response")],
616            vec![("fast", "fast", "llama-3-70b")],
617        );
618
619        let result = router.simple_chat("hello", "hint:fast", 0.5).await.unwrap();
620        assert_eq!(result, "fast-response");
621        assert_eq!(mocks[0].call_count(), 1);
622        assert_eq!(mocks[0].last_model(), "llama-3-70b");
623    }
624
625    #[tokio::test]
626    async fn unknown_hint_falls_back_to_default() {
627        let (router, mocks) = make_router(
628            vec![("default", "default-response"), ("other", "other-response")],
629            vec![],
630        );
631
632        let result = router
633            .simple_chat("hello", "hint:nonexistent", 0.5)
634            .await
635            .unwrap();
636        assert_eq!(result, "default-response");
637        assert_eq!(mocks[0].call_count(), 1);
638        // Falls back to default with the hint as model name
639        assert_eq!(mocks[0].last_model(), "hint:nonexistent");
640    }
641
642    #[tokio::test]
643    async fn non_hint_model_uses_default_provider() {
644        let (router, mocks) = make_router(
645            vec![
646                ("primary", "primary-response"),
647                ("secondary", "secondary-response"),
648            ],
649            vec![("code", "secondary", "codellama")],
650        );
651
652        let result = router
653            .simple_chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5)
654            .await
655            .unwrap();
656        assert_eq!(result, "primary-response");
657        assert_eq!(mocks[0].call_count(), 1);
658        assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
659    }
660
661    #[test]
662    fn resolve_preserves_model_for_non_hints() {
663        let (router, _) = make_router(vec![("default", "ok")], vec![]);
664
665        let (idx, model) = router.resolve("gpt-4o");
666        assert_eq!(idx, 0);
667        assert_eq!(model, "gpt-4o");
668    }
669
670    #[test]
671    fn resolve_strips_hint_prefix() {
672        let (router, _) = make_router(
673            vec![("fast", "ok"), ("smart", "ok")],
674            vec![("reasoning", "smart", "claude-opus")],
675        );
676
677        let (idx, model) = router.resolve("hint:reasoning");
678        assert_eq!(idx, 1);
679        assert_eq!(model, "claude-opus");
680    }
681
682    #[test]
683    fn skips_routes_with_unknown_provider() {
684        let (router, _) = make_router(
685            vec![("default", "ok")],
686            vec![("broken", "nonexistent", "model")],
687        );
688
689        // Route should not exist
690        assert!(!router.routes.contains_key("broken"));
691    }
692
693    #[tokio::test]
694    async fn warmup_calls_all_providers() {
695        let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);
696
697        // Warmup should not error
698        assert!(router.warmup().await.is_ok());
699    }
700
701    #[tokio::test]
702    async fn chat_with_system_passes_system_prompt() {
703        let mock = Arc::new(MockProvider::new("response"));
704        let router = RouterProvider::new(
705            vec![(
706                "default".into(),
707                Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
708            )],
709            vec![],
710            "model".into(),
711        );
712
713        let result = router
714            .chat_with_system(Some("system"), "hello", "model", 0.5)
715            .await
716            .unwrap();
717        assert_eq!(result, "response");
718        assert_eq!(mock.call_count(), 1);
719    }
720
721    #[tokio::test]
722    async fn chat_with_tools_delegates_to_resolved_provider() {
723        let mock = Arc::new(MockProvider::new("tool-response"));
724        let router = RouterProvider::new(
725            vec![(
726                "default".into(),
727                Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
728            )],
729            vec![],
730            "model".into(),
731        );
732
733        let messages = vec![ChatMessage {
734            role: "user".to_string(),
735            content: "use tools".to_string(),
736        }];
737        let tools = vec![serde_json::json!({
738            "type": "function",
739            "function": {
740                "name": "shell",
741                "description": "Run shell command",
742                "parameters": {}
743            }
744        })];
745
746        // chat_with_tools should delegate through the router to the mock.
747        // MockProvider's default chat_with_tools calls chat_with_history -> chat_with_system.
748        let result = router
749            .chat_with_tools(&messages, &tools, "model", 0.7)
750            .await
751            .unwrap();
752        assert_eq!(result.text.as_deref(), Some("tool-response"));
753        assert_eq!(mock.call_count(), 1);
754        assert_eq!(mock.last_model(), "model");
755    }
756
757    #[tokio::test]
758    async fn chat_with_tools_routes_hint_correctly() {
759        let (router, mocks) = make_router(
760            vec![("fast", "fast-tool"), ("smart", "smart-tool")],
761            vec![("reasoning", "smart", "claude-opus")],
762        );
763
764        let messages = vec![ChatMessage {
765            role: "user".to_string(),
766            content: "reason about this".to_string(),
767        }];
768        let tools = vec![serde_json::json!({"type": "function", "function": {"name": "test"}})];
769
770        let result = router
771            .chat_with_tools(&messages, &tools, "hint:reasoning", 0.5)
772            .await
773            .unwrap();
774        assert_eq!(result.text.as_deref(), Some("smart-tool"));
775        assert_eq!(mocks[1].call_count(), 1);
776        assert_eq!(mocks[1].last_model(), "claude-opus");
777        assert_eq!(mocks[0].call_count(), 0);
778    }
779
780    // ── Cost-optimized routing tests ────────────────────────────────
781
782    use crate::providers::traits::ProviderCapabilities;
783
    /// Mock provider with configurable capability flags (vision / tools),
    /// used to exercise capability filtering in cost-optimized routing.
    struct CapableMockProvider {
        response: &'static str,
        vision: bool,
        tools: bool,
    }
790
791    impl CapableMockProvider {
792        fn new(response: &'static str, vision: bool, tools: bool) -> Self {
793            Self {
794                response,
795                vision,
796                tools,
797            }
798        }
799    }
800
801    #[async_trait]
802    impl Provider for CapableMockProvider {
803        fn capabilities(&self) -> ProviderCapabilities {
804            ProviderCapabilities {
805                native_tool_calling: self.tools,
806                vision: self.vision,
807                prompt_caching: false,
808            }
809        }
810
811        async fn chat_with_system(
812            &self,
813            _system_prompt: Option<&str>,
814            _message: &str,
815            _model: &str,
816            _temperature: f64,
817        ) -> anyhow::Result<String> {
818            Ok(self.response.to_string())
819        }
820    }
821
822    fn make_pricing(entries: Vec<(&str, f64, f64)>) -> HashMap<String, ModelPricing> {
823        entries
824            .into_iter()
825            .map(|(model, input, output)| (model.to_string(), ModelPricing { input, output }))
826            .collect()
827    }
828
829    #[test]
830    fn cost_optimized_selects_cheapest_provider() {
831        let providers: Vec<(String, Box<dyn Provider>)> = vec![
832            (
833                "expensive".into(),
834                Box::new(CapableMockProvider::new("exp", false, false)),
835            ),
836            (
837                "cheap".into(),
838                Box::new(CapableMockProvider::new("chp", false, false)),
839            ),
840        ];
841        let routes = vec![
842            (
843                "expensive".to_string(),
844                Route {
845                    provider_name: "expensive".into(),
846                    model: "big-model".into(),
847                },
848            ),
849            (
850                "cheap".to_string(),
851                Route {
852                    provider_name: "cheap".into(),
853                    model: "small-model".into(),
854                },
855            ),
856        ];
857        let router = RouterProvider::new(providers, routes, "default-model".into());
858
859        let prices = make_pricing(vec![("big-model", 15.0, 75.0), ("small-model", 0.25, 1.25)]);
860
861        let (idx, model) =
862            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
863        assert_eq!(model, "small-model");
864        assert_eq!(idx, 1);
865    }
866
867    #[test]
868    fn cost_optimized_respects_vision_requirement() {
869        let providers: Vec<(String, Box<dyn Provider>)> = vec![
870            (
871                "no-vision".into(),
872                Box::new(CapableMockProvider::new("nv", false, false)),
873            ),
874            (
875                "has-vision".into(),
876                Box::new(CapableMockProvider::new("hv", true, false)),
877            ),
878        ];
879        let routes = vec![
880            (
881                "cheap".to_string(),
882                Route {
883                    provider_name: "no-vision".into(),
884                    model: "cheap-model".into(),
885                },
886            ),
887            (
888                "vision".to_string(),
889                Route {
890                    provider_name: "has-vision".into(),
891                    model: "vision-model".into(),
892                },
893            ),
894        ];
895        let router = RouterProvider::new(providers, routes, "default-model".into());
896
897        let prices = make_pricing(vec![
898            ("cheap-model", 0.10, 0.40),
899            ("vision-model", 3.0, 15.0),
900        ]);
901
902        // With vision required, the cheap model (no vision) is filtered out
903        let (_, model) = router.resolve_cost_optimized("hint:cheapest", &prices, true, false);
904        assert_eq!(model, "vision-model");
905    }
906
907    #[test]
908    fn cost_optimized_respects_tools_requirement() {
909        let providers: Vec<(String, Box<dyn Provider>)> = vec![
910            (
911                "no-tools".into(),
912                Box::new(CapableMockProvider::new("nt", false, false)),
913            ),
914            (
915                "has-tools".into(),
916                Box::new(CapableMockProvider::new("ht", false, true)),
917            ),
918        ];
919        let routes = vec![
920            (
921                "basic".to_string(),
922                Route {
923                    provider_name: "no-tools".into(),
924                    model: "basic-model".into(),
925                },
926            ),
927            (
928                "tools".to_string(),
929                Route {
930                    provider_name: "has-tools".into(),
931                    model: "tools-model".into(),
932                },
933            ),
934        ];
935        let router = RouterProvider::new(providers, routes, "default-model".into());
936
937        let prices = make_pricing(vec![
938            ("basic-model", 0.10, 0.40),
939            ("tools-model", 5.0, 15.0),
940        ]);
941
942        // With tools required, the basic model (no tools) is filtered out
943        let (_, model) = router.resolve_cost_optimized("hint:cost-optimized", &prices, false, true);
944        assert_eq!(model, "tools-model");
945    }
946
947    #[test]
948    fn cost_optimized_falls_back_when_no_pricing() {
949        let (router, _) = make_router(
950            vec![("default", "ok"), ("other", "ok")],
951            vec![("route-a", "other", "some-model")],
952        );
953
954        // Empty pricing map — no matches possible
955        let prices: HashMap<String, ModelPricing> = HashMap::new();
956        let (idx, model) =
957            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
958        assert_eq!(idx, 0);
959        assert_eq!(model, "default-model");
960    }
961
962    #[test]
963    fn cost_optimized_with_single_route() {
964        let providers: Vec<(String, Box<dyn Provider>)> = vec![(
965            "only".into(),
966            Box::new(CapableMockProvider::new("ok", false, false)),
967        )];
968        let routes = vec![(
969            "single".to_string(),
970            Route {
971                provider_name: "only".into(),
972                model: "the-model".into(),
973            },
974        )];
975        let router = RouterProvider::new(providers, routes, "default-model".into());
976
977        let prices = make_pricing(vec![("the-model", 1.0, 2.0)]);
978
979        let (idx, model) = router.resolve_cost_optimized("hint:cheapest", &prices, false, false);
980        assert_eq!(idx, 0);
981        assert_eq!(model, "the-model");
982    }
983
984    #[test]
985    fn cost_optimized_prefers_lower_total_cost() {
986        let providers: Vec<(String, Box<dyn Provider>)> = vec![
987            (
988                "p1".into(),
989                Box::new(CapableMockProvider::new("r1", false, false)),
990            ),
991            (
992                "p2".into(),
993                Box::new(CapableMockProvider::new("r2", false, false)),
994            ),
995            (
996                "p3".into(),
997                Box::new(CapableMockProvider::new("r3", false, false)),
998            ),
999        ];
1000        let routes = vec![
1001            (
1002                "a".to_string(),
1003                Route {
1004                    provider_name: "p1".into(),
1005                    model: "model-a".into(),
1006                },
1007            ),
1008            (
1009                "b".to_string(),
1010                Route {
1011                    provider_name: "p2".into(),
1012                    model: "model-b".into(),
1013                },
1014            ),
1015            (
1016                "c".to_string(),
1017                Route {
1018                    provider_name: "p3".into(),
1019                    model: "model-c".into(),
1020                },
1021            ),
1022        ];
1023        let router = RouterProvider::new(providers, routes, "default-model".into());
1024
1025        let prices = make_pricing(vec![
1026            ("model-a", 10.0, 50.0), // total: 60
1027            ("model-b", 0.15, 0.60), // total: 0.75 (cheapest)
1028            ("model-c", 3.0, 15.0),  // total: 18
1029        ]);
1030
1031        let (idx, model) =
1032            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
1033        assert_eq!(model, "model-b");
1034        assert_eq!(idx, 1);
1035    }
1036
1037    #[test]
1038    fn cost_optimized_strategy_score() {
1039        let prices = make_pricing(vec![("cheap", 0.10, 0.40), ("expensive", 15.0, 75.0)]);
1040        let strategy = CostOptimizedStrategy::new(prices);
1041
1042        assert!((strategy.score("cheap").unwrap() - 0.50).abs() < f64::EPSILON);
1043        assert!((strategy.score("expensive").unwrap() - 90.0).abs() < f64::EPSILON);
1044        assert!(strategy.score("unknown").is_none());
1045    }
1046
1047    #[tokio::test]
1048    async fn supports_streaming_returns_true_when_any_provider_supports_it() {
1049        let streaming = Arc::new(StreamingMockProvider::new("stream"));
1050        let router = RouterProvider::new(
1051            vec![
1052                (
1053                    "default".into(),
1054                    Box::new(MockProvider::new("default")) as Box<dyn Provider>,
1055                ),
1056                (
1057                    "streaming".into(),
1058                    Box::new(Arc::clone(&streaming)) as Box<dyn Provider>,
1059                ),
1060            ],
1061            vec![(
1062                "reasoning".into(),
1063                Route {
1064                    provider_name: "streaming".into(),
1065                    model: "claude-opus".into(),
1066                },
1067            )],
1068            "model".into(),
1069        );
1070
1071        assert!(router.supports_streaming());
1072    }
1073
1074    #[tokio::test]
1075    async fn stream_chat_with_history_routes_hint_to_correct_provider_and_model() {
1076        let streaming = Arc::new(StreamingMockProvider::new("streamed response"));
1077        let router = RouterProvider::new(
1078            vec![
1079                (
1080                    "default".into(),
1081                    Box::new(MockProvider::new("default")) as Box<dyn Provider>,
1082                ),
1083                (
1084                    "streaming".into(),
1085                    Box::new(Arc::clone(&streaming)) as Box<dyn Provider>,
1086                ),
1087            ],
1088            vec![(
1089                "reasoning".into(),
1090                Route {
1091                    provider_name: "streaming".into(),
1092                    model: "claude-opus".into(),
1093                },
1094            )],
1095            "model".into(),
1096        );
1097
1098        let messages = vec![ChatMessage::user("hello")];
1099        let mut stream = router.stream_chat_with_history(
1100            &messages,
1101            "hint:reasoning",
1102            0.0,
1103            StreamOptions::new(true),
1104        );
1105
1106        let mut collected = String::new();
1107        while let Some(chunk) = stream.next().await {
1108            let chunk = chunk.expect("stream chunk should be ok");
1109            collected.push_str(&chunk.delta);
1110        }
1111
1112        assert_eq!(collected, "streamed response");
1113        assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1114        assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1115    }
1116
1117    #[tokio::test]
1118    async fn stream_chat_routes_hint_with_structured_tool_events() {
1119        let streaming = Arc::new(ToolEventStreamingMockProvider::new());
1120        let router = RouterProvider::new(
1121            vec![
1122                (
1123                    "default".into(),
1124                    Box::new(MockProvider::new("default")) as Box<dyn Provider>,
1125                ),
1126                (
1127                    "streaming".into(),
1128                    Box::new(Arc::clone(&streaming)) as Box<dyn Provider>,
1129                ),
1130            ],
1131            vec![(
1132                "reasoning".into(),
1133                Route {
1134                    provider_name: "streaming".into(),
1135                    model: "claude-opus".into(),
1136                },
1137            )],
1138            "model".into(),
1139        );
1140
1141        let messages = vec![ChatMessage::user("hello")];
1142        let tools = vec![ToolSpec {
1143            name: "shell".to_string(),
1144            description: "run shell commands".to_string(),
1145            parameters: serde_json::json!({
1146                "type": "object",
1147                "properties": {
1148                    "command": { "type": "string" }
1149                }
1150            }),
1151        }];
1152
1153        let mut stream = router.stream_chat(
1154            ChatRequest {
1155                messages: &messages,
1156                tools: Some(&tools),
1157            },
1158            "hint:reasoning",
1159            0.0,
1160            StreamOptions::new(true),
1161        );
1162
1163        let first = stream.next().await.unwrap().unwrap();
1164        let second = stream.next().await.unwrap().unwrap();
1165        assert!(stream.next().await.is_none());
1166
1167        match first {
1168            StreamEvent::ToolCall(call) => {
1169                assert_eq!(call.name, "shell");
1170                assert_eq!(call.arguments, r#"{"command":"date"}"#);
1171            }
1172            other => panic!("expected tool-call event, got {other:?}"),
1173        }
1174        assert!(matches!(second, StreamEvent::Final));
1175        assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
1176        assert_eq!(streaming.tool_event_calls.load(Ordering::SeqCst), 1);
1177        assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
1178    }
1179}