use echo_core::error::{ConfigError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::OnceLock;
/// Wire protocol a model endpoint speaks.
///
/// `OpenAi` is the default and also covers every OpenAI-compatible vendor
/// (DeepSeek, DashScope, Moonshot, Zhipu, ...) — see `parse_provider`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
pub enum LlmProvider {
    /// OpenAI chat-completions protocol (default).
    #[default]
    OpenAi,
    /// Anthropic messages API.
    Anthropic,
    /// Ollama local server API.
    Ollama,
}
/// Connection settings for a single LLM endpoint.
///
/// `Debug` is implemented manually below so the API key is never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Wire protocol; falls back to `LlmProvider::default()` when absent.
    #[serde(default)]
    pub provider: LlmProvider,
    /// Full chat endpoint URL (not just the host).
    pub base_url: String,
    /// Secret credential; empty for key-less backends such as Ollama.
    pub api_key: String,
    /// Model identifier sent to the provider.
    pub model: String,
}
// Manual Debug: identical to the derived output except that `api_key` is
// replaced by a fixed "[REDACTED]" marker so secrets never reach logs.
impl std::fmt::Debug for LlmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlmConfig")
            .field("provider", &self.provider)
            .field("base_url", &self.base_url)
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .finish()
    }
}
impl LlmConfig {
pub fn new(
base_url: impl Into<String>,
api_key: impl Into<String>,
model: impl Into<String>,
) -> Self {
Self {
provider: LlmProvider::OpenAi,
base_url: base_url.into(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn from_model(model_name: &str) -> Result<Self> {
let config = Config::get_model(model_name)?;
Ok(Self {
provider: config.provider,
base_url: config.baseurl,
api_key: config.apikey,
model: config.model,
})
}
pub fn from_env(model_name: &str) -> Result<Self> {
Self::from_model(model_name)
}
pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LlmProvider::OpenAi,
base_url: "https://api.openai.com/v1/chat/completions".to_string(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn anthropic(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LlmProvider::Anthropic,
base_url: "https://api.anthropic.com/v1/messages".to_string(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn ollama(model: impl Into<String>) -> Self {
Self {
provider: LlmProvider::Ollama,
base_url: "http://localhost:11434/api/chat".to_string(),
api_key: String::new(),
model: model.into(),
}
}
pub fn deepseek(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LlmProvider::OpenAi,
base_url: "https://api.deepseek.com/chat/completions".to_string(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn dashscope(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LlmProvider::OpenAi,
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
.to_string(),
api_key: api_key.into(),
model: model.into(),
}
}
pub fn custom(
base_url: impl Into<String>,
api_key: impl Into<String>,
model: impl Into<String>,
) -> Self {
Self::new(base_url, api_key, model)
}
pub fn build_client(&self) -> Result<Box<dyn echo_core::llm::LlmClient>> {
match self.provider {
LlmProvider::OpenAi => {
let client = super::openai::OpenAiClient::new(self.clone())?;
Ok(Box::new(client))
}
LlmProvider::Anthropic => {
let client = super::anthropic::AnthropicClient::with_base_url(
&self.base_url,
&self.api_key,
&self.model,
);
Ok(Box::new(client))
}
LlmProvider::Ollama => {
let client =
super::ollama::OllamaClient::with_base_url(&self.base_url, &self.model);
Ok(Box::new(client))
}
}
}
pub fn to_model_config(&self) -> ModelConfig {
ModelConfig {
model: self.model.clone(),
baseurl: self.base_url.clone(),
apikey: self.api_key.clone(),
provider: self.provider.clone(),
}
}
}
/// Stateless factory that builds LLM clients from `"provider:model"` strings
/// or config-file model names.
pub struct ProviderFactory;
impl ProviderFactory {
    /// Build a client from either `"provider:model"` (API key taken from the
    /// environment) or a bare model name declared in the config file.
    pub fn create(config_str: &str) -> Result<Box<dyn echo_core::llm::LlmClient>> {
        if let Some((provider_name, model_name)) = config_str.split_once(':') {
            Self::from_provider_model(provider_name.trim(), model_name.trim())
        } else {
            let config = LlmConfig::from_model(config_str)?;
            config.build_client()
        }
    }

    /// Build a client from an explicit, already-resolved config.
    pub fn from_config(config: &LlmConfig) -> Result<Box<dyn echo_core::llm::LlmClient>> {
        config.build_client()
    }

    /// Resolve provider + model into a client, pulling the API key from the
    /// provider's conventional environment variables.
    ///
    /// # Errors
    /// - unknown provider name
    /// - missing API key for providers that require one (all except ollama)
    fn from_provider_model(
        provider: &str,
        model: &str,
    ) -> Result<Box<dyn echo_core::llm::LlmClient>> {
        let base_url = provider_base_url(provider).ok_or_else(|| {
            ConfigError::ConfigFileError(format!(
                "未知的 provider: '{provider}',\
                支持: openai, anthropic, deepseek, dashscope, moonshot, zhipu, ollama"
            ))
        })?;
        let api_key = Self::env_api_key(provider);
        if api_key.trim().is_empty() && !matches!(provider.to_lowercase().as_str(), "ollama") {
            return Err(ConfigError::MissingConfig(
                format!("{provider}:{model}"),
                format!(
                    "缺少 API key,请设置以下任一环境变量: {}",
                    provider_env_var_names(provider).join(", ")
                ),
            )
            .into());
        }
        let llm_provider = parse_provider(provider);
        let config = LlmConfig {
            provider: llm_provider,
            base_url: base_url.to_string(),
            api_key,
            model: model.to_string(),
        };
        config.build_client()
    }

    /// First non-empty value among the provider's known env vars; empty
    /// string for ollama / unknown providers.
    ///
    /// Delegates to `provider_env_var_names` instead of duplicating its
    /// tables, so the lookup and the error message in `from_provider_model`
    /// can never drift apart.
    fn env_api_key(provider: &str) -> String {
        first_present_env(provider_env_var_names(provider)).unwrap_or_default()
    }

    /// Provider names accepted by [`Self::create`], aliases included.
    pub fn supported_providers() -> &'static [&'static str] {
        &[
            "openai",
            "anthropic",
            "deepseek",
            "dashscope",
            "qwen",
            // `aliyun` was accepted by provider_base_url but missing here.
            "aliyun",
            "moonshot",
            "kimi",
            "zhipu",
            "glm",
            "ollama",
        ]
    }
}
/// Map a provider name (case-insensitive, aliases allowed) to its default
/// chat endpoint; `None` for unrecognized providers.
fn provider_base_url(provider: &str) -> Option<&'static str> {
    // (accepted names, endpoint) — one row per upstream service.
    const ENDPOINTS: &[(&[&str], &str)] = &[
        (&["openai"], "https://api.openai.com/v1/chat/completions"),
        (&["anthropic"], "https://api.anthropic.com/v1/messages"),
        (&["deepseek"], "https://api.deepseek.com/chat/completions"),
        (
            &["dashscope", "qwen", "aliyun"],
            "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ),
        (&["moonshot", "kimi"], "https://api.moonshot.cn/v1/chat/completions"),
        (&["zhipu", "glm"], "https://open.bigmodel.cn/api/paas/v4/chat/completions"),
        (&["ollama"], "http://localhost:11434/v1/chat/completions"),
    ];
    let wanted = provider.to_lowercase();
    ENDPOINTS
        .iter()
        .find(|(names, _)| names.contains(&wanted.as_str()))
        .map(|&(_, url)| url)
}
/// Map a provider name onto the wire protocol it speaks; anything that is
/// not explicitly Anthropic or Ollama is treated as OpenAI-compatible.
fn parse_provider(provider: &str) -> LlmProvider {
    let normalized = provider.to_lowercase();
    if normalized == "anthropic" {
        LlmProvider::Anthropic
    } else if normalized == "ollama" {
        LlmProvider::Ollama
    } else {
        LlmProvider::OpenAi
    }
}
/// Guess the wire protocol from an endpoint URL; defaults to
/// OpenAI-compatible when no Anthropic/Ollama marker is present.
fn detect_provider_from_url(url: &str) -> LlmProvider {
    let normalized = url.to_lowercase();
    if normalized.contains("anthropic.com") {
        return LlmProvider::Anthropic;
    }
    let looks_like_ollama =
        normalized.contains("localhost:11434") || normalized.contains("ollama");
    if looks_like_ollama {
        return LlmProvider::Ollama;
    }
    LlmProvider::OpenAi
}
/// Raw on-disk YAML schema (`echo-agent.yaml`) before validation.
#[derive(Debug, Deserialize)]
struct ConfigFile {
    /// Model entries keyed by their config-file alias.
    models: HashMap<String, ModelEntry>,
    /// Optional embedding section.
    #[serde(default)]
    embedding: Option<EmbeddingEntry>,
}
/// Unvalidated model entry from YAML; `Config::from_config_file` requires
/// at least one of `base_url` / `provider` and resolves `${VAR}` references.
#[derive(Deserialize)]
struct ModelEntry {
    /// Full chat endpoint; takes precedence over `provider` when both set.
    #[serde(default)]
    base_url: Option<String>,
    /// API key, possibly containing `${VAR}` references.
    api_key: String,
    /// Upstream model name; defaults to the map key when omitted.
    #[serde(default)]
    model: Option<String>,
    /// Provider name used to derive base URL and wire protocol.
    #[serde(default)]
    provider: Option<String>,
}
// Manual Debug that redacts the API key (see Debug for LlmConfig).
impl std::fmt::Debug for ModelEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelEntry")
            .field("base_url", &self.base_url)
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("provider", &self.provider)
            .finish()
    }
}
/// Unvalidated embedding entry from YAML; validation requires at least one
/// of `endpoint_url` / `base_url` (`endpoint_url` wins when both set).
#[derive(Deserialize)]
struct EmbeddingEntry {
    /// Full embeddings endpoint; takes precedence over `base_url`.
    #[serde(default)]
    endpoint_url: Option<String>,
    /// Host base; `/v1/embeddings` is appended during validation.
    #[serde(default)]
    base_url: Option<String>,
    /// API key, possibly containing `${VAR}` references.
    api_key: String,
    /// Embedding model; defaults to "text-embedding-3-small" downstream.
    #[serde(default)]
    model: Option<String>,
    /// Request timeout; defaults to 30 seconds downstream.
    #[serde(default)]
    timeout_secs: Option<u64>,
}
// Manual Debug that redacts the API key (see Debug for LlmConfig).
impl std::fmt::Debug for EmbeddingEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EmbeddingEntry")
            .field("endpoint_url", &self.endpoint_url)
            .field("base_url", &self.base_url)
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("timeout_secs", &self.timeout_secs)
            .finish()
    }
}
/// Validated model configuration exposed to callers.
///
/// NOTE(review): the underscore-less field names `baseurl`/`apikey` are part
/// of the serialized format — renaming would break existing consumers.
#[derive(Serialize, Deserialize, Clone)]
pub struct ModelConfig {
    /// Upstream model identifier.
    pub model: String,
    /// Full chat endpoint URL.
    pub baseurl: String,
    /// Resolved API key (env references already expanded).
    pub apikey: String,
    /// Wire protocol; defaults via `LlmProvider::default()` when absent.
    #[serde(default)]
    pub provider: LlmProvider,
}
// Manual Debug that redacts the API key (see Debug for LlmConfig).
impl std::fmt::Debug for ModelConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelConfig")
            .field("model", &self.model)
            .field("baseurl", &self.baseurl)
            .field("apikey", &"[REDACTED]")
            .field("provider", &self.provider)
            .finish()
    }
}
/// Fully parsed configuration: valid entries plus the error text of any
/// entries that failed validation, so lookups can report *why* they failed.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Config {
    /// Valid models, keyed by both config alias and model name.
    pub models: HashMap<String, ModelConfig>,
    /// Validation error per skipped model entry.
    #[serde(default)]
    invalid_models: HashMap<String, String>,
    /// Valid embedding section, if any.
    #[serde(default)]
    pub embedding: Option<EmbeddingConfig>,
    /// Validation error of a skipped embedding section.
    #[serde(default)]
    invalid_embedding: Option<String>,
}
/// Validated embedding endpoint configuration.
#[derive(Serialize, Deserialize, Clone)]
pub struct EmbeddingConfig {
    /// Full embeddings endpoint URL.
    pub url: String,
    /// Resolved API key.
    pub api_key: String,
    /// Embedding model identifier.
    pub model: String,
    /// Request timeout in seconds.
    pub timeout_secs: u64,
}
// Manual Debug that redacts the API key (see Debug for LlmConfig).
impl std::fmt::Debug for EmbeddingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EmbeddingConfig")
            .field("url", &self.url)
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("timeout_secs", &self.timeout_secs)
            .finish()
    }
}
// Process-wide cache of the parsed config file; the error message is cached
// too, so a bad file is not re-read and re-parsed on every lookup.
static MODEL_CONFIG: OnceLock<std::result::Result<Config, String>> = OnceLock::new();
impl Config {
    /// Load configuration from the config file.
    ///
    /// Loads `.env` first (best-effort) so `${VAR}` references in the YAML
    /// can resolve against variables declared there.
    ///
    /// # Errors
    /// Returns a `ConfigFileError` when no config file is found, and
    /// propagates read/parse failures from `from_config_file`.
    pub fn load() -> Result<Self> {
        dotenv::dotenv().ok();
        if let Some(config) = Self::from_config_file()? {
            tracing::info!("已从配置文件加载模型配置");
            return Ok(config);
        }
        Err(ConfigError::ConfigFileError(
            "未找到模型配置文件,请提供 echo-agent.yaml;环境变量仅支持通过 `${VAR}` 在 YAML 中注入值".to_string(),
        )
        .into())
    }

    /// Locate the config file, in priority order:
    /// `$ECHO_AGENT_CONFIG`, then `./echo-agent.yaml`, then
    /// `$HOME/.echo-agent/config.yaml`. `None` when none exists.
    fn config_file_path() -> Option<PathBuf> {
        if let Ok(path) = std::env::var("ECHO_AGENT_CONFIG") {
            let p = PathBuf::from(&path);
            if p.exists() {
                return Some(p);
            }
        }
        let local = PathBuf::from("./echo-agent.yaml");
        if local.exists() {
            return Some(local);
        }
        if let Ok(home) = std::env::var("HOME") {
            let global = PathBuf::from(home).join(".echo-agent").join("config.yaml");
            if global.exists() {
                return Some(global);
            }
        }
        None
    }

    /// Read, parse and validate the config file.
    ///
    /// Invalid model/embedding entries are *recorded* (with their error
    /// text) rather than failing the whole load, so one bad entry does not
    /// take down every model. Returns `Ok(None)` when no file exists.
    fn from_config_file() -> Result<Option<Self>> {
        let path = match Self::config_file_path() {
            Some(p) => p,
            None => return Ok(None),
        };
        tracing::debug!("正在加载配置文件: {}", path.display());
        let content = std::fs::read_to_string(&path).map_err(|e| {
            ConfigError::ConfigFileError(format!("无法读取配置文件 {}: {}", path.display(), e))
        })?;
        let file: ConfigFile = serde_yaml::from_str(&content).map_err(|e| {
            ConfigError::ConfigFileError(format!("配置文件解析失败 {}: {}", path.display(), e))
        })?;
        let mut models = HashMap::new();
        let mut invalid_models = HashMap::new();
        for (key, entry) in file.models {
            // Validate each entry inside an immediately-invoked closure so
            // `?` can short-circuit per entry instead of per file.
            let parsed: Result<(String, String, String, LlmProvider)> = (|| {
                // base_url wins over provider; exactly one must be present.
                let base_url = match (entry.base_url.as_deref(), entry.provider.as_deref()) {
                    (Some(url), _) => resolve_env_ref(url),
                    (None, Some(provider)) => {
                        let resolved_provider = resolve_env_ref(provider);
                        provider_base_url(&resolved_provider)
                            .ok_or_else(|| {
                                ConfigError::ConfigFileError(format!(
                                    "模型 '{}' 指定了未知的 provider: '{}',\
                                    支持的 provider: openai, anthropic, deepseek, dashscope, moonshot, zhipu, ollama",
                                    key, resolved_provider
                                ))
                            })?
                            .to_string()
                    }
                    (None, None) => {
                        return Err(ConfigError::MissingConfig(
                            key.clone(),
                            "base_url 或 provider".to_string(),
                        )
                        .into());
                    }
                };
                let api_key = ensure_resolved_api_key(
                    &key,
                    "api_key",
                    &entry.api_key,
                    &resolve_env_ref(&entry.api_key),
                )?;
                // Model name defaults to the entry's map key.
                let model_name = entry
                    .model
                    .as_deref()
                    .map(resolve_env_ref)
                    .unwrap_or_else(|| key.clone());
                // Explicit provider wins; otherwise sniff it from the URL.
                let provider = match entry.provider.as_deref() {
                    Some(p) => parse_provider(&resolve_env_ref(p)),
                    None => detect_provider_from_url(&base_url),
                };
                Ok((base_url, api_key, model_name, provider))
            })();
            match parsed {
                Ok((base_url, api_key, model_name, provider)) => {
                    let mc = ModelConfig {
                        model: model_name.clone(),
                        baseurl: base_url,
                        apikey: api_key,
                        provider,
                    };
                    models.insert(key.clone(), mc.clone());
                    // Also register the underlying model name as an alias.
                    // NOTE(review): this can overwrite another entry whose
                    // key equals this model name — last one wins.
                    if key != model_name {
                        models.insert(model_name, mc);
                    }
                }
                Err(err) => {
                    tracing::warn!("跳过无效模型配置 {}: {}", key, err);
                    invalid_models.insert(key.clone(), err.to_string());
                }
            }
        }
        let (embedding, invalid_embedding) = match file.embedding {
            Some(entry) => {
                let parsed: Result<EmbeddingConfig> = (|| {
                    // endpoint_url wins; otherwise derive from base_url.
                    let url = match (entry.endpoint_url.as_deref(), entry.base_url.as_deref()) {
                        (Some(url), _) => resolve_env_ref(url),
                        (None, Some(base)) => {
                            let resolved = resolve_env_ref(base);
                            format!("{}/v1/embeddings", resolved.trim_end_matches('/'))
                        }
                        (None, None) => {
                            return Err(ConfigError::MissingConfig(
                                "embedding".to_string(),
                                "endpoint_url 或 base_url".to_string(),
                            )
                            .into());
                        }
                    };
                    Ok(EmbeddingConfig {
                        url,
                        api_key: ensure_resolved_api_key(
                            "embedding",
                            "api_key",
                            &entry.api_key,
                            &resolve_env_ref(&entry.api_key),
                        )?,
                        model: entry
                            .model
                            .as_deref()
                            .map(resolve_env_ref)
                            .unwrap_or_else(|| "text-embedding-3-small".to_string()),
                        timeout_secs: entry.timeout_secs.unwrap_or(30),
                    })
                })();
                match parsed {
                    Ok(cfg) => (Some(cfg), None),
                    Err(err) => {
                        tracing::warn!("跳过无效 embedding 配置: {}", err);
                        (None, Some(err.to_string()))
                    }
                }
            }
            None => (None, None),
        };
        Ok(Some(Config {
            models,
            invalid_models,
            embedding,
            invalid_embedding,
        }))
    }

    /// Load once per process and cache the outcome (success *or* failure).
    pub fn load_cached() -> Result<&'static Config> {
        let result = MODEL_CONFIG.get_or_init(|| Config::load().map_err(|e| e.to_string()));
        match result {
            Ok(config) => Ok(config),
            Err(msg) => Err(ConfigError::ConfigFileError(msg.clone()).into()),
        }
    }

    /// Look up a model by alias or model name.
    ///
    /// # Errors
    /// Returns the recorded validation error if the entry was invalid, or a
    /// not-found error listing the available model names.
    pub fn get_model(model: &str) -> Result<ModelConfig> {
        let config = Self::load_cached()?;
        if let Some(err) = config.invalid_models.get(model) {
            return Err(ConfigError::ConfigFileError(err.clone()).into());
        }
        Ok(config
            .models
            .get(model)
            .ok_or_else(|| {
                let available: Vec<&str> = config.models.keys().map(|k| k.as_str()).collect();
                ConfigError::NotFindModelError(format!(
                    "{}(可用模型: {})",
                    model,
                    if available.is_empty() {
                        "无,请创建 echo-agent.yaml 并在其中声明 models.*".to_string()
                    } else {
                        available.join(", ")
                    }
                ))
            })
            .cloned()?)
    }

    /// Whether a model with this alias/name is configured (false on any
    /// load failure).
    pub fn has_model(model: &str) -> bool {
        Self::load_cached()
            .map(|config| config.models.contains_key(model))
            .unwrap_or(false)
    }

    /// All configured model keys (empty on any load failure).
    pub fn list_models() -> Vec<String> {
        Self::load_cached()
            .map(|config| config.models.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// The embedding configuration.
    ///
    /// # Errors
    /// Returns the recorded validation error if the section was invalid, or
    /// a missing-config error if no section was declared.
    pub fn get_embedding() -> Result<EmbeddingConfig> {
        let config = Self::load_cached()?;
        if let Some(err) = &config.invalid_embedding {
            return Err(ConfigError::ConfigFileError(err.clone()).into());
        }
        config.embedding.clone().ok_or_else(|| {
            ConfigError::MissingConfig(
                "embedding".to_string(),
                "请在 echo-agent.yaml 中配置 embedding 段".to_string(),
            )
            .into()
        })
    }

    /// Whether a valid embedding section exists (false on any load failure).
    pub fn has_embedding() -> bool {
        Self::load_cached()
            .map(|config| config.embedding.is_some())
            .unwrap_or(false)
    }

    /// Backward-compatible alias for [`Self::load`].
    pub fn from_env() -> Result<Self> {
        Self::load()
    }
}
/// Expand `${VAR}` references in `value` from the process environment.
///
/// - Unset variables (including via alias) are left as literal `${VAR}` and
///   only warned about.
/// - After a substitution the scan resumes *after* the inserted value, so
///   text coming from the environment is never re-expanded.
/// - A `${` with no closing `}` terminates the scan.
fn resolve_env_ref(value: &str) -> String {
    // Fast path: nothing to expand.
    if !value.contains("${") {
        return value.to_string();
    }
    let mut result = value.to_string();
    let mut search_from = 0;
    while let Some(rel_start) = result[search_from..].find("${") {
        let start = search_from + rel_start;
        if let Some(rel_end) = result[start..].find('}') {
            let end = start + rel_end;
            let var_name = &result[start + 2..end];
            // Try the variable itself, then its paired alias
            // (e.g. DASHSCOPE_API_KEY <-> QWEN_API_KEY).
            match std::env::var(var_name).ok().or_else(|| {
                fallback_env_alias(var_name)
                    .and_then(std::env::var_os)
                    .map(|v| v.to_string_lossy().into_owned())
            }) {
                Some(val) => {
                    result = format!("{}{}{}", &result[..start], val, &result[end + 1..]);
                    // Skip past the substituted value: no recursive expansion.
                    search_from = start + val.len();
                }
                None => {
                    tracing::warn!("环境变量 {} 未设置", var_name);
                    search_from = end + 1;
                }
            }
        } else {
            break;
        }
    }
    result
}
/// First environment variable in `names` whose value is set and non-blank.
fn first_present_env(names: &[&str]) -> Option<String> {
    for name in names {
        if let Ok(value) = std::env::var(name) {
            if !value.trim().is_empty() {
                return Some(value);
            }
        }
    }
    None
}
/// Symmetric alias for an API-key environment variable, when one exists
/// (e.g. DASHSCOPE_API_KEY <-> QWEN_API_KEY).
fn fallback_env_alias(var_name: &str) -> Option<&'static str> {
    const ALIAS_PAIRS: &[(&str, &str)] = &[
        ("DASHSCOPE_API_KEY", "QWEN_API_KEY"),
        ("MOONSHOT_API_KEY", "KIMI_API_KEY"),
        ("ZHIPU_API_KEY", "GLM_API_KEY"),
    ];
    ALIAS_PAIRS.iter().find_map(|&(a, b)| {
        if var_name == a {
            Some(b)
        } else if var_name == b {
            Some(a)
        } else {
            None
        }
    })
}
/// Environment variables that conventionally hold the API key for
/// `provider` (case-insensitive, aliases allowed). Empty for ollama (no key
/// required) and for unknown providers.
fn provider_env_var_names(provider: &str) -> &'static [&'static str] {
    const TABLE: &[(&[&str], &[&str])] = &[
        (&["anthropic"], &["ANTHROPIC_API_KEY"]),
        (&["openai"], &["OPENAI_API_KEY"]),
        (&["deepseek"], &["DEEPSEEK_API_KEY"]),
        (
            &["dashscope", "qwen", "aliyun"],
            &["DASHSCOPE_API_KEY", "QWEN_API_KEY"],
        ),
        (&["moonshot", "kimi"], &["MOONSHOT_API_KEY", "KIMI_API_KEY"]),
        (&["zhipu", "glm"], &["ZHIPU_API_KEY", "GLM_API_KEY"]),
    ];
    let wanted = provider.to_lowercase();
    TABLE
        .iter()
        .find(|(providers, _)| providers.contains(&wanted.as_str()))
        .map(|&(_, vars)| vars)
        .unwrap_or(&[])
}
/// Validate that an api_key containing `${VAR}` references actually
/// resolved to a usable value.
///
/// `raw_value` is the literal YAML text; `resolved_value` is the output of
/// `resolve_env_ref` for it. Values with no `${` pass through untouched.
/// Otherwise the value is rejected with a `MissingConfig` error — listing
/// the unset variables and their known aliases — when any referenced
/// variable is unset, the resolved text still contains `${`, or it is
/// empty/whitespace.
fn ensure_resolved_api_key(
    scope: &str,
    field: &str,
    raw_value: &str,
    resolved_value: &str,
) -> Result<String> {
    // No references at all: accept as-is.
    if !raw_value.contains("${") {
        return Ok(resolved_value.to_string());
    }
    // Variables that are unset both directly and via their paired alias.
    let unresolved: Vec<String> = extract_env_refs(raw_value)
        .into_iter()
        .filter(|name| {
            std::env::var(name).is_err()
                && fallback_env_alias(name)
                    .and_then(std::env::var_os)
                    .is_none()
        })
        .collect();
    if !unresolved.is_empty() || resolved_value.contains("${") || resolved_value.trim().is_empty() {
        // Human-readable list of the missing variables (with aliases).
        let details = unresolved
            .into_iter()
            .map(|name| {
                if let Some(alias) = fallback_env_alias(&name) {
                    format!("{name}(或别名 {alias})")
                } else {
                    name
                }
            })
            .collect::<Vec<_>>()
            .join(", ");
        return Err(ConfigError::MissingConfig(
            scope.to_string(),
            if details.is_empty() {
                format!("{field} 未解析为有效值")
            } else {
                format!("{field} 依赖的环境变量未设置: {details}")
            },
        )
        .into());
    }
    Ok(resolved_value.to_string())
}
/// Collect every `${NAME}` reference in `value`, in order of appearance.
/// A trailing `${` with no closing `}` ends the scan.
fn extract_env_refs(value: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut rest = value;
    while let Some(start) = rest.find("${") {
        let after_open = &rest[start + 2..];
        match after_open.find('}') {
            Some(close) => {
                found.push(after_open[..close].to_string());
                rest = &after_open[close + 1..];
            }
            None => break,
        }
    }
    found
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that mutate process environment variables, since
    /// the environment is process-global state shared across test threads.
    fn env_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(())).lock().unwrap()
    }

    #[test]
    fn test_llm_config_new() {
        let config = LlmConfig::new("https://example.com", "sk-test", "gpt-4o");
        assert_eq!(config.base_url, "https://example.com");
        assert_eq!(config.api_key, "sk-test");
        assert_eq!(config.model, "gpt-4o");
    }

    #[test]
    fn test_llm_config_openai() {
        let config = LlmConfig::openai("sk-test", "gpt-4o");
        assert!(config.base_url.contains("openai.com"));
        assert_eq!(config.model, "gpt-4o");
    }

    #[test]
    fn test_llm_config_deepseek() {
        let config = LlmConfig::deepseek("sk-test", "deepseek-chat");
        assert!(config.base_url.contains("deepseek.com"));
        assert_eq!(config.model, "deepseek-chat");
    }

    #[test]
    fn test_llm_config_dashscope() {
        let config = LlmConfig::dashscope("sk-test", "qwen3-max");
        assert!(config.base_url.contains("dashscope.aliyuncs.com"));
        assert_eq!(config.model, "qwen3-max");
    }

    #[test]
    fn test_llm_config_anthropic() {
        let config = LlmConfig::anthropic("sk-test", "claude-sonnet-4-6");
        assert!(config.base_url.contains("anthropic.com"));
        assert_eq!(config.model, "claude-sonnet-4-6");
    }

    #[test]
    fn test_provider_base_url() {
        assert!(provider_base_url("openai").is_some());
        assert!(provider_base_url("anthropic").is_some());
        assert!(provider_base_url("deepseek").is_some());
        assert!(provider_base_url("dashscope").is_some());
        assert!(provider_base_url("qwen").is_some());
        assert!(provider_base_url("ollama").is_some());
        assert!(provider_base_url("unknown_provider").is_none());
    }

    #[test]
    fn test_resolve_env_ref_plain() {
        assert_eq!(resolve_env_ref("sk-plain-key"), "sk-plain-key");
    }

    #[test]
    fn test_resolve_env_ref_with_var() {
        let _guard = env_test_lock();
        unsafe { std::env::set_var("TEST_ECHO_KEY", "resolved-value") };
        assert_eq!(resolve_env_ref("${TEST_ECHO_KEY}"), "resolved-value");
        unsafe { std::env::remove_var("TEST_ECHO_KEY") };
    }

    #[test]
    fn test_resolve_env_ref_missing_var() {
        let _guard = env_test_lock();
        // Missing variables are left as the literal reference.
        let result = resolve_env_ref("${NONEXISTENT_VAR_12345}");
        assert_eq!(result, "${NONEXISTENT_VAR_12345}");
    }

    #[test]
    fn test_resolve_env_ref_supports_dashscope_qwen_alias() {
        let _guard = env_test_lock();
        unsafe {
            std::env::remove_var("DASHSCOPE_API_KEY");
            std::env::set_var("QWEN_API_KEY", "qwen-alias-value");
        }
        assert_eq!(resolve_env_ref("${DASHSCOPE_API_KEY}"), "qwen-alias-value");
        unsafe {
            std::env::remove_var("QWEN_API_KEY");
        }
    }

    #[test]
    fn test_env_api_key_supports_qwen_alias() {
        let _guard = env_test_lock();
        unsafe {
            std::env::remove_var("DASHSCOPE_API_KEY");
            std::env::set_var("QWEN_API_KEY", "qwen-provider-key");
        }
        assert_eq!(
            ProviderFactory::env_api_key("dashscope"),
            "qwen-provider-key"
        );
        assert_eq!(ProviderFactory::env_api_key("qwen"), "qwen-provider-key");
        unsafe {
            std::env::remove_var("QWEN_API_KEY");
        }
    }

    #[test]
    fn test_ensure_resolved_api_key_reports_missing_alias_group() {
        let _guard = env_test_lock();
        unsafe {
            std::env::remove_var("DASHSCOPE_API_KEY");
            std::env::remove_var("QWEN_API_KEY");
        }
        let err = ensure_resolved_api_key(
            "qwen3.6-plus",
            "api_key",
            "${DASHSCOPE_API_KEY}",
            "${DASHSCOPE_API_KEY}",
        )
        .unwrap_err();
        // The error must mention both the variable and its alias.
        assert!(format!("{err}").contains("DASHSCOPE_API_KEY"));
        assert!(format!("{err}").contains("QWEN_API_KEY"));
    }

    #[test]
    fn test_config_from_yaml_string() {
        let yaml = r#"
models:
  test-model:
    base_url: https://api.example.com/v1/chat
    api_key: sk-test-key
  alias-model:
    provider: openai
    api_key: sk-alias
    model: gpt-4o-mini
"#;
        let file: ConfigFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(file.models.len(), 2);
        assert!(file.models.contains_key("test-model"));
        assert!(file.models.contains_key("alias-model"));
        let entry = &file.models["alias-model"];
        assert_eq!(entry.provider.as_deref(), Some("openai"));
        assert_eq!(entry.model.as_deref(), Some("gpt-4o-mini"));
    }

    #[test]
    fn test_config_from_yaml_with_embedding() {
        let yaml = r#"
models:
  test-model:
    provider: openai
    api_key: sk-test
embedding:
  base_url: https://api.openai.com
  api_key: ${TEST_EMBED_KEY}
  model: text-embedding-3-small
  timeout_secs: 45
"#;
        // NOTE(review): mutates the environment without env_test_lock —
        // could race with the other env-mutating tests.
        unsafe { std::env::set_var("TEST_EMBED_KEY", "embed-key") };
        let file: ConfigFile = serde_yaml::from_str(yaml).unwrap();
        let entry = file.embedding.expect("embedding should exist");
        assert_eq!(entry.base_url.as_deref(), Some("https://api.openai.com"));
        assert_eq!(resolve_env_ref(&entry.api_key), "embed-key");
        unsafe { std::env::remove_var("TEST_EMBED_KEY") };
    }

    #[test]
    fn test_to_model_config() {
        let config = LlmConfig::new("https://example.com", "sk-test", "model-1");
        let mc = config.to_model_config();
        assert_eq!(mc.model, "model-1");
        assert_eq!(mc.baseurl, "https://example.com");
        assert_eq!(mc.apikey, "sk-test");
        assert_eq!(mc.provider, LlmProvider::OpenAi);
    }

    #[test]
    fn test_to_model_config_anthropic() {
        let config = LlmConfig::anthropic("sk-ant-test", "claude-sonnet-4-6");
        let mc = config.to_model_config();
        assert_eq!(mc.provider, LlmProvider::Anthropic);
    }

    #[test]
    fn test_parse_provider() {
        assert_eq!(parse_provider("anthropic"), LlmProvider::Anthropic);
        assert_eq!(parse_provider("Anthropic"), LlmProvider::Anthropic);
        assert_eq!(parse_provider("ollama"), LlmProvider::Ollama);
        assert_eq!(parse_provider("openai"), LlmProvider::OpenAi);
        assert_eq!(parse_provider("deepseek"), LlmProvider::OpenAi);
        assert_eq!(parse_provider("unknown"), LlmProvider::OpenAi);
    }

    #[test]
    fn test_detect_provider_from_url() {
        assert_eq!(
            detect_provider_from_url("https://api.anthropic.com/v1/messages"),
            LlmProvider::Anthropic,
        );
        assert_eq!(
            detect_provider_from_url("http://localhost:11434/api/chat"),
            LlmProvider::Ollama,
        );
        assert_eq!(
            detect_provider_from_url("https://api.openai.com/v1/chat/completions"),
            LlmProvider::OpenAi,
        );
        assert_eq!(
            detect_provider_from_url("https://api.deepseek.com/chat/completions"),
            LlmProvider::OpenAi,
        );
    }

    #[test]
    fn test_provider_factory_supported_providers() {
        let providers = ProviderFactory::supported_providers();
        assert!(providers.contains(&"openai"));
        assert!(providers.contains(&"anthropic"));
        assert!(providers.contains(&"ollama"));
        assert!(providers.contains(&"deepseek"));
    }

    #[test]
    fn test_provider_factory_parse_config_str() {
        // Key-less / unknown providers resolve to an empty key.
        assert_eq!(ProviderFactory::env_api_key("ollama"), "");
        assert_eq!(ProviderFactory::env_api_key("unknown"), "");
    }

    #[test]
    fn test_config_from_yaml_with_provider_detection() {
        let yaml = r#"
models:
  claude-test:
    base_url: https://api.anthropic.com/v1/messages
    api_key: sk-test
  ollama-test:
    base_url: http://localhost:11434/api/chat
    api_key: ""
  openai-test:
    provider: openai
    api_key: sk-test
"#;
        let file: ConfigFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(file.models.len(), 3);
    }
}