1use converge_traits::{
13 AgentRequirements, ComplianceLevel, CostClass, DataSovereignty, LlmError, ModelSelectorTrait,
14};
15
/// Per-component fitness scores for one model, as produced by
/// `ModelMetadata::fitness_breakdown`.
#[derive(Debug, Clone, PartialEq)]
pub struct FitnessBreakdown {
    /// Score derived from the model's cost class (1.0 for the cheapest class
    /// down to 0.2 for the most expensive one).
    pub cost_score: f64,
    /// `1.0 - min(typical_latency / max_latency, 1.0)`; lower latency scores higher.
    pub latency_score: f64,
    /// The model's quality value, used as-is.
    pub quality_score: f64,
    /// Weighted sum: 40% cost + 30% latency + 30% quality.
    pub total: f64,
}
30
31impl std::fmt::Display for FitnessBreakdown {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(
34 f,
35 "{:.3} = 40%×cost({:.2}) + 30%×latency({:.2}) + 30%×quality({:.2})",
36 self.total, self.cost_score, self.latency_score, self.quality_score
37 )
38 }
39}
40
/// Detailed outcome of `ProviderRegistry::select_with_details`.
#[derive(Debug, Clone)]
pub struct SelectionResult {
    /// The winning model (the first entry of `candidates`).
    pub selected: ModelMetadata,
    /// Fitness breakdown of the winning model.
    pub fitness: FitnessBreakdown,
    /// Every model that satisfied the requirements, sorted by total fitness,
    /// best first.
    pub candidates: Vec<(ModelMetadata, FitnessBreakdown)>,
    /// Every model that was ruled out, paired with the reason.
    pub rejected: Vec<(ModelMetadata, RejectionReason)>,
}
54
/// Why a model was excluded from selection.
#[derive(Debug, Clone, PartialEq)]
pub enum RejectionReason {
    /// The model's provider is not in the registry's set of available providers.
    ProviderUnavailable,
    /// The model's cost class is not among the classes allowed by the requirements.
    CostTooHigh {
        /// Cost class of the rejected model.
        model_cost: CostClass,
        /// Cost ceiling from the requirements.
        max_allowed: CostClass,
    },
    /// The model's typical latency exceeds the allowed maximum.
    LatencyTooHigh {
        /// Typical latency of the rejected model, in milliseconds.
        model_latency_ms: u32,
        /// Latency ceiling from the requirements, in milliseconds.
        max_allowed_ms: u32,
    },
    /// The model's quality is below the required minimum.
    QualityTooLow {
        /// Quality of the rejected model.
        model_quality: f64,
        /// Minimum quality from the requirements.
        min_required: f64,
    },
    /// Reasoning was required but the model does not provide it.
    ReasoningRequired,
    /// Web search was required but the model does not support it.
    WebSearchRequired,
    /// A specific data sovereignty was required and the model's differs.
    DataSovereigntyMismatch {
        /// Sovereignty demanded by the requirements.
        required: DataSovereignty,
        /// Sovereignty declared by the model.
        model_has: DataSovereignty,
    },
    /// A specific compliance level was required and the model's differs.
    ComplianceMismatch {
        /// Compliance level demanded by the requirements.
        required: ComplianceLevel,
        /// Compliance level declared by the model.
        model_has: ComplianceLevel,
    },
    /// Multilingual support was required but the model does not provide it.
    MultilingualRequired,
}
92
93impl std::fmt::Display for RejectionReason {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 Self::ProviderUnavailable => write!(f, "provider unavailable (no API key)"),
97 Self::CostTooHigh {
98 model_cost,
99 max_allowed,
100 } => {
101 write!(f, "cost {model_cost:?} exceeds max {max_allowed:?}")
102 }
103 Self::LatencyTooHigh {
104 model_latency_ms,
105 max_allowed_ms,
106 } => {
107 write!(
108 f,
109 "latency {model_latency_ms}ms exceeds max {max_allowed_ms}ms"
110 )
111 }
112 Self::QualityTooLow {
113 model_quality,
114 min_required,
115 } => {
116 write!(f, "quality {model_quality:.2} below min {min_required:.2}")
117 }
118 Self::ReasoningRequired => write!(f, "reasoning required but not supported"),
119 Self::WebSearchRequired => write!(f, "web search required but not supported"),
120 Self::DataSovereigntyMismatch {
121 required,
122 model_has,
123 } => {
124 write!(f, "data sovereignty {model_has:?} != required {required:?}")
125 }
126 Self::ComplianceMismatch {
127 required,
128 model_has,
129 } => {
130 write!(f, "compliance {model_has:?} != required {required:?}")
131 }
132 Self::MultilingualRequired => write!(f, "multilingual required but not supported"),
133 }
134 }
135}
136
/// Static description of one model offered by a provider: cost, latency,
/// quality and capability flags used for selection.
#[derive(Debug, Clone, PartialEq)]
#[allow(clippy::struct_excessive_bools)]
pub struct ModelMetadata {
    /// Provider identifier (e.g. `"anthropic"`).
    pub provider: String,
    /// Model identifier as known to the provider.
    pub model: String,
    /// Cost class of the model.
    pub cost_class: CostClass,
    /// Typical latency in milliseconds.
    pub typical_latency_ms: u32,
    /// Quality score; `ModelMetadata::new` clamps it into `0.0..=1.0`.
    pub quality: f64,
    /// Whether the model provides reasoning.
    pub has_reasoning: bool,
    /// Whether the model supports web search.
    pub supports_web_search: bool,
    /// Data sovereignty declared for the model.
    pub data_sovereignty: DataSovereignty,
    /// Compliance level declared for the model.
    pub compliance: ComplianceLevel,
    /// Whether the model supports multiple languages.
    pub supports_multilingual: bool,
    /// Context window size in tokens (defaults to 8192 in `new`).
    pub context_tokens: usize,
    /// Whether the model supports tool use. Informational: not consulted by
    /// `satisfies`.
    pub supports_tool_use: bool,
    /// Whether the model supports vision input. Informational: not consulted
    /// by `satisfies`.
    pub supports_vision: bool,
    /// Whether the model supports structured output. Informational: not
    /// consulted by `satisfies`.
    pub supports_structured_output: bool,
    /// Whether the model supports code. Informational: not consulted by
    /// `satisfies`.
    pub supports_code: bool,
    /// Country associated with the model (defaults to `"US"` in `new`).
    pub country: String,
    /// Region associated with the model (defaults to `"US"` in `new`).
    pub region: String,
}
177
178impl ModelMetadata {
179 #[must_use]
181 pub fn new(
182 provider: impl Into<String>,
183 model: impl Into<String>,
184 cost_class: CostClass,
185 typical_latency_ms: u32,
186 quality: f64,
187 ) -> Self {
188 Self {
189 provider: provider.into(),
190 model: model.into(),
191 cost_class,
192 typical_latency_ms,
193 quality: quality.clamp(0.0, 1.0),
194 has_reasoning: false,
195 supports_web_search: false,
196 data_sovereignty: DataSovereignty::Any,
197 compliance: ComplianceLevel::None,
198 supports_multilingual: false,
199 context_tokens: 8192,
201 supports_tool_use: false,
202 supports_vision: false,
203 supports_structured_output: false,
204 supports_code: false,
205 country: "US".to_string(),
206 region: "US".to_string(),
207 }
208 }
209
210 #[must_use]
212 pub fn with_reasoning(mut self, has: bool) -> Self {
213 self.has_reasoning = has;
214 self
215 }
216
217 #[must_use]
219 pub fn with_web_search(mut self, supports: bool) -> Self {
220 self.supports_web_search = supports;
221 self
222 }
223
224 #[must_use]
226 pub fn with_data_sovereignty(mut self, sovereignty: DataSovereignty) -> Self {
227 self.data_sovereignty = sovereignty;
228 self
229 }
230
231 #[must_use]
233 pub fn with_compliance(mut self, compliance: ComplianceLevel) -> Self {
234 self.compliance = compliance;
235 self
236 }
237
238 #[must_use]
240 pub fn with_multilingual(mut self, supports: bool) -> Self {
241 self.supports_multilingual = supports;
242 self
243 }
244
245 #[must_use]
247 pub fn with_context_tokens(mut self, tokens: usize) -> Self {
248 self.context_tokens = tokens;
249 self
250 }
251
252 #[must_use]
254 pub fn with_tool_use(mut self, supports: bool) -> Self {
255 self.supports_tool_use = supports;
256 self
257 }
258
259 #[must_use]
261 pub fn with_vision(mut self, supports: bool) -> Self {
262 self.supports_vision = supports;
263 self
264 }
265
266 #[must_use]
268 pub fn with_structured_output(mut self, supports: bool) -> Self {
269 self.supports_structured_output = supports;
270 self
271 }
272
273 #[must_use]
275 pub fn with_code(mut self, supports: bool) -> Self {
276 self.supports_code = supports;
277 self
278 }
279
280 #[must_use]
282 pub fn with_location(mut self, country: impl Into<String>, region: impl Into<String>) -> Self {
283 self.country = country.into();
284 self.region = region.into();
285 self
286 }
287
288 #[must_use]
290 pub fn satisfies(&self, requirements: &AgentRequirements) -> bool {
291 if !requirements
293 .max_cost_class
294 .allowed_classes()
295 .contains(&self.cost_class)
296 {
297 return false;
298 }
299
300 if self.typical_latency_ms > requirements.max_latency_ms {
302 return false;
303 }
304
305 if requirements.requires_reasoning && !self.has_reasoning {
307 return false;
308 }
309
310 if requirements.requires_web_search && !self.supports_web_search {
312 return false;
313 }
314
315 if self.quality < requirements.min_quality {
317 return false;
318 }
319
320 if requirements.data_sovereignty != DataSovereignty::Any
322 && self.data_sovereignty != requirements.data_sovereignty
323 {
324 return false;
325 }
326
327 if requirements.compliance != ComplianceLevel::None
329 && self.compliance != requirements.compliance
330 {
331 return false;
332 }
333
334 if requirements.requires_multilingual && !self.supports_multilingual {
336 return false;
337 }
338
339 true
340 }
341
342 #[must_use]
349 pub fn fitness_score(&self, requirements: &AgentRequirements) -> f64 {
350 if !self.satisfies(requirements) {
351 return 0.0;
352 }
353
354 let cost_score = match self.cost_class {
356 CostClass::VeryLow => 1.0,
357 CostClass::Low => 0.8,
358 CostClass::Medium => 0.6,
359 CostClass::High => 0.4,
360 CostClass::VeryHigh => 0.2,
361 };
362
363 let latency_ratio =
365 f64::from(self.typical_latency_ms) / f64::from(requirements.max_latency_ms);
366 let latency_score = 1.0 - latency_ratio.min(1.0);
367
368 let quality_score = self.quality;
370
371 0.4 * cost_score + 0.3 * latency_score + 0.3 * quality_score
374 }
375
376 #[must_use]
380 pub fn fitness_breakdown(&self, requirements: &AgentRequirements) -> Option<FitnessBreakdown> {
381 if !self.satisfies(requirements) {
382 return None;
383 }
384
385 let cost_score = match self.cost_class {
386 CostClass::VeryLow => 1.0,
387 CostClass::Low => 0.8,
388 CostClass::Medium => 0.6,
389 CostClass::High => 0.4,
390 CostClass::VeryHigh => 0.2,
391 };
392
393 let latency_ratio =
394 f64::from(self.typical_latency_ms) / f64::from(requirements.max_latency_ms);
395 let latency_score = 1.0 - latency_ratio.min(1.0);
396
397 let quality_score = self.quality;
398
399 let total = 0.4 * cost_score + 0.3 * latency_score + 0.3 * quality_score;
400
401 Some(FitnessBreakdown {
402 cost_score,
403 latency_score,
404 quality_score,
405 total,
406 })
407 }
408
409 #[must_use]
413 pub fn rejection_reason(&self, requirements: &AgentRequirements) -> Option<RejectionReason> {
414 if !requirements
416 .max_cost_class
417 .allowed_classes()
418 .contains(&self.cost_class)
419 {
420 return Some(RejectionReason::CostTooHigh {
421 model_cost: self.cost_class,
422 max_allowed: requirements.max_cost_class,
423 });
424 }
425
426 if self.typical_latency_ms > requirements.max_latency_ms {
428 return Some(RejectionReason::LatencyTooHigh {
429 model_latency_ms: self.typical_latency_ms,
430 max_allowed_ms: requirements.max_latency_ms,
431 });
432 }
433
434 if requirements.requires_reasoning && !self.has_reasoning {
436 return Some(RejectionReason::ReasoningRequired);
437 }
438
439 if requirements.requires_web_search && !self.supports_web_search {
441 return Some(RejectionReason::WebSearchRequired);
442 }
443
444 if self.quality < requirements.min_quality {
446 return Some(RejectionReason::QualityTooLow {
447 model_quality: self.quality,
448 min_required: requirements.min_quality,
449 });
450 }
451
452 if requirements.data_sovereignty != DataSovereignty::Any
454 && self.data_sovereignty != requirements.data_sovereignty
455 {
456 return Some(RejectionReason::DataSovereigntyMismatch {
457 required: requirements.data_sovereignty,
458 model_has: self.data_sovereignty,
459 });
460 }
461
462 if requirements.compliance != ComplianceLevel::None
464 && self.compliance != requirements.compliance
465 {
466 return Some(RejectionReason::ComplianceMismatch {
467 required: requirements.compliance,
468 model_has: self.compliance,
469 });
470 }
471
472 if requirements.requires_multilingual && !self.supports_multilingual {
474 return Some(RejectionReason::MultilingualRequired);
475 }
476
477 None
478 }
479}
480
/// Selects the best-fitting model from a fixed catalogue of `ModelMetadata`.
#[derive(Debug, Clone)]
pub struct ModelSelector {
    /// Catalogue of known models, in registration order.
    models: Vec<ModelMetadata>,
}
487
488impl ModelSelector {
489 #[must_use]
491 pub fn new() -> Self {
492 Self::default()
493 }
494
495 #[must_use]
497 pub fn empty() -> Self {
498 Self { models: Vec::new() }
499 }
500
501 #[must_use]
503 pub fn with_model(mut self, metadata: ModelMetadata) -> Self {
504 self.models.push(metadata);
505 self
506 }
507
508 #[must_use]
510 pub fn list_satisfying(&self, requirements: &AgentRequirements) -> Vec<&ModelMetadata> {
511 self.models
512 .iter()
513 .filter(|m| m.satisfies(requirements))
514 .collect()
515 }
516}
517
518impl ModelSelectorTrait for ModelSelector {
519 fn select(&self, requirements: &AgentRequirements) -> Result<(String, String), LlmError> {
520 let mut candidates: Vec<(&ModelMetadata, f64)> = self
521 .models
522 .iter()
523 .filter_map(|m| {
524 if m.satisfies(requirements) {
525 Some((m, m.fitness_score(requirements)))
526 } else {
527 None
528 }
529 })
530 .collect();
531
532 if candidates.is_empty() {
533 return Err(LlmError::provider(format!(
534 "No model found satisfying requirements: cost <= {:?}, latency <= {}ms, reasoning = {}, web_search = {}, quality >= {:.2}, data_sovereignty = {:?}, compliance = {:?}, multilingual = {}",
535 requirements.max_cost_class,
536 requirements.max_latency_ms,
537 requirements.requires_reasoning,
538 requirements.requires_web_search,
539 requirements.min_quality,
540 requirements.data_sovereignty,
541 requirements.compliance,
542 requirements.requires_multilingual
543 )));
544 }
545
546 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
548
549 let best = candidates[0].0;
551 Ok((best.provider.clone(), best.model.clone()))
552 }
553}
554
impl Default for ModelSelector {
    // Builds the built-in model catalogue. Every entry is gated on a cargo
    // feature, so the catalogue only contains models for providers compiled
    // into the build (and is empty when no provider feature is enabled).
    //
    // Positional arguments of `ModelMetadata::new` are:
    // provider, model, cost class, typical latency (ms), quality.
    #[allow(clippy::too_many_lines)] fn default() -> Self {
        Self {
            models: vec![
                // --- Anthropic ---
                #[cfg(feature = "anthropic")]
                ModelMetadata::new(
                    "anthropic",
                    "claude-haiku-4-5-20251001",
                    CostClass::VeryLow,
                    1200,
                    0.78,
                )
                .with_tool_use(true)
                .with_vision(true)
                .with_context_tokens(200_000),
                #[cfg(feature = "anthropic")]
                ModelMetadata::new(
                    "anthropic",
                    "claude-sonnet-4-6",
                    CostClass::Low,
                    2500,
                    0.93,
                )
                .with_reasoning(true)
                .with_tool_use(true)
                .with_vision(true)
                .with_structured_output(true)
                .with_code(true)
                .with_context_tokens(200_000),
                #[cfg(feature = "anthropic")]
                ModelMetadata::new(
                    "anthropic",
                    "claude-opus-4-6",
                    CostClass::High,
                    7000,
                    0.97,
                )
                .with_reasoning(true)
                .with_tool_use(true)
                .with_vision(true)
                .with_structured_output(true)
                .with_code(true)
                .with_context_tokens(200_000),
                // --- OpenAI ---
                #[cfg(feature = "openai")]
                ModelMetadata::new("openai", "gpt-3.5-turbo", CostClass::VeryLow, 1200, 0.70),
                #[cfg(feature = "openai")]
                ModelMetadata::new("openai", "gpt-4", CostClass::Medium, 5000, 0.90)
                    .with_reasoning(true),
                #[cfg(feature = "openai")]
                ModelMetadata::new("openai", "gpt-4-turbo", CostClass::Medium, 4000, 0.92)
                    .with_reasoning(true),
                #[cfg(feature = "openai")]
                ModelMetadata::new("openai", "gpt-5.4-mini", CostClass::Low, 2500, 0.95)
                    .with_reasoning(true)
                    .with_web_search(true)
                    .with_multilingual(true)
                    .with_context_tokens(1_050_000)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_code(true),
                #[cfg(feature = "openai")]
                ModelMetadata::new("openai", "gpt-5.4", CostClass::High, 5500, 0.99)
                    .with_reasoning(true)
                    .with_web_search(true)
                    .with_multilingual(true)
                    .with_context_tokens(1_050_000)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_code(true),
                // NOTE(review): unlike the other gpt-5.4 entries this one does not
                // set structured output - confirm that is intentional.
                #[cfg(feature = "openai")]
                ModelMetadata::new("openai", "gpt-5.4-pro", CostClass::VeryHigh, 11000, 1.00)
                    .with_reasoning(true)
                    .with_web_search(true)
                    .with_multilingual(true)
                    .with_context_tokens(1_050_000)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_code(true),
                // --- Gemini ---
                #[cfg(feature = "gemini")]
                ModelMetadata::new("gemini", "gemini-pro", CostClass::Low, 2000, 0.80)
                    .with_tool_use(true)
                    .with_structured_output(true)
                    .with_context_tokens(32000),
                #[cfg(feature = "gemini")]
                ModelMetadata::new("gemini", "gemini-1.5-flash", CostClass::VeryLow, 800, 0.78)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_multilingual(true)
                    .with_context_tokens(1_000_000),
                #[cfg(feature = "gemini")]
                ModelMetadata::new("gemini", "gemini-1.5-pro", CostClass::Medium, 3000, 0.88)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_code(true)
                    .with_reasoning(true)
                    .with_multilingual(true)
                    .with_context_tokens(2_000_000),
                #[cfg(feature = "gemini")]
                ModelMetadata::new("gemini", "gemini-2.0-flash", CostClass::VeryLow, 700, 0.82)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_code(true)
                    .with_reasoning(true)
                    .with_multilingual(true)
                    .with_context_tokens(1_000_000),
                #[cfg(feature = "gemini")]
                ModelMetadata::new("gemini", "gemini-2.5-flash", CostClass::VeryLow, 800, 0.82)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_code(true)
                    .with_reasoning(true)
                    .with_multilingual(true)
                    .with_context_tokens(1_000_000),
                #[cfg(feature = "gemini")]
                ModelMetadata::new("gemini", "gemini-3-flash-preview", CostClass::VeryLow, 900, 0.90)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_code(true)
                    .with_reasoning(true)
                    .with_multilingual(true)
                    .with_context_tokens(1_050_000),
                #[cfg(feature = "gemini")]
                ModelMetadata::new("gemini", "gemini-3-pro", CostClass::Medium, 2500, 0.96)
                    .with_tool_use(true)
                    .with_vision(true)
                    .with_structured_output(true)
                    .with_code(true)
                    .with_reasoning(true)
                    .with_multilingual(true)
                    .with_context_tokens(2_000_000),
                // --- Perplexity ---
                #[cfg(feature = "perplexity")]
                ModelMetadata::new(
                    "perplexity",
                    "pplx-70b-online",
                    CostClass::Medium,
                    4000,
                    0.90,
                )
                .with_reasoning(true)
                .with_web_search(true),
                #[cfg(feature = "perplexity")]
                ModelMetadata::new("perplexity", "pplx-7b-online", CostClass::Low, 2500, 0.75)
                    .with_web_search(true),
                // --- Qwen ---
                #[cfg(feature = "qwen")]
                ModelMetadata::new("qwen", "qwen-turbo", CostClass::VeryLow, 1500, 0.70),
                #[cfg(feature = "qwen")]
                ModelMetadata::new("qwen", "qwen-plus", CostClass::Low, 2500, 0.80),
                // --- OpenRouter (compiled in together with the `openai` feature) ---
                #[cfg(feature = "openai")]
                ModelMetadata::new(
                    "openrouter",
                    "anthropic/claude-haiku-4-5-20251001",
                    CostClass::VeryLow,
                    1200,
                    0.78,
                ),
                #[cfg(feature = "openai")]
                ModelMetadata::new("openrouter", "openai/gpt-4", CostClass::Medium, 5000, 0.90)
                    .with_reasoning(true),
                // --- MinMax ---
                #[cfg(feature = "minmax")]
                ModelMetadata::new("minmax", "abab5.5-chat", CostClass::Low, 2000, 0.75),
                // --- Grok ---
                #[cfg(feature = "grok")]
                ModelMetadata::new("grok", "grok-beta", CostClass::Medium, 3000, 0.80),
                // --- Mistral ---
                #[cfg(feature = "mistral")]
                ModelMetadata::new(
                    "mistral",
                    "mistral-large-latest",
                    CostClass::Low,
                    3000,
                    0.85,
                )
                .with_reasoning(true)
                .with_multilingual(true),
                #[cfg(feature = "mistral")]
                ModelMetadata::new(
                    "mistral",
                    "mistral-medium-latest",
                    CostClass::Medium,
                    4000,
                    0.88,
                )
                .with_reasoning(true)
                .with_multilingual(true),
                // --- DeepSeek ---
                #[cfg(feature = "deepseek")]
                ModelMetadata::new("deepseek", "deepseek-chat", CostClass::VeryLow, 1500, 0.75)
                    .with_reasoning(true),
                #[cfg(feature = "deepseek")]
                ModelMetadata::new("deepseek", "deepseek-r1", CostClass::Low, 3000, 0.85)
                    .with_reasoning(true),
                // --- Baidu (data sovereignty: China) ---
                #[cfg(feature = "baidu")]
                ModelMetadata::new("baidu", "ernie-bot", CostClass::Low, 2500, 0.80)
                    .with_data_sovereignty(DataSovereignty::China)
                    .with_multilingual(true),
                #[cfg(feature = "baidu")]
                ModelMetadata::new("baidu", "ernie-bot-turbo", CostClass::VeryLow, 1500, 0.75)
                    .with_data_sovereignty(DataSovereignty::China)
                    .with_multilingual(true),
                // --- Zhipu (data sovereignty: China) ---
                #[cfg(feature = "zhipu")]
                ModelMetadata::new("zhipu", "glm-4", CostClass::Low, 2500, 0.82)
                    .with_data_sovereignty(DataSovereignty::China)
                    .with_multilingual(true),
                #[cfg(feature = "zhipu")]
                ModelMetadata::new("zhipu", "glm-4.5", CostClass::Medium, 3000, 0.88)
                    .with_data_sovereignty(DataSovereignty::China)
                    .with_reasoning(true)
                    .with_multilingual(true),
                // --- Kimi ---
                #[cfg(feature = "kimi")]
                ModelMetadata::new("kimi", "moonshot-v1-8k", CostClass::Low, 2000, 0.80)
                    .with_multilingual(true),
                #[cfg(feature = "kimi")]
                ModelMetadata::new("kimi", "moonshot-v1-32k", CostClass::Medium, 3000, 0.85)
                    .with_reasoning(true)
                    .with_multilingual(true),
                // --- Apertus (data sovereignty: Switzerland, GDPR compliance) ---
                #[cfg(feature = "apertus")]
                ModelMetadata::new("apertus", "apertus-v1", CostClass::Medium, 4000, 0.85)
                    .with_data_sovereignty(DataSovereignty::Switzerland)
                    .with_compliance(ComplianceLevel::GDPR)
                    .with_multilingual(true),
            ],
        }
    }
}
799
/// Reports whether the environment variables required by `provider` are set.
///
/// Only providers whose cargo feature is compiled in are recognised; any
/// other name yields `false`. A provider that needs several variables
/// (Baidu) is available only when all of them are set.
#[must_use]
pub fn is_provider_available(provider: &str) -> bool {
    // Environment variables that must all be present for the provider.
    // An empty list marks an unknown (or compiled-out) provider.
    let required_vars: &[&str] = match provider {
        #[cfg(feature = "anthropic")]
        "anthropic" => &["ANTHROPIC_API_KEY"],
        #[cfg(feature = "openai")]
        "openai" => &["OPENAI_API_KEY"],
        #[cfg(feature = "gemini")]
        "gemini" => &["GEMINI_API_KEY"],
        #[cfg(feature = "perplexity")]
        "perplexity" => &["PERPLEXITY_API_KEY"],
        #[cfg(feature = "openai")]
        "openrouter" => &["OPENROUTER_API_KEY"],
        #[cfg(feature = "qwen")]
        "qwen" => &["QWEN_API_KEY"],
        #[cfg(feature = "minmax")]
        "minmax" => &["MINMAX_API_KEY"],
        #[cfg(feature = "grok")]
        "grok" => &["GROK_API_KEY"],
        #[cfg(feature = "mistral")]
        "mistral" => &["MISTRAL_API_KEY"],
        #[cfg(feature = "deepseek")]
        "deepseek" => &["DEEPSEEK_API_KEY"],
        #[cfg(feature = "baidu")]
        "baidu" => &["BAIDU_API_KEY", "BAIDU_SECRET_KEY"],
        #[cfg(feature = "zhipu")]
        "zhipu" => &["ZHIPU_API_KEY"],
        #[cfg(feature = "kimi")]
        "kimi" => &["KIMI_API_KEY"],
        #[cfg(feature = "apertus")]
        "apertus" => &["APERTUS_API_KEY"],
        #[cfg(feature = "brave")]
        "brave" => &["BRAVE_API_KEY"],
        _ => &[],
    };

    !required_vars.is_empty()
        && required_vars
            .iter()
            .all(|name| std::env::var(name).is_ok())
}
842
/// Reports whether the Brave provider can be used.
///
/// With the `brave` feature enabled this defers to
/// `is_provider_available("brave")`; without it the result is always `false`.
#[must_use]
pub fn is_brave_available() -> bool {
    // Exactly one of the two blocks below survives conditional compilation
    // and becomes the function's tail expression.
    #[cfg(feature = "brave")]
    {
        is_provider_available("brave")
    }
    #[cfg(not(feature = "brave"))]
    {
        false
    }
}
855
/// Model selection restricted to a set of available providers, with optional
/// per-model metadata overrides.
#[derive(Debug, Clone)]
pub struct ProviderRegistry {
    /// Catalogue of known models.
    base_selector: ModelSelector,
    /// Names of the providers that may be selected.
    available_providers: std::collections::HashSet<String>,
    /// Replacement metadata keyed by `(provider, model)`.
    metadata_overrides: std::collections::HashMap<(String, String), ModelMetadata>,
}
872
873impl ProviderRegistry {
874 #[must_use]
878 pub fn from_env() -> Self {
879 let base_selector = ModelSelector::new();
880
881 let known_providers = vec![
883 "anthropic",
885 "openai",
886 "gemini",
887 "perplexity",
888 "openrouter",
889 "qwen",
890 "minmax",
891 "grok",
892 "mistral",
893 "deepseek",
894 "baidu",
895 "zhipu",
896 "kimi",
897 "apertus",
898 "brave",
900 ];
901
902 let available_providers: std::collections::HashSet<String> = known_providers
903 .into_iter()
904 .filter(|p| is_provider_available(p))
905 .map(std::string::ToString::to_string)
906 .collect();
907
908 Self {
909 base_selector,
910 available_providers,
911 metadata_overrides: std::collections::HashMap::new(),
912 }
913 }
914
915 #[must_use]
920 pub fn with_providers(providers: &[&str]) -> Self {
921 let base_selector = ModelSelector::new();
922 let available_providers: std::collections::HashSet<String> = providers
923 .iter()
924 .map(std::string::ToString::to_string)
925 .collect();
926
927 Self {
928 base_selector,
929 available_providers,
930 metadata_overrides: std::collections::HashMap::new(),
931 }
932 }
933
934 pub fn update_metadata(
939 &mut self,
940 provider: impl Into<String>,
941 model: impl Into<String>,
942 metadata: ModelMetadata,
943 ) {
944 self.metadata_overrides
945 .insert((provider.into(), model.into()), metadata);
946 }
947
948 #[must_use]
950 pub fn list_available(&self, requirements: &AgentRequirements) -> Vec<&ModelMetadata> {
951 self.base_selector
952 .list_satisfying(requirements)
953 .into_iter()
954 .filter(|m| self.available_providers.contains(&m.provider))
955 .collect()
956 }
957
958 #[must_use]
960 pub fn available_providers(&self) -> Vec<&str> {
961 self.available_providers
962 .iter()
963 .map(std::string::String::as_str)
964 .collect()
965 }
966
967 #[must_use]
969 pub fn is_available(&self, provider: &str) -> bool {
970 self.available_providers.contains(provider)
971 }
972
973 pub fn select_with_details(
984 &self,
985 requirements: &AgentRequirements,
986 ) -> Result<SelectionResult, LlmError> {
987 let mut candidates: Vec<(ModelMetadata, FitnessBreakdown)> = Vec::new();
988 let mut rejected: Vec<(ModelMetadata, RejectionReason)> = Vec::new();
989
990 for model in &self.base_selector.models {
992 if !self.available_providers.contains(&model.provider) {
994 rejected.push((model.clone(), RejectionReason::ProviderUnavailable));
995 continue;
996 }
997
998 let metadata = self
1000 .metadata_overrides
1001 .get(&(model.provider.clone(), model.model.clone()))
1002 .unwrap_or(model);
1003
1004 if let Some(breakdown) = metadata.fitness_breakdown(requirements) {
1006 candidates.push((metadata.clone(), breakdown));
1007 } else if let Some(reason) = metadata.rejection_reason(requirements) {
1008 rejected.push((metadata.clone(), reason));
1009 }
1010 }
1011
1012 if candidates.is_empty() {
1013 let available = self
1014 .available_providers
1015 .iter()
1016 .map(std::string::String::as_str)
1017 .collect::<Vec<_>>()
1018 .join(", ");
1019 return Err(LlmError::provider(format!(
1020 "No available model found satisfying requirements. Available providers: [{}]",
1021 if available.is_empty() {
1022 "none (set API keys)".to_string()
1023 } else {
1024 available
1025 }
1026 )));
1027 }
1028
1029 candidates.sort_by(|a, b| {
1031 b.1.total
1032 .partial_cmp(&a.1.total)
1033 .unwrap_or(std::cmp::Ordering::Equal)
1034 });
1035
1036 let (selected, fitness) = candidates[0].clone();
1038
1039 Ok(SelectionResult {
1040 selected,
1041 fitness,
1042 candidates,
1043 rejected,
1044 })
1045 }
1046}
1047
1048impl ModelSelectorTrait for ProviderRegistry {
1049 fn select(&self, requirements: &AgentRequirements) -> Result<(String, String), LlmError> {
1050 let all_candidates = self.base_selector.list_satisfying(requirements);
1052
1053 let mut candidates: Vec<(&ModelMetadata, f64)> = all_candidates
1055 .iter()
1056 .filter(|m| self.available_providers.contains(&m.provider))
1057 .map(|m| {
1058 let metadata = self
1060 .metadata_overrides
1061 .get(&(m.provider.clone(), m.model.clone()))
1062 .unwrap_or(m);
1063 (metadata, metadata.fitness_score(requirements))
1064 })
1065 .collect();
1066
1067 if candidates.is_empty() {
1068 let available = self
1069 .available_providers
1070 .iter()
1071 .map(std::string::String::as_str)
1072 .collect::<Vec<_>>()
1073 .join(", ");
1074 return Err(LlmError::provider(format!(
1075 "No available model found satisfying requirements. Available providers: [{}]",
1076 if available.is_empty() {
1077 "none (set API keys)".to_string()
1078 } else {
1079 available
1080 }
1081 )));
1082 }
1083
1084 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1086
1087 let best = candidates[0].0;
1089 Ok((best.provider.clone(), best.model.clone()))
1090 }
1091}
1092
1093impl Default for ProviderRegistry {
1094 fn default() -> Self {
1095 Self::from_env()
1096 }
1097}
1098
#[cfg(test)]
mod tests {
    use super::*;
    use converge_traits::CostClass;

    /// Smoke test: the availability check must not panic. The result depends
    /// on the environment, so it is deliberately not asserted.
    #[test]
    fn test_provider_availability_check() {
        let _ = is_provider_available("anthropic");
    }

    /// `with_providers` marks exactly the listed providers as available.
    #[test]
    fn test_registry_with_explicit_providers() {
        let registry = ProviderRegistry::with_providers(&["anthropic", "openai"]);
        assert!(registry.is_available("anthropic"));
        assert!(registry.is_available("openai"));
        assert!(!registry.is_available("gemini"));
    }

    /// Selection still succeeds after a model's metadata has been overridden.
    // NOTE(review): relies on the `anthropic` feature being enabled, otherwise
    // the catalogue has no anthropic models - confirm the test configuration.
    #[test]
    fn test_metadata_override() {
        let mut registry = ProviderRegistry::with_providers(&["anthropic"]);

        // Same model as the catalogue entry, with a lower latency (1000 ms).
        let updated = ModelMetadata::new(
            "anthropic",
            "claude-haiku-4-5-20251001",
            CostClass::VeryLow,
            1000, 0.78,
        );
        registry.update_metadata("anthropic", "claude-haiku-4-5-20251001", updated);

        let reqs = AgentRequirements::fast_cheap();
        let result = registry.select(&reqs);
        assert!(result.is_ok());
    }

    /// The default catalogue yields a fast, cheap model for `fast_cheap`.
    // NOTE(review): `unwrap` panics when no provider feature contributing a
    // matching model is enabled - confirm the features used for testing.
    #[test]
    fn test_model_selection() {
        let selector = ModelSelector::new();
        let reqs = AgentRequirements::fast_cheap();

        let (provider, model) = selector.select(&reqs).unwrap();
        assert!(
            provider == "anthropic"
                || provider == "openai"
                || provider == "gemini"
                || provider == "qwen"
        );
        assert!(
            model.contains("haiku")
                || model.contains("flash")
                || model.contains("turbo")
                || model.contains("qwen")
        );
    }

    /// Only the model offering both reasoning and web search may be chosen.
    #[test]
    fn test_selection_requires_reasoning_and_web_search() {
        let selector = ModelSelector::empty()
            .with_model(ModelMetadata::new(
                "alpha",
                "basic",
                CostClass::Low,
                1200,
                0.85,
            ))
            .with_model(
                ModelMetadata::new("beta", "reasoning-only", CostClass::Low, 1400, 0.88)
                    .with_reasoning(true),
            )
            .with_model(
                ModelMetadata::new("gamma", "reasoning-search", CostClass::Low, 1500, 0.87)
                    .with_reasoning(true)
                    .with_web_search(true),
            );

        let reqs = AgentRequirements::new(CostClass::Low, 5000, true).with_web_search(true);
        let (provider, model) = selector.select(&reqs).unwrap();
        assert_eq!(provider, "gamma");
        assert_eq!(model, "reasoning-search");
    }

    /// A model is chosen only when both sovereignty and compliance match.
    #[test]
    fn test_selection_respects_data_sovereignty_and_compliance() {
        let selector = ModelSelector::empty()
            .with_model(
                ModelMetadata::new("us", "us-model", CostClass::Low, 1500, 0.85)
                    .with_data_sovereignty(DataSovereignty::US),
            )
            .with_model(
                ModelMetadata::new("eu", "eu-gdpr", CostClass::Low, 1800, 0.86)
                    .with_data_sovereignty(DataSovereignty::EU)
                    .with_compliance(ComplianceLevel::GDPR),
            );

        let reqs = AgentRequirements::balanced()
            .with_data_sovereignty(DataSovereignty::EU)
            .with_compliance(ComplianceLevel::GDPR);
        let (provider, model) = selector.select(&reqs).unwrap();
        assert_eq!(provider, "eu");
        assert_eq!(model, "eu-gdpr");
    }

    /// The cheaper, faster model is skipped because it is not multilingual.
    #[test]
    fn test_selection_requires_multilingual() {
        let selector = ModelSelector::empty()
            .with_model(
                ModelMetadata::new("mono", "fast", CostClass::VeryLow, 800, 0.80)
                    .with_multilingual(false),
            )
            .with_model(
                ModelMetadata::new("multi", "polyglot", CostClass::Low, 1200, 0.82)
                    .with_multilingual(true),
            );

        let reqs = AgentRequirements::new(CostClass::Low, 2000, false).with_multilingual(true);
        let (provider, model) = selector.select(&reqs).unwrap();
        assert_eq!(provider, "multi");
        assert_eq!(model, "polyglot");
    }
}