1use super::Provider;
2use super::traits::{
3 ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamEvent, StreamOptions, StreamResult,
4};
5use crate::config::schema::ModelPricing;
6use async_trait::async_trait;
7use futures_util::stream::BoxStream;
8use std::collections::HashMap;
9
/// A single routing rule: maps a hint name to a concrete provider/model pair.
#[derive(Debug, Clone)]
pub struct Route {
    /// Name of the target provider; must match an entry in the provider list
    /// handed to `RouterProvider::new`, otherwise the route is dropped at
    /// construction time.
    pub provider_name: String,
    /// Model identifier to use when this route is selected.
    pub model: String,
}
16
17pub struct RouterProvider {
26 routes: HashMap<String, (usize, String)>, providers: Vec<(String, Box<dyn Provider>)>,
28 default_index: usize,
29 default_model: String,
30}
31
32impl RouterProvider {
33 pub fn new(
38 providers: Vec<(String, Box<dyn Provider>)>,
39 routes: Vec<(String, Route)>,
40 default_model: String,
41 ) -> Self {
42 let name_to_index: HashMap<&str, usize> = providers
44 .iter()
45 .enumerate()
46 .map(|(i, (name, _))| (name.as_str(), i))
47 .collect();
48
49 let resolved_routes: HashMap<String, (usize, String)> = routes
51 .into_iter()
52 .filter_map(|(hint, route)| {
53 let index = name_to_index.get(route.provider_name.as_str()).copied();
54 match index {
55 Some(i) => Some((hint, (i, route.model))),
56 None => {
57 tracing::warn!(
58 hint = hint,
59 provider = route.provider_name,
60 "Route references unknown provider, skipping"
61 );
62 None
63 }
64 }
65 })
66 .collect();
67
68 Self {
69 routes: resolved_routes,
70 providers,
71 default_index: 0,
72 default_model,
73 }
74 }
75
76 pub fn resolve_cost_optimized(
85 &self,
86 model: &str,
87 prices: &HashMap<String, ModelPricing>,
88 required_vision: bool,
89 required_tools: bool,
90 ) -> (usize, String) {
91 let hint = model.strip_prefix("hint:");
92 let is_cost_hint = matches!(hint, Some("cost-optimized" | "cheapest"));
93
94 if !is_cost_hint {
95 return self.resolve(model);
96 }
97
98 let mut candidates: Vec<(usize, String, f64)> = Vec::new();
99
100 for (idx, route_model) in self.routes.values() {
101 if let Some((_, provider)) = self.providers.get(*idx) {
103 if required_vision && !provider.supports_vision() {
104 continue;
105 }
106 if required_tools && !provider.supports_native_tools() {
107 continue;
108 }
109 }
110
111 if let Some(pricing) = prices.get(route_model) {
112 let total_cost = pricing.input + pricing.output;
113 candidates.push((*idx, route_model.clone(), total_cost));
114 }
115 }
116
117 candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
119
120 if let Some((idx, route_model, _)) = candidates.into_iter().next() {
121 return (idx, route_model);
122 }
123
124 tracing::warn!(
126 "No cost-optimized route found with matching pricing data, \
127 falling back to default"
128 );
129 (self.default_index, self.default_model.clone())
130 }
131
132 fn resolve(&self, model: &str) -> (usize, String) {
138 if let Some(hint) = model.strip_prefix("hint:") {
139 if let Some((idx, resolved_model)) = self.routes.get(hint) {
140 return (*idx, resolved_model.clone());
141 }
142 tracing::warn!(
143 hint = hint,
144 "Unknown route hint, falling back to default provider"
145 );
146 }
147
148 (self.default_index, model.to_string())
150 }
151}
152
/// Configuration for cost-based selection: a per-model pricing table plus
/// the capabilities a candidate provider must offer.
#[derive(Debug, Clone)]
pub struct CostOptimizedStrategy {
    /// Pricing table keyed by model id.
    pub prices: HashMap<String, ModelPricing>,
    /// When true, only vision-capable providers are acceptable.
    pub required_vision: bool,
    /// When true, only providers with native tool calling are acceptable.
    pub required_tools: bool,
}
167
168impl CostOptimizedStrategy {
169 pub fn new(prices: HashMap<String, ModelPricing>) -> Self {
171 Self {
172 prices,
173 required_vision: false,
174 required_tools: false,
175 }
176 }
177
178 pub fn with_vision(mut self, required: bool) -> Self {
180 self.required_vision = required;
181 self
182 }
183
184 pub fn with_tools(mut self, required: bool) -> Self {
186 self.required_tools = required;
187 self
188 }
189
190 pub fn score(&self, model: &str) -> Option<f64> {
193 self.prices.get(model).map(|p| p.input + p.output)
194 }
195}
196
197#[async_trait]
198impl Provider for RouterProvider {
199 async fn chat_with_system(
200 &self,
201 system_prompt: Option<&str>,
202 message: &str,
203 model: &str,
204 temperature: f64,
205 ) -> anyhow::Result<String> {
206 let (provider_idx, resolved_model) = self.resolve(model);
207
208 let (provider_name, provider) = &self.providers[provider_idx];
209 tracing::info!(
210 provider = provider_name.as_str(),
211 model = resolved_model.as_str(),
212 "Router dispatching request"
213 );
214
215 provider
216 .chat_with_system(system_prompt, message, &resolved_model, temperature)
217 .await
218 }
219
220 async fn chat_with_history(
221 &self,
222 messages: &[ChatMessage],
223 model: &str,
224 temperature: f64,
225 ) -> anyhow::Result<String> {
226 let (provider_idx, resolved_model) = self.resolve(model);
227 let (_, provider) = &self.providers[provider_idx];
228 provider
229 .chat_with_history(messages, &resolved_model, temperature)
230 .await
231 }
232
233 async fn chat(
234 &self,
235 request: ChatRequest<'_>,
236 model: &str,
237 temperature: f64,
238 ) -> anyhow::Result<ChatResponse> {
239 let (provider_idx, resolved_model) = self.resolve(model);
240 let (_, provider) = &self.providers[provider_idx];
241 provider.chat(request, &resolved_model, temperature).await
242 }
243
244 async fn chat_with_tools(
245 &self,
246 messages: &[ChatMessage],
247 tools: &[serde_json::Value],
248 model: &str,
249 temperature: f64,
250 ) -> anyhow::Result<ChatResponse> {
251 let (provider_idx, resolved_model) = self.resolve(model);
252 let (_, provider) = &self.providers[provider_idx];
253 provider
254 .chat_with_tools(messages, tools, &resolved_model, temperature)
255 .await
256 }
257
258 fn supports_native_tools(&self) -> bool {
259 self.providers
260 .get(self.default_index)
261 .map(|(_, p)| p.supports_native_tools())
262 .unwrap_or(false)
263 }
264
265 fn supports_streaming(&self) -> bool {
266 self.providers
267 .iter()
268 .any(|(_, provider)| provider.supports_streaming())
269 }
270
271 fn supports_streaming_tool_events(&self) -> bool {
272 self.providers
273 .iter()
274 .any(|(_, provider)| provider.supports_streaming_tool_events())
275 }
276
277 fn stream_chat_with_history(
278 &self,
279 messages: &[ChatMessage],
280 model: &str,
281 temperature: f64,
282 options: StreamOptions,
283 ) -> BoxStream<'static, StreamResult<StreamChunk>> {
284 let (provider_idx, resolved_model) = self.resolve(model);
285 let (_, provider) = &self.providers[provider_idx];
286 provider.stream_chat_with_history(messages, &resolved_model, temperature, options)
287 }
288
289 fn stream_chat(
290 &self,
291 request: ChatRequest<'_>,
292 model: &str,
293 temperature: f64,
294 options: StreamOptions,
295 ) -> BoxStream<'static, StreamResult<StreamEvent>> {
296 let (provider_idx, resolved_model) = self.resolve(model);
297 let (_, provider) = &self.providers[provider_idx];
298 provider.stream_chat(request, &resolved_model, temperature, options)
299 }
300
301 fn supports_vision(&self) -> bool {
302 self.providers
303 .iter()
304 .any(|(_, provider)| provider.supports_vision())
305 }
306
307 async fn warmup(&self) -> anyhow::Result<()> {
308 for (name, provider) in &self.providers {
309 tracing::info!(provider = name, "Warming up routed provider");
310 if let Err(e) = provider.warmup().await {
311 tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
312 }
313 }
314 Ok(())
315 }
316}
317
318#[cfg(test)]
319mod tests {
320 use super::*;
321 use crate::tools::ToolSpec;
322 use futures_util::StreamExt;
323 use std::sync::Arc;
324 use std::sync::atomic::{AtomicUsize, Ordering};
325
    /// Minimal mock provider: returns a fixed response while recording the
    /// call count and the last model string it was asked for.
    struct MockProvider {
        calls: Arc<AtomicUsize>,
        response: &'static str,
        last_model: parking_lot::Mutex<String>,
    }

    impl MockProvider {
        fn new(response: &'static str) -> Self {
            Self {
                calls: Arc::new(AtomicUsize::new(0)),
                response,
                last_model: parking_lot::Mutex::new(String::new()),
            }
        }

        // Number of chat calls observed so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        // Model string passed to the most recent chat call.
        fn last_model(&self) -> String {
            self.last_model.lock().clone()
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_model.lock() = model.to_string();
            Ok(self.response.to_string())
        }
    }
364
    /// Builds a router from `(name, canned_response)` providers and
    /// `(hint, provider_name, model)` routes, returning the router plus the
    /// `Arc`'d mocks so tests can inspect calls after dispatch.
    fn make_router(
        providers: Vec<(&'static str, &'static str)>,
        routes: Vec<(&str, &str, &str)>,
    ) -> (RouterProvider, Vec<Arc<MockProvider>>) {
        let mocks: Vec<Arc<MockProvider>> = providers
            .iter()
            .map(|(_, response)| Arc::new(MockProvider::new(response)))
            .collect();

        let provider_list: Vec<(String, Box<dyn Provider>)> = providers
            .iter()
            .zip(mocks.iter())
            .map(|((name, _), mock)| {
                (
                    (*name).to_string(),
                    Box::new(Arc::clone(mock)) as Box<dyn Provider>,
                )
            })
            .collect();

        let route_list: Vec<(String, Route)> = routes
            .iter()
            .map(|(hint, provider_name, model)| {
                (
                    (*hint).to_string(),
                    Route {
                        provider_name: (*provider_name).to_string(),
                        model: (*model).to_string(),
                    },
                )
            })
            .collect();

        let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());

        (router, mocks)
    }

    // Forwarding impl so the same Arc'd mock can be boxed into the router
    // and still be inspected by the test afterwards.
    #[async_trait]
    impl Provider for Arc<MockProvider> {
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }
    }
418
    /// Mock provider that supports chunk-based streaming, recording the
    /// stream call count and the model used.
    struct StreamingMockProvider {
        stream_calls: Arc<AtomicUsize>,
        last_stream_model: parking_lot::Mutex<String>,
        response: &'static str,
    }

    impl StreamingMockProvider {
        fn new(response: &'static str) -> Self {
            Self {
                stream_calls: Arc::new(AtomicUsize::new(0)),
                last_stream_model: parking_lot::Mutex::new(String::new()),
                response,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingMockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".to_string())
        }

        fn supports_streaming(&self) -> bool {
            true
        }

        // Emits one delta chunk followed by a final chunk.
        fn stream_chat_with_history(
            &self,
            _messages: &[ChatMessage],
            model: &str,
            _temperature: f64,
            _options: StreamOptions,
        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            *self.last_stream_model.lock() = model.to_string();
            let chunks = vec![
                Ok(StreamChunk::delta(self.response)),
                Ok(StreamChunk::final_chunk()),
            ];
            futures_util::stream::iter(chunks).boxed()
        }
    }

    // Forwarding impl so the Arc'd streaming mock can be boxed into the
    // router and still be inspected by the test.
    #[async_trait]
    impl Provider for Arc<StreamingMockProvider> {
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }

        fn supports_streaming(&self) -> bool {
            self.as_ref().supports_streaming()
        }

        fn stream_chat_with_history(
            &self,
            messages: &[ChatMessage],
            model: &str,
            temperature: f64,
            options: StreamOptions,
        ) -> BoxStream<'static, StreamResult<StreamChunk>> {
            self.as_ref()
                .stream_chat_with_history(messages, model, temperature, options)
        }
    }
497
    /// Mock provider that streams structured tool-call events, counting both
    /// total stream calls and those that carried a non-empty tool list.
    struct ToolEventStreamingMockProvider {
        stream_calls: Arc<AtomicUsize>,
        tool_event_calls: Arc<AtomicUsize>,
        last_stream_model: parking_lot::Mutex<String>,
    }

    impl ToolEventStreamingMockProvider {
        fn new() -> Self {
            Self {
                stream_calls: Arc::new(AtomicUsize::new(0)),
                tool_event_calls: Arc::new(AtomicUsize::new(0)),
                last_stream_model: parking_lot::Mutex::new(String::new()),
            }
        }
    }

    #[async_trait]
    impl Provider for ToolEventStreamingMockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".to_string())
        }

        fn supports_streaming(&self) -> bool {
            true
        }

        fn supports_streaming_tool_events(&self) -> bool {
            true
        }

        // Emits one canned ToolCall event followed by Final.
        fn stream_chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            _temperature: f64,
            _options: StreamOptions,
        ) -> BoxStream<'static, StreamResult<StreamEvent>> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            if request.tools.is_some_and(|tools| !tools.is_empty()) {
                self.tool_event_calls.fetch_add(1, Ordering::SeqCst);
            }
            *self.last_stream_model.lock() = model.to_string();
            futures_util::stream::iter(vec![
                Ok(StreamEvent::ToolCall(crate::providers::ToolCall {
                    id: "call_router_1".to_string(),
                    name: "shell".to_string(),
                    arguments: r#"{"command":"date"}"#.to_string(),
                })),
                Ok(StreamEvent::Final),
            ])
            .boxed()
        }
    }

    // Forwarding impl so the Arc'd tool-event mock can be boxed into the
    // router and still be inspected by the test.
    #[async_trait]
    impl Provider for Arc<ToolEventStreamingMockProvider> {
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }

        fn supports_streaming(&self) -> bool {
            self.as_ref().supports_streaming()
        }

        fn supports_streaming_tool_events(&self) -> bool {
            self.as_ref().supports_streaming_tool_events()
        }

        fn stream_chat(
            &self,
            request: ChatRequest<'_>,
            model: &str,
            temperature: f64,
            options: StreamOptions,
        ) -> BoxStream<'static, StreamResult<StreamEvent>> {
            self.as_ref()
                .stream_chat(request, model, temperature, options)
        }
    }
591
    // A known hint dispatches to the mapped provider with the route's model;
    // the other provider is untouched.
    #[tokio::test]
    async fn routes_hint_to_correct_provider() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![
                ("fast", "fast", "llama-3-70b"),
                ("reasoning", "smart", "claude-opus"),
            ],
        );

        let result = router
            .simple_chat("hello", "hint:reasoning", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "smart-response");
        assert_eq!(mocks[1].call_count(), 1);
        assert_eq!(mocks[1].last_model(), "claude-opus");
        assert_eq!(mocks[0].call_count(), 0);
    }

    // The "fast" hint resolves to the first provider and its routed model.
    #[tokio::test]
    async fn routes_fast_hint() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![("fast", "fast", "llama-3-70b")],
        );

        let result = router.simple_chat("hello", "hint:fast", 0.5).await.unwrap();
        assert_eq!(result, "fast-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "llama-3-70b");
    }

    // An unrecognized hint goes to the default (first) provider; the raw
    // "hint:..." string is passed through as the model unchanged.
    #[tokio::test]
    async fn unknown_hint_falls_back_to_default() {
        let (router, mocks) = make_router(
            vec![("default", "default-response"), ("other", "other-response")],
            vec![],
        );

        let result = router
            .simple_chat("hello", "hint:nonexistent", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "default-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "hint:nonexistent");
    }

    // Plain model names (no "hint:" prefix) always use the default provider
    // and keep the model string untouched.
    #[tokio::test]
    async fn non_hint_model_uses_default_provider() {
        let (router, mocks) = make_router(
            vec![
                ("primary", "primary-response"),
                ("secondary", "secondary-response"),
            ],
            vec![("code", "secondary", "codellama")],
        );

        let result = router
            .simple_chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "primary-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
    }
660
    // resolve() passes non-hint model strings through verbatim.
    #[test]
    fn resolve_preserves_model_for_non_hints() {
        let (router, _) = make_router(vec![("default", "ok")], vec![]);

        let (idx, model) = router.resolve("gpt-4o");
        assert_eq!(idx, 0);
        assert_eq!(model, "gpt-4o");
    }

    // resolve() swaps a known hint for the route's provider index and model.
    #[test]
    fn resolve_strips_hint_prefix() {
        let (router, _) = make_router(
            vec![("fast", "ok"), ("smart", "ok")],
            vec![("reasoning", "smart", "claude-opus")],
        );

        let (idx, model) = router.resolve("hint:reasoning");
        assert_eq!(idx, 1);
        assert_eq!(model, "claude-opus");
    }

    // Routes naming a provider that was never registered are dropped at
    // construction time.
    #[test]
    fn skips_routes_with_unknown_provider() {
        let (router, _) = make_router(
            vec![("default", "ok")],
            vec![("broken", "nonexistent", "model")],
        );

        assert!(!router.routes.contains_key("broken"));
    }

    // Warmup iterates all providers without erroring (mock warmups are the
    // trait default).
    #[tokio::test]
    async fn warmup_calls_all_providers() {
        let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);

        assert!(router.warmup().await.is_ok());
    }
700
    // chat_with_system is forwarded to the resolved provider.
    #[tokio::test]
    async fn chat_with_system_passes_system_prompt() {
        let mock = Arc::new(MockProvider::new("response"));
        let router = RouterProvider::new(
            vec![(
                "default".into(),
                Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
            )],
            vec![],
            "model".into(),
        );

        let result = router
            .chat_with_system(Some("system"), "hello", "model", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "response");
        assert_eq!(mock.call_count(), 1);
    }

    // chat_with_tools on a non-hint model delegates to the default provider
    // with the model string intact.
    #[tokio::test]
    async fn chat_with_tools_delegates_to_resolved_provider() {
        let mock = Arc::new(MockProvider::new("tool-response"));
        let router = RouterProvider::new(
            vec![(
                "default".into(),
                Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
            )],
            vec![],
            "model".into(),
        );

        let messages = vec![ChatMessage {
            role: "user".to_string(),
            content: "use tools".to_string(),
        }];
        let tools = vec![serde_json::json!({
            "type": "function",
            "function": {
                "name": "shell",
                "description": "Run shell command",
                "parameters": {}
            }
        })];

        let result = router
            .chat_with_tools(&messages, &tools, "model", 0.7)
            .await
            .unwrap();
        assert_eq!(result.text.as_deref(), Some("tool-response"));
        assert_eq!(mock.call_count(), 1);
        assert_eq!(mock.last_model(), "model");
    }

    // chat_with_tools honors hint routing just like plain chat.
    #[tokio::test]
    async fn chat_with_tools_routes_hint_correctly() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-tool"), ("smart", "smart-tool")],
            vec![("reasoning", "smart", "claude-opus")],
        );

        let messages = vec![ChatMessage {
            role: "user".to_string(),
            content: "reason about this".to_string(),
        }];
        let tools = vec![serde_json::json!({"type": "function", "function": {"name": "test"}})];

        let result = router
            .chat_with_tools(&messages, &tools, "hint:reasoning", 0.5)
            .await
            .unwrap();
        assert_eq!(result.text.as_deref(), Some("smart-tool"));
        assert_eq!(mocks[1].call_count(), 1);
        assert_eq!(mocks[1].last_model(), "claude-opus");
        assert_eq!(mocks[0].call_count(), 0);
    }
779
    use crate::providers::traits::ProviderCapabilities;

    /// Mock provider with configurable vision/tools capabilities, used to
    /// exercise the capability filters in resolve_cost_optimized.
    struct CapableMockProvider {
        response: &'static str,
        vision: bool,
        tools: bool,
    }

    impl CapableMockProvider {
        fn new(response: &'static str, vision: bool, tools: bool) -> Self {
            Self {
                response,
                vision,
                tools,
            }
        }
    }

    #[async_trait]
    impl Provider for CapableMockProvider {
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                native_tool_calling: self.tools,
                vision: self.vision,
                prompt_caching: false,
            }
        }

        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(self.response.to_string())
        }
    }

    /// Builds a pricing table from `(model, input_rate, output_rate)` tuples.
    fn make_pricing(entries: Vec<(&str, f64, f64)>) -> HashMap<String, ModelPricing> {
        entries
            .into_iter()
            .map(|(model, input, output)| (model.to_string(), ModelPricing { input, output }))
            .collect()
    }
828
    // With pricing for both routes, the cheaper (input + output) one wins.
    #[test]
    fn cost_optimized_selects_cheapest_provider() {
        let providers: Vec<(String, Box<dyn Provider>)> = vec![
            (
                "expensive".into(),
                Box::new(CapableMockProvider::new("exp", false, false)),
            ),
            (
                "cheap".into(),
                Box::new(CapableMockProvider::new("chp", false, false)),
            ),
        ];
        let routes = vec![
            (
                "expensive".to_string(),
                Route {
                    provider_name: "expensive".into(),
                    model: "big-model".into(),
                },
            ),
            (
                "cheap".to_string(),
                Route {
                    provider_name: "cheap".into(),
                    model: "small-model".into(),
                },
            ),
        ];
        let router = RouterProvider::new(providers, routes, "default-model".into());

        let prices = make_pricing(vec![("big-model", 15.0, 75.0), ("small-model", 0.25, 1.25)]);

        let (idx, model) =
            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
        assert_eq!(model, "small-model");
        assert_eq!(idx, 1);
    }

    // When vision is required, the cheaper non-vision route is filtered out.
    #[test]
    fn cost_optimized_respects_vision_requirement() {
        let providers: Vec<(String, Box<dyn Provider>)> = vec![
            (
                "no-vision".into(),
                Box::new(CapableMockProvider::new("nv", false, false)),
            ),
            (
                "has-vision".into(),
                Box::new(CapableMockProvider::new("hv", true, false)),
            ),
        ];
        let routes = vec![
            (
                "cheap".to_string(),
                Route {
                    provider_name: "no-vision".into(),
                    model: "cheap-model".into(),
                },
            ),
            (
                "vision".to_string(),
                Route {
                    provider_name: "has-vision".into(),
                    model: "vision-model".into(),
                },
            ),
        ];
        let router = RouterProvider::new(providers, routes, "default-model".into());

        let prices = make_pricing(vec![
            ("cheap-model", 0.10, 0.40),
            ("vision-model", 3.0, 15.0),
        ]);

        let (_, model) = router.resolve_cost_optimized("hint:cheapest", &prices, true, false);
        assert_eq!(model, "vision-model");
    }

    // When native tools are required, the cheaper tool-less route is
    // filtered out.
    #[test]
    fn cost_optimized_respects_tools_requirement() {
        let providers: Vec<(String, Box<dyn Provider>)> = vec![
            (
                "no-tools".into(),
                Box::new(CapableMockProvider::new("nt", false, false)),
            ),
            (
                "has-tools".into(),
                Box::new(CapableMockProvider::new("ht", false, true)),
            ),
        ];
        let routes = vec![
            (
                "basic".to_string(),
                Route {
                    provider_name: "no-tools".into(),
                    model: "basic-model".into(),
                },
            ),
            (
                "tools".to_string(),
                Route {
                    provider_name: "has-tools".into(),
                    model: "tools-model".into(),
                },
            ),
        ];
        let router = RouterProvider::new(providers, routes, "default-model".into());

        let prices = make_pricing(vec![
            ("basic-model", 0.10, 0.40),
            ("tools-model", 5.0, 15.0),
        ]);

        let (_, model) = router.resolve_cost_optimized("hint:cost-optimized", &prices, false, true);
        assert_eq!(model, "tools-model");
    }
946
    // With an empty pricing table there are no candidates; the router falls
    // back to (default_index, default_model).
    #[test]
    fn cost_optimized_falls_back_when_no_pricing() {
        let (router, _) = make_router(
            vec![("default", "ok"), ("other", "ok")],
            vec![("route-a", "other", "some-model")],
        );

        let prices: HashMap<String, ModelPricing> = HashMap::new();
        let (idx, model) =
            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
        assert_eq!(idx, 0);
        assert_eq!(model, "default-model");
    }

    // A single priced route is trivially the cheapest.
    #[test]
    fn cost_optimized_with_single_route() {
        let providers: Vec<(String, Box<dyn Provider>)> = vec![(
            "only".into(),
            Box::new(CapableMockProvider::new("ok", false, false)),
        )];
        let routes = vec![(
            "single".to_string(),
            Route {
                provider_name: "only".into(),
                model: "the-model".into(),
            },
        )];
        let router = RouterProvider::new(providers, routes, "default-model".into());

        let prices = make_pricing(vec![("the-model", 1.0, 2.0)]);

        let (idx, model) = router.resolve_cost_optimized("hint:cheapest", &prices, false, false);
        assert_eq!(idx, 0);
        assert_eq!(model, "the-model");
    }

    // Among three priced routes, the lowest total (input + output) wins.
    #[test]
    fn cost_optimized_prefers_lower_total_cost() {
        let providers: Vec<(String, Box<dyn Provider>)> = vec![
            (
                "p1".into(),
                Box::new(CapableMockProvider::new("r1", false, false)),
            ),
            (
                "p2".into(),
                Box::new(CapableMockProvider::new("r2", false, false)),
            ),
            (
                "p3".into(),
                Box::new(CapableMockProvider::new("r3", false, false)),
            ),
        ];
        let routes = vec![
            (
                "a".to_string(),
                Route {
                    provider_name: "p1".into(),
                    model: "model-a".into(),
                },
            ),
            (
                "b".to_string(),
                Route {
                    provider_name: "p2".into(),
                    model: "model-b".into(),
                },
            ),
            (
                "c".to_string(),
                Route {
                    provider_name: "p3".into(),
                    model: "model-c".into(),
                },
            ),
        ];
        let router = RouterProvider::new(providers, routes, "default-model".into());

        let prices = make_pricing(vec![
            ("model-a", 10.0, 50.0),
            ("model-b", 0.15, 0.60),
            ("model-c", 3.0, 15.0),
        ]);

        let (idx, model) =
            router.resolve_cost_optimized("hint:cost-optimized", &prices, false, false);
        assert_eq!(model, "model-b");
        assert_eq!(idx, 1);
    }

    // score() sums input + output pricing and is None for unknown models.
    #[test]
    fn cost_optimized_strategy_score() {
        let prices = make_pricing(vec![("cheap", 0.10, 0.40), ("expensive", 15.0, 75.0)]);
        let strategy = CostOptimizedStrategy::new(prices);

        assert!((strategy.score("cheap").unwrap() - 0.50).abs() < f64::EPSILON);
        assert!((strategy.score("expensive").unwrap() - 90.0).abs() < f64::EPSILON);
        assert!(strategy.score("unknown").is_none());
    }
1046
    // Router reports streaming support if at least one provider streams.
    #[tokio::test]
    async fn supports_streaming_returns_true_when_any_provider_supports_it() {
        let streaming = Arc::new(StreamingMockProvider::new("stream"));
        let router = RouterProvider::new(
            vec![
                (
                    "default".into(),
                    Box::new(MockProvider::new("default")) as Box<dyn Provider>,
                ),
                (
                    "streaming".into(),
                    Box::new(Arc::clone(&streaming)) as Box<dyn Provider>,
                ),
            ],
            vec![(
                "reasoning".into(),
                Route {
                    provider_name: "streaming".into(),
                    model: "claude-opus".into(),
                },
            )],
            "model".into(),
        );

        assert!(router.supports_streaming());
    }

    // Streaming dispatch follows hint routing: the streaming mock gets the
    // routed model and yields its chunks through the router.
    #[tokio::test]
    async fn stream_chat_with_history_routes_hint_to_correct_provider_and_model() {
        let streaming = Arc::new(StreamingMockProvider::new("streamed response"));
        let router = RouterProvider::new(
            vec![
                (
                    "default".into(),
                    Box::new(MockProvider::new("default")) as Box<dyn Provider>,
                ),
                (
                    "streaming".into(),
                    Box::new(Arc::clone(&streaming)) as Box<dyn Provider>,
                ),
            ],
            vec![(
                "reasoning".into(),
                Route {
                    provider_name: "streaming".into(),
                    model: "claude-opus".into(),
                },
            )],
            "model".into(),
        );

        let messages = vec![ChatMessage::user("hello")];
        let mut stream = router.stream_chat_with_history(
            &messages,
            "hint:reasoning",
            0.0,
            StreamOptions::new(true),
        );

        let mut collected = String::new();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.expect("stream chunk should be ok");
            collected.push_str(&chunk.delta);
        }

        assert_eq!(collected, "streamed response");
        assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
        assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
    }

    // Event-based streaming also follows hint routing and passes structured
    // tool-call events through unchanged.
    #[tokio::test]
    async fn stream_chat_routes_hint_with_structured_tool_events() {
        let streaming = Arc::new(ToolEventStreamingMockProvider::new());
        let router = RouterProvider::new(
            vec![
                (
                    "default".into(),
                    Box::new(MockProvider::new("default")) as Box<dyn Provider>,
                ),
                (
                    "streaming".into(),
                    Box::new(Arc::clone(&streaming)) as Box<dyn Provider>,
                ),
            ],
            vec![(
                "reasoning".into(),
                Route {
                    provider_name: "streaming".into(),
                    model: "claude-opus".into(),
                },
            )],
            "model".into(),
        );

        let messages = vec![ChatMessage::user("hello")];
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "run shell commands".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "command": { "type": "string" }
                }
            }),
        }];

        let mut stream = router.stream_chat(
            ChatRequest {
                messages: &messages,
                tools: Some(&tools),
            },
            "hint:reasoning",
            0.0,
            StreamOptions::new(true),
        );

        let first = stream.next().await.unwrap().unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert!(stream.next().await.is_none());

        match first {
            StreamEvent::ToolCall(call) => {
                assert_eq!(call.name, "shell");
                assert_eq!(call.arguments, r#"{"command":"date"}"#);
            }
            other => panic!("expected tool-call event, got {other:?}"),
        }
        assert!(matches!(second, StreamEvent::Final));
        assert_eq!(streaming.stream_calls.load(Ordering::SeqCst), 1);
        assert_eq!(streaming.tool_event_calls.load(Ordering::SeqCst), 1);
        assert_eq!(*streaming.last_stream_model.lock(), "claude-opus");
    }
1179}