use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderProtocol {
OpenAiCompletions,
Anthropic,
Ollama,
GithubCopilot,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SetupHint {
ApiKey {
secret_name: String,
#[serde(default)]
key_url: Option<String>,
display_name: String,
#[serde(default)]
can_list_models: bool,
#[serde(default)]
models_filter: Option<String>,
},
Ollama {
display_name: String,
#[serde(default)]
can_list_models: bool,
},
OpenAiCompatible {
secret_name: String,
display_name: String,
#[serde(default)]
can_list_models: bool,
},
}
impl SetupHint {
pub fn display_name(&self) -> &str {
match self {
Self::ApiKey { display_name, .. } => display_name,
Self::Ollama { display_name, .. } => display_name,
Self::OpenAiCompatible { display_name, .. } => display_name,
}
}
pub fn can_list_models(&self) -> bool {
match self {
Self::ApiKey {
can_list_models, ..
} => *can_list_models,
Self::Ollama {
can_list_models, ..
} => *can_list_models,
Self::OpenAiCompatible {
can_list_models, ..
} => *can_list_models,
}
}
pub fn secret_name(&self) -> Option<&str> {
match self {
Self::ApiKey { secret_name, .. } => Some(secret_name),
Self::OpenAiCompatible { secret_name, .. } => Some(secret_name),
Self::Ollama { .. } => None,
}
}
pub fn models_filter(&self) -> Option<&str> {
match self {
Self::ApiKey { models_filter, .. } => models_filter.as_deref(),
_ => None,
}
}
}
mod unsupported_params_de {
use serde::{Deserialize, Deserializer};
const VALID_PARAMS: &[&str] = &["temperature", "max_tokens", "stop_sequences"];
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
let params: Vec<String> = Deserialize::deserialize(deserializer)?;
for param in ¶ms {
if !VALID_PARAMS.contains(¶m.as_str()) {
return Err(serde::de::Error::custom(format!(
"unsupported parameter name '{}': must be one of: {}",
param,
VALID_PARAMS.join(", ")
)));
}
}
Ok(params)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderDefinition {
pub id: String,
#[serde(default)]
pub aliases: Vec<String>,
pub protocol: ProviderProtocol,
#[serde(default)]
pub default_base_url: Option<String>,
#[serde(default)]
pub base_url_env: Option<String>,
#[serde(default)]
pub base_url_required: bool,
#[serde(default)]
pub api_key_env: Option<String>,
#[serde(default)]
pub api_key_required: bool,
pub model_env: String,
pub default_model: String,
pub description: String,
#[serde(default)]
pub extra_headers_env: Option<String>,
#[serde(default)]
pub setup: Option<SetupHint>,
#[serde(default, deserialize_with = "unsupported_params_de::deserialize")]
pub unsupported_params: Vec<String>,
}
pub struct ProviderRegistry {
providers: Vec<ProviderDefinition>,
lookup: HashMap<String, usize>,
}
impl ProviderRegistry {
pub fn new(providers: Vec<ProviderDefinition>) -> Self {
let mut lookup = HashMap::new();
for (idx, def) in providers.iter().enumerate() {
lookup.insert(def.id.to_lowercase(), idx);
for alias in &def.aliases {
lookup.insert(alias.to_lowercase(), idx);
}
}
Self { providers, lookup }
}
pub fn load() -> Self {
let builtins: Vec<ProviderDefinition> =
serde_json::from_str(include_str!("../../providers.json"))
.expect("built-in providers.json must be valid JSON");
let mut all = builtins;
if let Some(user_path) = user_providers_path()
&& user_path.exists()
{
match std::fs::read_to_string(&user_path) {
Ok(contents) => match serde_json::from_str::<Vec<ProviderDefinition>>(&contents) {
Ok(user_defs) => {
tracing::info!(
count = user_defs.len(),
path = %user_path.display(),
"Loaded user provider definitions"
);
all.extend(user_defs);
}
Err(e) => {
tracing::warn!(
path = %user_path.display(),
error = %e,
"Failed to parse user providers.json, skipping"
);
}
},
Err(e) => {
tracing::warn!(
path = %user_path.display(),
error = %e,
"Failed to read user providers.json, skipping"
);
}
}
}
Self::new(all)
}
pub fn find(&self, id: &str) -> Option<&ProviderDefinition> {
self.lookup
.get(&id.to_lowercase())
.map(|&idx| &self.providers[idx])
}
pub fn all(&self) -> &[ProviderDefinition] {
&self.providers
}
pub fn selectable(&self) -> Vec<&ProviderDefinition> {
let mut seen = HashMap::new();
for def in &self.providers {
seen.insert(def.id.as_str(), def);
}
let mut result = Vec::new();
let mut emitted = std::collections::HashSet::new();
for def in &self.providers {
if emitted.insert(def.id.as_str()) {
let final_def = seen[def.id.as_str()];
if final_def.setup.is_some() {
result.push(final_def);
}
}
}
result
}
pub fn is_known(&self, backend: &str) -> bool {
backend == "nearai"
|| backend == "near_ai"
|| backend == "near"
|| self.find(backend).is_some()
}
pub fn model_env_var(&self, backend: &str) -> &str {
if backend == "nearai" || backend == "near_ai" || backend == "near" {
return "NEARAI_MODEL";
}
self.find(backend)
.map(|def| def.model_env.as_str())
.unwrap_or("LLM_MODEL")
}
}
fn user_providers_path() -> Option<std::path::PathBuf> {
Some(crate::bootstrap::ironclaw_base_dir().join("providers.json"))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_builtin_registry_loads() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
assert!(
registry.all().len() >= 5,
"should have at least 5 built-in providers"
);
}
#[test]
fn test_find_by_id() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
let openai = registry.find("openai").expect("openai should exist");
assert_eq!(openai.id, "openai");
assert_eq!(openai.protocol, ProviderProtocol::OpenAiCompletions);
}
#[test]
fn test_find_by_alias() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
let openai = registry
.find("open_ai")
.expect("alias open_ai should resolve");
assert_eq!(openai.id, "openai");
}
#[test]
fn test_find_case_insensitive() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
assert!(registry.find("OpenAI").is_some());
assert!(registry.find("GROQ").is_some());
assert!(registry.find("Tinfoil").is_some());
}
#[test]
fn test_find_unknown_returns_none() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
assert!(registry.find("nonexistent_provider").is_none());
}
#[test]
fn test_selectable_has_setup_hints() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
let selectable = registry.selectable();
assert!(!selectable.is_empty());
for def in &selectable {
assert!(
def.setup.is_some(),
"selectable provider {} must have setup hint",
def.id
);
}
}
#[test]
fn test_user_override_wins() {
let builtins: Vec<ProviderDefinition> =
serde_json::from_str(include_str!("../../providers.json")).unwrap();
let mut all = builtins;
all.push(ProviderDefinition {
id: "tinfoil".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("https://custom.tinfoil.example/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: Some("TINFOIL_API_KEY".to_string()),
api_key_required: true,
model_env: "TINFOIL_MODEL".to_string(),
default_model: "custom-model".to_string(),
description: "Custom tinfoil".to_string(),
extra_headers_env: None,
setup: None,
unsupported_params: vec![],
});
let registry = ProviderRegistry::new(all);
let tf = registry.find("tinfoil").expect("tinfoil should exist");
assert_eq!(tf.default_model, "custom-model", "user override should win");
}
#[test]
fn test_model_env_var_nearai() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
assert_eq!(registry.model_env_var("nearai"), "NEARAI_MODEL");
assert_eq!(registry.model_env_var("near_ai"), "NEARAI_MODEL");
}
#[test]
fn test_model_env_var_registry_provider() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
assert_eq!(registry.model_env_var("groq"), "GROQ_MODEL");
assert_eq!(registry.model_env_var("tinfoil"), "TINFOIL_MODEL");
assert_eq!(registry.model_env_var("openai"), "OPENAI_MODEL");
}
#[test]
fn test_model_env_var_unknown_fallback() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
assert_eq!(registry.model_env_var("nonexistent"), "LLM_MODEL");
}
#[test]
fn test_is_known() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
assert!(registry.is_known("nearai"));
assert!(registry.is_known("openai"));
assert!(registry.is_known("groq"));
assert!(!registry.is_known("nonexistent"));
}
#[test]
fn test_all_providers_have_required_fields() {
let providers: Vec<ProviderDefinition> =
serde_json::from_str(include_str!("../../providers.json")).unwrap();
for def in &providers {
assert!(!def.id.is_empty(), "provider must have an id");
assert!(!def.model_env.is_empty(), "{}: model_env required", def.id);
assert!(
!def.default_model.is_empty(),
"{}: default_model required",
def.id
);
assert!(
!def.description.is_empty(),
"{}: description required",
def.id
);
}
}
#[test]
fn test_openai_compatible_providers_have_base_url() {
let providers: Vec<ProviderDefinition> =
serde_json::from_str(include_str!("../../providers.json")).unwrap();
for def in &providers {
if def.protocol == ProviderProtocol::OpenAiCompletions
&& def.id != "openai"
&& def.id != "openai_compatible"
&& def.id != "bedrock"
&& def.id != "cloudflare"
{
assert!(
def.default_base_url.is_some(),
"{}: OpenAI-completions provider should have a default_base_url",
def.id
);
}
}
}
#[test]
fn test_models_filter_accessor() {
let registry = ProviderRegistry::new(
serde_json::from_str(include_str!("../../providers.json")).unwrap(),
);
let groq = registry.find("groq").expect("groq should exist");
let filter = groq
.setup
.as_ref()
.and_then(|s| s.models_filter())
.expect("groq should have models_filter");
assert_eq!(filter, "chat");
let openai = registry.find("openai").expect("openai should exist");
assert!(
openai
.setup
.as_ref()
.and_then(|s| s.models_filter())
.is_none(),
"openai should not have models_filter"
);
let ollama = registry.find("ollama").expect("ollama should exist");
assert!(
ollama
.setup
.as_ref()
.and_then(|s| s.models_filter())
.is_none(),
"ollama should not have models_filter"
);
}
#[test]
fn test_selectable_user_override_adds_setup() {
let mut providers: Vec<ProviderDefinition> = vec![ProviderDefinition {
id: "custom".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://localhost/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: None,
api_key_required: false,
model_env: "CUSTOM_MODEL".to_string(),
default_model: "m1".to_string(),
description: "No setup".to_string(),
extra_headers_env: None,
setup: None, unsupported_params: vec![],
}];
let registry = ProviderRegistry::new(providers.clone());
assert!(
registry.selectable().is_empty(),
"provider without setup should not be selectable"
);
providers.push(ProviderDefinition {
id: "custom".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://localhost/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: Some("CUSTOM_API_KEY".to_string()),
api_key_required: true,
model_env: "CUSTOM_MODEL".to_string(),
default_model: "m1".to_string(),
description: "Now with setup".to_string(),
extra_headers_env: None,
setup: Some(SetupHint::ApiKey {
secret_name: "llm_custom_api_key".to_string(),
key_url: None,
display_name: "Custom".to_string(),
can_list_models: false,
models_filter: None,
}),
unsupported_params: vec![],
});
let registry = ProviderRegistry::new(providers);
let selectable = registry.selectable();
assert_eq!(
selectable.len(),
1,
"user override with setup should appear"
);
assert_eq!(selectable[0].id, "custom");
assert_eq!(
selectable[0].description, "Now with setup",
"should use the overridden definition"
);
}
#[test]
fn test_selectable_user_override_removes_setup() {
let providers = vec![
ProviderDefinition {
id: "provider_a".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://a/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: Some("A_KEY".to_string()),
api_key_required: true,
model_env: "A_MODEL".to_string(),
default_model: "m1".to_string(),
description: "Has setup".to_string(),
extra_headers_env: None,
setup: Some(SetupHint::ApiKey {
secret_name: "a".to_string(),
key_url: None,
display_name: "A".to_string(),
can_list_models: false,
models_filter: None,
}),
unsupported_params: vec![],
},
ProviderDefinition {
id: "provider_a".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://a/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: Some("A_KEY".to_string()),
api_key_required: false,
model_env: "A_MODEL".to_string(),
default_model: "m1".to_string(),
description: "No setup now".to_string(),
extra_headers_env: None,
setup: None,
unsupported_params: vec![],
},
];
let registry = ProviderRegistry::new(providers);
assert!(
registry.selectable().is_empty(),
"user override removing setup should exclude from selectable"
);
let def = registry
.find("provider_a")
.expect("should still be findable");
assert_eq!(def.description, "No setup now");
}
#[test]
fn test_selectable_preserves_order_with_dedup() {
let providers = vec![
ProviderDefinition {
id: "aaa".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://a/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: None,
api_key_required: false,
model_env: "A".to_string(),
default_model: "m".to_string(),
description: "A".to_string(),
extra_headers_env: None,
setup: Some(SetupHint::Ollama {
display_name: "A".to_string(),
can_list_models: false,
}),
unsupported_params: vec![],
},
ProviderDefinition {
id: "bbb".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://b/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: None,
api_key_required: false,
model_env: "B".to_string(),
default_model: "m".to_string(),
description: "B-original".to_string(),
extra_headers_env: None,
setup: Some(SetupHint::Ollama {
display_name: "B".to_string(),
can_list_models: false,
}),
unsupported_params: vec![],
},
ProviderDefinition {
id: "ccc".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://c/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: None,
api_key_required: false,
model_env: "C".to_string(),
default_model: "m".to_string(),
description: "C".to_string(),
extra_headers_env: None,
setup: Some(SetupHint::Ollama {
display_name: "C".to_string(),
can_list_models: false,
}),
unsupported_params: vec![],
},
ProviderDefinition {
id: "bbb".to_string(),
aliases: vec![],
protocol: ProviderProtocol::OpenAiCompletions,
default_base_url: Some("http://b-new/v1".to_string()),
base_url_env: None,
base_url_required: false,
api_key_env: None,
api_key_required: false,
model_env: "B".to_string(),
default_model: "m".to_string(),
description: "B-override".to_string(),
extra_headers_env: None,
setup: Some(SetupHint::Ollama {
display_name: "B".to_string(),
can_list_models: false,
}),
unsupported_params: vec![],
},
];
let registry = ProviderRegistry::new(providers);
let selectable = registry.selectable();
let ids: Vec<&str> = selectable.iter().map(|d| d.id.as_str()).collect();
assert_eq!(ids, vec!["aaa", "bbb", "ccc"], "order should be preserved");
assert_eq!(
selectable[1].description, "B-override",
"should use the overridden definition"
);
}
#[test]
fn test_unsupported_params_deserialized() {
let providers: Vec<ProviderDefinition> =
serde_json::from_str(include_str!("../../providers.json")).unwrap();
let tinfoil = providers.iter().find(|p| p.id == "tinfoil").unwrap();
assert!(
tinfoil
.unsupported_params
.contains(&"temperature".to_string()),
"tinfoil should have 'temperature' in unsupported_params"
);
let openai = providers.iter().find(|p| p.id == "openai").unwrap();
assert!(
openai
.unsupported_params
.contains(&"temperature".to_string()),
"openai should have 'temperature' in unsupported_params"
);
let groq = providers.iter().find(|p| p.id == "groq").unwrap();
assert!(
groq.unsupported_params.is_empty(),
"groq should have empty unsupported_params (field absent in JSON)"
);
for def in &providers {
for param in &def.unsupported_params {
assert!(
!param.is_empty(),
"{}: unsupported_params contains empty string",
def.id
);
assert!(
matches!(
param.as_str(),
"temperature" | "max_tokens" | "stop_sequences"
),
"{}: unsupported_params contains invalid parameter '{}'",
def.id,
param
);
}
}
}
#[test]
fn test_unsupported_params_validation_rejects_invalid() {
let invalid_json = r#"[{
"id": "test",
"protocol": "open_ai_completions",
"model_env": "TEST_MODEL",
"default_model": "test-model",
"description": "Test provider",
"unsupported_params": ["temperrature"]
}]"#;
let result: Result<Vec<ProviderDefinition>, _> = serde_json::from_str(invalid_json);
assert!(
result.is_err(),
"should reject invalid parameter name 'temperrature'"
);
assert!(
result.err().unwrap().to_string().contains("temperrature"),
"error message should mention the invalid parameter"
);
}
#[test]
fn test_all_builtin_api_key_providers_have_api_key_env() {
let providers: Vec<ProviderDefinition> =
serde_json::from_str(include_str!("../../providers.json")).unwrap();
for def in &providers {
if let Some(SetupHint::ApiKey { .. }) = &def.setup {
assert!(
def.api_key_env.is_some(),
"{}: ApiKey setup hint requires api_key_env to be set",
def.id
);
}
}
}
}