use crate::model_selection::{ModelMetadata, ModelSelector};
use converge_provider_api::selection::{ComplianceLevel, CostClass, DataSovereignty};
use schemars::JsonSchema;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::Path;
/// Errors produced while loading the model registry.
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
    /// The registry file could not be read from disk.
    #[error("Failed to read registry file: {0}")]
    IoError(#[from] std::io::Error),
    /// The text is not valid YAML or does not match the registry structure
    /// (unknown fields and unknown enum variants are rejected at this stage).
    #[error("Failed to parse registry YAML: {0}")]
    ParseError(#[from] serde_yaml::Error),
    /// The YAML parsed but broke a semantic rule; the payload joins every
    /// collected violation with `"; "`.
    #[error("Registry validation failed: {0}")]
    ValidationError(String),
}
/// Root of a registry YAML document such as `config/models.yaml`.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RegistryYaml {
    /// Providers, keyed by provider id.
    pub providers: HashMap<String, ProviderYaml>,
}
/// Kind of provider: `direct` (the default when omitted) or `aggregator`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum ProviderTypeYaml {
    #[default]
    Direct,
    Aggregator,
}
/// One entry under `providers:`.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ProviderYaml {
    /// Environment variable holding the provider's API key. Must not be empty.
    pub env_key: String,
    /// Optional second environment variable that must also be set for the
    /// provider to count as available.
    #[serde(default)]
    pub env_key_secondary: Option<String>,
    /// Where an API key can be obtained. Must start with `http://` or `https://`.
    pub key_url: String,
    /// Base URL of the provider API. Must start with `http://` or `https://`.
    pub api_url: String,
    /// 2-letter ISO country code (e.g. `US`), or `LOCAL`.
    pub country: String,
    /// Region the provider operates in.
    pub region: RegionYaml,
    /// Compliance standards the provider declares.
    #[serde(default)]
    pub compliance: Vec<ComplianceYaml>,
    /// Provider kind; defaults to `direct`.
    #[serde(default)]
    pub provider_type: ProviderTypeYaml,
    /// Models keyed by model id. At least one model is required.
    pub models: HashMap<String, ModelYaml>,
}
/// Region a provider operates in. Variant names are written verbatim in YAML
/// (e.g. `region: EU`); `LOCAL` denotes an on-premises provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
pub enum RegionYaml {
    US,
    EU,
    EEA,
    CH,
    CN,
    JP,
    UK,
    // Mapped to on-premises data sovereignty by the selector conversion.
    LOCAL,
}
impl RegionYaml {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::US => "US",
Self::EU => "EU",
Self::EEA => "EEA",
Self::CH => "CH",
Self::CN => "CN",
Self::JP => "JP",
Self::UK => "UK",
Self::LOCAL => "LOCAL",
}
}
}
/// Compliance standard a provider declares, written verbatim in YAML
/// (`GDPR`, `SOC2`, `HIPAA`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
pub enum ComplianceYaml {
    GDPR,
    SOC2,
    HIPAA,
}
impl From<ComplianceYaml> for ComplianceLevel {
    /// Maps each YAML compliance tag onto the identically named level.
    fn from(value: ComplianceYaml) -> Self {
        match value {
            ComplianceYaml::GDPR => Self::GDPR,
            ComplianceYaml::SOC2 => Self::SOC2,
            ComplianceYaml::HIPAA => Self::HIPAA,
        }
    }
}
/// Relative cost bucket of a model, written verbatim in YAML
/// (e.g. `cost_class: VeryLow`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema)]
pub enum CostClassYaml {
    VeryLow,
    Low,
    Medium,
    High,
    VeryHigh,
}
impl From<CostClassYaml> for CostClass {
    /// One-to-one mapping onto the identically named cost class.
    fn from(value: CostClassYaml) -> Self {
        match value {
            CostClassYaml::VeryLow => Self::VeryLow,
            CostClassYaml::Low => Self::Low,
            CostClassYaml::Medium => Self::Medium,
            CostClassYaml::High => Self::High,
            CostClassYaml::VeryHigh => Self::VeryHigh,
        }
    }
}
/// Capability tag a model can declare, written in snake_case in YAML
/// (e.g. `capabilities: [tool_use, vision]`). Unknown tags are rejected.
// `CapabilityYaml::as_str` must stay in sync with these serde names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum CapabilityYaml {
    ToolUse,
    Vision,
    StructuredOutput,
    Code,
    Reasoning,
    Multilingual,
    WebSearch,
    Audio,
    ImageGeneration,
    Streaming,
    Logprobs,
    Seed,
    ToolChoice,
    ParallelToolCalls,
    PromptCaching,
    FileSearch,
    CodeInterpreter,
    ComputerUse,
    ToolSearch,
    Mcp,
    HostedShell,
    ApplyPatch,
    NativeCompaction,
    ReasoningEffort,
}
impl CapabilityYaml {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::ToolUse => "tool_use",
Self::Vision => "vision",
Self::StructuredOutput => "structured_output",
Self::Code => "code",
Self::Reasoning => "reasoning",
Self::Multilingual => "multilingual",
Self::WebSearch => "web_search",
Self::Audio => "audio",
Self::ImageGeneration => "image_generation",
Self::Streaming => "streaming",
Self::Logprobs => "logprobs",
Self::Seed => "seed",
Self::ToolChoice => "tool_choice",
Self::ParallelToolCalls => "parallel_tool_calls",
Self::PromptCaching => "prompt_caching",
Self::FileSearch => "file_search",
Self::CodeInterpreter => "code_interpreter",
Self::ComputerUse => "computer_use",
Self::ToolSearch => "tool_search",
Self::Mcp => "mcp",
Self::HostedShell => "hosted_shell",
Self::ApplyPatch => "apply_patch",
Self::NativeCompaction => "native_compaction",
Self::ReasoningEffort => "reasoning_effort",
}
}
}
/// Reasoning-effort level a model accepts, written in lower case in YAML
/// (`none`, `minimal`, `low`, `medium`, `high`, `xhigh`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffortYaml {
    None,
    Minimal,
    Low,
    Medium,
    High,
    Xhigh,
}
impl ReasoningEffortYaml {
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::None => "none",
Self::Minimal => "minimal",
Self::Low => "low",
Self::Medium => "medium",
Self::High => "high",
Self::Xhigh => "xhigh",
}
}
}
/// Model kind, written as `type:` in YAML; defaults to `llm`.
/// Embedding models must also specify `dimensions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum ModelTypeYaml {
    #[default]
    Llm,
    Embedding,
    Reranker,
    Ocr,
}
/// Model architecture; defaults to `dense`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum ArchitectureYaml {
    #[default]
    Dense,
    Moe,
    Hybrid,
}
/// A modality listed for a model (`text`, `image`, `video`, `audio`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ModalityYaml {
    Text,
    Image,
    Video,
    Audio,
}
/// Optional agentic-capability details for a model.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct AgenticYaml {
    #[serde(default)]
    pub max_parallel_agents: Option<u32>,
    /// Defaults to `false`.
    #[serde(default)]
    pub supports_orchestration: bool,
}
/// Optional pricing for a model; both fields are optional.
// NOTE(review): `_per_m` presumably means "per million tokens"; the currency
// is not recorded anywhere in this file. Confirm before relying on it.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct PricingYaml {
    #[serde(default)]
    pub input_per_m: Option<f64>,
    #[serde(default)]
    pub output_per_m: Option<f64>,
}
/// Optional published rate limits for a model; every field is optional.
#[derive(Debug, Clone, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RateLimitsYaml {
    #[serde(default)]
    pub requests_per_min: Option<u32>,
    #[serde(default)]
    pub tokens_per_min: Option<u32>,
    #[serde(default)]
    pub requests_per_day: Option<u32>,
    #[serde(default)]
    pub concurrent_requests: Option<u32>,
}
/// One entry under a provider's `models:` map.
#[derive(Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ModelYaml {
    /// Relative cost bucket.
    pub cost_class: CostClassYaml,
    /// Typical latency in milliseconds; must be greater than 0.
    pub typical_latency_ms: u32,
    /// Quality score; must lie within 0.0..=1.0.
    pub quality: f64,
    /// Context length in tokens; defaults to 8192 and must be greater than 0.
    #[serde(default = "default_context_tokens")]
    pub context_tokens: usize,
    /// Capability tags, e.g. `[tool_use, reasoning]`.
    #[serde(default)]
    pub capabilities: Vec<CapabilityYaml>,
    /// Model kind, written as `type:`; defaults to `llm`.
    #[serde(default, rename = "type")]
    pub model_type: ModelTypeYaml,
    /// Embedding dimensionality; required when `type: embedding`.
    #[serde(default)]
    pub dimensions: Option<usize>,
    /// Architecture; defaults to `dense`.
    #[serde(default)]
    pub architecture: ArchitectureYaml,
    // The remaining fields are optional descriptive metadata that the loader
    // copies onto `LoadedModel`.
    #[serde(default)]
    pub total_params_b: Option<f64>,
    #[serde(default)]
    pub active_params_b: Option<f64>,
    #[serde(default)]
    pub max_output_tokens: Option<usize>,
    #[serde(default)]
    pub native_multimodal: bool,
    #[serde(default)]
    pub modalities: Vec<ModalityYaml>,
    #[serde(default)]
    pub agentic: Option<AgenticYaml>,
    #[serde(default)]
    pub thinking_mode: bool,
    #[serde(default)]
    pub reasoning_effort_levels: Vec<ReasoningEffortYaml>,
    #[serde(default)]
    pub native_compaction: bool,
    #[serde(default)]
    pub thinking_variant: Option<String>,
    #[serde(default)]
    pub pricing: Option<PricingYaml>,
    #[serde(default)]
    pub publisher: Option<String>,
    #[serde(default)]
    pub family: Option<String>,
    #[serde(default)]
    pub release_date: Option<String>,
    #[serde(default)]
    pub training_cutoff: Option<String>,
    #[serde(default)]
    pub open_weights: bool,
    #[serde(default)]
    pub license: Option<String>,
    #[serde(default)]
    pub deprecated: bool,
    #[serde(default)]
    pub beta: bool,
    #[serde(default)]
    pub benchmarks: HashMap<String, f64>,
    #[serde(default)]
    pub tags: Vec<String>,
    #[serde(default)]
    pub rate_limits: Option<RateLimitsYaml>,
    #[serde(default)]
    pub notes: Option<String>,
}
/// Serde default for `ModelYaml::context_tokens` when the YAML omits it (8 KiB tokens).
fn default_context_tokens() -> usize {
    8 * 1024
}
/// Builds the JSON Schema for the registry YAML format, rooted at
/// [`RegistryYaml`].
#[must_use]
pub fn generate_schema() -> schemars::schema::RootSchema {
    schemars::schema_for!(RegistryYaml)
}
/// Runtime counterpart of [`ProviderTypeYaml`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Used when the YAML omits `provider_type`.
    Direct,
    Aggregator,
}
/// A validated provider, ready for use at runtime.
#[derive(Debug, Clone)]
pub struct LoadedProvider {
    /// Provider id (its key under `providers:` in the YAML).
    pub id: String,
    /// Environment variable holding the primary API key.
    pub env_key: String,
    /// Optional second environment variable that must also be set.
    pub env_key_secondary: Option<String>,
    pub key_url: String,
    pub api_url: String,
    pub country: String,
    /// Region code; the loader fills it from `RegionYaml::as_str`.
    pub region: String,
    pub compliance: Vec<ComplianceLevel>,
    pub provider_type: ProviderType,
    /// Models; the loader sorts them by id.
    pub models: Vec<LoadedModel>,
}

impl LoadedProvider {
    /// Reads `name` from the environment, treating an unset, non-UTF-8,
    /// empty or whitespace-only value as "no key".
    fn non_blank_env(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .filter(|value| !value.trim().is_empty())
    }

    /// Returns `true` when the primary key variable holds a non-blank value
    /// and, if a secondary variable is configured, that one does too.
    ///
    /// A variable exported as an empty string (common with templated env
    /// files) counts as missing, so it is not reported as usable.
    #[must_use]
    pub fn is_available(&self) -> bool {
        let secondary_ok = self.env_key_secondary.is_none() || self.secondary_api_key().is_some();
        self.api_key().is_some() && secondary_ok
    }

    /// The primary API key, or `None` when its variable is unset, not valid
    /// UTF-8, or blank.
    #[must_use]
    pub fn api_key(&self) -> Option<String> {
        Self::non_blank_env(&self.env_key)
    }

    /// The secondary key, or `None` when no secondary variable is configured
    /// or its value is unset, not valid UTF-8, or blank.
    #[must_use]
    pub fn secondary_api_key(&self) -> Option<String> {
        self.env_key_secondary
            .as_deref()
            .and_then(Self::non_blank_env)
    }
}
/// Model architecture, converted from [`ArchitectureYaml`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Architecture {
    Dense,
    Moe,
    Hybrid,
}

/// A modality listed for a model, converted from [`ModalityYaml`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Modality {
    Text,
    Image,
    Video,
    Audio,
}

/// Runtime copy of [`AgenticYaml`].
#[derive(Clone, Debug, Default)]
pub struct AgenticCapabilities {
    pub max_parallel_agents: Option<u32>,
    pub supports_orchestration: bool,
}

/// Runtime copy of [`PricingYaml`].
#[derive(Clone, Debug, Default)]
pub struct Pricing {
    pub input_per_m: Option<f64>,
    pub output_per_m: Option<f64>,
}

/// Runtime copy of [`RateLimitsYaml`].
#[derive(Clone, Debug, Default)]
pub struct RateLimits {
    pub requests_per_min: Option<u32>,
    pub tokens_per_min: Option<u32>,
    pub requests_per_day: Option<u32>,
    pub concurrent_requests: Option<u32>,
}
/// Reasoning-effort level a model accepts, converted from
/// [`ReasoningEffortYaml`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReasoningEffort {
    None,
    Minimal,
    Low,
    Medium,
    High,
    Xhigh,
}

impl ReasoningEffort {
    /// Returns the lower-case spelling used for this level in the registry YAML.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            ReasoningEffort::None => "none",
            ReasoningEffort::Minimal => "minimal",
            ReasoningEffort::Low => "low",
            ReasoningEffort::Medium => "medium",
            ReasoningEffort::High => "high",
            ReasoningEffort::Xhigh => "xhigh",
        }
    }
}
/// A validated model entry, ready for use at runtime.
///
/// The loader builds it from a [`ModelYaml`]. The `supports_*` flags are
/// precomputed from `capabilities` for the tags `to_model_selector` consumes.
// Each bool mirrors one YAML flag or capability tag one-to-one; folding them
// into a bit set would obscure that mapping, hence the lint allowance.
#[derive(Clone, Debug)]
#[allow(clippy::struct_excessive_bools)]
pub struct LoadedModel {
    /// Model id (its key under the provider's `models:`).
    pub id: String,
    pub cost_class: CostClass,
    pub typical_latency_ms: u32,
    pub quality: f64,
    pub context_tokens: usize,
    pub model_type: ModelType,
    /// Set for embedding models (the loader requires it for them).
    pub dimensions: Option<usize>,
    /// Capability tags exactly as listed in the YAML.
    pub capabilities: Vec<CapabilityYaml>,
    pub supports_tool_use: bool,
    pub supports_vision: bool,
    pub supports_structured_output: bool,
    pub supports_code: bool,
    pub supports_reasoning: bool,
    pub supports_multilingual: bool,
    pub supports_web_search: bool,
    pub architecture: Architecture,
    pub total_params_b: Option<f64>,
    pub active_params_b: Option<f64>,
    pub max_output_tokens: Option<usize>,
    pub native_multimodal: bool,
    pub modalities: Vec<Modality>,
    pub agentic: Option<AgenticCapabilities>,
    pub thinking_mode: bool,
    pub reasoning_effort_levels: Vec<ReasoningEffort>,
    pub native_compaction: bool,
    pub thinking_variant: Option<String>,
    pub pricing: Option<Pricing>,
    pub publisher: Option<String>,
    pub family: Option<String>,
    pub release_date: Option<String>,
    pub training_cutoff: Option<String>,
    pub open_weights: bool,
    pub license: Option<String>,
    pub deprecated: bool,
    pub beta: bool,
    pub benchmarks: HashMap<String, f64>,
    pub tags: Vec<String>,
    pub rate_limits: Option<RateLimits>,
    pub notes: Option<String>,
}
/// Kind of model, converted from [`ModelTypeYaml`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelType {
    /// `type: llm`, the YAML default; the only kind exported to the selector.
    Llm,
    /// `type: embedding`; always carries `dimensions`.
    Embedding,
    Reranker,
    Ocr,
}
/// A fully loaded and validated registry.
///
/// Obtain one from [`load_registry`], [`load_registry_from_path`] or
/// [`load_registry_from_str`]; the field is private to this module.
#[derive(Clone, Debug)]
pub struct LoadedRegistry {
    // Sorted by provider id by the loader.
    providers: Vec<LoadedProvider>,
}
impl LoadedRegistry {
#[must_use]
pub fn providers(&self) -> &[LoadedProvider] {
&self.providers
}
#[must_use]
pub fn available_providers(&self) -> Vec<&LoadedProvider> {
self.providers.iter().filter(|p| p.is_available()).collect()
}
#[must_use]
pub fn get_provider(&self, id: &str) -> Option<&LoadedProvider> {
self.providers.iter().find(|p| p.id == id)
}
#[must_use]
pub fn llm_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
self.providers
.iter()
.flat_map(|p| {
p.models
.iter()
.filter(|m| m.model_type == ModelType::Llm)
.map(move |m| (p, m))
})
.collect()
}
#[must_use]
pub fn embedding_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
self.providers
.iter()
.flat_map(|p| {
p.models
.iter()
.filter(|m| m.model_type == ModelType::Embedding)
.map(move |m| (p, m))
})
.collect()
}
#[must_use]
pub fn reranker_models(&self) -> Vec<(&LoadedProvider, &LoadedModel)> {
self.providers
.iter()
.flat_map(|p| {
p.models
.iter()
.filter(|m| m.model_type == ModelType::Reranker)
.map(move |m| (p, m))
})
.collect()
}
#[must_use]
pub fn to_model_selector(&self) -> ModelSelector {
let mut selector = ModelSelector::empty();
for provider in &self.providers {
for model in &provider.models {
if model.model_type != ModelType::Llm {
continue; }
let data_sovereignty = match provider.region.as_str() {
"EU" | "EEA" => DataSovereignty::EU,
"CH" => DataSovereignty::Switzerland,
"CN" => DataSovereignty::China,
"US" => DataSovereignty::US,
"LOCAL" => DataSovereignty::OnPremises,
_ => DataSovereignty::Any,
};
let compliance = provider
.compliance
.first()
.copied()
.unwrap_or(ComplianceLevel::None);
let metadata = ModelMetadata::new(
&provider.id,
&model.id,
model.cost_class,
model.typical_latency_ms,
model.quality,
)
.with_reasoning(model.supports_reasoning)
.with_web_search(model.supports_web_search)
.with_data_sovereignty(data_sovereignty)
.with_compliance(compliance)
.with_multilingual(model.supports_multilingual)
.with_context_tokens(model.context_tokens)
.with_tool_use(model.supports_tool_use)
.with_vision(model.supports_vision)
.with_structured_output(model.supports_structured_output)
.with_code(model.supports_code)
.with_location(&provider.country, &provider.region);
selector = selector.with_model(metadata);
}
}
selector
}
pub fn print_summary(&self) {
println!("Model Registry Summary");
println!("======================\n");
for provider in &self.providers {
let status = if provider.is_available() {
"✓ available"
} else {
"✗ no key"
};
println!(
"{} ({}) - {} models [{}]",
provider.id,
provider.region,
provider.models.len(),
status
);
println!(" Key URL: {}", provider.key_url);
println!(" API URL: {}", provider.api_url);
println!();
}
}
}
pub const DEFAULT_REGISTRY_PATH: &str = "converge-provider/config/models.yaml";
pub fn load_registry() -> Result<LoadedRegistry, RegistryError> {
if let Ok(path) = std::env::var("CONVERGE_MODELS_PATH") {
return load_registry_from_path(&path);
}
if std::path::Path::new(DEFAULT_REGISTRY_PATH).exists() {
return load_registry_from_path(DEFAULT_REGISTRY_PATH);
}
let crate_path = "config/models.yaml";
if std::path::Path::new(crate_path).exists() {
return load_registry_from_path(crate_path);
}
load_registry_from_str(include_str!("../config/models.yaml"))
}
/// Reads a registry YAML file and loads it with [`load_registry_from_str`].
///
/// # Errors
///
/// Returns [`RegistryError::IoError`] if the file cannot be read. Its message
/// names the path and the original `ErrorKind` is preserved. Otherwise it
/// returns any error produced by [`load_registry_from_str`].
pub fn load_registry_from_path(path: impl AsRef<Path>) -> Result<LoadedRegistry, RegistryError> {
    let path = path.as_ref();
    // Several candidate paths may be tried by `load_registry`, so say which
    // one failed instead of surfacing a bare "No such file or directory".
    let content = std::fs::read_to_string(path)
        .map_err(|e| std::io::Error::new(e.kind(), format!("{}: {e}", path.display())))?;
    load_registry_from_str(&content)
}
pub fn load_registry_from_str(yaml: &str) -> Result<LoadedRegistry, RegistryError> {
let registry_yaml: RegistryYaml = serde_yaml::from_str(yaml)?;
let mut providers = Vec::new();
let mut errors = Vec::new();
for (provider_id, provider_yaml) in registry_yaml.providers {
if let Err(e) = validate_provider(&provider_id, &provider_yaml) {
errors.push(e);
continue;
}
let compliance = provider_yaml
.compliance
.iter()
.map(|c| ComplianceLevel::from(*c))
.collect();
let mut models = Vec::new();
for (model_id, model_yaml) in provider_yaml.models {
if let Err(e) = validate_model(&provider_id, &model_id, &model_yaml) {
errors.push(e);
continue;
}
let capabilities: std::collections::HashSet<_> =
model_yaml.capabilities.iter().copied().collect();
let modalities: Vec<Modality> = model_yaml
.modalities
.iter()
.map(|m| match m {
ModalityYaml::Text => Modality::Text,
ModalityYaml::Image => Modality::Image,
ModalityYaml::Video => Modality::Video,
ModalityYaml::Audio => Modality::Audio,
})
.collect();
let reasoning_effort_levels = model_yaml
.reasoning_effort_levels
.iter()
.copied()
.map(ReasoningEffort::from)
.collect();
let agentic = model_yaml.agentic.as_ref().map(|a| AgenticCapabilities {
max_parallel_agents: a.max_parallel_agents,
supports_orchestration: a.supports_orchestration,
});
let pricing = model_yaml.pricing.as_ref().map(|p| Pricing {
input_per_m: p.input_per_m,
output_per_m: p.output_per_m,
});
let rate_limits = model_yaml.rate_limits.as_ref().map(|r| RateLimits {
requests_per_min: r.requests_per_min,
tokens_per_min: r.tokens_per_min,
requests_per_day: r.requests_per_day,
concurrent_requests: r.concurrent_requests,
});
let model = LoadedModel {
id: model_id,
cost_class: model_yaml.cost_class.into(),
typical_latency_ms: model_yaml.typical_latency_ms,
quality: model_yaml.quality,
context_tokens: model_yaml.context_tokens,
model_type: model_yaml.model_type.into(),
dimensions: model_yaml.dimensions,
capabilities: model_yaml.capabilities.clone(),
supports_tool_use: capabilities.contains(&CapabilityYaml::ToolUse),
supports_vision: capabilities.contains(&CapabilityYaml::Vision),
supports_structured_output: capabilities
.contains(&CapabilityYaml::StructuredOutput),
supports_code: capabilities.contains(&CapabilityYaml::Code),
supports_reasoning: capabilities.contains(&CapabilityYaml::Reasoning),
supports_multilingual: capabilities.contains(&CapabilityYaml::Multilingual),
supports_web_search: capabilities.contains(&CapabilityYaml::WebSearch),
architecture: model_yaml.architecture.into(),
total_params_b: model_yaml.total_params_b,
active_params_b: model_yaml.active_params_b,
max_output_tokens: model_yaml.max_output_tokens,
native_multimodal: model_yaml.native_multimodal,
modalities,
agentic,
thinking_mode: model_yaml.thinking_mode,
reasoning_effort_levels,
native_compaction: model_yaml.native_compaction,
thinking_variant: model_yaml.thinking_variant.clone(),
pricing,
publisher: model_yaml.publisher.clone(),
family: model_yaml.family.clone(),
release_date: model_yaml.release_date.clone(),
training_cutoff: model_yaml.training_cutoff.clone(),
open_weights: model_yaml.open_weights,
license: model_yaml.license.clone(),
deprecated: model_yaml.deprecated,
beta: model_yaml.beta,
benchmarks: model_yaml.benchmarks.clone(),
tags: model_yaml.tags.clone(),
rate_limits,
notes: model_yaml.notes.clone(),
};
models.push(model);
}
models.sort_by(|a, b| a.id.cmp(&b.id));
let provider = LoadedProvider {
id: provider_id,
env_key: provider_yaml.env_key,
env_key_secondary: provider_yaml.env_key_secondary,
key_url: provider_yaml.key_url,
api_url: provider_yaml.api_url,
country: provider_yaml.country,
region: provider_yaml.region.as_str().to_string(),
compliance,
provider_type: provider_yaml.provider_type.into(),
models,
};
providers.push(provider);
}
if !errors.is_empty() {
return Err(RegistryError::ValidationError(errors.join("; ")));
}
providers.sort_by(|a, b| a.id.cmp(&b.id));
Ok(LoadedRegistry { providers })
}
/// Checks provider-level rules that serde cannot express.
///
/// Returns the first violation as a message prefixed with `Provider '<id>'`.
fn validate_provider(id: &str, provider: &ProviderYaml) -> Result<(), String> {
    // A whitespace-only name can never be exported as a usable variable.
    if provider.env_key.trim().is_empty() {
        return Err(format!("Provider '{id}': env_key cannot be empty"));
    }
    if !is_http_url(&provider.key_url) {
        return Err(format!(
            "Provider '{id}': key_url must be a valid URL, got '{}'",
            provider.key_url
        ));
    }
    if !is_http_url(&provider.api_url) {
        return Err(format!(
            "Provider '{id}': api_url must be a valid URL, got '{}'",
            provider.api_url
        ));
    }
    if !is_country_code(&provider.country) {
        return Err(format!(
            "Provider '{id}': country must be 2-letter ISO code or 'LOCAL', got '{}'",
            provider.country
        ));
    }
    if provider.models.is_empty() {
        return Err(format!("Provider '{id}': must have at least one model"));
    }
    Ok(())
}

/// Shallow URL check: an `http://` or `https://` scheme followed by a
/// non-empty remainder. No host or port parsing is attempted.
fn is_http_url(url: &str) -> bool {
    let rest = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"));
    matches!(rest, Some(r) if !r.is_empty())
}

/// `LOCAL`, or exactly two ASCII letters. A byte-length check alone would
/// also accept digits or a single 2-byte UTF-8 character.
fn is_country_code(country: &str) -> bool {
    country == "LOCAL"
        || (country.len() == 2 && country.bytes().all(|b| b.is_ascii_alphabetic()))
}
/// Checks model-level rules that serde cannot express.
///
/// Returns the first violation as a message prefixed with
/// `Model '<provider>/<model>'`.
fn validate_model(provider_id: &str, model_id: &str, model: &ModelYaml) -> Result<(), String> {
    // `RangeInclusive::contains` is false for NaN, so NaN is rejected as well.
    if !(0.0..=1.0).contains(&model.quality) {
        return Err(format!(
            "Model '{provider_id}/{model_id}': quality must be 0.0-1.0, got {}",
            model.quality
        ));
    }
    if model.typical_latency_ms == 0 {
        return Err(format!(
            "Model '{provider_id}/{model_id}': typical_latency_ms must be > 0"
        ));
    }
    if model.context_tokens == 0 {
        return Err(format!(
            "Model '{provider_id}/{model_id}': context_tokens must be > 0"
        ));
    }
    if model.model_type == ModelTypeYaml::Embedding && model.dimensions.is_none() {
        return Err(format!(
            "Model '{provider_id}/{model_id}': embedding models must specify dimensions"
        ));
    }
    if model.dimensions == Some(0) {
        return Err(format!(
            "Model '{provider_id}/{model_id}': dimensions must be > 0"
        ));
    }
    if model.max_output_tokens == Some(0) {
        return Err(format!(
            "Model '{provider_id}/{model_id}': max_output_tokens must be > 0"
        ));
    }
    if let Some(pricing) = &model.pricing {
        let prices = [
            ("input_per_m", pricing.input_per_m),
            ("output_per_m", pricing.output_per_m),
        ];
        for (field, price) in prices {
            // Zero is allowed (e.g. free or local models); negative, NaN and
            // infinite prices are not.
            if let Some(bad) = price.filter(|p| !p.is_finite() || *p < 0.0) {
                return Err(format!(
                    "Model '{provider_id}/{model_id}': pricing.{field} must be a non-negative number, got {bad}"
                ));
            }
        }
    }
    Ok(())
}
impl From<ModelTypeYaml> for ModelType {
    /// One-to-one mapping from the YAML model kind.
    fn from(kind: ModelTypeYaml) -> Self {
        match kind {
            ModelTypeYaml::Llm => Self::Llm,
            ModelTypeYaml::Embedding => Self::Embedding,
            ModelTypeYaml::Reranker => Self::Reranker,
            ModelTypeYaml::Ocr => Self::Ocr,
        }
    }
}
impl From<ArchitectureYaml> for Architecture {
    /// One-to-one mapping from the YAML architecture.
    fn from(arch: ArchitectureYaml) -> Self {
        match arch {
            ArchitectureYaml::Dense => Self::Dense,
            ArchitectureYaml::Moe => Self::Moe,
            ArchitectureYaml::Hybrid => Self::Hybrid,
        }
    }
}
impl From<ReasoningEffortYaml> for ReasoningEffort {
fn from(effort: ReasoningEffortYaml) -> Self {
match effort {
ReasoningEffortYaml::None => Self::None,
ReasoningEffortYaml::Minimal => Self::Minimal,
ReasoningEffortYaml::Low => Self::Low,
ReasoningEffortYaml::Medium => Self::Medium,
ReasoningEffortYaml::High => Self::High,
ReasoningEffortYaml::Xhigh => Self::Xhigh,
}
}
}
impl From<ProviderTypeYaml> for ProviderType {
    /// One-to-one mapping from the YAML provider kind.
    fn from(kind: ProviderTypeYaml) -> Self {
        match kind {
            ProviderTypeYaml::Direct => Self::Direct,
            ProviderTypeYaml::Aggregator => Self::Aggregator,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // YAML fixtures. Indentation inside the raw strings is significant: it is
    // what nests providers, models and their fields.

    /// One valid provider with an LLM and an embedding model.
    const TEST_YAML: &str = r"
providers:
  test-provider:
    env_key: TEST_API_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      test-model:
        cost_class: Low
        typical_latency_ms: 2000
        quality: 0.85
        context_tokens: 128000
        capabilities: [tool_use, reasoning, code]
      test-embedding:
        cost_class: VeryLow
        typical_latency_ms: 100
        quality: 0.80
        context_tokens: 8192
        capabilities: []
        type: embedding
        dimensions: 1024
";

    /// `cost_class` is not a known variant.
    const INVALID_COST_CLASS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: SuperLow
        typical_latency_ms: 100
        quality: 0.5
";

    /// `telepathy` is not a known capability tag.
    const INVALID_CAPABILITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        capabilities: [tool_use, telepathy]
";

    /// `quality` is outside 0.0..=1.0.
    const INVALID_QUALITY_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 1.5
";

    /// Embedding model without `dimensions`.
    const MISSING_DIMENSIONS_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    models:
      bad-embedding:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
        type: embedding
";

    /// Provider carries a field the schema does not know.
    const UNKNOWN_FIELD_YAML: &str = r"
providers:
  bad-provider:
    env_key: TEST_KEY
    key_url: https://test.com/keys
    api_url: https://api.test.com/v1
    country: US
    region: US
    unknown_field: oops
    models:
      model:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";

    #[test]
    fn parse_yaml() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        assert_eq!(registry.providers.len(), 1);
        let provider = &registry.providers[0];
        assert_eq!(provider.id, "test-provider");
        assert_eq!(provider.key_url, "https://test.com/keys");
        assert_eq!(provider.api_url, "https://api.test.com/v1");
        assert_eq!(provider.models.len(), 2);
    }

    #[test]
    fn parse_model_capabilities() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];
        let llm = provider
            .models
            .iter()
            .find(|m| m.id == "test-model")
            .unwrap();
        assert!(llm.supports_tool_use);
        assert!(llm.supports_reasoning);
        assert!(llm.supports_code);
        assert!(!llm.supports_vision);
        assert_eq!(llm.model_type, ModelType::Llm);
    }

    #[test]
    fn parse_embedding_model() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];
        let embedding = provider
            .models
            .iter()
            .find(|m| m.id == "test-embedding")
            .unwrap();
        assert_eq!(embedding.model_type, ModelType::Embedding);
        assert_eq!(embedding.dimensions, Some(1024));
    }

    #[test]
    fn models_sorted_by_id() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let ids: Vec<&str> = registry.providers[0]
            .models
            .iter()
            .map(|m| m.id.as_str())
            .collect();
        assert_eq!(ids, ["test-embedding", "test-model"]);
    }

    #[test]
    fn filter_by_model_type() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let llms = registry.llm_models();
        assert_eq!(llms.len(), 1);
        assert_eq!(llms[0].1.id, "test-model");
        let embeddings = registry.embedding_models();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].1.id, "test-embedding");
    }

    #[test]
    fn to_model_selector() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let selector = registry.to_model_selector();
        let reqs = converge_core::model_selection::AgentRequirements::balanced();
        let satisfying = selector.list_satisfying(&reqs);
        assert_eq!(satisfying.len(), 1);
    }

    #[test]
    fn provider_availability() {
        let registry = load_registry_from_str(TEST_YAML).unwrap();
        let provider = &registry.providers[0];
        // Smoke test: the result depends on the environment, but it must not panic.
        let _ = provider.is_available();
    }

    #[test]
    fn provider_with_unset_key_is_unavailable() {
        let yaml = r"
providers:
  keyless:
    env_key: CONVERGE_REGISTRY_TEST_KEY_THAT_IS_NEVER_SET
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let registry = load_registry_from_str(yaml).unwrap();
        let provider = registry.get_provider("keyless").unwrap();
        assert!(!provider.is_available());
        assert!(provider.api_key().is_none());
        assert!(provider.secondary_api_key().is_none());
        assert!(registry.available_providers().is_empty());
    }

    #[test]
    fn load_real_registry() {
        let registry = load_registry().unwrap();
        assert!(
            registry.providers.len() >= 10,
            "Expected at least 10 providers"
        );
        let provider_ids: Vec<_> = registry.providers.iter().map(|p| p.id.as_str()).collect();
        assert!(provider_ids.contains(&"anthropic"), "Missing anthropic");
        assert!(provider_ids.contains(&"openai"), "Missing openai");
        assert!(provider_ids.contains(&"mistral"), "Missing mistral");
        assert!(provider_ids.contains(&"ollama"), "Missing ollama");
        let anthropic = registry.get_provider("anthropic").unwrap();
        assert_eq!(
            anthropic.key_url,
            "https://console.anthropic.com/settings/keys"
        );
        assert_eq!(anthropic.api_url, "https://api.anthropic.com/v1");
        assert_eq!(anthropic.env_key, "ANTHROPIC_API_KEY");
        let ollama = registry.get_provider("ollama").unwrap();
        assert_eq!(ollama.region, "LOCAL");
        let llms = registry.llm_models();
        assert!(llms.len() >= 30, "Expected at least 30 LLM models");
        let embeddings = registry.embedding_models();
        assert!(
            embeddings.len() >= 3,
            "Expected at least 3 embedding models"
        );
        println!(
            "Loaded {} providers with {} LLM models and {} embedding models",
            registry.providers.len(),
            llms.len(),
            embeddings.len()
        );
    }

    #[test]
    fn rejects_invalid_cost_class() {
        let result = load_registry_from_str(INVALID_COST_CLASS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("SuperLow") || err.contains("unknown variant"),
            "Expected error about invalid cost class, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_capability() {
        let result = load_registry_from_str(INVALID_CAPABILITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("telepathy") || err.contains("unknown variant"),
            "Expected error about invalid capability, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_quality() {
        let result = load_registry_from_str(INVALID_QUALITY_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("quality") && err.contains("1.5"),
            "Expected error about quality out of range, got: {err}"
        );
    }

    #[test]
    fn rejects_embedding_without_dimensions() {
        let result = load_registry_from_str(MISSING_DIMENSIONS_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("dimensions"),
            "Expected error about missing dimensions, got: {err}"
        );
    }

    #[test]
    fn rejects_unknown_fields() {
        let result = load_registry_from_str(UNKNOWN_FIELD_YAML);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("unknown_field") || err.contains("unknown field"),
            "Expected error about unknown field, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_region() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: INVALID
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("INVALID") || err.contains("unknown variant"),
            "Expected error about invalid region, got: {err}"
        );
    }

    #[test]
    fn rejects_invalid_url() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: not-a-url
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 100
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("key_url") && err.contains("URL"),
            "Expected error about invalid URL, got: {err}"
        );
    }

    #[test]
    fn rejects_zero_latency() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models:
      m:
        cost_class: Low
        typical_latency_ms: 0
        quality: 0.5
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("latency") && err.contains("0"),
            "Expected error about zero latency, got: {err}"
        );
    }

    #[test]
    fn rejects_empty_provider() {
        let yaml = r"
providers:
  empty:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models: {}
";
        let result = load_registry_from_str(yaml);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("at least one model"),
            "Expected error about empty models, got: {err}"
        );
    }

    #[test]
    fn reports_every_invalid_model() {
        let yaml = r"
providers:
  bad:
    env_key: KEY
    key_url: https://test.com
    api_url: https://api.test.com
    country: US
    region: US
    models:
      low-quality:
        cost_class: Low
        typical_latency_ms: 100
        quality: 1.5
      zero-latency:
        cost_class: Low
        typical_latency_ms: 0
        quality: 0.5
";
        let err = load_registry_from_str(yaml).unwrap_err().to_string();
        assert!(
            err.contains("bad/low-quality"),
            "Expected the quality violation to be reported, got: {err}"
        );
        assert!(
            err.contains("bad/zero-latency"),
            "Expected the latency violation to be reported, got: {err}"
        );
    }
}