use std::collections::HashMap;
use std::sync::OnceLock;
use crate::config::{Config, ProviderAuth, ProviderEntry};
/// Provider kinds recognised by [`parse_provider`].
///
/// This is a fieldless `Copy` enum, so it also derives `Eq` and `Hash`.
/// That lets it be used directly as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderKind {
    OpenRouter,
    OpenAI,
    Anthropic,
    Gemini,
    DeepSeek,
    Glm,
    Ollama,
    Custom,
}
/// Returns the model id used for `provider_name` when no model was chosen.
///
/// The name is interpreted by [`parse_provider`]. OpenRouter, Custom and
/// unrecognised names all share the `deepseek/deepseek-v4-flash` default.
pub fn default_model_for(provider_name: &str) -> &'static str {
    const SHARED_DEFAULT: &str = "deepseek/deepseek-v4-flash";

    let Some(kind) = parse_provider(provider_name) else {
        return SHARED_DEFAULT;
    };
    match kind {
        ProviderKind::OpenAI => "gpt-4o",
        ProviderKind::Anthropic => "claude-sonnet-4-6",
        ProviderKind::Gemini => "gemini-2.0-flash",
        ProviderKind::DeepSeek => "deepseek-v4-pro",
        ProviderKind::Glm => "glm-4",
        ProviderKind::Ollama => "llama3",
        ProviderKind::OpenRouter | ProviderKind::Custom => SHARED_DEFAULT,
    }
}
/// Returns the default model for a configured provider entry.
///
/// The provider type comes from `Config::provider_type_of(alias, entry)`.
pub fn default_model_for_entry(alias: &str, entry: &ProviderEntry) -> &'static str {
    let provider_type = Config::provider_type_of(alias, entry);
    default_model_for(&provider_type)
}
/// Returns the default model for `alias`.
///
/// If `providers` has an entry for `alias`, that entry's provider type
/// decides. The exact key is tried first, then the ASCII-lowercased key.
/// Otherwise `alias` itself is treated as a provider name.
pub fn default_model_for_alias(
    alias: &str,
    providers: &HashMap<String, ProviderEntry>,
) -> &'static str {
    let configured = providers
        .get(alias)
        .or_else(|| providers.get(alias.to_ascii_lowercase().as_str()));
    configured.map_or_else(
        || default_model_for(alias),
        |entry| default_model_for_entry(alias, entry),
    )
}
/// Parses a provider name into a [`ProviderKind`].
///
/// Matching is case-insensitive (Unicode lowercasing). `google` is accepted
/// as an alias for Gemini and `zhipu` as an alias for GLM. Unknown names
/// yield `None`.
pub fn parse_provider(name: &str) -> Option<ProviderKind> {
    const ALIASES: &[(&str, ProviderKind)] = &[
        ("openrouter", ProviderKind::OpenRouter),
        ("openai", ProviderKind::OpenAI),
        ("anthropic", ProviderKind::Anthropic),
        ("gemini", ProviderKind::Gemini),
        ("google", ProviderKind::Gemini),
        ("deepseek", ProviderKind::DeepSeek),
        ("glm", ProviderKind::Glm),
        ("zhipu", ProviderKind::Glm),
        ("ollama", ProviderKind::Ollama),
        ("custom", ProviderKind::Custom),
    ];

    let wanted = name.to_lowercase();
    ALIASES
        .iter()
        .find(|(alias, _)| *alias == wanted)
        .map(|&(_, kind)| kind)
}
/// Resolved description of a provider, as produced by `resolve_provider_info`.
pub struct ProviderInfo {
/// Provider kind the name (or its configured provider type) parsed to.
pub kind: ProviderKind,
/// Endpoint copied from the config/plugin entry; `None` for bare built-in names.
pub base_url: Option<String>,
/// `api_key_env` copied from the config/plugin entry; `None` for bare built-in names.
pub api_key_env: Option<String>,
/// `auth` copied from the config/plugin entry; `None` for bare built-in names.
pub auth: Option<ProviderAuth>,
/// Key obtained from the entry's `resolved_api_key()`. `None` when the entry has
/// no such key or when it references an environment variable that is unset.
pub api_key_literal: Option<String>,
}
/// Resolves a provider `name` to its [`ProviderInfo`].
///
/// Sources are consulted in order:
/// 1. user-configured `providers` (exact key, then ASCII-lowercased key);
/// 2. providers registered by plugins (see [`install_plugin_providers`]),
///    looked up the same way;
/// 3. the built-in names understood by [`parse_provider`].
///
/// The first source with a matching entry decides the result. Later sources
/// are not tried, even if that entry fails to resolve.
///
/// Returns `None` in two cases:
/// - the matching entry's provider type is not a known [`ProviderKind`]
///   (or, with no entry, `name` itself is not);
/// - the entry's `base_url` fails `validate_custom_provider`. Plugin entries
///   are additionally forbidden from reusing built-in names.
///
/// Validation failures are logged via `tracing` and printed to stderr. An
/// `api_key` that references an unset environment variable is reported the
/// same way but does not fail resolution; `api_key_literal` is left `None`.
pub fn resolve_provider_info(
    name: &str,
    providers: &HashMap<String, ProviderEntry>,
) -> Option<ProviderInfo> {
    /// Builds a [`ProviderInfo`] from a configured or plugin-supplied entry.
    ///
    /// `from_plugin` enables the built-in-name collision check. It also
    /// switches the diagnostics to "plugin provider" wording.
    fn from_entry(name: &str, entry: &ProviderEntry, from_plugin: bool) -> Option<ProviderInfo> {
        let ptype = Config::provider_type_of(name, entry);
        let kind = parse_provider(&ptype)?;
        if let Some(url) = entry.base_url.as_deref()
            && let Err(err) =
                validate_custom_provider(name, url, entry.allow_insecure, from_plugin)
        {
            tracing::error!(
                target: "dirge::provider",
                "{err}"
            );
            eprintln!("error: {err}");
            return None;
        }
        // Only the subject wording differs between config and plugin diagnostics.
        let label = if from_plugin { "plugin provider" } else { "provider" };
        let api_key_literal = match entry.resolved_api_key() {
            Some(Ok(key)) => Some(key),
            Some(Err(missing)) => {
                tracing::error!(
                    target: "dirge::provider",
                    "{label} '{name}' references env var ${{{missing}}} via api_key but it is unset",
                );
                eprintln!(
                    "error: {label} '{name}' references env var ${{{missing}}} via api_key but it is unset"
                );
                None
            }
            None => None,
        };
        Some(ProviderInfo {
            kind,
            base_url: entry.base_url.clone(),
            api_key_env: entry.api_key_env.clone(),
            auth: entry.auth,
            api_key_literal,
        })
    }

    let lower = name.to_ascii_lowercase();
    if let Some(entry) = providers.get(name).or_else(|| providers.get(&lower)) {
        return from_entry(name, entry, false);
    }
    if let Some(entry) = plugin_provider(name).or_else(|| plugin_provider(&lower)) {
        return from_entry(name, &entry, true);
    }
    let kind = parse_provider(name)?;
    Some(ProviderInfo {
        kind,
        base_url: None,
        api_key_env: None,
        auth: None,
        api_key_literal: None,
    })
}
/// Built-in provider names and aliases (the spellings `parse_provider` accepts).
///
/// When `validate_custom_provider` is called with `enforce_builtin_collision`,
/// a provider may not reuse any of these (compared ASCII case-insensitively).
/// NOTE: keep in sync with `parse_provider`.
const BUILTIN_PROVIDER_NAMES: &[&str] = &[
    "openai",
    "anthropic",
    "gemini",
    "google",
    "deepseek",
    "glm",
    "zhipu",
    "ollama",
    "openrouter",
    "custom",
];

/// Returns `true` when `url` begins with `scheme`, compared ASCII
/// case-insensitively (RFC 3986 §3.1: schemes are case-insensitive).
///
/// Never panics. `str::get` returns `None` when the cut would split a
/// multi-byte character.
fn url_has_scheme(url: &str, scheme: &str) -> bool {
    url.get(..scheme.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(scheme))
}

/// Validates a non-built-in provider's name and `base_url`.
///
/// # Errors
/// - If `enforce_builtin_collision` is set and `name` matches an entry of
///   `BUILTIN_PROVIDER_NAMES` (ASCII case-insensitive).
/// - If `allow_insecure` is false and `base_url` is not `https://` (scheme
///   compared case-insensitively).
///
/// With `allow_insecure`, an `http://` URL whose host does not look local
/// (see `looks_like_local_host`) is accepted, but a warning goes to stderr.
pub(crate) fn validate_custom_provider(
    name: &str,
    base_url: &str,
    allow_insecure: bool,
    enforce_builtin_collision: bool,
) -> Result<(), String> {
    if enforce_builtin_collision
        && BUILTIN_PROVIDER_NAMES
            .iter()
            .any(|builtin| builtin.eq_ignore_ascii_case(name))
    {
        return Err(format!(
            "Custom provider '{}' collides with built-in provider name. \
             Choose a different name.",
            name
        ));
    }
    if !allow_insecure && !url_has_scheme(base_url, "https://") {
        return Err(format!(
            "Custom provider '{}' has insecure base_url '{}'. \
             Set allow_insecure: true in config.json if this is a \
             local-only endpoint (e.g. ollama, vllm). All other \
             http:// URLs send your data in plaintext.",
            name, base_url
        ));
    }
    // Case-insensitive so `HTTP://remote-host` cannot slip past the warning.
    if allow_insecure && url_has_scheme(base_url, "http://") && !looks_like_local_host(base_url) {
        eprintln!(
            " ⚠️ WARNING: custom provider '{}' is using http:// over a NON-LOCAL host: {}\n Every prompt, file content, and tool result is sent in plaintext.\n This is allowed because allow_insecure: true is set in config.json,\n but you should verify this is intentional — the typical allow_insecure\n use case is loopback (127.0.0.1 / localhost) endpoints like ollama.",
            name, base_url,
        );
    }
    Ok(())
}

/// Returns `true` when an `http://` URL points at a host treated as local.
///
/// A host is treated as local when it is any of:
/// - `localhost`, `ip6-localhost` or `ip6-loopback`;
/// - an IPv4 loopback, private or link-local address;
/// - an IPv6 loopback, unspecified, unique-local or link-local address;
/// - a `*.local` name.
///
/// Any other scheme yields `false` (the scheme is compared case-insensitively).
/// A `userinfo@` prefix is stripped first, so
/// `http://127.0.0.1:80@example.com` is judged by `example.com`, the host
/// actually contacted.
fn looks_like_local_host(base_url: &str) -> bool {
    const SCHEME: &str = "http://";
    if !url_has_scheme(base_url, SCHEME) {
        return false;
    }
    // In bounds and on a char boundary: `url_has_scheme` just matched the ASCII prefix.
    let after = &base_url[SCHEME.len()..];
    // The authority ends at the first path, query, or fragment delimiter.
    let authority_end = after.find(['/', '?', '#']).unwrap_or(after.len());
    let authority = &after[..authority_end];
    // Drop `userinfo@`. Like the WHATWG URL parser, split on the last '@'.
    let host_and_port = authority.rsplit_once('@').map_or(authority, |(_, hp)| hp);
    let host = if let Some(rest) = host_and_port.strip_prefix('[')
        && let Some(close) = rest.find(']')
    {
        // Bracketed IPv6 literal, e.g. `[::1]:11434`.
        &rest[..close]
    } else {
        host_and_port
            .rsplit_once(':')
            .map_or(host_and_port, |(h, _)| h)
    };
    let lower = host.to_ascii_lowercase();
    if matches!(
        lower.as_str(),
        "localhost" | "ip6-localhost" | "ip6-loopback"
    ) {
        return true;
    }
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        return match ip {
            std::net::IpAddr::V4(v4) => v4.is_loopback() || v4.is_private() || v4.is_link_local(),
            // Unique-local (fc00::/7) and link-local (fe80::/10) mirror the
            // IPv4 private / link-local ranges accepted above.
            std::net::IpAddr::V6(v6) => {
                v6.is_loopback()
                    || v6.is_unspecified()
                    || v6.is_unique_local()
                    || v6.is_unicast_link_local()
            }
        };
    }
    lower.ends_with(".local")
}
/// Provider entries registered by plugins, keyed by provider name.
/// Set at most once (by `install_plugin_providers`); read by `plugin_provider`.
static PLUGIN_PROVIDERS: OnceLock<HashMap<String, ProviderEntry>> = OnceLock::new();
/// Installs the provider table contributed by plugins.
///
/// Returns the number of plugin providers now in effect. The table can only
/// be installed once. A later call is ignored, logs a warning, and returns
/// the size of the table that was installed first.
// Silence dead-code warnings in builds without the `plugin` feature
// (presumably the only configuration that calls this).
#[cfg_attr(not(feature = "plugin"), allow(dead_code))]
pub fn install_plugin_providers(map: HashMap<String, ProviderEntry>) -> usize {
    let requested = map.len();
    match PLUGIN_PROVIDERS.set(map) {
        Ok(()) => requested,
        Err(rejected) => {
            let in_effect = PLUGIN_PROVIDERS.get().map_or(0, |installed| installed.len());
            tracing::warn!(
                target: "dirge::provider",
                attempted = rejected.len(),
                in_effect,
                "plugin providers already installed — ignoring re-registration (runtime hot-reload of providers is not supported)",
            );
            in_effect
        }
    }
}
/// Returns an owned copy of the plugin provider registered under exactly `name`.
///
/// Returns `None` if no plugin table has been installed or `name` is absent.
fn plugin_provider(name: &str) -> Option<ProviderEntry> {
    let registry = PLUGIN_PROVIDERS.get()?;
    registry.get(name).cloned()
}
fn provider_env_var(kind: ProviderKind) -> &'static str {
match kind {
ProviderKind::OpenAI => "OPENAI_API_KEY",
ProviderKind::Anthropic => "ANTHROPIC_API_KEY",
ProviderKind::Gemini => "GEMINI_API_KEY",
ProviderKind::DeepSeek => "DEEPSEEK_API_KEY",
ProviderKind::Glm => "GLM_API_KEY",
ProviderKind::Ollama => "OLLAMA_API_KEY",
ProviderKind::OpenRouter => "OPENROUTER_API_KEY",
ProviderKind::Custom => "CUSTOM_API_KEY",
}
}
/// Picks a provider name based on which API-key variables are set in the
/// real process environment. Unset variables and ones that are not valid
/// Unicode both count as absent.
pub fn auto_detect_provider() -> Option<&'static str> {
    let read_env = |var: &str| std::env::var(var).ok();
    auto_detect_provider_from(read_env)
}
/// `(environment variable, provider name)` pairs checked by
/// [`auto_detect_provider_from`], highest priority first.
/// `ZHIPU_API_KEY` selects `glm` as an alternative to `GLM_API_KEY`.
pub(crate) const PROVIDER_AUTODETECT_ORDER: &[(&str, &str)] = &[
    ("DEEPSEEK_API_KEY", "deepseek"),
    ("OPENAI_API_KEY", "openai"),
    ("ANTHROPIC_API_KEY", "anthropic"),
    ("GEMINI_API_KEY", "gemini"),
    ("GLM_API_KEY", "glm"),
    ("ZHIPU_API_KEY", "glm"),
    ("OLLAMA_API_KEY", "ollama"),
    ("OPENROUTER_API_KEY", "openrouter"),
];

/// Returns the provider for the first [`PROVIDER_AUTODETECT_ORDER`] entry
/// whose variable `env` reports as set to a non-empty value.
///
/// Returns `None` if no entry qualifies. `env` is called in table order and
/// not again after the first hit.
pub(crate) fn auto_detect_provider_from<F: Fn(&str) -> Option<String>>(
    env: F,
) -> Option<&'static str> {
    PROVIDER_AUTODETECT_ORDER
        .iter()
        .find(|(var, _)| env(var).is_some_and(|value| !value.is_empty()))
        .map(|&(_, provider)| provider)
}
/// Alternate environment variables that may hold `kind`'s API key, in
/// priority order. Providers without alternates get an empty slice.
pub(crate) fn provider_env_var_fallbacks(kind: ProviderKind) -> &'static [&'static str] {
    match kind {
        ProviderKind::Anthropic => &["ANTHROPIC_OAUTH_TOKEN"],
        ProviderKind::Gemini => &["GOOGLE_GENERATIVE_AI_API_KEY", "GOOGLE_API_KEY"],
        ProviderKind::Glm => &["ZHIPU_API_KEY"],
        ProviderKind::OpenRouter
        | ProviderKind::OpenAI
        | ProviderKind::DeepSeek
        | ProviderKind::Ollama
        | ProviderKind::Custom => &[],
    }
}
/// Resolves the API key for `kind`.
///
/// Lookup order (empty values count as unset at every step):
/// 1. `cli_key`.
/// 2. The variable named by `api_key_env_override`, or, when that is absent
///    or empty, the provider's primary variable from `provider_env_var`.
/// 3. Only when no override was given: the alternates from
///    `provider_env_var_fallbacks`, in order. An explicit override is
///    authoritative.
///
/// `env` performs the variable lookup. `auto_detect_provider` passes
/// `std::env::var(..).ok()` for the same pattern.
///
/// # Errors
/// Fails when no key is found, except for `ProviderKind::Ollama` and
/// `ProviderKind::Custom`, which yield an empty string instead. The error
/// names only the variables that were actually consulted.
pub(crate) fn resolve_api_key_from<F>(
    kind: ProviderKind,
    api_key_env_override: Option<&str>,
    cli_key: Option<&str>,
    env: F,
) -> anyhow::Result<String>
where
    F: Fn(&str) -> Option<String>,
{
    if let Some(key) = cli_key.filter(|k| !k.is_empty()) {
        return Ok(key.to_string());
    }
    let override_var = api_key_env_override.filter(|s| !s.is_empty());
    let env_var = override_var.unwrap_or_else(|| provider_env_var(kind));
    // Fallbacks apply only to the default variable. Computing them once keeps
    // the lookup below and the error message in agreement.
    let fallbacks: &[&str] = if override_var.is_none() {
        provider_env_var_fallbacks(kind)
    } else {
        &[]
    };
    let from_env = std::iter::once(env_var)
        .chain(fallbacks.iter().copied())
        .find_map(|var| env(var).filter(|value| !value.is_empty()));
    if let Some(key) = from_env {
        return Ok(key);
    }
    // These providers may be used without a key; return an empty one
    // instead of failing.
    if matches!(kind, ProviderKind::Ollama | ProviderKind::Custom) {
        return Ok(String::new());
    }
    if fallbacks.is_empty() {
        anyhow::bail!(
            "No API key found for {kind:?}. Set the {env_var} environment variable or pass --api-key."
        )
    } else {
        anyhow::bail!(
            "No API key found for {kind:?}. Set {env_var} (or one of: {}) or pass --api-key.",
            fallbacks.join(", ")
        )
    }
}