use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
/// Identifies one of the supported model API providers.
///
/// Because of `rename_all = "lowercase"`, serde reads and writes each variant
/// as its lowercase name (`"openai"`, `"anthropic"`, `"google"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
/// OpenAI; this is the variant returned by `Provider::default()`.
#[default]
OpenAI,
/// Anthropic.
Anthropic,
/// Google.
Google,
}
impl Provider {
pub const ALL: &'static [Provider] = &[Provider::OpenAI, Provider::Anthropic, Provider::Google];
#[must_use]
pub const fn name(&self) -> &'static str {
match self {
Self::OpenAI => "openai",
Self::Anthropic => "anthropic",
Self::Google => "google",
}
}
#[must_use]
pub const fn default_model(&self) -> &'static str {
match self {
Self::OpenAI => "gpt-5.4",
Self::Anthropic => "claude-opus-4-6",
Self::Google => "gemini-3-pro-preview",
}
}
#[must_use]
pub const fn default_fast_model(&self) -> &'static str {
match self {
Self::OpenAI => "gpt-5.4-mini",
Self::Anthropic => "claude-haiku-4-5-20251001",
Self::Google => "gemini-2.5-flash",
}
}
#[must_use]
pub const fn context_window(&self) -> usize {
match self {
Self::OpenAI => 128_000,
Self::Anthropic => 200_000,
Self::Google => 1_000_000,
}
}
#[must_use]
pub const fn api_key_env(&self) -> &'static str {
match self {
Self::OpenAI => "OPENAI_API_KEY",
Self::Anthropic => "ANTHROPIC_API_KEY",
Self::Google => "GOOGLE_API_KEY",
}
}
#[must_use]
pub fn api_key_prefixes(&self) -> &'static [&'static str] {
match self {
Self::OpenAI => &["sk-", "sk-proj-"],
Self::Anthropic => &["sk-ant-"],
Self::Google => &[], }
}
#[must_use]
pub const fn api_key_prefix(&self) -> Option<&'static str> {
match self {
Self::OpenAI => Some("sk-"),
Self::Anthropic => Some("sk-ant-"),
Self::Google => None,
}
}
pub fn validate_api_key_format(&self, key: &str) -> Result<(), String> {
if key.len() < 20 {
return Err(format!(
"{} API key appears too short (got {} chars, expected 20+)",
self.name(),
key.len()
));
}
let prefixes = self.api_key_prefixes();
if !prefixes.is_empty() && !prefixes.iter().any(|p| key.starts_with(p)) {
let expected = if prefixes.len() == 1 {
format!("'{}'", prefixes[0])
} else {
prefixes
.iter()
.map(|p| format!("'{p}'"))
.collect::<Vec<_>>()
.join(" or ")
};
return Err(format!(
"{} API key should start with {} (key has unexpected prefix)",
self.name(),
expected
));
}
Ok(())
}
pub fn all_names() -> Vec<&'static str> {
Self::ALL.iter().map(Self::name).collect()
}
}
/// Parses a provider from its name, ignoring ASCII case and surrounding
/// whitespace.
///
/// Besides the canonical names, `"claude"` is accepted as an alias for
/// Anthropic and `"gemini"` as an alias for Google.
impl FromStr for Provider {
    type Err = ProviderError;

    /// # Errors
    ///
    /// Returns [`ProviderError::Unknown`] carrying the input exactly as given
    /// when it does not name a supported provider.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Values read from env vars or config files often carry stray
        // whitespace or a trailing newline; tolerate it.
        let wanted = s.trim();

        // Resolve aliases to canonical names. Every name is plain ASCII, so
        // an ASCII case-insensitive comparison avoids allocating a lowercased
        // copy of the input.
        let canonical = if wanted.eq_ignore_ascii_case("claude") {
            "anthropic"
        } else if wanted.eq_ignore_ascii_case("gemini") {
            "google"
        } else {
            wanted
        };

        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.name().eq_ignore_ascii_case(canonical))
            .ok_or_else(|| ProviderError::Unknown(s.to_string()))
    }
}
/// Displays a provider as its canonical lowercase name.
impl fmt::Display for Provider {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written verbatim; width and fill flags are not applied.
        let label = self.name();
        f.write_str(label)
    }
}
/// Errors raised while resolving or configuring a provider.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
/// The input did not name a supported provider; the payload is the input
/// string. The "Supported:" list in the message is hard-coded, so it has to
/// be updated by hand when a provider is added.
#[error("Unknown provider: {0}. Supported: openai, anthropic, google")]
Unknown(String),
/// An API key was required but not available; the payload is interpolated
/// into the message as the provider.
// NOTE(review): not constructed anywhere in this file — confirm what callers pass as the payload.
#[error("API key required for provider: {0}")]
MissingApiKey(String),
}
/// Settings for a single provider: API key, model choices, and optional overrides.
///
/// Every field may be absent on input (`serde(default)`) and is omitted from
/// serialized output when it is empty or `None`. `Debug` is implemented by
/// hand rather than derived so that the API key is never printed.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ProviderConfig {
/// API key; an empty string means no key is set.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub api_key: String,
/// Primary model name; when empty, `effective_model` falls back to the provider default.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub model: String,
/// Fast model name; when `None`, `effective_fast_model` falls back to the provider default.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub fast_model: Option<String>,
/// Token limit override; when `None`, `effective_token_limit` uses the provider's context window.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_limit: Option<usize>,
/// Extra string key/value parameters.
// NOTE(review): not read anywhere in this file — confirm how consumers interpret these.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub additional_params: HashMap<String, String>,
}
/// Debug output that never reveals the API key: the key is shown only as a
/// marker saying whether one is present.
impl fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stand-in printed in place of the secret.
        let key_marker: &str = match self.api_key.as_str() {
            "" => "<empty>",
            _ => "[REDACTED]",
        };

        let mut out = f.debug_struct("ProviderConfig");
        out.field("api_key", &key_marker);
        out.field("model", &self.model);
        out.field("fast_model", &self.fast_model);
        out.field("token_limit", &self.token_limit);
        out.field("additional_params", &self.additional_params);
        out.finish()
    }
}
impl ProviderConfig {
#[must_use]
pub fn with_defaults(provider: Provider) -> Self {
Self {
api_key: String::new(),
model: provider.default_model().to_string(),
fast_model: Some(provider.default_fast_model().to_string()),
token_limit: None,
additional_params: HashMap::new(),
}
}
#[must_use]
pub fn effective_model(&self, provider: Provider) -> &str {
if self.model.is_empty() {
provider.default_model()
} else {
&self.model
}
}
#[must_use]
pub fn effective_fast_model(&self, provider: Provider) -> &str {
self.fast_model
.as_deref()
.unwrap_or_else(|| provider.default_fast_model())
}
#[must_use]
pub fn effective_token_limit(&self, provider: Provider) -> usize {
self.token_limit
.unwrap_or_else(|| provider.context_window())
}
#[must_use]
pub fn has_api_key(&self) -> bool {
!self.api_key.is_empty()
}
#[must_use]
pub fn api_key_if_set(&self) -> Option<&str> {
if self.api_key.is_empty() {
None
} else {
Some(&self.api_key)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a provider name, discarding the error detail.
    fn parsed(raw: &str) -> Option<Provider> {
        raw.parse().ok()
    }

    /// Names, case-insensitivity, aliases, and rejection of unknown input.
    #[test]
    fn test_provider_from_str() {
        assert_eq!(parsed("openai"), Some(Provider::OpenAI));
        assert_eq!(parsed("ANTHROPIC"), Some(Provider::Anthropic));
        assert_eq!(parsed("claude"), Some(Provider::Anthropic));
        assert_eq!(parsed("gemini"), Some(Provider::Google));
        assert!(parsed("invalid").is_none());
    }

    /// Spot-checks the per-provider constant tables.
    #[test]
    fn test_provider_defaults() {
        let openai = Provider::OpenAI;
        assert_eq!(openai.default_model(), "gpt-5.4");
        assert_eq!(openai.default_fast_model(), "gpt-5.4-mini");
        assert_eq!(Provider::Anthropic.context_window(), 200_000);
        assert_eq!(Provider::Google.api_key_env(), "GOOGLE_API_KEY");
    }

    /// `with_defaults` copies the provider's default model names.
    #[test]
    fn test_provider_config_defaults() {
        let ProviderConfig {
            model, fast_model, ..
        } = ProviderConfig::with_defaults(Provider::Anthropic);
        assert_eq!(model, "claude-opus-4-6");
        assert_eq!(fast_model.as_deref(), Some("claude-haiku-4-5-20251001"));
    }

    #[test]
    fn test_api_key_prefix() {
        let expectations = [
            (Provider::OpenAI, Some("sk-")),
            (Provider::Anthropic, Some("sk-ant-")),
            (Provider::Google, None),
        ];
        for (provider, prefix) in expectations {
            assert_eq!(provider.api_key_prefix(), prefix);
        }
    }

    /// The key is exposed only while it is non-empty.
    #[test]
    fn test_api_key_if_set() {
        let mut config = ProviderConfig::with_defaults(Provider::OpenAI);

        config.api_key = String::from("sk-test-key-12345678901234567890");
        assert_eq!(
            config.api_key_if_set(),
            Some("sk-test-key-12345678901234567890")
        );

        config.api_key.clear();
        assert_eq!(config.api_key_if_set(), None);
    }

    #[test]
    fn test_api_key_prefixes() {
        assert_eq!(
            Provider::OpenAI.api_key_prefixes(),
            &["sk-", "sk-proj-"][..]
        );
        assert_eq!(Provider::Anthropic.api_key_prefixes(), &["sk-ant-"][..]);
        assert_eq!(Provider::Google.api_key_prefixes().len(), 0);
    }

    #[test]
    fn test_api_key_validation_valid_openai() {
        assert_eq!(
            Provider::OpenAI.validate_api_key_format("sk-1234567890abcdefghijklmnop"),
            Ok(())
        );
    }

    #[test]
    fn test_api_key_validation_valid_openai_project_key() {
        assert_eq!(
            Provider::OpenAI.validate_api_key_format("sk-proj-1234567890abcdefghijklmnop"),
            Ok(())
        );
    }

    #[test]
    fn test_api_key_validation_valid_anthropic() {
        assert_eq!(
            Provider::Anthropic.validate_api_key_format("sk-ant-1234567890abcdefghijklmnop"),
            Ok(())
        );
    }

    /// Google has no prefix convention, so only the length is checked.
    #[test]
    fn test_api_key_validation_valid_google() {
        assert_eq!(
            Provider::Google.validate_api_key_format("AIzaSyA1234567890abcdefgh"),
            Ok(())
        );
    }

    #[test]
    fn test_api_key_validation_too_short() {
        let message = Provider::OpenAI
            .validate_api_key_format("sk-short")
            .expect_err("should be err");
        assert!(message.contains("too short"));
    }

    /// The error names the expected prefixes and must not echo the key.
    #[test]
    fn test_api_key_validation_wrong_prefix_openai() {
        let err = Provider::OpenAI
            .validate_api_key_format("wrong-prefix-1234567890abcdef")
            .expect_err("should be err");
        assert!(err.contains("should start with"));
        assert!(["'sk-'", "'sk-proj-'"].iter().any(|quoted| err.contains(quoted)));
        assert!(!err.contains("wrong-"));
    }

    /// A bare `sk-` key must not pass as an Anthropic key.
    #[test]
    fn test_api_key_validation_wrong_prefix_anthropic() {
        let err = Provider::Anthropic
            .validate_api_key_format("sk-1234567890abcdefghijklmnop")
            .expect_err("should be err");
        for needle in ["sk-ant-", "unexpected prefix"] {
            assert!(err.contains(needle));
        }
    }
}