1use crate::model_selection::{ModelMetadata, ModelSelector};
29use converge_provider_api::selection::{ComplianceLevel, CostClass, DataSovereignty};
30use schemars::JsonSchema;
31use serde::Deserialize;
32use std::collections::HashMap;
33use std::path::Path;
34
35#[derive(Debug, thiserror::Error)]
37pub enum RegistryError {
38 #[error("Failed to read registry file: {0}")]
40 IoError(#[from] std::io::Error),
41
42 #[error("Failed to parse registry YAML: {0}")]
44 ParseError(#[from] serde_yaml::Error),
45
46 #[error("Registry validation failed: {0}")]
48 ValidationError(String),
49}
50
51#[derive(Debug, Deserialize, JsonSchema)]
59#[serde(deny_unknown_fields)]
60pub struct RegistryYaml {
61 pub providers: HashMap<String, ProviderYaml>,
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
67#[serde(rename_all = "snake_case")]
68pub enum ProviderTypeYaml {
69 #[default]
71 Direct,
72 Aggregator,
74}
75
76#[derive(Debug, Deserialize, JsonSchema)]
78#[serde(deny_unknown_fields)]
79pub struct ProviderYaml {
80 pub env_key: String,
82 #[serde(default)]
84 pub env_key_secondary: Option<String>,
85 pub key_url: String,
87 pub api_url: String,
89 pub country: String,
91 pub region: RegionYaml,
93 #[serde(default)]
95 pub compliance: Vec<ComplianceYaml>,
96 #[serde(default)]
98 pub provider_type: ProviderTypeYaml,
99 pub models: HashMap<String, ModelYaml>,
101}
102
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
107pub enum RegionYaml {
108 US,
110 EU,
112 EEA,
114 CH,
116 CN,
118 JP,
120 UK,
122 LOCAL,
124}
125
126impl RegionYaml {
127 #[must_use]
129 pub fn as_str(&self) -> &'static str {
130 match self {
131 Self::US => "US",
132 Self::EU => "EU",
133 Self::EEA => "EEA",
134 Self::CH => "CH",
135 Self::CN => "CN",
136 Self::JP => "JP",
137 Self::UK => "UK",
138 Self::LOCAL => "LOCAL",
139 }
140 }
141}
142
143#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
145pub enum ComplianceYaml {
146 GDPR,
148 SOC2,
150 HIPAA,
152}
153
154impl From<ComplianceYaml> for ComplianceLevel {
155 fn from(c: ComplianceYaml) -> Self {
156 match c {
157 ComplianceYaml::GDPR => ComplianceLevel::GDPR,
158 ComplianceYaml::SOC2 => ComplianceLevel::SOC2,
159 ComplianceYaml::HIPAA => ComplianceLevel::HIPAA,
160 }
161 }
162}
163
164#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
166pub enum CostClassYaml {
167 VeryLow,
169 Low,
171 Medium,
173 High,
175 VeryHigh,
177}
178
179impl From<CostClassYaml> for CostClass {
180 fn from(c: CostClassYaml) -> Self {
181 match c {
182 CostClassYaml::VeryLow => CostClass::VeryLow,
183 CostClassYaml::Low => CostClass::Low,
184 CostClassYaml::Medium => CostClass::Medium,
185 CostClassYaml::High => CostClass::High,
186 CostClassYaml::VeryHigh => CostClass::VeryHigh,
187 }
188 }
189}
190
191#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
193#[serde(rename_all = "snake_case")]
194pub enum CapabilityYaml {
195 ToolUse,
197 Vision,
199 StructuredOutput,
201 Code,
203 Reasoning,
205 Multilingual,
207 WebSearch,
209 Audio,
211 ImageGeneration,
213 Streaming,
215 Logprobs,
217 Seed,
219 ToolChoice,
221 ParallelToolCalls,
223 PromptCaching,
225 FileSearch,
227 CodeInterpreter,
229 ComputerUse,
231 ToolSearch,
233 Mcp,
235 HostedShell,
237 ApplyPatch,
239 NativeCompaction,
241 ReasoningEffort,
243}
244
245impl CapabilityYaml {
246 #[must_use]
248 pub fn as_str(&self) -> &'static str {
249 match self {
250 Self::ToolUse => "tool_use",
251 Self::Vision => "vision",
252 Self::StructuredOutput => "structured_output",
253 Self::Code => "code",
254 Self::Reasoning => "reasoning",
255 Self::Multilingual => "multilingual",
256 Self::WebSearch => "web_search",
257 Self::Audio => "audio",
258 Self::ImageGeneration => "image_generation",
259 Self::Streaming => "streaming",
260 Self::Logprobs => "logprobs",
261 Self::Seed => "seed",
262 Self::ToolChoice => "tool_choice",
263 Self::ParallelToolCalls => "parallel_tool_calls",
264 Self::PromptCaching => "prompt_caching",
265 Self::FileSearch => "file_search",
266 Self::CodeInterpreter => "code_interpreter",
267 Self::ComputerUse => "computer_use",
268 Self::ToolSearch => "tool_search",
269 Self::Mcp => "mcp",
270 Self::HostedShell => "hosted_shell",
271 Self::ApplyPatch => "apply_patch",
272 Self::NativeCompaction => "native_compaction",
273 Self::ReasoningEffort => "reasoning_effort",
274 }
275 }
276}
277
278#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
280#[serde(rename_all = "snake_case")]
281pub enum ReasoningEffortYaml {
282 None,
284 Minimal,
286 Low,
288 Medium,
290 High,
292 Xhigh,
294}
295
296impl ReasoningEffortYaml {
297 #[must_use]
299 pub fn as_str(&self) -> &'static str {
300 match self {
301 Self::None => "none",
302 Self::Minimal => "minimal",
303 Self::Low => "low",
304 Self::Medium => "medium",
305 Self::High => "high",
306 Self::Xhigh => "xhigh",
307 }
308 }
309}
310
311#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
313#[serde(rename_all = "snake_case")]
314pub enum ModelTypeYaml {
315 #[default]
317 Llm,
318 Embedding,
320 Reranker,
322 Ocr,
324}
325
326#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
328#[serde(rename_all = "snake_case")]
329pub enum ArchitectureYaml {
330 #[default]
332 Dense,
333 Moe,
335 Hybrid,
337}
338
339#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
341#[serde(rename_all = "snake_case")]
342pub enum ModalityYaml {
343 Text,
345 Image,
347 Video,
349 Audio,
351}
352
353#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
355#[serde(deny_unknown_fields)]
356pub struct AgenticYaml {
357 #[serde(default)]
359 pub max_parallel_agents: Option<u32>,
360 #[serde(default)]
362 pub supports_orchestration: bool,
363}
364
365#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
367#[serde(deny_unknown_fields)]
368pub struct PricingYaml {
369 #[serde(default)]
371 pub input_per_m: Option<f64>,
372 #[serde(default)]
374 pub output_per_m: Option<f64>,
375}
376
377#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
379#[serde(deny_unknown_fields)]
380pub struct RateLimitsYaml {
381 #[serde(default)]
383 pub requests_per_min: Option<u32>,
384 #[serde(default)]
386 pub tokens_per_min: Option<u32>,
387 #[serde(default)]
389 pub requests_per_day: Option<u32>,
390 #[serde(default)]
392 pub concurrent_requests: Option<u32>,
393}
394
395#[derive(Debug, Deserialize, JsonSchema)]
397#[serde(deny_unknown_fields)]
398pub struct ModelYaml {
399 pub cost_class: CostClassYaml,
401 pub typical_latency_ms: u32,
403 pub quality: f64,
405 #[serde(default = "default_context_tokens")]
407 pub context_tokens: usize,
408 #[serde(default)]
410 pub capabilities: Vec<CapabilityYaml>,
411 #[serde(default, rename = "type")]
413 pub model_type: ModelTypeYaml,
414 #[serde(default)]
416 pub dimensions: Option<usize>,
417
418 #[serde(default)]
421 pub architecture: ArchitectureYaml,
422 #[serde(default)]
424 pub total_params_b: Option<f64>,
425 #[serde(default)]
427 pub active_params_b: Option<f64>,
428 #[serde(default)]
430 pub max_output_tokens: Option<usize>,
431 #[serde(default)]
433 pub native_multimodal: bool,
434 #[serde(default)]
436 pub modalities: Vec<ModalityYaml>,
437 #[serde(default)]
439 pub agentic: Option<AgenticYaml>,
440 #[serde(default)]
442 pub thinking_mode: bool,
443 #[serde(default)]
445 pub reasoning_effort_levels: Vec<ReasoningEffortYaml>,
446 #[serde(default)]
448 pub native_compaction: bool,
449 #[serde(default)]
451 pub thinking_variant: Option<String>,
452 #[serde(default)]
454 pub pricing: Option<PricingYaml>,
455 #[serde(default)]
457 pub publisher: Option<String>,
458 #[serde(default)]
460 pub family: Option<String>,
461 #[serde(default)]
463 pub release_date: Option<String>,
464 #[serde(default)]
466 pub training_cutoff: Option<String>,
467 #[serde(default)]
469 pub open_weights: bool,
470 #[serde(default)]
472 pub license: Option<String>,
473 #[serde(default)]
475 pub deprecated: bool,
476 #[serde(default)]
478 pub beta: bool,
479 #[serde(default)]
481 pub benchmarks: HashMap<String, f64>,
482 #[serde(default)]
484 pub tags: Vec<String>,
485 #[serde(default)]
487 pub rate_limits: Option<RateLimitsYaml>,
488 #[serde(default)]
490 pub notes: Option<String>,
491}
492
/// Serde default for `ModelYaml::context_tokens` when the key is omitted.
fn default_context_tokens() -> usize {
    const FALLBACK_CONTEXT_TOKENS: usize = 8_192;
    FALLBACK_CONTEXT_TOKENS
}
496
497#[must_use]
513pub fn generate_schema() -> schemars::schema::RootSchema {
514 schemars::schema_for!(RegistryYaml)
515}
516
/// Runtime provider kind, converted from [`ProviderTypeYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    /// Default kind.
    Direct,
    /// Presumably a provider that routes to other vendors' models (TODO confirm).
    Aggregator,
}
529
530#[derive(Debug, Clone)]
532pub struct LoadedProvider {
533 pub id: String,
535 pub env_key: String,
537 pub env_key_secondary: Option<String>,
539 pub key_url: String,
541 pub api_url: String,
543 pub country: String,
545 pub region: String,
547 pub compliance: Vec<ComplianceLevel>,
549 pub provider_type: ProviderType,
551 pub models: Vec<LoadedModel>,
553}
554
555impl LoadedProvider {
556 #[must_use]
558 pub fn is_available(&self) -> bool {
559 let primary_ok = std::env::var(&self.env_key).is_ok();
560 let secondary_ok = self
561 .env_key_secondary
562 .as_ref()
563 .map(|k| std::env::var(k).is_ok())
564 .unwrap_or(true);
565 primary_ok && secondary_ok
566 }
567
568 #[must_use]
570 pub fn api_key(&self) -> Option<String> {
571 std::env::var(&self.env_key).ok()
572 }
573
574 #[must_use]
576 pub fn secondary_api_key(&self) -> Option<String> {
577 self.env_key_secondary
578 .as_ref()
579 .and_then(|k| std::env::var(k).ok())
580 }
581}
582
/// Runtime network architecture, converted from [`ArchitectureYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Architecture {
    /// Dense model.
    Dense,
    /// Mixture-of-experts.
    Moe,
    /// Hybrid architecture.
    Hybrid,
}
593
/// Runtime input/output modality, converted from [`ModalityYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    Text,
    Image,
    Video,
    Audio,
}
606
/// Runtime agentic capabilities, converted from [`AgenticYaml`].
#[derive(Debug, Clone, Default)]
pub struct AgenticCapabilities {
    /// Maximum number of parallel agents, if stated.
    pub max_parallel_agents: Option<u32>,
    /// Whether orchestration is supported.
    pub supports_orchestration: bool,
}
615
/// Runtime pricing, converted from [`PricingYaml`].
///
/// Presumably prices per million tokens (`_per_m`) — TODO confirm currency.
#[derive(Debug, Clone, Default)]
pub struct Pricing {
    /// Input-token price, if known.
    pub input_per_m: Option<f64>,
    /// Output-token price, if known.
    pub output_per_m: Option<f64>,
}
624
/// Runtime rate limits, converted from [`RateLimitsYaml`]; every limit is optional.
#[derive(Debug, Clone, Default)]
pub struct RateLimits {
    /// Requests per minute.
    pub requests_per_min: Option<u32>,
    /// Tokens per minute.
    pub tokens_per_min: Option<u32>,
    /// Requests per day.
    pub requests_per_day: Option<u32>,
    /// Maximum concurrent requests.
    pub concurrent_requests: Option<u32>,
}
637
/// Runtime reasoning-effort level, converted from [`ReasoningEffortYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    None,
    Minimal,
    Low,
    Medium,
    High,
    Xhigh,
}

impl ReasoningEffort {
    /// Returns the lowercase wire name of this level.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            ReasoningEffort::None => "none",
            ReasoningEffort::Minimal => "minimal",
            ReasoningEffort::Low => "low",
            ReasoningEffort::Medium => "medium",
            ReasoningEffort::High => "high",
            ReasoningEffort::Xhigh => "xhigh",
        }
    }
}
669
670#[derive(Debug, Clone)]
672#[allow(clippy::struct_excessive_bools)]
673pub struct LoadedModel {
674 pub id: String,
676 pub cost_class: CostClass,
678 pub typical_latency_ms: u32,
680 pub quality: f64,
682 pub context_tokens: usize,
684 pub model_type: ModelType,
686 pub dimensions: Option<usize>,
688 pub capabilities: Vec<CapabilityYaml>,
690 pub supports_tool_use: bool,
693 pub supports_vision: bool,
695 pub supports_structured_output: bool,
697 pub supports_code: bool,
699 pub supports_reasoning: bool,
701 pub supports_multilingual: bool,
703 pub supports_web_search: bool,
705
706 pub architecture: Architecture,
709 pub total_params_b: Option<f64>,
711 pub active_params_b: Option<f64>,
713 pub max_output_tokens: Option<usize>,
715 pub native_multimodal: bool,
717 pub modalities: Vec<Modality>,
719 pub agentic: Option<AgenticCapabilities>,
721 pub thinking_mode: bool,
723 pub reasoning_effort_levels: Vec<ReasoningEffort>,
725 pub native_compaction: bool,
727 pub thinking_variant: Option<String>,
729 pub pricing: Option<Pricing>,
731 pub publisher: Option<String>,
733 pub family: Option<String>,
735 pub release_date: Option<String>,
737 pub training_cutoff: Option<String>,
739 pub open_weights: bool,
741 pub license: Option<String>,
743 pub deprecated: bool,
745 pub beta: bool,
747 pub benchmarks: HashMap<String, f64>,
749 pub tags: Vec<String>,
751 pub rate_limits: Option<RateLimits>,
753 pub notes: Option<String>,
755}
756
/// Runtime model kind, converted from [`ModelTypeYaml`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    /// Text-generation model.
    Llm,
    /// Embedding model.
    Embedding,
    /// Reranking model.
    Reranker,
    /// OCR model.
    Ocr,
}
769
770#[derive(Debug, Clone)]
772pub struct LoadedRegistry {
773 providers: Vec<LoadedProvider>,
775}
776
777impl LoadedRegistry {
778 #[must_use]
780 pub fn providers(&self) -> &[LoadedProvider] {
781 &self.providers
782 }
783
784 #[must_use]
786 pub fn available_providers(&self) -> Vec<&LoadedProvider> {
787 self.providers.iter().filter(|p| p.is_available()).collect()
788 }
789
790 #[must_use]
792 pub fn get_provider(&self, id: &str) -> Option<&LoadedProvider> {
793 self.providers.iter().find(|p| p.id == id)
794 }
795
796 #[must_use]
798 pub fn llm_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
799 self.providers
800 .iter()
801 .flat_map(|p| {
802 p.models
803 .iter()
804 .filter(|m| m.model_type == ModelType::Llm)
805 .map(move |m| (p, m))
806 })
807 .collect()
808 }
809
810 #[must_use]
812 pub fn embedding_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
813 self.providers
814 .iter()
815 .flat_map(|p| {
816 p.models
817 .iter()
818 .filter(|m| m.model_type == ModelType::Embedding)
819 .map(move |m| (p, m))
820 })
821 .collect()
822 }
823
824 #[must_use]
826 pub fn reranker_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
827 self.providers
828 .iter()
829 .flat_map(|p| {
830 p.models
831 .iter()
832 .filter(|m| m.model_type == ModelType::Reranker)
833 .map(move |m| (p, m))
834 })
835 .collect()
836 }
837
838 #[must_use]
840 pub fn to_model_selector(&self) -> ModelSelector {
841 let mut selector = ModelSelector::empty();
842
843 for provider in &self.providers {
844 for model in &provider.models {
845 if model.model_type != ModelType::Llm {
846 continue; }
848
849 let data_sovereignty = match provider.region.as_str() {
850 "EU" | "EEA" => DataSovereignty::EU,
851 "CH" => DataSovereignty::Switzerland,
852 "CN" => DataSovereignty::China,
853 "US" => DataSovereignty::US,
854 "LOCAL" => DataSovereignty::OnPremises,
855 _ => DataSovereignty::Any,
856 };
857
858 let compliance = provider
859 .compliance
860 .first()
861 .copied()
862 .unwrap_or(ComplianceLevel::None);
863
864 let metadata = ModelMetadata::new(
865 &provider.id,
866 &model.id,
867 model.cost_class,
868 model.typical_latency_ms,
869 model.quality,
870 )
871 .with_reasoning(model.supports_reasoning)
872 .with_web_search(model.supports_web_search)
873 .with_data_sovereignty(data_sovereignty)
874 .with_compliance(compliance)
875 .with_multilingual(model.supports_multilingual)
876 .with_context_tokens(model.context_tokens)
877 .with_tool_use(model.supports_tool_use)
878 .with_vision(model.supports_vision)
879 .with_structured_output(model.supports_structured_output)
880 .with_code(model.supports_code)
881 .with_location(&provider.country, &provider.region);
882
883 selector = selector.with_model(metadata);
884 }
885 }
886
887 selector
888 }
889
890 pub fn print_summary(&self) {
892 println!("Model Registry Summary");
893 println!("======================\n");
894
895 for provider in &self.providers {
896 let status = if provider.is_available() {
897 "✓ available"
898 } else {
899 "✗ no key"
900 };
901
902 println!(
903 "{} ({}) - {} models [{}]",
904 provider.id,
905 provider.region,
906 provider.models.len(),
907 status
908 );
909 println!(" Key URL: {}", provider.key_url);
910 println!(" API URL: {}", provider.api_url);
911 println!();
912 }
913 }
914}
915
916pub const DEFAULT_REGISTRY_PATH: &str = "converge-provider/config/models.yaml";
922
923pub fn load_registry() -> Result<LoadedRegistry, RegistryError> {
934 if let Ok(path) = std::env::var("CONVERGE_MODELS_PATH") {
936 return load_registry_from_path(&path);
937 }
938
939 if std::path::Path::new(DEFAULT_REGISTRY_PATH).exists() {
941 return load_registry_from_path(DEFAULT_REGISTRY_PATH);
942 }
943
944 let crate_path = "config/models.yaml";
946 if std::path::Path::new(crate_path).exists() {
947 return load_registry_from_path(crate_path);
948 }
949
950 load_registry_from_str(include_str!("../config/models.yaml"))
952}
953
954pub fn load_registry_from_path(path: impl AsRef<Path>) -> Result<LoadedRegistry, RegistryError> {
960 let content = std::fs::read_to_string(path)?;
961 load_registry_from_str(&content)
962}
963
964pub fn load_registry_from_str(yaml: &str) -> Result<LoadedRegistry, RegistryError> {
970 let registry_yaml: RegistryYaml = serde_yaml::from_str(yaml)?;
971
972 let mut providers = Vec::new();
973 let mut errors = Vec::new();
974
975 for (provider_id, provider_yaml) in registry_yaml.providers {
976 if let Err(e) = validate_provider(&provider_id, &provider_yaml) {
978 errors.push(e);
979 continue;
980 }
981
982 let compliance = provider_yaml
983 .compliance
984 .iter()
985 .map(|c| ComplianceLevel::from(*c))
986 .collect();
987
988 let mut models = Vec::new();
989
990 for (model_id, model_yaml) in provider_yaml.models {
991 if let Err(e) = validate_model(&provider_id, &model_id, &model_yaml) {
993 errors.push(e);
994 continue;
995 }
996
997 let capabilities: std::collections::HashSet<_> =
998 model_yaml.capabilities.iter().copied().collect();
999
1000 let modalities: Vec<Modality> = model_yaml
1002 .modalities
1003 .iter()
1004 .map(|m| match m {
1005 ModalityYaml::Text => Modality::Text,
1006 ModalityYaml::Image => Modality::Image,
1007 ModalityYaml::Video => Modality::Video,
1008 ModalityYaml::Audio => Modality::Audio,
1009 })
1010 .collect();
1011
1012 let reasoning_effort_levels = model_yaml
1014 .reasoning_effort_levels
1015 .iter()
1016 .copied()
1017 .map(ReasoningEffort::from)
1018 .collect();
1019
1020 let agentic = model_yaml.agentic.as_ref().map(|a| AgenticCapabilities {
1022 max_parallel_agents: a.max_parallel_agents,
1023 supports_orchestration: a.supports_orchestration,
1024 });
1025
1026 let pricing = model_yaml.pricing.as_ref().map(|p| Pricing {
1028 input_per_m: p.input_per_m,
1029 output_per_m: p.output_per_m,
1030 });
1031
1032 let rate_limits = model_yaml.rate_limits.as_ref().map(|r| RateLimits {
1034 requests_per_min: r.requests_per_min,
1035 tokens_per_min: r.tokens_per_min,
1036 requests_per_day: r.requests_per_day,
1037 concurrent_requests: r.concurrent_requests,
1038 });
1039
1040 let model = LoadedModel {
1041 id: model_id,
1042 cost_class: model_yaml.cost_class.into(),
1043 typical_latency_ms: model_yaml.typical_latency_ms,
1044 quality: model_yaml.quality,
1045 context_tokens: model_yaml.context_tokens,
1046 model_type: model_yaml.model_type.into(),
1047 dimensions: model_yaml.dimensions,
1048 capabilities: model_yaml.capabilities.clone(),
1049 supports_tool_use: capabilities.contains(&CapabilityYaml::ToolUse),
1050 supports_vision: capabilities.contains(&CapabilityYaml::Vision),
1051 supports_structured_output: capabilities
1052 .contains(&CapabilityYaml::StructuredOutput),
1053 supports_code: capabilities.contains(&CapabilityYaml::Code),
1054 supports_reasoning: capabilities.contains(&CapabilityYaml::Reasoning),
1055 supports_multilingual: capabilities.contains(&CapabilityYaml::Multilingual),
1056 supports_web_search: capabilities.contains(&CapabilityYaml::WebSearch),
1057 architecture: model_yaml.architecture.into(),
1059 total_params_b: model_yaml.total_params_b,
1060 active_params_b: model_yaml.active_params_b,
1061 max_output_tokens: model_yaml.max_output_tokens,
1062 native_multimodal: model_yaml.native_multimodal,
1063 modalities,
1064 agentic,
1065 thinking_mode: model_yaml.thinking_mode,
1066 reasoning_effort_levels,
1067 native_compaction: model_yaml.native_compaction,
1068 thinking_variant: model_yaml.thinking_variant.clone(),
1069 pricing,
1070 publisher: model_yaml.publisher.clone(),
1071 family: model_yaml.family.clone(),
1072 release_date: model_yaml.release_date.clone(),
1073 training_cutoff: model_yaml.training_cutoff.clone(),
1074 open_weights: model_yaml.open_weights,
1075 license: model_yaml.license.clone(),
1076 deprecated: model_yaml.deprecated,
1077 beta: model_yaml.beta,
1078 benchmarks: model_yaml.benchmarks.clone(),
1079 tags: model_yaml.tags.clone(),
1080 rate_limits,
1081 notes: model_yaml.notes.clone(),
1082 };
1083
1084 models.push(model);
1085 }
1086
1087 models.sort_by(|a, b| a.id.cmp(&b.id));
1089
1090 let provider = LoadedProvider {
1091 id: provider_id,
1092 env_key: provider_yaml.env_key,
1093 env_key_secondary: provider_yaml.env_key_secondary,
1094 key_url: provider_yaml.key_url,
1095 api_url: provider_yaml.api_url,
1096 country: provider_yaml.country,
1097 region: provider_yaml.region.as_str().to_string(),
1098 compliance,
1099 provider_type: provider_yaml.provider_type.into(),
1100 models,
1101 };
1102
1103 providers.push(provider);
1104 }
1105
1106 if !errors.is_empty() {
1108 return Err(RegistryError::ValidationError(errors.join("; ")));
1109 }
1110
1111 providers.sort_by(|a, b| a.id.cmp(&b.id));
1113
1114 Ok(LoadedRegistry { providers })
1115}
1116
1117fn validate_provider(id: &str, provider: &ProviderYaml) -> Result<(), String> {
1119 if provider.env_key.is_empty() {
1121 return Err(format!("Provider '{id}': env_key cannot be empty"));
1122 }
1123
1124 if !provider.key_url.starts_with("http://") && !provider.key_url.starts_with("https://") {
1126 return Err(format!(
1127 "Provider '{id}': key_url must be a valid URL, got '{}'",
1128 provider.key_url
1129 ));
1130 }
1131
1132 if !provider.api_url.starts_with("http://") && !provider.api_url.starts_with("https://") {
1133 return Err(format!(
1134 "Provider '{id}': api_url must be a valid URL, got '{}'",
1135 provider.api_url
1136 ));
1137 }
1138
1139 if provider.country != "LOCAL" && provider.country.len() != 2 {
1141 return Err(format!(
1142 "Provider '{id}': country must be 2-letter ISO code or 'LOCAL', got '{}'",
1143 provider.country
1144 ));
1145 }
1146
1147 if provider.models.is_empty() {
1149 return Err(format!("Provider '{id}': must have at least one model"));
1150 }
1151
1152 Ok(())
1153}
1154
1155fn validate_model(provider_id: &str, model_id: &str, model: &ModelYaml) -> Result<(), String> {
1157 if !(0.0..=1.0).contains(&model.quality) {
1159 return Err(format!(
1160 "Model '{provider_id}/{model_id}': quality must be 0.0-1.0, got {}",
1161 model.quality
1162 ));
1163 }
1164
1165 if model.typical_latency_ms == 0 {
1167 return Err(format!(
1168 "Model '{provider_id}/{model_id}': typical_latency_ms must be > 0"
1169 ));
1170 }
1171
1172 if model.context_tokens == 0 {
1174 return Err(format!(
1175 "Model '{provider_id}/{model_id}': context_tokens must be > 0"
1176 ));
1177 }
1178
1179 if model.model_type == ModelTypeYaml::Embedding && model.dimensions.is_none() {
1181 return Err(format!(
1182 "Model '{provider_id}/{model_id}': embedding models must specify dimensions"
1183 ));
1184 }
1185
1186 Ok(())
1187}
1188
1189impl From<ModelTypeYaml> for ModelType {
1190 fn from(t: ModelTypeYaml) -> Self {
1191 match t {
1192 ModelTypeYaml::Llm => ModelType::Llm,
1193 ModelTypeYaml::Embedding => ModelType::Embedding,
1194 ModelTypeYaml::Reranker => ModelType::Reranker,
1195 ModelTypeYaml::Ocr => ModelType::Ocr,
1196 }
1197 }
1198}
1199
1200impl From<ArchitectureYaml> for Architecture {
1201 fn from(a: ArchitectureYaml) -> Self {
1202 match a {
1203 ArchitectureYaml::Dense => Architecture::Dense,
1204 ArchitectureYaml::Moe => Architecture::Moe,
1205 ArchitectureYaml::Hybrid => Architecture::Hybrid,
1206 }
1207 }
1208}
1209
1210impl From<ReasoningEffortYaml> for ReasoningEffort {
1211 fn from(effort: ReasoningEffortYaml) -> Self {
1212 match effort {
1213 ReasoningEffortYaml::None => Self::None,
1214 ReasoningEffortYaml::Minimal => Self::Minimal,
1215 ReasoningEffortYaml::Low => Self::Low,
1216 ReasoningEffortYaml::Medium => Self::Medium,
1217 ReasoningEffortYaml::High => Self::High,
1218 ReasoningEffortYaml::Xhigh => Self::Xhigh,
1219 }
1220 }
1221}
1222
1223impl From<ProviderTypeYaml> for ProviderType {
1224 fn from(p: ProviderTypeYaml) -> Self {
1225 match p {
1226 ProviderTypeYaml::Direct => ProviderType::Direct,
1227 ProviderTypeYaml::Aggregator => ProviderType::Aggregator,
1228 }
1229 }
1230}
1231
#[cfg(test)]
mod tests {
    use super::*;

    /// One provider with an LLM and an embedding model; used by the happy-path tests.
    const TEST_YAML: &str = r"
providers:
  test-provider:
    env_key: TEST_API_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      test-model:
        cost_class: Low
        typical_latency_ms: 2000
        quality: 0.85
        context_tokens: 128000
        capabilities: [tool_use, reasoning, code]

      test-embedding:
        cost_class: VeryLow
        typical_latency_ms: 100
        quality: 0.80
        context_tokens: 8192
        capabilities: []
        type: embedding
        dimensions: 1024
";

    const INVALID_COST_CLASS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: SuperLow
        typical_latency_ms: 100
        quality: 0.5
";

    const INVALID_CAPABILITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        capabilities: [tool_use, telepathy]
";

    const INVALID_QUALITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 1.5
";

    const MISSING_DIMENSIONS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-embedding:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        type: embedding
";

    const UNKNOWN_FIELD_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    unknown_field: oops
    models:
      model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";

    #[test]
    fn parse_yaml() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        assert_eq!(registry.providers.len(), 1);

        let provider = &registry.providers[0];
        assert_eq!(provider.id, "test-provider");
        assert_eq!(provider.key_url, "https://test.com/keys");
        assert_eq!(provider.api_url, "https://api.test.com/v1");
        assert_eq!(provider.models.len(), 2);
    }

    #[test]
    fn parse_model_capabilities() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        let llm = provider
            .models
            .iter()
            .find(|m| m.id == "test-model")
            .unwrap();
        assert!(llm.supports_tool_use);
        assert!(llm.supports_reasoning);
        assert!(llm.supports_code);
        assert!(!llm.supports_vision);
        assert_eq!(llm.model_type, ModelType::Llm);
    }

    #[test]
    fn parse_embedding_model() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        let embedding = provider
            .models
            .iter()
            .find(|m| m.id == "test-embedding")
            .unwrap();
        assert_eq!(embedding.model_type, ModelType::Embedding);
        assert_eq!(embedding.dimensions, Some(1024));
    }

    #[test]
    fn filter_by_model_type() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();

        let llms = registry.llm_models();
        assert_eq!(llms.len(), 1);
        assert_eq!(llms[0].1.id, "test-model");

        let embeddings = registry.embedding_models();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].1.id, "test-embedding");
    }

    #[test]
    fn to_model_selector() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let selector = registry.to_model_selector();

        let reqs = converge_core::model_selection::AgentRequirements::balanced();
        let satisfying = selector.list_satisfying(&reqs);
        assert_eq!(satisfying.len(), 1);
    }

    #[test]
    fn provider_availability() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        // The actual value depends on the test environment. With no secondary
        // key configured, availability must agree with whether a primary key is present.
        assert_eq!(provider.is_available(), provider.api_key().is_some());
        assert!(provider.secondary_api_key().is_none());
    }

    #[test]
    fn load_real_registry() {
        let registry = load_registry().unwrap();

        assert!(
            registry.providers.len() >= 10,
            "Expected at least 10 providers"
        );

        let provider_ids: Vec<_> = registry.providers.iter().map(|p| p.id.as_str()).collect();
        assert!(provider_ids.contains(&"anthropic"), "Missing anthropic");
        assert!(provider_ids.contains(&"openai"), "Missing openai");
        assert!(provider_ids.contains(&"mistral"), "Missing mistral");
        assert!(provider_ids.contains(&"ollama"), "Missing ollama");

        let anthropic = registry.get_provider("anthropic").unwrap();
        assert_eq!(
            anthropic.key_url,
            "https://console.anthropic.com/settings/keys"
        );
        assert_eq!(anthropic.api_url, "https://api.anthropic.com/v1");
        assert_eq!(anthropic.env_key, "ANTHROPIC_API_KEY");

        let ollama = registry.get_provider("ollama").unwrap();
        assert_eq!(ollama.region, "LOCAL");

        let llms = registry.llm_models();
        assert!(llms.len() >= 30, "Expected at least 30 LLM models");

        let embeddings = registry.embedding_models();
        assert!(
            embeddings.len() >= 3,
            "Expected at least 3 embedding models"
        );

        println!(
            "Loaded {} providers with {} LLM models and {} embedding models",
            registry.providers.len(),
            llms.len(),
            embeddings.len()
        );
    }

    #[test]
    fn rejects_invalid_cost_class() {
        let result = load_registry_from_str(INVALID_COST_CLASS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("SuperLow") || err.contains("unknown variant"),
            "Expected error about invalid cost class, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_capability() {
        let result = load_registry_from_str(INVALID_CAPABILITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("telepathy") || err.contains("unknown variant"),
            "Expected error about invalid capability, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_quality() {
        let result = load_registry_from_str(INVALID_QUALITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("quality") && err.contains("1.5"),
            "Expected error about quality out of range, got: {err}"
        );
    }

    #[test]
    fn rejects_embedding_without_dimensions() {
        let result = load_registry_from_str(MISSING_DIMENSIONS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("dimensions"),
            "Expected error about missing dimensions, got: {err}"
        );
    }

    #[test]
    fn rejects_unknown_fields() {
        let result = load_registry_from_str(UNKNOWN_FIELD_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("unknown_field") || err.contains("unknown field"),
            "Expected error about unknown field, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_region() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: INVALID
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("INVALID") || err.contains("unknown variant"),
            "Expected error about invalid region, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_url() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: not-a-url
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("key_url") && err.contains("URL"),
            "Expected error about invalid URL, got: {err}"
        );
    }

    #[test]
    fn rejects_zero_latency() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 0
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("latency") && err.contains("0"),
            "Expected error about zero latency, got: {err}"
        );
    }

    #[test]
    fn rejects_empty_provider() {
        let yaml = r"
providers:
  empty:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models: {}
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("at least one model"),
            "Expected error about empty models, got: {err}"
        );
    }
}