Skip to main content

converge_provider/
registry_loader.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3// See LICENSE file in the project root for full license information.
4
5//! YAML-based model registry loader.
6//!
7//! Loads model metadata from `config/models.yaml` and provides
8//! a registry that can be used for model selection.
9//!
10//! # Example
11//!
12//! ```ignore
13//! use converge_provider::registry_loader::{load_registry, RegistryConfig};
14//!
15//! // Load from default path
16//! let registry = load_registry()?;
17//!
18//! // Check available providers
19//! for provider in registry.providers() {
20//!     println!("{}: {} (key: {})",
21//!         provider.id,
22//!         provider.api_url,
23//!         if provider.is_available() { "set" } else { "missing" }
24//!     );
25//! }
26//! ```
27
28use crate::model_selection::{ModelMetadata, ModelSelector};
29use converge_provider_api::selection::{ComplianceLevel, CostClass, DataSovereignty};
30use schemars::JsonSchema;
31use serde::Deserialize;
32use std::collections::HashMap;
33use std::path::Path;
34
35/// Error type for registry loading.
36#[derive(Debug, thiserror::Error)]
37pub enum RegistryError {
38    /// Failed to read the YAML file.
39    #[error("Failed to read registry file: {0}")]
40    IoError(#[from] std::io::Error),
41
42    /// Failed to parse the YAML.
43    #[error("Failed to parse registry YAML: {0}")]
44    ParseError(#[from] serde_yaml::Error),
45
46    /// Validation error in the registry.
47    #[error("Registry validation failed: {0}")]
48    ValidationError(String),
49}
50
51// =============================================================================
52// YAML SCHEMA TYPES (Type-safe with serde enums)
53// =============================================================================
54
55/// Root of the YAML file.
56///
57/// This is the schema for `config/models.yaml`.
58#[derive(Debug, Deserialize, JsonSchema)]
59#[serde(deny_unknown_fields)]
60pub struct RegistryYaml {
61    /// All providers.
62    pub providers: HashMap<String, ProviderYaml>,
63}
64
65/// Provider type classification.
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
67#[serde(rename_all = "snake_case")]
68pub enum ProviderTypeYaml {
69    /// Direct API access to model provider (default).
70    #[default]
71    Direct,
72    /// Routes to multiple underlying providers (adds latency overhead).
73    Aggregator,
74}
75
76/// A provider in the YAML.
77#[derive(Debug, Deserialize, JsonSchema)]
78#[serde(deny_unknown_fields)]
79pub struct ProviderYaml {
80    /// Environment variable for API key.
81    pub env_key: String,
82    /// Optional secondary environment variable (e.g., Baidu secret key).
83    #[serde(default)]
84    pub env_key_secondary: Option<String>,
85    /// URL to get an API key.
86    pub key_url: String,
87    /// API endpoint URL.
88    pub api_url: String,
89    /// ISO country code (2 letters) or "LOCAL".
90    pub country: String,
91    /// Region (US, EU, CN, LOCAL, etc.).
92    pub region: RegionYaml,
93    /// Compliance certifications.
94    #[serde(default)]
95    pub compliance: Vec<ComplianceYaml>,
96    /// Provider type (direct or aggregator).
97    #[serde(default)]
98    pub provider_type: ProviderTypeYaml,
99    /// Models provided.
100    pub models: HashMap<String, ModelYaml>,
101}
102
103/// Region enum - type-safe parsing.
104///
105/// Represents the data residency region for a provider.
106#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
107pub enum RegionYaml {
108    /// United States
109    US,
110    /// European Union
111    EU,
112    /// European Economic Area
113    EEA,
114    /// Switzerland
115    CH,
116    /// China
117    CN,
118    /// Japan
119    JP,
120    /// United Kingdom
121    UK,
122    /// Local/on-premises (any jurisdiction)
123    LOCAL,
124}
125
126impl RegionYaml {
127    /// Converts to string for storage.
128    #[must_use]
129    pub fn as_str(&self) -> &'static str {
130        match self {
131            Self::US => "US",
132            Self::EU => "EU",
133            Self::EEA => "EEA",
134            Self::CH => "CH",
135            Self::CN => "CN",
136            Self::JP => "JP",
137            Self::UK => "UK",
138            Self::LOCAL => "LOCAL",
139        }
140    }
141}
142
143/// Compliance certification enum.
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
145pub enum ComplianceYaml {
146    /// General Data Protection Regulation (EU)
147    GDPR,
148    /// Service Organization Control 2
149    SOC2,
150    /// Health Insurance Portability and Accountability Act
151    HIPAA,
152}
153
154impl From<ComplianceYaml> for ComplianceLevel {
155    fn from(c: ComplianceYaml) -> Self {
156        match c {
157            ComplianceYaml::GDPR => ComplianceLevel::GDPR,
158            ComplianceYaml::SOC2 => ComplianceLevel::SOC2,
159            ComplianceYaml::HIPAA => ComplianceLevel::HIPAA,
160        }
161    }
162}
163
164/// Cost class for model pricing tier.
165#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
166pub enum CostClassYaml {
167    /// Very low cost (e.g., Haiku, GPT-3.5, local models)
168    VeryLow,
169    /// Low cost (e.g., Sonnet, GPT-4o)
170    Low,
171    /// Medium cost (e.g., GPT-4 Turbo)
172    Medium,
173    /// High cost (e.g., Opus, o1-mini)
174    High,
175    /// Very high cost (e.g., o1-preview)
176    VeryHigh,
177}
178
179impl From<CostClassYaml> for CostClass {
180    fn from(c: CostClassYaml) -> Self {
181        match c {
182            CostClassYaml::VeryLow => CostClass::VeryLow,
183            CostClassYaml::Low => CostClass::Low,
184            CostClassYaml::Medium => CostClass::Medium,
185            CostClassYaml::High => CostClass::High,
186            CostClassYaml::VeryHigh => CostClass::VeryHigh,
187        }
188    }
189}
190
191/// Model capability flags.
192#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
193#[serde(rename_all = "snake_case")]
194pub enum CapabilityYaml {
195    /// Function/tool calling support
196    ToolUse,
197    /// Image understanding
198    Vision,
199    /// JSON mode / schema enforcement
200    StructuredOutput,
201    /// Code generation/understanding
202    Code,
203    /// Multi-step logical reasoning
204    Reasoning,
205    /// Good performance across languages
206    Multilingual,
207    /// Real-time web information retrieval
208    WebSearch,
209    /// Audio input/output support
210    Audio,
211    /// Image generation support
212    ImageGeneration,
213    /// Streaming responses
214    Streaming,
215    /// Logprobs support
216    Logprobs,
217    /// Deterministic seed support
218    Seed,
219    /// Tool choice (e.g., required/none/auto)
220    ToolChoice,
221    /// Parallel tool call support
222    ParallelToolCalls,
223    /// Prompt caching support
224    PromptCaching,
225    /// Built-in file search retrieval
226    FileSearch,
227    /// Built-in code interpreter / sandbox execution
228    CodeInterpreter,
229    /// Built-in browser automation / computer use
230    ComputerUse,
231    /// Tool-level web search (native search tool)
232    ToolSearch,
233    /// Model Context Protocol tool support
234    Mcp,
235    /// Hosted shell tool support
236    HostedShell,
237    /// Apply-patch tool support
238    ApplyPatch,
239    /// Native context compaction support
240    NativeCompaction,
241    /// Reasoning effort controls (e.g., low/medium/high)
242    ReasoningEffort,
243    /// Strong content generation / business writing
244    ContentGeneration,
245    /// Business acumen (financial, strategic, market analysis)
246    BusinessAcumen,
247}
248
249impl CapabilityYaml {
250    /// Stable `snake_case` string representation used in API responses.
251    #[must_use]
252    pub fn as_str(&self) -> &'static str {
253        match self {
254            Self::ToolUse => "tool_use",
255            Self::Vision => "vision",
256            Self::StructuredOutput => "structured_output",
257            Self::Code => "code",
258            Self::Reasoning => "reasoning",
259            Self::Multilingual => "multilingual",
260            Self::WebSearch => "web_search",
261            Self::Audio => "audio",
262            Self::ImageGeneration => "image_generation",
263            Self::Streaming => "streaming",
264            Self::Logprobs => "logprobs",
265            Self::Seed => "seed",
266            Self::ToolChoice => "tool_choice",
267            Self::ParallelToolCalls => "parallel_tool_calls",
268            Self::PromptCaching => "prompt_caching",
269            Self::FileSearch => "file_search",
270            Self::CodeInterpreter => "code_interpreter",
271            Self::ComputerUse => "computer_use",
272            Self::ToolSearch => "tool_search",
273            Self::Mcp => "mcp",
274            Self::HostedShell => "hosted_shell",
275            Self::ApplyPatch => "apply_patch",
276            Self::NativeCompaction => "native_compaction",
277            Self::ReasoningEffort => "reasoning_effort",
278            Self::ContentGeneration => "content_generation",
279            Self::BusinessAcumen => "business_acumen",
280        }
281    }
282}
283
284/// Supported reasoning effort level.
285#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
286#[serde(rename_all = "snake_case")]
287pub enum ReasoningEffortYaml {
288    /// Disable explicit chain-of-thought style effort controls.
289    None,
290    /// Minimal extra reasoning.
291    Minimal,
292    /// Low extra reasoning.
293    Low,
294    /// Medium extra reasoning.
295    Medium,
296    /// High extra reasoning.
297    High,
298    /// Extra-high reasoning.
299    Xhigh,
300}
301
302impl ReasoningEffortYaml {
303    /// Stable `snake_case` string representation used in API responses.
304    #[must_use]
305    pub fn as_str(&self) -> &'static str {
306        match self {
307            Self::None => "none",
308            Self::Minimal => "minimal",
309            Self::Low => "low",
310            Self::Medium => "medium",
311            Self::High => "high",
312            Self::Xhigh => "xhigh",
313        }
314    }
315}
316
317/// Model type classification.
318#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
319#[serde(rename_all = "snake_case")]
320pub enum ModelTypeYaml {
321    /// LLM for chat/completion (default)
322    #[default]
323    Llm,
324    /// Vector embedding model
325    Embedding,
326    /// Cross-encoder reranking model
327    Reranker,
328    /// OCR / Document AI model
329    Ocr,
330}
331
332/// Model architecture type.
333#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
334#[serde(rename_all = "snake_case")]
335pub enum ArchitectureYaml {
336    /// Traditional transformer (all parameters active).
337    #[default]
338    Dense,
339    /// Mixture of Experts (only subset active per forward pass).
340    Moe,
341    /// Hybrid architecture (e.g., Jamba's Mamba-Transformer).
342    Hybrid,
343}
344
345/// Input modality type.
346#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
347#[serde(rename_all = "snake_case")]
348pub enum ModalityYaml {
349    /// Text input/output.
350    Text,
351    /// Image input.
352    Image,
353    /// Video input.
354    Video,
355    /// Audio input.
356    Audio,
357}
358
359/// Agentic capabilities configuration.
360#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
361#[serde(deny_unknown_fields)]
362pub struct AgenticYaml {
363    /// Maximum number of parallel agents this model can orchestrate.
364    #[serde(default)]
365    pub max_parallel_agents: Option<u32>,
366    /// Whether the model supports agent orchestration/swarm.
367    #[serde(default)]
368    pub supports_orchestration: bool,
369}
370
371/// Pricing information (USD per million tokens).
372#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
373#[serde(deny_unknown_fields)]
374pub struct PricingYaml {
375    /// Input price per million tokens (USD).
376    #[serde(default)]
377    pub input_per_m: Option<f64>,
378    /// Output price per million tokens (USD).
379    #[serde(default)]
380    pub output_per_m: Option<f64>,
381}
382
383/// Rate limit information (provider- or model-specific).
384#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
385#[serde(deny_unknown_fields)]
386pub struct RateLimitsYaml {
387    /// Requests per minute.
388    #[serde(default)]
389    pub requests_per_min: Option<u32>,
390    /// Tokens per minute.
391    #[serde(default)]
392    pub tokens_per_min: Option<u32>,
393    /// Requests per day.
394    #[serde(default)]
395    pub requests_per_day: Option<u32>,
396    /// Maximum concurrent requests.
397    #[serde(default)]
398    pub concurrent_requests: Option<u32>,
399}
400
401/// A model entry in the registry.
402#[derive(Debug, Deserialize, JsonSchema)]
403#[serde(deny_unknown_fields)]
404pub struct ModelYaml {
405    /// Cost class - validated at parse time.
406    pub cost_class: CostClassYaml,
407    /// Typical latency in milliseconds (must be > 0).
408    pub typical_latency_ms: u32,
409    /// Quality score (must be 0.0-1.0).
410    pub quality: f64,
411    /// Context window size in tokens.
412    #[serde(default = "default_context_tokens")]
413    pub context_tokens: usize,
414    /// Capabilities list - validated at parse time.
415    #[serde(default)]
416    pub capabilities: Vec<CapabilityYaml>,
417    /// Model type - validated at parse time.
418    #[serde(default, rename = "type")]
419    pub model_type: ModelTypeYaml,
420    /// Embedding dimensions (required for embedding models).
421    #[serde(default)]
422    pub dimensions: Option<usize>,
423
424    // === ENRICHED SCHEMA ===
425    /// Model architecture (dense, moe, hybrid).
426    #[serde(default)]
427    pub architecture: ArchitectureYaml,
428    /// Total parameters in billions.
429    #[serde(default)]
430    pub total_params_b: Option<f64>,
431    /// Active parameters per forward pass in billions (for `MoE` models).
432    #[serde(default)]
433    pub active_params_b: Option<f64>,
434    /// Maximum output tokens.
435    #[serde(default)]
436    pub max_output_tokens: Option<usize>,
437    /// Whether the model is native multimodal (trained on mixed modalities).
438    #[serde(default)]
439    pub native_multimodal: bool,
440    /// Supported input modalities.
441    #[serde(default)]
442    pub modalities: Vec<ModalityYaml>,
443    /// Agentic/swarm capabilities.
444    #[serde(default)]
445    pub agentic: Option<AgenticYaml>,
446    /// Whether the model supports extended thinking/reasoning mode.
447    #[serde(default)]
448    pub thinking_mode: bool,
449    /// Supported reasoning effort levels (e.g., [low, medium, high]).
450    #[serde(default)]
451    pub reasoning_effort_levels: Vec<ReasoningEffortYaml>,
452    /// Whether the model supports native context compaction.
453    #[serde(default)]
454    pub native_compaction: bool,
455    /// Model ID of the thinking variant (if this is the base model).
456    #[serde(default)]
457    pub thinking_variant: Option<String>,
458    /// Pricing information.
459    #[serde(default)]
460    pub pricing: Option<PricingYaml>,
461    /// Model publisher or organization (e.g., `OpenAI`, Anthropic).
462    #[serde(default)]
463    pub publisher: Option<String>,
464    /// Model family name (e.g., Claude, GPT, Llama).
465    #[serde(default)]
466    pub family: Option<String>,
467    /// Release date (ISO-8601 format recommended).
468    #[serde(default)]
469    pub release_date: Option<String>,
470    /// Training data cutoff date (ISO-8601 format recommended).
471    #[serde(default)]
472    pub training_cutoff: Option<String>,
473    /// Whether model weights are openly available.
474    #[serde(default)]
475    pub open_weights: bool,
476    /// License identifier or URL.
477    #[serde(default)]
478    pub license: Option<String>,
479    /// Whether the model is deprecated.
480    #[serde(default)]
481    pub deprecated: bool,
482    /// Whether the model is in beta/preview.
483    #[serde(default)]
484    pub beta: bool,
485    /// Benchmark scores (keyed by benchmark name).
486    #[serde(default)]
487    pub benchmarks: HashMap<String, f64>,
488    /// Free-form tags for routing or promotion.
489    #[serde(default)]
490    pub tags: Vec<String>,
491    /// Rate limit information (if published).
492    #[serde(default)]
493    pub rate_limits: Option<RateLimitsYaml>,
494    /// Free-form notes.
495    #[serde(default)]
496    pub notes: Option<String>,
497}
498
/// Serde default for `ModelYaml::context_tokens`: an 8K-token window when a
/// model entry omits `context_tokens`.
fn default_context_tokens() -> usize {
    8 * 1024
}
502
503/// Generates JSON Schema for the model registry.
504///
505/// This can be used for:
506/// - IDE autocompletion in YAML files
507/// - Pre-runtime validation
508/// - Documentation generation
509///
510/// # Example
511///
512/// ```
513/// use converge_provider::registry_loader::generate_schema;
514///
515/// let schema = generate_schema();
516/// println!("{}", serde_json::to_string_pretty(&schema).unwrap());
517/// ```
518#[must_use]
519pub fn generate_schema() -> schemars::schema::RootSchema {
520    schemars::schema_for!(RegistryYaml)
521}
522
523// =============================================================================
524// LOADED REGISTRY
525// =============================================================================
526
/// How a loaded provider reaches the underlying models.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProviderType {
    /// The provider serves its own models directly.
    Direct,
    /// The provider proxies requests to other providers, which adds latency.
    Aggregator,
}
535
536/// A loaded provider with its models.
537#[derive(Debug, Clone)]
538pub struct LoadedProvider {
539    /// Provider ID (e.g., "anthropic").
540    pub id: String,
541    /// Environment variable name for API key.
542    pub env_key: String,
543    /// Optional secondary env key.
544    pub env_key_secondary: Option<String>,
545    /// URL to get an API key.
546    pub key_url: String,
547    /// API endpoint URL.
548    pub api_url: String,
549    /// ISO country code.
550    pub country: String,
551    /// Region.
552    pub region: String,
553    /// Compliance certifications.
554    pub compliance: Vec<ComplianceLevel>,
555    /// Provider type (direct or aggregator).
556    pub provider_type: ProviderType,
557    /// Models available.
558    pub models: Vec<LoadedModel>,
559}
560
561impl LoadedProvider {
562    /// Checks if this provider is available (env key is set).
563    #[must_use]
564    pub fn is_available(&self) -> bool {
565        let primary_ok = std::env::var(&self.env_key).is_ok();
566        let secondary_ok = self
567            .env_key_secondary
568            .as_ref()
569            .map(|k| std::env::var(k).is_ok())
570            .unwrap_or(true);
571        primary_ok && secondary_ok
572    }
573
574    /// Returns the API key from environment (if available).
575    #[must_use]
576    pub fn api_key(&self) -> Option<String> {
577        std::env::var(&self.env_key).ok()
578    }
579
580    /// Returns the secondary API key from environment (if available).
581    #[must_use]
582    pub fn secondary_api_key(&self) -> Option<String> {
583        self.env_key_secondary
584            .as_ref()
585            .and_then(|k| std::env::var(k).ok())
586    }
587}
588
/// Neural architecture of a loaded model.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Architecture {
    /// Dense transformer: every parameter participates in each forward pass.
    Dense,
    /// Mixture of Experts: only a subset of parameters is active per pass.
    Moe,
    /// Mixed design, such as Jamba's Mamba/Transformer hybrid.
    Hybrid,
}
599
/// Kind of input a loaded model accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Plain text (input and output).
    Text,
    /// Still images.
    Image,
    /// Video.
    Video,
    /// Audio.
    Audio,
}
612
/// Agent and swarm features of a loaded model.
#[derive(Clone, Debug, Default)]
pub struct AgenticCapabilities {
    /// Upper bound on agents the model can coordinate in parallel, if known.
    pub max_parallel_agents: Option<u32>,
    /// `true` when the model can orchestrate agents or swarms.
    pub supports_orchestration: bool,
}
621
/// Published per-token prices, in USD per one million tokens.
#[derive(Clone, Debug, Default)]
pub struct Pricing {
    /// Cost of one million input (prompt) tokens, if published.
    pub input_per_m: Option<f64>,
    /// Cost of one million output (completion) tokens, if published.
    pub output_per_m: Option<f64>,
}
630
/// Published rate limits for a provider or an individual model.
#[derive(Clone, Debug, Default)]
pub struct RateLimits {
    /// Maximum requests per minute, if published.
    pub requests_per_min: Option<u32>,
    /// Maximum tokens per minute, if published.
    pub tokens_per_min: Option<u32>,
    /// Maximum requests per day, if published.
    pub requests_per_day: Option<u32>,
    /// Maximum simultaneous in-flight requests, if published.
    pub concurrent_requests: Option<u32>,
}
643
/// How much extra reasoning a model can be asked to perform.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    /// Effort controls switched off.
    None,
    /// The smallest amount of additional reasoning.
    Minimal,
    /// A little additional reasoning.
    Low,
    /// A moderate amount of additional reasoning.
    Medium,
    /// A large amount of additional reasoning.
    High,
    /// The largest amount of additional reasoning.
    Xhigh,
}

impl ReasoningEffort {
    /// Returns the level's `snake_case` name as exposed in API responses.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            ReasoningEffort::None => "none",
            ReasoningEffort::Minimal => "minimal",
            ReasoningEffort::Low => "low",
            ReasoningEffort::Medium => "medium",
            ReasoningEffort::High => "high",
            ReasoningEffort::Xhigh => "xhigh",
        }
    }
}
675
676/// A loaded model.
677#[derive(Debug, Clone)]
678#[allow(clippy::struct_excessive_bools)]
679pub struct LoadedModel {
680    /// Model ID.
681    pub id: String,
682    /// Cost class.
683    pub cost_class: CostClass,
684    /// Typical latency in ms.
685    pub typical_latency_ms: u32,
686    /// Quality score.
687    pub quality: f64,
688    /// Context tokens.
689    pub context_tokens: usize,
690    /// Model type (llm, embedding, reranker).
691    pub model_type: ModelType,
692    /// Embedding dimensions (for embedding models).
693    pub dimensions: Option<usize>,
694    /// Full capability list (`snake_case` enum values from YAML).
695    pub capabilities: Vec<CapabilityYaml>,
696    // Capabilities
697    /// Tool use support.
698    pub supports_tool_use: bool,
699    /// Vision support.
700    pub supports_vision: bool,
701    /// Structured output support.
702    pub supports_structured_output: bool,
703    /// Code support.
704    pub supports_code: bool,
705    /// Reasoning support.
706    pub supports_reasoning: bool,
707    /// Multilingual support.
708    pub supports_multilingual: bool,
709    /// Web search support.
710    pub supports_web_search: bool,
711    /// Content generation / business writing support.
712    pub supports_content_generation: bool,
713    /// Business acumen (financial, strategic, market analysis).
714    pub supports_business_acumen: bool,
715
716    // === ENRICHED FIELDS ===
717    /// Model architecture (dense, moe, hybrid).
718    pub architecture: Architecture,
719    /// Total parameters in billions.
720    pub total_params_b: Option<f64>,
721    /// Active parameters per forward pass in billions (for `MoE` models).
722    pub active_params_b: Option<f64>,
723    /// Maximum output tokens.
724    pub max_output_tokens: Option<usize>,
725    /// Whether the model is native multimodal (trained on mixed modalities).
726    pub native_multimodal: bool,
727    /// Supported input modalities.
728    pub modalities: Vec<Modality>,
729    /// Agentic/swarm capabilities.
730    pub agentic: Option<AgenticCapabilities>,
731    /// Whether the model supports extended thinking/reasoning mode.
732    pub thinking_mode: bool,
733    /// Supported reasoning effort levels.
734    pub reasoning_effort_levels: Vec<ReasoningEffort>,
735    /// Whether the model supports native context compaction.
736    pub native_compaction: bool,
737    /// Model ID of the thinking variant (if this is the base model).
738    pub thinking_variant: Option<String>,
739    /// Pricing information.
740    pub pricing: Option<Pricing>,
741    /// Model publisher or organization (e.g., `OpenAI`, Anthropic).
742    pub publisher: Option<String>,
743    /// Model family name (e.g., Claude, GPT, Llama).
744    pub family: Option<String>,
745    /// Release date (ISO-8601 format recommended).
746    pub release_date: Option<String>,
747    /// Training data cutoff date (ISO-8601 format recommended).
748    pub training_cutoff: Option<String>,
749    /// Whether model weights are openly available.
750    pub open_weights: bool,
751    /// License identifier or URL.
752    pub license: Option<String>,
753    /// Whether the model is deprecated.
754    pub deprecated: bool,
755    /// Whether the model is in beta/preview.
756    pub beta: bool,
757    /// Benchmark scores (keyed by benchmark name).
758    pub benchmarks: HashMap<String, f64>,
759    /// Free-form tags for routing or promotion.
760    pub tags: Vec<String>,
761    /// Rate limit information (if published).
762    pub rate_limits: Option<RateLimits>,
763    /// Free-form notes.
764    pub notes: Option<String>,
765}
766
/// Role a loaded model plays.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ModelType {
    /// Chat / completion language model.
    Llm,
    /// Produces vector embeddings.
    Embedding,
    /// Scores and reorders candidate documents.
    Reranker,
    /// OCR / document-understanding model.
    Ocr,
}
779
780/// The loaded model registry.
781#[derive(Debug, Clone)]
782pub struct LoadedRegistry {
783    /// All providers.
784    providers: Vec<LoadedProvider>,
785}
786
787impl LoadedRegistry {
788    /// Returns all providers.
789    #[must_use]
790    pub fn providers(&self) -> &[LoadedProvider] {
791        &self.providers
792    }
793
794    /// Returns available providers (with API keys set).
795    #[must_use]
796    pub fn available_providers(&self) -> Vec<&LoadedProvider> {
797        self.providers.iter().filter(|p| p.is_available()).collect()
798    }
799
800    /// Finds a provider by ID.
801    #[must_use]
802    pub fn get_provider(&self, id: &str) -> Option<&LoadedProvider> {
803        self.providers.iter().find(|p| p.id == id)
804    }
805
806    /// Returns all LLM models.
807    #[must_use]
808    pub fn llm_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
809        self.providers
810            .iter()
811            .flat_map(|p| {
812                p.models
813                    .iter()
814                    .filter(|m| m.model_type == ModelType::Llm)
815                    .map(move |m| (p, m))
816            })
817            .collect()
818    }
819
820    /// Returns all embedding models.
821    #[must_use]
822    pub fn embedding_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
823        self.providers
824            .iter()
825            .flat_map(|p| {
826                p.models
827                    .iter()
828                    .filter(|m| m.model_type == ModelType::Embedding)
829                    .map(move |m| (p, m))
830            })
831            .collect()
832    }
833
834    /// Returns all reranker models.
835    #[must_use]
836    pub fn reranker_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
837        self.providers
838            .iter()
839            .flat_map(|p| {
840                p.models
841                    .iter()
842                    .filter(|m| m.model_type == ModelType::Reranker)
843                    .map(move |m| (p, m))
844            })
845            .collect()
846    }
847
848    /// Converts to a `ModelSelector` for use with the selection system.
849    #[must_use]
850    pub fn to_model_selector(&self) -> ModelSelector {
851        let mut selector = ModelSelector::empty();
852
853        for provider in &self.providers {
854            for model in &provider.models {
855                if model.model_type != ModelType::Llm {
856                    continue; // ModelSelector is for LLMs only
857                }
858
859                let data_sovereignty = match provider.region.as_str() {
860                    "EU" | "EEA" => DataSovereignty::EU,
861                    "CH" => DataSovereignty::Switzerland,
862                    "CN" => DataSovereignty::China,
863                    "US" => DataSovereignty::US,
864                    "LOCAL" => DataSovereignty::OnPremises,
865                    _ => DataSovereignty::Any,
866                };
867
868                let compliance = provider
869                    .compliance
870                    .first()
871                    .copied()
872                    .unwrap_or(ComplianceLevel::None);
873
874                let metadata = ModelMetadata::new(
875                    &provider.id,
876                    &model.id,
877                    model.cost_class,
878                    model.typical_latency_ms,
879                    model.quality,
880                )
881                .with_reasoning(model.supports_reasoning)
882                .with_web_search(model.supports_web_search)
883                .with_data_sovereignty(data_sovereignty)
884                .with_compliance(compliance)
885                .with_multilingual(model.supports_multilingual)
886                .with_context_tokens(model.context_tokens)
887                .with_tool_use(model.supports_tool_use)
888                .with_vision(model.supports_vision)
889                .with_structured_output(model.supports_structured_output)
890                .with_code(model.supports_code)
891                .with_content_generation(model.supports_content_generation)
892                .with_business_acumen(model.supports_business_acumen)
893                .with_location(&provider.country, &provider.region);
894
895                selector = selector.with_model(metadata);
896            }
897        }
898
899        selector
900    }
901
902    /// Prints a summary of all providers.
903    pub fn print_summary(&self) {
904        println!("Model Registry Summary");
905        println!("======================\n");
906
907        for provider in &self.providers {
908            let status = if provider.is_available() {
909                "✓ available"
910            } else {
911                "✗ no key"
912            };
913
914            println!(
915                "{} ({}) - {} models [{}]",
916                provider.id,
917                provider.region,
918                provider.models.len(),
919                status
920            );
921            println!("  Key URL: {}", provider.key_url);
922            println!("  API URL: {}", provider.api_url);
923            println!();
924        }
925    }
926}
927
928// =============================================================================
929// LOADING FUNCTIONS
930// =============================================================================
931
/// Default location of the model registry, relative to the **workspace** root.
///
/// The path includes the `converge-provider/` crate directory, so it resolves
/// only when the process runs from the workspace root. [`load_registry`] also
/// tries the crate-relative `config/models.yaml`.
pub const DEFAULT_REGISTRY_PATH: &str = "converge-provider/config/models.yaml";
934
935/// Loads the registry from the default path.
936///
937/// Tries these paths in order:
938/// 1. `converge-provider/config/models.yaml` (when run from workspace root)
939/// 2. `config/models.yaml` (when run from converge-provider directory)
940/// 3. `CONVERGE_MODELS_PATH` environment variable
941///
942/// # Errors
943///
944/// Returns error if the file cannot be read or parsed.
945pub fn load_registry() -> Result<LoadedRegistry, RegistryError> {
946    // Check environment variable first
947    if let Ok(path) = std::env::var("CONVERGE_MODELS_PATH") {
948        return load_registry_from_path(&path);
949    }
950
951    // Try workspace-relative path
952    if std::path::Path::new(DEFAULT_REGISTRY_PATH).exists() {
953        return load_registry_from_path(DEFAULT_REGISTRY_PATH);
954    }
955
956    // Try crate-relative path
957    let crate_path = "config/models.yaml";
958    if std::path::Path::new(crate_path).exists() {
959        return load_registry_from_path(crate_path);
960    }
961
962    // Fall back to compiled-in default
963    load_registry_from_str(include_str!("../config/models.yaml"))
964}
965
966/// Loads the registry from a specific path.
967///
968/// # Errors
969///
970/// Returns error if the file cannot be read or parsed.
971pub fn load_registry_from_path(path: impl AsRef<Path>) -> Result<LoadedRegistry, RegistryError> {
972    let content = std::fs::read_to_string(path)?;
973    load_registry_from_str(&content)
974}
975
976/// Loads the registry from a YAML string.
977///
978/// # Errors
979///
980/// Returns error if the YAML cannot be parsed or validation fails.
981pub fn load_registry_from_str(yaml: &str) -> Result<LoadedRegistry, RegistryError> {
982    let registry_yaml: RegistryYaml = serde_yaml::from_str(yaml)?;
983
984    let mut providers = Vec::new();
985    let mut errors = Vec::new();
986
987    for (provider_id, provider_yaml) in registry_yaml.providers {
988        // Validate provider
989        if let Err(e) = validate_provider(&provider_id, &provider_yaml) {
990            errors.push(e);
991            continue;
992        }
993
994        let compliance = provider_yaml
995            .compliance
996            .iter()
997            .map(|c| ComplianceLevel::from(*c))
998            .collect();
999
1000        let mut models = Vec::new();
1001
1002        for (model_id, model_yaml) in provider_yaml.models {
1003            // Validate model
1004            if let Err(e) = validate_model(&provider_id, &model_id, &model_yaml) {
1005                errors.push(e);
1006                continue;
1007            }
1008
1009            let capabilities: std::collections::HashSet<_> =
1010                model_yaml.capabilities.iter().copied().collect();
1011
1012            // Map modalities
1013            let modalities: Vec<Modality> = model_yaml
1014                .modalities
1015                .iter()
1016                .map(|m| match m {
1017                    ModalityYaml::Text => Modality::Text,
1018                    ModalityYaml::Image => Modality::Image,
1019                    ModalityYaml::Video => Modality::Video,
1020                    ModalityYaml::Audio => Modality::Audio,
1021                })
1022                .collect();
1023
1024            // Map reasoning effort levels
1025            let reasoning_effort_levels = model_yaml
1026                .reasoning_effort_levels
1027                .iter()
1028                .copied()
1029                .map(ReasoningEffort::from)
1030                .collect();
1031
1032            // Map agentic capabilities
1033            let agentic = model_yaml.agentic.as_ref().map(|a| AgenticCapabilities {
1034                max_parallel_agents: a.max_parallel_agents,
1035                supports_orchestration: a.supports_orchestration,
1036            });
1037
1038            // Map pricing
1039            let pricing = model_yaml.pricing.as_ref().map(|p| Pricing {
1040                input_per_m: p.input_per_m,
1041                output_per_m: p.output_per_m,
1042            });
1043
1044            // Map rate limits
1045            let rate_limits = model_yaml.rate_limits.as_ref().map(|r| RateLimits {
1046                requests_per_min: r.requests_per_min,
1047                tokens_per_min: r.tokens_per_min,
1048                requests_per_day: r.requests_per_day,
1049                concurrent_requests: r.concurrent_requests,
1050            });
1051
1052            let model = LoadedModel {
1053                id: model_id,
1054                cost_class: model_yaml.cost_class.into(),
1055                typical_latency_ms: model_yaml.typical_latency_ms,
1056                quality: model_yaml.quality,
1057                context_tokens: model_yaml.context_tokens,
1058                model_type: model_yaml.model_type.into(),
1059                dimensions: model_yaml.dimensions,
1060                capabilities: model_yaml.capabilities.clone(),
1061                supports_tool_use: capabilities.contains(&CapabilityYaml::ToolUse),
1062                supports_vision: capabilities.contains(&CapabilityYaml::Vision),
1063                supports_structured_output: capabilities
1064                    .contains(&CapabilityYaml::StructuredOutput),
1065                supports_code: capabilities.contains(&CapabilityYaml::Code),
1066                supports_reasoning: capabilities.contains(&CapabilityYaml::Reasoning),
1067                supports_multilingual: capabilities.contains(&CapabilityYaml::Multilingual),
1068                supports_web_search: capabilities.contains(&CapabilityYaml::WebSearch),
1069                supports_content_generation: capabilities
1070                    .contains(&CapabilityYaml::ContentGeneration),
1071                supports_business_acumen: capabilities.contains(&CapabilityYaml::BusinessAcumen),
1072                // Enriched fields
1073                architecture: model_yaml.architecture.into(),
1074                total_params_b: model_yaml.total_params_b,
1075                active_params_b: model_yaml.active_params_b,
1076                max_output_tokens: model_yaml.max_output_tokens,
1077                native_multimodal: model_yaml.native_multimodal,
1078                modalities,
1079                agentic,
1080                thinking_mode: model_yaml.thinking_mode,
1081                reasoning_effort_levels,
1082                native_compaction: model_yaml.native_compaction,
1083                thinking_variant: model_yaml.thinking_variant.clone(),
1084                pricing,
1085                publisher: model_yaml.publisher.clone(),
1086                family: model_yaml.family.clone(),
1087                release_date: model_yaml.release_date.clone(),
1088                training_cutoff: model_yaml.training_cutoff.clone(),
1089                open_weights: model_yaml.open_weights,
1090                license: model_yaml.license.clone(),
1091                deprecated: model_yaml.deprecated,
1092                beta: model_yaml.beta,
1093                benchmarks: model_yaml.benchmarks.clone(),
1094                tags: model_yaml.tags.clone(),
1095                rate_limits,
1096                notes: model_yaml.notes.clone(),
1097            };
1098
1099            models.push(model);
1100        }
1101
1102        // Sort models by id for consistent ordering
1103        models.sort_by(|a, b| a.id.cmp(&b.id));
1104
1105        let provider = LoadedProvider {
1106            id: provider_id,
1107            env_key: provider_yaml.env_key,
1108            env_key_secondary: provider_yaml.env_key_secondary,
1109            key_url: provider_yaml.key_url,
1110            api_url: provider_yaml.api_url,
1111            country: provider_yaml.country,
1112            region: provider_yaml.region.as_str().to_string(),
1113            compliance,
1114            provider_type: provider_yaml.provider_type.into(),
1115            models,
1116        };
1117
1118        providers.push(provider);
1119    }
1120
1121    // Fail if there were any validation errors
1122    if !errors.is_empty() {
1123        return Err(RegistryError::ValidationError(errors.join("; ")));
1124    }
1125
1126    // Sort providers alphabetically for consistent ordering
1127    providers.sort_by(|a, b| a.id.cmp(&b.id));
1128
1129    Ok(LoadedRegistry { providers })
1130}
1131
1132/// Validates a provider entry.
1133fn validate_provider(id: &str, provider: &ProviderYaml) -> Result<(), String> {
1134    // Validate env_key is not empty
1135    if provider.env_key.is_empty() {
1136        return Err(format!("Provider '{id}': env_key cannot be empty"));
1137    }
1138
1139    // Validate URLs are valid
1140    if !provider.key_url.starts_with("http://") && !provider.key_url.starts_with("https://") {
1141        return Err(format!(
1142            "Provider '{id}': key_url must be a valid URL, got '{}'",
1143            provider.key_url
1144        ));
1145    }
1146
1147    if !provider.api_url.starts_with("http://") && !provider.api_url.starts_with("https://") {
1148        return Err(format!(
1149            "Provider '{id}': api_url must be a valid URL, got '{}'",
1150            provider.api_url
1151        ));
1152    }
1153
1154    // Validate country code (2 letters or LOCAL)
1155    if provider.country != "LOCAL" && provider.country.len() != 2 {
1156        return Err(format!(
1157            "Provider '{id}': country must be 2-letter ISO code or 'LOCAL', got '{}'",
1158            provider.country
1159        ));
1160    }
1161
1162    // Validate has at least one model
1163    if provider.models.is_empty() {
1164        return Err(format!("Provider '{id}': must have at least one model"));
1165    }
1166
1167    Ok(())
1168}
1169
1170/// Validates a model entry.
1171fn validate_model(provider_id: &str, model_id: &str, model: &ModelYaml) -> Result<(), String> {
1172    // Validate quality is in range
1173    if !(0.0..=1.0).contains(&model.quality) {
1174        return Err(format!(
1175            "Model '{provider_id}/{model_id}': quality must be 0.0-1.0, got {}",
1176            model.quality
1177        ));
1178    }
1179
1180    // Validate latency is reasonable
1181    if model.typical_latency_ms == 0 {
1182        return Err(format!(
1183            "Model '{provider_id}/{model_id}': typical_latency_ms must be > 0"
1184        ));
1185    }
1186
1187    // Validate context_tokens is reasonable
1188    if model.context_tokens == 0 {
1189        return Err(format!(
1190            "Model '{provider_id}/{model_id}': context_tokens must be > 0"
1191        ));
1192    }
1193
1194    // Validate embedding models have dimensions
1195    if model.model_type == ModelTypeYaml::Embedding && model.dimensions.is_none() {
1196        return Err(format!(
1197            "Model '{provider_id}/{model_id}': embedding models must specify dimensions"
1198        ));
1199    }
1200
1201    Ok(())
1202}
1203
1204impl From<ModelTypeYaml> for ModelType {
1205    fn from(t: ModelTypeYaml) -> Self {
1206        match t {
1207            ModelTypeYaml::Llm => ModelType::Llm,
1208            ModelTypeYaml::Embedding => ModelType::Embedding,
1209            ModelTypeYaml::Reranker => ModelType::Reranker,
1210            ModelTypeYaml::Ocr => ModelType::Ocr,
1211        }
1212    }
1213}
1214
1215impl From<ArchitectureYaml> for Architecture {
1216    fn from(a: ArchitectureYaml) -> Self {
1217        match a {
1218            ArchitectureYaml::Dense => Architecture::Dense,
1219            ArchitectureYaml::Moe => Architecture::Moe,
1220            ArchitectureYaml::Hybrid => Architecture::Hybrid,
1221        }
1222    }
1223}
1224
1225impl From<ReasoningEffortYaml> for ReasoningEffort {
1226    fn from(effort: ReasoningEffortYaml) -> Self {
1227        match effort {
1228            ReasoningEffortYaml::None => Self::None,
1229            ReasoningEffortYaml::Minimal => Self::Minimal,
1230            ReasoningEffortYaml::Low => Self::Low,
1231            ReasoningEffortYaml::Medium => Self::Medium,
1232            ReasoningEffortYaml::High => Self::High,
1233            ReasoningEffortYaml::Xhigh => Self::Xhigh,
1234        }
1235    }
1236}
1237
1238impl From<ProviderTypeYaml> for ProviderType {
1239    fn from(p: ProviderTypeYaml) -> Self {
1240        match p {
1241            ProviderTypeYaml::Direct => ProviderType::Direct,
1242            ProviderTypeYaml::Aggregator => ProviderType::Aggregator,
1243        }
1244    }
1245}
1246
1247// =============================================================================
1248// TESTS
1249// =============================================================================
1250
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_YAML: &str = r"
providers:
  test-provider:
    env_key: TEST_API_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      test-model:
        cost_class: Low
        typical_latency_ms: 2000
        quality: 0.85
        context_tokens: 128000
        capabilities: [tool_use, reasoning, code]

      test-embedding:
        cost_class: VeryLow
        typical_latency_ms: 100
        quality: 0.80
        context_tokens: 8192
        capabilities: []
        type: embedding
        dimensions: 1024
";

    const INVALID_COST_CLASS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: SuperLow
        typical_latency_ms: 100
        quality: 0.5
";

    const INVALID_CAPABILITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        capabilities: [tool_use, telepathy]
";

    const INVALID_QUALITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 1.5
";

    const MISSING_DIMENSIONS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-embedding:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        type: embedding
";

    const UNKNOWN_FIELD_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    unknown_field: oops
    models:
      model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";

    #[test]
    fn parse_yaml() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        assert_eq!(registry.providers.len(), 1);

        let provider = &registry.providers[0];
        assert_eq!(provider.id, "test-provider");
        assert_eq!(provider.key_url, "https://test.com/keys");
        assert_eq!(provider.api_url, "https://api.test.com/v1");
        assert_eq!(provider.models.len(), 2);
    }

    #[test]
    fn parse_model_capabilities() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        let llm = provider
            .models
            .iter()
            .find(|m| m.id == "test-model")
            .unwrap();
        assert!(llm.supports_tool_use);
        assert!(llm.supports_reasoning);
        assert!(llm.supports_code);
        assert!(!llm.supports_vision);
        assert_eq!(llm.model_type, ModelType::Llm);
    }

    #[test]
    fn parse_embedding_model() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        let embedding = provider
            .models
            .iter()
            .find(|m| m.id == "test-embedding")
            .unwrap();
        assert_eq!(embedding.model_type, ModelType::Embedding);
        assert_eq!(embedding.dimensions, Some(1024));
    }

    #[test]
    fn filter_by_model_type() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();

        let llms = registry.llm_models();
        assert_eq!(llms.len(), 1);
        assert_eq!(llms[0].1.id, "test-model");

        let embeddings = registry.embedding_models();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].1.id, "test-embedding");
    }

    #[test]
    fn to_model_selector() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let selector = registry.to_model_selector();

        // Should have 1 LLM model (embedding is excluded)
        let reqs = converge_core::model_selection::AgentRequirements::balanced();
        let satisfying = selector.list_satisfying(&reqs);
        assert_eq!(satisfying.len(), 1);
    }

    #[test]
    fn provider_availability() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];

        // Should not be available (TEST_API_KEY not set by default)
        // Note: We don't test setting env vars as it requires unsafe in Rust 2024
        let _ = provider.is_available(); // Just verify method works
    }

    #[test]
    fn load_real_registry() {
        // Parse the compiled-in registry directly. Going through `load_registry()`
        // would consult `CONVERGE_MODELS_PATH` and the working directory first,
        // making this test depend on the environment it runs in.
        let registry = load_registry_from_str(include_str!("../config/models.yaml")).unwrap();

        // Should have multiple providers
        assert!(
            registry.providers.len() >= 10,
            "Expected at least 10 providers"
        );

        // Check some known providers exist
        let provider_ids: Vec<_> = registry.providers.iter().map(|p| p.id.as_str()).collect();
        assert!(provider_ids.contains(&"anthropic"), "Missing anthropic");
        assert!(provider_ids.contains(&"openai"), "Missing openai");
        assert!(provider_ids.contains(&"mistral"), "Missing mistral");
        assert!(provider_ids.contains(&"ollama"), "Missing ollama");

        // Check anthropic has correct URLs
        let anthropic = registry.get_provider("anthropic").unwrap();
        assert_eq!(
            anthropic.key_url,
            "https://console.anthropic.com/settings/keys"
        );
        assert_eq!(anthropic.api_url, "https://api.anthropic.com/v1");
        assert_eq!(anthropic.env_key, "ANTHROPIC_API_KEY");

        // Check ollama is marked as LOCAL
        let ollama = registry.get_provider("ollama").unwrap();
        assert_eq!(ollama.region, "LOCAL");

        // Check we have LLM models
        let llms = registry.llm_models();
        assert!(llms.len() >= 30, "Expected at least 30 LLM models");

        // Check we have embedding models
        let embeddings = registry.embedding_models();
        assert!(
            embeddings.len() >= 3,
            "Expected at least 3 embedding models"
        );

        println!(
            "Loaded {} providers with {} LLM models and {} embedding models",
            registry.providers.len(),
            llms.len(),
            embeddings.len()
        );
    }

    #[test]
    fn load_registry_default_search_succeeds() {
        // Only meaningful when no explicit override is configured; with an
        // override, the outcome depends on whatever file it points to.
        if std::env::var_os("CONVERGE_MODELS_PATH").is_none() {
            load_registry().expect("default registry search should succeed");
        }
    }

    // =========================================================================
    // TYPE-SAFE VALIDATION TESTS
    // =========================================================================

    #[test]
    fn rejects_invalid_cost_class() {
        let result = load_registry_from_str(INVALID_COST_CLASS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("SuperLow") || err.contains("unknown variant"),
            "Expected error about invalid cost class, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_capability() {
        let result = load_registry_from_str(INVALID_CAPABILITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("telepathy") || err.contains("unknown variant"),
            "Expected error about invalid capability, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_quality() {
        let result = load_registry_from_str(INVALID_QUALITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("quality") && err.contains("1.5"),
            "Expected error about quality out of range, got: {err}"
        );
    }

    #[test]
    fn rejects_embedding_without_dimensions() {
        let result = load_registry_from_str(MISSING_DIMENSIONS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("dimensions"),
            "Expected error about missing dimensions, got: {err}"
        );
    }

    #[test]
    fn rejects_unknown_fields() {
        let result = load_registry_from_str(UNKNOWN_FIELD_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("unknown_field") || err.contains("unknown field"),
            "Expected error about unknown field, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_region() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: INVALID
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("INVALID") || err.contains("unknown variant"),
            "Expected error about invalid region, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_url() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: not-a-url
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("key_url") && err.contains("URL"),
            "Expected error about invalid URL, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_country() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: USA
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("country") && err.contains("USA"),
            "Expected error about invalid country, got: {err}"
        );
    }

    #[test]
    fn rejects_empty_env_key() {
        let yaml = r"
providers:
  bad:
    env_key: ''
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("env_key"),
            "Expected error about empty env_key, got: {err}"
        );
    }

    #[test]
    fn rejects_zero_latency() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 0
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("latency") && err.contains("0"),
            "Expected error about zero latency, got: {err}"
        );
    }

    #[test]
    fn rejects_empty_provider() {
        let yaml = r"
providers:
  empty:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models: {}
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("at least one model"),
            "Expected error about empty models, got: {err}"
        );
    }
}