use std::fmt;
use std::str::FromStr;
use std::time::Duration;
/// Identifies one of the API providers known to this module.
///
/// Every variant has a canonical lowercase name, used by the `Display` and
/// `FromStr` implementations in this file, and a default base URL.
///
/// The enum is `#[non_exhaustive]`: code in other crates must include a
/// wildcard arm when matching on it, so new variants can be added later
/// without breaking those matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Provider {
/// Canonical name `"openai"`.
OpenAI,
/// Canonical name `"aliyun"`; the default endpoint is the DashScope
/// `compatible-mode` URL.
Aliyun,
/// Canonical name `"anthropic"`.
Anthropic,
/// Canonical name `"google"`; the default endpoint is on
/// `generativelanguage.googleapis.com`.
Google,
/// Canonical name `"ollama"`; the default endpoint is on `localhost`.
Ollama,
/// Canonical name `"zhipu"`; the default endpoint is on `open.bigmodel.cn`.
Zhipu,
}
impl Provider {
const fn default_base_url(self) -> &'static str {
match self {
Provider::OpenAI => "https://api.openai.com/v1",
Provider::Aliyun => "https://dashscope.aliyuncs.com/compatible-mode/v1",
Provider::Anthropic => "https://api.anthropic.com/v1",
Provider::Google => "https://generativelanguage.googleapis.com/v1beta",
Provider::Ollama => "http://localhost:11434/v1",
Provider::Zhipu => "https://open.bigmodel.cn/api/paas/v4",
}
}
}
impl fmt::Display for Provider {
    /// Formats the provider as its canonical lowercase name (for example
    /// `"openai"`), the same spelling accepted when parsing a provider from a
    /// string.
    ///
    /// Width, fill, alignment and precision flags in the format spec are
    /// honoured, so `format!("{:>10}", provider)` pads as it would for a `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Provider::OpenAI => "openai",
            Provider::Aliyun => "aliyun",
            Provider::Anthropic => "anthropic",
            Provider::Google => "google",
            Provider::Ollama => "ollama",
            Provider::Zhipu => "zhipu",
        };
        // `pad` applies the formatter's flags; `write_str` would bypass them
        // and silently ignore any requested width or alignment.
        f.pad(name)
    }
}
impl FromStr for Provider {
type Err = crate::error::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.trim().to_ascii_lowercase().as_str() {
"openai" => Ok(Provider::OpenAI),
"aliyun" => Ok(Provider::Aliyun),
"anthropic" => Ok(Provider::Anthropic),
"google" => Ok(Provider::Google),
"ollama" => Ok(Provider::Ollama),
"zhipu" => Ok(Provider::Zhipu),
other => Err(crate::error::Error::UnknownProvider(other.to_string())),
}
}
}
/// Configuration for one [`Provider`]: credentials, endpoint, model and a few
/// optional tuning values.
///
/// Build one with [`ProviderConfig::new`] or [`ProviderConfig::with_base_url`];
/// both leave every optional field as `None`. All fields are public, so they
/// can be adjusted after construction.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
/// Provider this configuration is for.
pub provider: Provider,
/// API key, stored exactly as passed to the constructor.
pub api_key: String,
/// Base URL of the provider's API. `new` fills this with the provider's
/// default; `with_base_url` stores the caller's value unchanged.
pub base_url: String,
/// Model identifier, stored exactly as passed to the constructor.
pub model: String,
/// Optional dimension setting; `None` unless set by the caller.
/// NOTE(review): presumably the requested output vector size; confirm
/// against the code that consumes this field.
pub dimension: Option<usize>,
/// Optional timeout; `None` unless set by the caller.
/// NOTE(review): presumably a per-request timeout; the behavior when this is
/// `None` is not visible in this file.
pub timeout: Option<Duration>,
/// Optional concurrency limit; `None` unless set by the caller.
/// NOTE(review): presumably caps the number of in-flight requests; confirm
/// against the consumer.
pub max_concurrent: Option<usize>,
}
impl ProviderConfig {
    /// Creates a configuration that targets `provider` at its default base URL.
    ///
    /// `dimension`, `timeout` and `max_concurrent` start out as `None`.
    pub fn new(provider: Provider, api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let default_endpoint = provider.default_base_url();
        Self::with_base_url(provider, api_key, default_endpoint, model)
    }

    /// Creates a configuration that targets `provider` at an explicit
    /// `base_url`, stored as given without any normalization.
    ///
    /// `dimension`, `timeout` and `max_concurrent` start out as `None`.
    pub fn with_base_url(
        provider: Provider,
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        let api_key = api_key.into();
        let base_url = base_url.into();
        let model = model.into();
        Self {
            provider,
            api_key,
            base_url,
            model,
            dimension: None,
            timeout: None,
            max_concurrent: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{Provider, ProviderConfig};
    use std::str::FromStr;

    /// Every provider with its canonical name and default base URL, so each
    /// table-driven test below covers all variants.
    const KNOWN: [(Provider, &str, &str); 6] = [
        (Provider::OpenAI, "openai", "https://api.openai.com/v1"),
        (
            Provider::Aliyun,
            "aliyun",
            "https://dashscope.aliyuncs.com/compatible-mode/v1",
        ),
        (
            Provider::Anthropic,
            "anthropic",
            "https://api.anthropic.com/v1",
        ),
        (
            Provider::Google,
            "google",
            "https://generativelanguage.googleapis.com/v1beta",
        ),
        (Provider::Ollama, "ollama", "http://localhost:11434/v1"),
        (
            Provider::Zhipu,
            "zhipu",
            "https://open.bigmodel.cn/api/paas/v4",
        ),
    ];

    #[test]
    fn provider_from_str_case_insensitive() {
        for &(provider, name, _) in KNOWN.iter() {
            assert_eq!(Provider::from_str(name).unwrap(), provider);
            assert_eq!(
                Provider::from_str(&name.to_ascii_uppercase()).unwrap(),
                provider
            );
        }
        // Mixed case, not only all-lower or all-upper.
        assert_eq!(Provider::from_str("Aliyun").unwrap(), Provider::Aliyun);
        assert_eq!(
            Provider::from_str("AnThRoPiC").unwrap(),
            Provider::Anthropic
        );
    }

    #[test]
    fn provider_from_str_trims_whitespace() {
        assert_eq!(Provider::from_str("  ollama\n").unwrap(), Provider::Ollama);
        assert_eq!(Provider::from_str("\tZhipu ").unwrap(), Provider::Zhipu);
    }

    #[test]
    fn provider_from_str_unknown() {
        assert!(Provider::from_str("unknown").is_err());
        assert!(Provider::from_str("").is_err());
        // Only surrounding whitespace is ignored, not whitespace inside the name.
        assert!(Provider::from_str("open ai").is_err());
    }

    #[test]
    fn provider_display_round_trips_through_from_str() {
        for &(provider, name, _) in KNOWN.iter() {
            assert_eq!(provider.to_string(), name);
            assert_eq!(Provider::from_str(&provider.to_string()).unwrap(), provider);
        }
    }

    #[test]
    fn provider_default_base_url_matches_official_endpoints() {
        for &(provider, _, url) in KNOWN.iter() {
            assert_eq!(
                provider.default_base_url(),
                url,
                "default base URL for {:?}",
                provider
            );
        }
    }

    #[test]
    fn provider_config_new_uses_default_base_url() {
        let cfg = ProviderConfig::new(Provider::OpenAI, "sk-test", "gpt-test");
        assert_eq!(cfg.provider, Provider::OpenAI);
        assert_eq!(cfg.api_key, "sk-test");
        assert_eq!(cfg.base_url, Provider::OpenAI.default_base_url());
        assert_eq!(cfg.model, "gpt-test");
        assert_eq!(cfg.dimension, None);
        assert_eq!(cfg.timeout, None);
        assert_eq!(cfg.max_concurrent, None);
    }

    #[test]
    fn provider_config_with_base_url_preserves_explicit_endpoint() {
        let cfg = ProviderConfig::with_base_url(
            Provider::Zhipu,
            "zk-test",
            "https://api.z.ai/api/paas/v4",
            "glm-test",
        );
        assert_eq!(cfg.provider, Provider::Zhipu);
        assert_eq!(cfg.api_key, "zk-test");
        assert_eq!(cfg.base_url, "https://api.z.ai/api/paas/v4");
        assert_eq!(cfg.model, "glm-test");
        assert_eq!(cfg.dimension, None);
        assert_eq!(cfg.timeout, None);
        assert_eq!(cfg.max_concurrent, None);
    }
}