1use crate::model_selection::{ModelMetadata, ModelSelector};
29use converge_provider_api::selection::{ComplianceLevel, CostClass, DataSovereignty};
30use schemars::JsonSchema;
31use serde::Deserialize;
32use std::collections::HashMap;
33use std::path::Path;
34
/// Errors produced while loading or validating the model registry.
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
    /// Reading the registry file failed; wraps the underlying I/O error.
    #[error("Failed to read registry file: {0}")]
    IoError(#[from] std::io::Error),

    /// The YAML text could not be deserialized into the registry schema.
    #[error("Failed to parse registry YAML: {0}")]
    ParseError(#[from] serde_yaml::Error),

    /// A provider or model failed validation; the string carries the
    /// collected messages.
    #[error("Registry validation failed: {0}")]
    ValidationError(String),
}
50
/// Root of the registry YAML document.
///
/// Unknown top-level keys are rejected during deserialization
/// (`deny_unknown_fields`).
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RegistryYaml {
    /// Provider definitions keyed by provider id.
    pub providers: HashMap<String, ProviderYaml>,
}
64
/// Provider kind as written in YAML (`direct` or `aggregator`).
///
/// `Direct` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum ProviderTypeYaml {
    /// YAML value `direct`.
    #[default]
    Direct,
    /// YAML value `aggregator`.
    Aggregator,
}
75
/// One provider entry in the registry YAML.
///
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ProviderYaml {
    /// Name of the environment variable read for the provider's API key.
    /// Validation rejects an empty value.
    pub env_key: String,
    /// Optional name of a second environment variable; absent when omitted.
    #[serde(default)]
    pub env_key_secondary: Option<String>,
    /// URL associated with the provider's keys (presumably where a key can be
    /// obtained — confirm); validation requires an `http://` or `https://` prefix.
    pub key_url: String,
    /// URL of the provider API (presumably the base URL — confirm); validation
    /// requires an `http://` or `https://` prefix.
    pub api_url: String,
    /// Country of the provider; validation expects a 2-letter ISO code or `LOCAL`.
    pub country: String,
    /// Region tag of the provider.
    pub region: RegionYaml,
    /// Compliance tags; empty when omitted.
    #[serde(default)]
    pub compliance: Vec<ComplianceYaml>,
    /// Provider kind; `direct` when omitted.
    #[serde(default)]
    pub provider_type: ProviderTypeYaml,
    /// Models offered by this provider, keyed by model id. Validation requires
    /// at least one entry.
    pub models: HashMap<String, ModelYaml>,
}
102
/// Region tag of a provider.
///
/// There is no `rename_all`, so YAML must use the exact upper-case variant
/// names (`US`, `EU`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
pub enum RegionYaml {
    US,
    EU,
    EEA,
    CH,
    CN,
    JP,
    UK,
    LOCAL,
}
125
126impl RegionYaml {
127 #[must_use]
129 pub fn as_str(&self) -> &'static str {
130 match self {
131 Self::US => "US",
132 Self::EU => "EU",
133 Self::EEA => "EEA",
134 Self::CH => "CH",
135 Self::CN => "CN",
136 Self::JP => "JP",
137 Self::UK => "UK",
138 Self::LOCAL => "LOCAL",
139 }
140 }
141}
142
/// Compliance tag as written in YAML.
///
/// There is no `rename_all`, so YAML must use the exact variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
pub enum ComplianceYaml {
    GDPR,
    SOC2,
    HIPAA,
}
153
154impl From<ComplianceYaml> for ComplianceLevel {
155 fn from(c: ComplianceYaml) -> Self {
156 match c {
157 ComplianceYaml::GDPR => ComplianceLevel::GDPR,
158 ComplianceYaml::SOC2 => ComplianceLevel::SOC2,
159 ComplianceYaml::HIPAA => ComplianceLevel::HIPAA,
160 }
161 }
162}
163
/// Cost tier as written in YAML.
///
/// There is no `rename_all`, so YAML must use the exact variant names
/// (for example `VeryLow`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
pub enum CostClassYaml {
    VeryLow,
    Low,
    Medium,
    High,
    VeryHigh,
}
178
179impl From<CostClassYaml> for CostClass {
180 fn from(c: CostClassYaml) -> Self {
181 match c {
182 CostClassYaml::VeryLow => CostClass::VeryLow,
183 CostClassYaml::Low => CostClass::Low,
184 CostClassYaml::Medium => CostClass::Medium,
185 CostClassYaml::High => CostClass::High,
186 CostClassYaml::VeryHigh => CostClass::VeryHigh,
187 }
188 }
189}
190
/// Capability tag a model can declare in its `capabilities` list.
///
/// Variants deserialize from their `snake_case` names (for example
/// `tool_use`); any other value is a deserialization error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum CapabilityYaml {
    ToolUse,
    Vision,
    StructuredOutput,
    Code,
    Reasoning,
    Multilingual,
    WebSearch,
    Audio,
    ImageGeneration,
    Streaming,
    Logprobs,
    Seed,
    ToolChoice,
    ParallelToolCalls,
    PromptCaching,
    FileSearch,
    CodeInterpreter,
    ComputerUse,
    ToolSearch,
    Mcp,
    HostedShell,
    ApplyPatch,
    NativeCompaction,
    ReasoningEffort,
    ContentGeneration,
    BusinessAcumen,
}
248
249impl CapabilityYaml {
250 #[must_use]
252 pub fn as_str(&self) -> &'static str {
253 match self {
254 Self::ToolUse => "tool_use",
255 Self::Vision => "vision",
256 Self::StructuredOutput => "structured_output",
257 Self::Code => "code",
258 Self::Reasoning => "reasoning",
259 Self::Multilingual => "multilingual",
260 Self::WebSearch => "web_search",
261 Self::Audio => "audio",
262 Self::ImageGeneration => "image_generation",
263 Self::Streaming => "streaming",
264 Self::Logprobs => "logprobs",
265 Self::Seed => "seed",
266 Self::ToolChoice => "tool_choice",
267 Self::ParallelToolCalls => "parallel_tool_calls",
268 Self::PromptCaching => "prompt_caching",
269 Self::FileSearch => "file_search",
270 Self::CodeInterpreter => "code_interpreter",
271 Self::ComputerUse => "computer_use",
272 Self::ToolSearch => "tool_search",
273 Self::Mcp => "mcp",
274 Self::HostedShell => "hosted_shell",
275 Self::ApplyPatch => "apply_patch",
276 Self::NativeCompaction => "native_compaction",
277 Self::ReasoningEffort => "reasoning_effort",
278 Self::ContentGeneration => "content_generation",
279 Self::BusinessAcumen => "business_acumen",
280 }
281 }
282}
283
/// Reasoning-effort level as written in YAML.
///
/// Variants deserialize from their `snake_case` names (for example `xhigh`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffortYaml {
    None,
    Minimal,
    Low,
    Medium,
    High,
    Xhigh,
}
301
302impl ReasoningEffortYaml {
303 #[must_use]
305 pub fn as_str(&self) -> &'static str {
306 match self {
307 Self::None => "none",
308 Self::Minimal => "minimal",
309 Self::Low => "low",
310 Self::Medium => "medium",
311 Self::High => "high",
312 Self::Xhigh => "xhigh",
313 }
314 }
315}
316
/// Model kind as written in YAML (`llm`, `embedding`, `reranker`, `ocr`).
///
/// `Llm` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum ModelTypeYaml {
    #[default]
    Llm,
    Embedding,
    Reranker,
    Ocr,
}
331
/// Model architecture as written in YAML (`dense`, `moe`, `hybrid`).
///
/// `Dense` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum ArchitectureYaml {
    #[default]
    Dense,
    Moe,
    Hybrid,
}
344
/// Modality tag as written in YAML (`text`, `image`, `video`, `audio`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ModalityYaml {
    Text,
    Image,
    Video,
    Audio,
}
358
/// Optional `agentic` section of a model entry.
///
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct AgenticYaml {
    /// Maximum number of parallel agents, if stated.
    #[serde(default)]
    pub max_parallel_agents: Option<u32>,
    /// Whether orchestration is supported; `false` when omitted.
    #[serde(default)]
    pub supports_orchestration: bool,
}
370
/// Optional `pricing` section of a model entry.
///
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct PricingYaml {
    /// Input price (presumably per million tokens; currency is not visible
    /// here — confirm).
    #[serde(default)]
    pub input_per_m: Option<f64>,
    /// Output price (presumably per million tokens; currency is not visible
    /// here — confirm).
    #[serde(default)]
    pub output_per_m: Option<f64>,
}
382
/// Optional `rate_limits` section of a model entry; every limit is optional.
///
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RateLimitsYaml {
    /// Requests allowed per minute, if stated.
    #[serde(default)]
    pub requests_per_min: Option<u32>,
    /// Tokens allowed per minute, if stated.
    #[serde(default)]
    pub tokens_per_min: Option<u32>,
    /// Requests allowed per day, if stated.
    #[serde(default)]
    pub requests_per_day: Option<u32>,
    /// Concurrent requests allowed, if stated.
    #[serde(default)]
    pub concurrent_requests: Option<u32>,
}
400
/// One model entry in the registry YAML.
///
/// Only `cost_class`, `typical_latency_ms` and `quality` are required; every
/// other key falls back to a default. Unknown keys are rejected during
/// deserialization (`deny_unknown_fields`).
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ModelYaml {
    /// Cost tier of the model.
    pub cost_class: CostClassYaml,
    /// Typical latency in milliseconds; validation requires a value above 0.
    pub typical_latency_ms: u32,
    /// Quality score; validation requires a value within `0.0..=1.0`.
    pub quality: f64,
    /// Context size in tokens; 8192 when omitted, and validation requires a
    /// value above 0.
    #[serde(default = "default_context_tokens")]
    pub context_tokens: usize,
    /// Declared capabilities; empty when omitted.
    #[serde(default)]
    pub capabilities: Vec<CapabilityYaml>,
    /// Model kind, read from the YAML key `type`; `llm` when omitted.
    #[serde(default, rename = "type")]
    pub model_type: ModelTypeYaml,
    /// Dimensions value; validation requires it for embedding models.
    #[serde(default)]
    pub dimensions: Option<usize>,

    /// Model architecture; `dense` when omitted.
    #[serde(default)]
    pub architecture: ArchitectureYaml,
    /// Total parameter count (presumably in billions — confirm).
    #[serde(default)]
    pub total_params_b: Option<f64>,
    /// Active parameter count (presumably in billions — confirm).
    #[serde(default)]
    pub active_params_b: Option<f64>,
    /// Maximum number of output tokens, if stated.
    #[serde(default)]
    pub max_output_tokens: Option<usize>,
    /// Native-multimodal flag; `false` when omitted.
    #[serde(default)]
    pub native_multimodal: bool,
    /// Declared modalities; empty when omitted.
    #[serde(default)]
    pub modalities: Vec<ModalityYaml>,
    /// Optional agentic settings.
    #[serde(default)]
    pub agentic: Option<AgenticYaml>,
    /// Thinking-mode flag; `false` when omitted.
    #[serde(default)]
    pub thinking_mode: bool,
    /// Declared reasoning-effort levels; empty when omitted.
    #[serde(default)]
    pub reasoning_effort_levels: Vec<ReasoningEffortYaml>,
    /// Native-compaction flag; `false` when omitted.
    #[serde(default)]
    pub native_compaction: bool,
    /// Optional thinking-variant label (its meaning is not visible in this file).
    #[serde(default)]
    pub thinking_variant: Option<String>,
    /// Optional pricing information.
    #[serde(default)]
    pub pricing: Option<PricingYaml>,
    /// Optional publisher name.
    #[serde(default)]
    pub publisher: Option<String>,
    /// Optional model family name.
    #[serde(default)]
    pub family: Option<String>,
    /// Optional release date, kept as an unparsed string.
    #[serde(default)]
    pub release_date: Option<String>,
    /// Optional training cutoff, kept as an unparsed string.
    #[serde(default)]
    pub training_cutoff: Option<String>,
    /// Open-weights flag; `false` when omitted.
    #[serde(default)]
    pub open_weights: bool,
    /// Optional license text or identifier.
    #[serde(default)]
    pub license: Option<String>,
    /// Deprecation flag; `false` when omitted.
    #[serde(default)]
    pub deprecated: bool,
    /// Beta flag; `false` when omitted.
    #[serde(default)]
    pub beta: bool,
    /// Benchmark values keyed by name; empty when omitted.
    #[serde(default)]
    pub benchmarks: HashMap<String, f64>,
    /// Free-form tags; empty when omitted.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Optional rate limits.
    #[serde(default)]
    pub rate_limits: Option<RateLimitsYaml>,
    /// Optional free-form notes.
    #[serde(default)]
    pub notes: Option<String>,
}
498
/// Serde default for `ModelYaml::context_tokens` when the YAML omits the key.
fn default_context_tokens() -> usize {
    const DEFAULT_CONTEXT_TOKENS: usize = 8 * 1024;
    DEFAULT_CONTEXT_TOKENS
}
502
/// Builds the JSON Schema describing [`RegistryYaml`], the root type of the
/// registry YAML.
#[must_use]
pub fn generate_schema() -> schemars::schema::RootSchema {
    schemars::schema_for!(RegistryYaml)
}
522
/// Provider kind in the loaded registry; converted from [`ProviderTypeYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    Direct,
    Aggregator,
}
535
/// A validated provider in the loaded registry.
#[derive(Debug, Clone)]
pub struct LoadedProvider {
    /// Provider id (the key of the entry in the YAML `providers` map).
    pub id: String,
    /// Name of the environment variable read by [`LoadedProvider::api_key`].
    pub env_key: String,
    /// Optional name of the environment variable read by
    /// [`LoadedProvider::secondary_api_key`].
    pub env_key_secondary: Option<String>,
    /// Key URL copied from the YAML entry.
    pub key_url: String,
    /// API URL copied from the YAML entry.
    pub api_url: String,
    /// Country value copied from the YAML entry.
    pub country: String,
    /// Region tag as a string (for example `"EU"` or `"LOCAL"`).
    pub region: String,
    /// Compliance levels converted from the YAML tags.
    pub compliance: Vec<ComplianceLevel>,
    /// Provider kind.
    pub provider_type: ProviderType,
    /// Models of this provider, sorted by id by the loader.
    pub models: Vec<LoadedModel>,
}
560
561impl LoadedProvider {
562 #[must_use]
564 pub fn is_available(&self) -> bool {
565 let primary_ok = std::env::var(&self.env_key).is_ok();
566 let secondary_ok = self
567 .env_key_secondary
568 .as_ref()
569 .map(|k| std::env::var(k).is_ok())
570 .unwrap_or(true);
571 primary_ok && secondary_ok
572 }
573
574 #[must_use]
576 pub fn api_key(&self) -> Option<String> {
577 std::env::var(&self.env_key).ok()
578 }
579
580 #[must_use]
582 pub fn secondary_api_key(&self) -> Option<String> {
583 self.env_key_secondary
584 .as_ref()
585 .and_then(|k| std::env::var(k).ok())
586 }
587}
588
/// Model architecture in the loaded registry; converted from
/// [`ArchitectureYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Architecture {
    Dense,
    Moe,
    Hybrid,
}
599
/// Modality tag in the loaded registry; mirrors [`ModalityYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    Text,
    Image,
    Video,
    Audio,
}
612
/// Agentic settings in the loaded registry; mirrors [`AgenticYaml`].
#[derive(Debug, Clone, Default)]
pub struct AgenticCapabilities {
    /// Maximum number of parallel agents, if stated.
    pub max_parallel_agents: Option<u32>,
    /// Whether orchestration is supported.
    pub supports_orchestration: bool,
}
621
/// Pricing in the loaded registry; mirrors [`PricingYaml`].
#[derive(Debug, Clone, Default)]
pub struct Pricing {
    /// Input price (presumably per million tokens — confirm).
    pub input_per_m: Option<f64>,
    /// Output price (presumably per million tokens — confirm).
    pub output_per_m: Option<f64>,
}
630
/// Rate limits in the loaded registry; mirrors [`RateLimitsYaml`].
#[derive(Debug, Clone, Default)]
pub struct RateLimits {
    /// Requests allowed per minute, if stated.
    pub requests_per_min: Option<u32>,
    /// Tokens allowed per minute, if stated.
    pub tokens_per_min: Option<u32>,
    /// Requests allowed per day, if stated.
    pub requests_per_day: Option<u32>,
    /// Concurrent requests allowed, if stated.
    pub concurrent_requests: Option<u32>,
}
643
/// Reasoning-effort level in the loaded registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    None,
    Minimal,
    Low,
    Medium,
    High,
    Xhigh,
}

impl ReasoningEffort {
    /// Returns the lower-case name of this effort level (for example
    /// `"minimal"`).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            ReasoningEffort::None => "none",
            ReasoningEffort::Minimal => "minimal",
            ReasoningEffort::Low => "low",
            ReasoningEffort::Medium => "medium",
            ReasoningEffort::High => "high",
            ReasoningEffort::Xhigh => "xhigh",
        }
    }
}
675
/// A validated model in the loaded registry.
///
/// The `supports_*` flags are derived by the loader from `capabilities`; the
/// remaining fields are carried over from the YAML entry.
#[derive(Debug, Clone)]
// The bool fields are independent capability/metadata flags, not a state machine.
#[allow(clippy::struct_excessive_bools)]
pub struct LoadedModel {
    /// Model id (the key of the entry in the provider's `models` map).
    pub id: String,
    /// Cost tier.
    pub cost_class: CostClass,
    /// Typical latency in milliseconds.
    pub typical_latency_ms: u32,
    /// Quality score (validated to lie within `0.0..=1.0`).
    pub quality: f64,
    /// Context size in tokens.
    pub context_tokens: usize,
    /// Model kind.
    pub model_type: ModelType,
    /// Dimensions value (required by validation for embedding models).
    pub dimensions: Option<usize>,
    /// Capabilities exactly as declared in YAML.
    pub capabilities: Vec<CapabilityYaml>,
    /// `capabilities` contains `tool_use`.
    pub supports_tool_use: bool,
    /// `capabilities` contains `vision`.
    pub supports_vision: bool,
    /// `capabilities` contains `structured_output`.
    pub supports_structured_output: bool,
    /// `capabilities` contains `code`.
    pub supports_code: bool,
    /// `capabilities` contains `reasoning`.
    pub supports_reasoning: bool,
    /// `capabilities` contains `multilingual`.
    pub supports_multilingual: bool,
    /// `capabilities` contains `web_search`.
    pub supports_web_search: bool,
    /// `capabilities` contains `content_generation`.
    pub supports_content_generation: bool,
    /// `capabilities` contains `business_acumen`.
    pub supports_business_acumen: bool,

    /// Model architecture.
    pub architecture: Architecture,
    /// Total parameter count (presumably in billions — confirm).
    pub total_params_b: Option<f64>,
    /// Active parameter count (presumably in billions — confirm).
    pub active_params_b: Option<f64>,
    /// Maximum number of output tokens, if stated.
    pub max_output_tokens: Option<usize>,
    /// Native-multimodal flag.
    pub native_multimodal: bool,
    /// Declared modalities.
    pub modalities: Vec<Modality>,
    /// Optional agentic settings.
    pub agentic: Option<AgenticCapabilities>,
    /// Thinking-mode flag.
    pub thinking_mode: bool,
    /// Declared reasoning-effort levels.
    pub reasoning_effort_levels: Vec<ReasoningEffort>,
    /// Native-compaction flag.
    pub native_compaction: bool,
    /// Optional thinking-variant label (its meaning is not visible in this file).
    pub thinking_variant: Option<String>,
    /// Optional pricing information.
    pub pricing: Option<Pricing>,
    /// Optional publisher name.
    pub publisher: Option<String>,
    /// Optional model family name.
    pub family: Option<String>,
    /// Optional release date, kept as an unparsed string.
    pub release_date: Option<String>,
    /// Optional training cutoff, kept as an unparsed string.
    pub training_cutoff: Option<String>,
    /// Open-weights flag.
    pub open_weights: bool,
    /// Optional license text or identifier.
    pub license: Option<String>,
    /// Deprecation flag.
    pub deprecated: bool,
    /// Beta flag.
    pub beta: bool,
    /// Benchmark values keyed by name.
    pub benchmarks: HashMap<String, f64>,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Optional rate limits.
    pub rate_limits: Option<RateLimits>,
    /// Optional free-form notes.
    pub notes: Option<String>,
}
766
/// Model kind in the loaded registry; converted from [`ModelTypeYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    Llm,
    Embedding,
    Reranker,
    Ocr,
}
779
/// The validated registry produced by the loader functions in this file.
#[derive(Debug, Clone)]
pub struct LoadedRegistry {
    /// Providers, sorted by id by the loader.
    providers: Vec<LoadedProvider>,
}
786
787impl LoadedRegistry {
788 #[must_use]
790 pub fn providers(&self) -> &[LoadedProvider] {
791 &self.providers
792 }
793
794 #[must_use]
796 pub fn available_providers(&self) -> Vec<&LoadedProvider> {
797 self.providers.iter().filter(|p| p.is_available()).collect()
798 }
799
800 #[must_use]
802 pub fn get_provider(&self, id: &str) -> Option<&LoadedProvider> {
803 self.providers.iter().find(|p| p.id == id)
804 }
805
806 #[must_use]
808 pub fn llm_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
809 self.providers
810 .iter()
811 .flat_map(|p| {
812 p.models
813 .iter()
814 .filter(|m| m.model_type == ModelType::Llm)
815 .map(move |m| (p, m))
816 })
817 .collect()
818 }
819
820 #[must_use]
822 pub fn embedding_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
823 self.providers
824 .iter()
825 .flat_map(|p| {
826 p.models
827 .iter()
828 .filter(|m| m.model_type == ModelType::Embedding)
829 .map(move |m| (p, m))
830 })
831 .collect()
832 }
833
834 #[must_use]
836 pub fn reranker_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
837 self.providers
838 .iter()
839 .flat_map(|p| {
840 p.models
841 .iter()
842 .filter(|m| m.model_type == ModelType::Reranker)
843 .map(move |m| (p, m))
844 })
845 .collect()
846 }
847
848 #[must_use]
850 pub fn to_model_selector(&self) -> ModelSelector {
851 let mut selector = ModelSelector::empty();
852
853 for provider in &self.providers {
854 for model in &provider.models {
855 if model.model_type != ModelType::Llm {
856 continue; }
858
859 let data_sovereignty = match provider.region.as_str() {
860 "EU" | "EEA" => DataSovereignty::EU,
861 "CH" => DataSovereignty::Switzerland,
862 "CN" => DataSovereignty::China,
863 "US" => DataSovereignty::US,
864 "LOCAL" => DataSovereignty::OnPremises,
865 _ => DataSovereignty::Any,
866 };
867
868 let compliance = provider
869 .compliance
870 .first()
871 .copied()
872 .unwrap_or(ComplianceLevel::None);
873
874 let metadata = ModelMetadata::new(
875 &provider.id,
876 &model.id,
877 model.cost_class,
878 model.typical_latency_ms,
879 model.quality,
880 )
881 .with_reasoning(model.supports_reasoning)
882 .with_web_search(model.supports_web_search)
883 .with_data_sovereignty(data_sovereignty)
884 .with_compliance(compliance)
885 .with_multilingual(model.supports_multilingual)
886 .with_context_tokens(model.context_tokens)
887 .with_tool_use(model.supports_tool_use)
888 .with_vision(model.supports_vision)
889 .with_structured_output(model.supports_structured_output)
890 .with_code(model.supports_code)
891 .with_content_generation(model.supports_content_generation)
892 .with_business_acumen(model.supports_business_acumen)
893 .with_location(&provider.country, &provider.region);
894
895 selector = selector.with_model(metadata);
896 }
897 }
898
899 selector
900 }
901
902 pub fn print_summary(&self) {
904 println!("Model Registry Summary");
905 println!("======================\n");
906
907 for provider in &self.providers {
908 let status = if provider.is_available() {
909 "✓ available"
910 } else {
911 "✗ no key"
912 };
913
914 println!(
915 "{} ({}) - {} models [{}]",
916 provider.id,
917 provider.region,
918 provider.models.len(),
919 status
920 );
921 println!(" Key URL: {}", provider.key_url);
922 println!(" API URL: {}", provider.api_url);
923 println!();
924 }
925 }
926}
927
/// Relative path that `load_registry` checks on disk after the
/// `CONVERGE_MODELS_PATH` environment variable.
pub const DEFAULT_REGISTRY_PATH: &str = "converge-provider/config/models.yaml";
934
935pub fn load_registry() -> Result<LoadedRegistry, RegistryError> {
946 if let Ok(path) = std::env::var("CONVERGE_MODELS_PATH") {
948 return load_registry_from_path(&path);
949 }
950
951 if std::path::Path::new(DEFAULT_REGISTRY_PATH).exists() {
953 return load_registry_from_path(DEFAULT_REGISTRY_PATH);
954 }
955
956 let crate_path = "config/models.yaml";
958 if std::path::Path::new(crate_path).exists() {
959 return load_registry_from_path(crate_path);
960 }
961
962 load_registry_from_str(include_str!("../config/models.yaml"))
964}
965
966pub fn load_registry_from_path(path: impl AsRef<Path>) -> Result<LoadedRegistry, RegistryError> {
972 let content = std::fs::read_to_string(path)?;
973 load_registry_from_str(&content)
974}
975
976pub fn load_registry_from_str(yaml: &str) -> Result<LoadedRegistry, RegistryError> {
982 let registry_yaml: RegistryYaml = serde_yaml::from_str(yaml)?;
983
984 let mut providers = Vec::new();
985 let mut errors = Vec::new();
986
987 for (provider_id, provider_yaml) in registry_yaml.providers {
988 if let Err(e) = validate_provider(&provider_id, &provider_yaml) {
990 errors.push(e);
991 continue;
992 }
993
994 let compliance = provider_yaml
995 .compliance
996 .iter()
997 .map(|c| ComplianceLevel::from(*c))
998 .collect();
999
1000 let mut models = Vec::new();
1001
1002 for (model_id, model_yaml) in provider_yaml.models {
1003 if let Err(e) = validate_model(&provider_id, &model_id, &model_yaml) {
1005 errors.push(e);
1006 continue;
1007 }
1008
1009 let capabilities: std::collections::HashSet<_> =
1010 model_yaml.capabilities.iter().copied().collect();
1011
1012 let modalities: Vec<Modality> = model_yaml
1014 .modalities
1015 .iter()
1016 .map(|m| match m {
1017 ModalityYaml::Text => Modality::Text,
1018 ModalityYaml::Image => Modality::Image,
1019 ModalityYaml::Video => Modality::Video,
1020 ModalityYaml::Audio => Modality::Audio,
1021 })
1022 .collect();
1023
1024 let reasoning_effort_levels = model_yaml
1026 .reasoning_effort_levels
1027 .iter()
1028 .copied()
1029 .map(ReasoningEffort::from)
1030 .collect();
1031
1032 let agentic = model_yaml.agentic.as_ref().map(|a| AgenticCapabilities {
1034 max_parallel_agents: a.max_parallel_agents,
1035 supports_orchestration: a.supports_orchestration,
1036 });
1037
1038 let pricing = model_yaml.pricing.as_ref().map(|p| Pricing {
1040 input_per_m: p.input_per_m,
1041 output_per_m: p.output_per_m,
1042 });
1043
1044 let rate_limits = model_yaml.rate_limits.as_ref().map(|r| RateLimits {
1046 requests_per_min: r.requests_per_min,
1047 tokens_per_min: r.tokens_per_min,
1048 requests_per_day: r.requests_per_day,
1049 concurrent_requests: r.concurrent_requests,
1050 });
1051
1052 let model = LoadedModel {
1053 id: model_id,
1054 cost_class: model_yaml.cost_class.into(),
1055 typical_latency_ms: model_yaml.typical_latency_ms,
1056 quality: model_yaml.quality,
1057 context_tokens: model_yaml.context_tokens,
1058 model_type: model_yaml.model_type.into(),
1059 dimensions: model_yaml.dimensions,
1060 capabilities: model_yaml.capabilities.clone(),
1061 supports_tool_use: capabilities.contains(&CapabilityYaml::ToolUse),
1062 supports_vision: capabilities.contains(&CapabilityYaml::Vision),
1063 supports_structured_output: capabilities
1064 .contains(&CapabilityYaml::StructuredOutput),
1065 supports_code: capabilities.contains(&CapabilityYaml::Code),
1066 supports_reasoning: capabilities.contains(&CapabilityYaml::Reasoning),
1067 supports_multilingual: capabilities.contains(&CapabilityYaml::Multilingual),
1068 supports_web_search: capabilities.contains(&CapabilityYaml::WebSearch),
1069 supports_content_generation: capabilities
1070 .contains(&CapabilityYaml::ContentGeneration),
1071 supports_business_acumen: capabilities.contains(&CapabilityYaml::BusinessAcumen),
1072 architecture: model_yaml.architecture.into(),
1074 total_params_b: model_yaml.total_params_b,
1075 active_params_b: model_yaml.active_params_b,
1076 max_output_tokens: model_yaml.max_output_tokens,
1077 native_multimodal: model_yaml.native_multimodal,
1078 modalities,
1079 agentic,
1080 thinking_mode: model_yaml.thinking_mode,
1081 reasoning_effort_levels,
1082 native_compaction: model_yaml.native_compaction,
1083 thinking_variant: model_yaml.thinking_variant.clone(),
1084 pricing,
1085 publisher: model_yaml.publisher.clone(),
1086 family: model_yaml.family.clone(),
1087 release_date: model_yaml.release_date.clone(),
1088 training_cutoff: model_yaml.training_cutoff.clone(),
1089 open_weights: model_yaml.open_weights,
1090 license: model_yaml.license.clone(),
1091 deprecated: model_yaml.deprecated,
1092 beta: model_yaml.beta,
1093 benchmarks: model_yaml.benchmarks.clone(),
1094 tags: model_yaml.tags.clone(),
1095 rate_limits,
1096 notes: model_yaml.notes.clone(),
1097 };
1098
1099 models.push(model);
1100 }
1101
1102 models.sort_by(|a, b| a.id.cmp(&b.id));
1104
1105 let provider = LoadedProvider {
1106 id: provider_id,
1107 env_key: provider_yaml.env_key,
1108 env_key_secondary: provider_yaml.env_key_secondary,
1109 key_url: provider_yaml.key_url,
1110 api_url: provider_yaml.api_url,
1111 country: provider_yaml.country,
1112 region: provider_yaml.region.as_str().to_string(),
1113 compliance,
1114 provider_type: provider_yaml.provider_type.into(),
1115 models,
1116 };
1117
1118 providers.push(provider);
1119 }
1120
1121 if !errors.is_empty() {
1123 return Err(RegistryError::ValidationError(errors.join("; ")));
1124 }
1125
1126 providers.sort_by(|a, b| a.id.cmp(&b.id));
1128
1129 Ok(LoadedRegistry { providers })
1130}
1131
1132fn validate_provider(id: &str, provider: &ProviderYaml) -> Result<(), String> {
1134 if provider.env_key.is_empty() {
1136 return Err(format!("Provider '{id}': env_key cannot be empty"));
1137 }
1138
1139 if !provider.key_url.starts_with("http://") && !provider.key_url.starts_with("https://") {
1141 return Err(format!(
1142 "Provider '{id}': key_url must be a valid URL, got '{}'",
1143 provider.key_url
1144 ));
1145 }
1146
1147 if !provider.api_url.starts_with("http://") && !provider.api_url.starts_with("https://") {
1148 return Err(format!(
1149 "Provider '{id}': api_url must be a valid URL, got '{}'",
1150 provider.api_url
1151 ));
1152 }
1153
1154 if provider.country != "LOCAL" && provider.country.len() != 2 {
1156 return Err(format!(
1157 "Provider '{id}': country must be 2-letter ISO code or 'LOCAL', got '{}'",
1158 provider.country
1159 ));
1160 }
1161
1162 if provider.models.is_empty() {
1164 return Err(format!("Provider '{id}': must have at least one model"));
1165 }
1166
1167 Ok(())
1168}
1169
1170fn validate_model(provider_id: &str, model_id: &str, model: &ModelYaml) -> Result<(), String> {
1172 if !(0.0..=1.0).contains(&model.quality) {
1174 return Err(format!(
1175 "Model '{provider_id}/{model_id}': quality must be 0.0-1.0, got {}",
1176 model.quality
1177 ));
1178 }
1179
1180 if model.typical_latency_ms == 0 {
1182 return Err(format!(
1183 "Model '{provider_id}/{model_id}': typical_latency_ms must be > 0"
1184 ));
1185 }
1186
1187 if model.context_tokens == 0 {
1189 return Err(format!(
1190 "Model '{provider_id}/{model_id}': context_tokens must be > 0"
1191 ));
1192 }
1193
1194 if model.model_type == ModelTypeYaml::Embedding && model.dimensions.is_none() {
1196 return Err(format!(
1197 "Model '{provider_id}/{model_id}': embedding models must specify dimensions"
1198 ));
1199 }
1200
1201 Ok(())
1202}
1203
1204impl From<ModelTypeYaml> for ModelType {
1205 fn from(t: ModelTypeYaml) -> Self {
1206 match t {
1207 ModelTypeYaml::Llm => ModelType::Llm,
1208 ModelTypeYaml::Embedding => ModelType::Embedding,
1209 ModelTypeYaml::Reranker => ModelType::Reranker,
1210 ModelTypeYaml::Ocr => ModelType::Ocr,
1211 }
1212 }
1213}
1214
1215impl From<ArchitectureYaml> for Architecture {
1216 fn from(a: ArchitectureYaml) -> Self {
1217 match a {
1218 ArchitectureYaml::Dense => Architecture::Dense,
1219 ArchitectureYaml::Moe => Architecture::Moe,
1220 ArchitectureYaml::Hybrid => Architecture::Hybrid,
1221 }
1222 }
1223}
1224
1225impl From<ReasoningEffortYaml> for ReasoningEffort {
1226 fn from(effort: ReasoningEffortYaml) -> Self {
1227 match effort {
1228 ReasoningEffortYaml::None => Self::None,
1229 ReasoningEffortYaml::Minimal => Self::Minimal,
1230 ReasoningEffortYaml::Low => Self::Low,
1231 ReasoningEffortYaml::Medium => Self::Medium,
1232 ReasoningEffortYaml::High => Self::High,
1233 ReasoningEffortYaml::Xhigh => Self::Xhigh,
1234 }
1235 }
1236}
1237
1238impl From<ProviderTypeYaml> for ProviderType {
1239 fn from(p: ProviderTypeYaml) -> Self {
1240 match p {
1241 ProviderTypeYaml::Direct => ProviderType::Direct,
1242 ProviderTypeYaml::Aggregator => ProviderType::Aggregator,
1243 }
1244 }
1245}
1246
/// Unit tests for registry parsing, conversion, and validation.
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal valid registry: one provider with one LLM and one embedding model.
    const TEST_YAML: &str = r"
providers:
  test-provider:
    env_key: TEST_API_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      test-model:
        cost_class: Low
        typical_latency_ms: 2000
        quality: 0.85
        context_tokens: 128000
        capabilities: [tool_use, reasoning, code]

      test-embedding:
        cost_class: VeryLow
        typical_latency_ms: 100
        quality: 0.80
        context_tokens: 8192
        capabilities: []
        type: embedding
        dimensions: 1024
";

    // `SuperLow` is not a `CostClassYaml` variant.
    const INVALID_COST_CLASS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: SuperLow
        typical_latency_ms: 100
        quality: 0.5
";

    // `telepathy` is not a `CapabilityYaml` variant.
    const INVALID_CAPABILITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        capabilities: [tool_use, telepathy]
";

    // Quality above the allowed 0.0-1.0 range.
    const INVALID_QUALITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 1.5
";

    // Embedding model that omits the required `dimensions`.
    const MISSING_DIMENSIONS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-embedding:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        type: embedding
";

    // Provider entry carrying a key that the schema does not define.
    const UNKNOWN_FIELD_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    unknown_field: oops
    models:
      model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";

    #[test]
    fn parse_yaml() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        assert_eq!(registry.providers.len(), 1);

        let provider = &registry.providers[0];
        assert_eq!(provider.id, "test-provider");
        assert_eq!(provider.key_url, "https://test.com/keys");
        assert_eq!(provider.api_url, "https://api.test.com/v1");
        assert_eq!(provider.models.len(), 2);
    }

    #[test]
    fn parse_model_capabilities() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        let llm = provider
            .models
            .iter()
            .find(|m| m.id == "test-model")
            .unwrap();
        assert!(llm.supports_tool_use);
        assert!(llm.supports_reasoning);
        assert!(llm.supports_code);
        assert!(!llm.supports_vision);
        assert_eq!(llm.model_type, ModelType::Llm);
    }

    #[test]
    fn parse_embedding_model() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        let embedding = provider
            .models
            .iter()
            .find(|m| m.id == "test-embedding")
            .unwrap();
        assert_eq!(embedding.model_type, ModelType::Embedding);
        assert_eq!(embedding.dimensions, Some(1024));
    }

    #[test]
    fn filter_by_model_type() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();

        let llms = registry.llm_models();
        assert_eq!(llms.len(), 1);
        assert_eq!(llms[0].1.id, "test-model");

        let embeddings = registry.embedding_models();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].1.id, "test-embedding");
    }

    #[test]
    fn to_model_selector() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let selector = registry.to_model_selector();

        // Only the LLM is registered with the selector; the embedding model is skipped.
        let reqs = converge_core::model_selection::AgentRequirements::balanced();
        let satisfying = selector.list_satisfying(&reqs);
        assert_eq!(satisfying.len(), 1);
    }

    #[test]
    fn provider_availability() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        // The result depends on the process environment, so this only checks
        // that the call completes.
        let _ = provider.is_available();
    }

    #[test]
    fn load_real_registry() {
        let registry = load_registry().unwrap();

        assert!(
            registry.providers.len() >= 10,
            "Expected at least 10 providers"
        );

        let provider_ids: Vec<_> = registry.providers.iter().map(|p| p.id.as_str()).collect();
        assert!(provider_ids.contains(&"anthropic"), "Missing anthropic");
        assert!(provider_ids.contains(&"openai"), "Missing openai");
        assert!(provider_ids.contains(&"mistral"), "Missing mistral");
        assert!(provider_ids.contains(&"ollama"), "Missing ollama");

        let anthropic = registry.get_provider("anthropic").unwrap();
        assert_eq!(
            anthropic.key_url,
            "https://console.anthropic.com/settings/keys"
        );
        assert_eq!(anthropic.api_url, "https://api.anthropic.com/v1");
        assert_eq!(anthropic.env_key, "ANTHROPIC_API_KEY");

        let ollama = registry.get_provider("ollama").unwrap();
        assert_eq!(ollama.region, "LOCAL");

        let llms = registry.llm_models();
        assert!(llms.len() >= 30, "Expected at least 30 LLM models");

        let embeddings = registry.embedding_models();
        assert!(
            embeddings.len() >= 3,
            "Expected at least 3 embedding models"
        );

        println!(
            "Loaded {} providers with {} LLM models and {} embedding models",
            registry.providers.len(),
            llms.len(),
            embeddings.len()
        );
    }

    #[test]
    fn rejects_invalid_cost_class() {
        let result = load_registry_from_str(INVALID_COST_CLASS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("SuperLow") || err.contains("unknown variant"),
            "Expected error about invalid cost class, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_capability() {
        let result = load_registry_from_str(INVALID_CAPABILITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("telepathy") || err.contains("unknown variant"),
            "Expected error about invalid capability, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_quality() {
        let result = load_registry_from_str(INVALID_QUALITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("quality") && err.contains("1.5"),
            "Expected error about quality out of range, got: {err}"
        );
    }

    #[test]
    fn rejects_embedding_without_dimensions() {
        let result = load_registry_from_str(MISSING_DIMENSIONS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("dimensions"),
            "Expected error about missing dimensions, got: {err}"
        );
    }

    #[test]
    fn rejects_unknown_fields() {
        let result = load_registry_from_str(UNKNOWN_FIELD_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("unknown_field") || err.contains("unknown field"),
            "Expected error about unknown field, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_region() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: INVALID
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("INVALID") || err.contains("unknown variant"),
            "Expected error about invalid region, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_url() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: not-a-url
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("key_url") && err.contains("URL"),
            "Expected error about invalid URL, got: {err}"
        );
    }

    #[test]
    fn rejects_zero_latency() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 0
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("latency") && err.contains("0"),
            "Expected error about zero latency, got: {err}"
        );
    }

    #[test]
    fn rejects_empty_provider() {
        let yaml = r"
providers:
  empty:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models: {}
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("at least one model"),
            "Expected error about empty models, got: {err}"
        );
    }
}