use crate::error::{GatewayError, Result};
use crate::providers::ProviderType;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
/// Connection settings for a single LLM provider.
///
/// `Debug` is implemented by hand (not derived) so that `api_key` is redacted and
/// never ends up in logs or panic messages produced via `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// API key for the provider. Redacted in `Debug` output.
    pub api_key: String,
    /// Optional base URL / endpoint override (`from_env` fills it from e.g.
    /// `OPENAI_BASE_URL` or `AZURE_OPENAI_ENDPOINT`).
    pub base_url: Option<String>,
    /// Optional default model name (for Azure, `from_env` stores the deployment name here).
    pub default_model: Option<String>,
    /// Provider-specific string settings (e.g. Azure's `api_version`).
    pub extra: HashMap<String, String>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("default_model", &self.default_model)
            .field("extra", &self.extra)
            .finish()
    }
}
impl ProviderConfig {
    /// Creates a provider entry holding only `api_key`.
    ///
    /// `base_url` and `default_model` start as `None` and `extra` starts empty.
    pub fn new(api_key: impl Into<String>) -> Self {
        let api_key = api_key.into();
        Self {
            api_key,
            base_url: None,
            default_model: None,
            extra: HashMap::default(),
        }
    }

    /// Returns the entry with `base_url` set, replacing any previous value.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: Some(base_url.into()),
            ..self
        }
    }

    /// Returns the entry with `default_model` set, replacing any previous value.
    pub fn with_default_model(self, model: impl Into<String>) -> Self {
        Self {
            default_model: Some(model.into()),
            ..self
        }
    }

    /// Returns the entry with `key` mapped to `value` in `extra`.
    ///
    /// An existing value for the same key is overwritten.
    pub fn with_extra(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut updated = self;
        updated.extra.insert(key.into(), value.into());
        updated
    }
}
/// Top-level gateway settings.
///
/// Holds the default provider, the per-provider connection settings, and optional
/// timeout / retry overrides.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Provider treated as the default. `from_env` requires it to have an entry in `providers`.
    pub default_provider: ProviderType,
    /// Connection settings keyed by provider type.
    pub providers: HashMap<ProviderType, ProviderConfig>,
    /// Timeout in seconds, if configured (`from_env` reads `LLM_TIMEOUT_SECONDS`).
    pub timeout_seconds: Option<u64>,
    /// Maximum retry count, if configured (`from_env` reads `LLM_MAX_RETRIES`).
    pub max_retries: Option<u32>,
}
impl GatewayConfig {
    /// Creates a configuration with `default_provider` as the default.
    ///
    /// No providers are registered and no timeout or retry override is set.
    pub fn new(default_provider: ProviderType) -> Self {
        Self {
            default_provider,
            providers: HashMap::new(),
            timeout_seconds: None,
            max_retries: None,
        }
    }

    /// Registers (or replaces) the settings for `provider_type`.
    pub fn with_provider(mut self, provider_type: ProviderType, config: ProviderConfig) -> Self {
        self.providers.insert(provider_type, config);
        self
    }

    /// Sets the timeout, in seconds.
    pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
        self.timeout_seconds = Some(timeout_seconds);
        self
    }

    /// Sets the maximum retry count.
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = Some(max_retries);
        self
    }

    /// Builds a configuration from environment variables.
    ///
    /// Recognised variables:
    /// - `LLM_MODEL` — `provider/model`, split at the first `/`, so the model part may
    ///   itself contain `/`. It selects the default provider (otherwise OpenAI) and
    ///   overrides that provider's default model.
    /// - `OPENAI_API_KEY`, `OPENAI_BASE_URL`, `OPENAI_MODEL`
    /// - `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT`,
    ///   `AZURE_OPENAI_API_VERSION` (stored as `extra["api_version"]`)
    /// - `DEEPSEEK_API_KEY`, `DEEPSEEK_BASE_URL`, `DEEPSEEK_MODEL`
    /// - `QWEN_API_KEY`, `QWEN_BASE_URL`, `QWEN_MODEL`
    /// - `LLM_TIMEOUT_SECONDS`, `LLM_MAX_RETRIES`
    ///
    /// A provider is registered only when its API-key variable is present. Variables
    /// that are unset or not valid Unicode are treated as absent.
    ///
    /// # Errors
    ///
    /// Returns `GatewayError::Config` when:
    /// - `LLM_MODEL` is not `provider/model` (including an empty model part) or names an
    ///   unparsable provider;
    /// - `LLM_TIMEOUT_SECONDS` or `LLM_MAX_RETRIES` does not parse;
    /// - no provider API key is set;
    /// - the default provider has no settings.
    pub fn from_env() -> Result<Self> {
        let mut config = Self::new(ProviderType::OpenAI);

        // `LLM_MODEL` is parsed exactly once, before any provider is loaded, so a
        // malformed value is reported ahead of every other configuration error.
        let model_override = match Self::llm_model_from_env()? {
            Some((provider_type, model_name)) => {
                config.default_provider = provider_type;
                Some(model_name)
            }
            None => None,
        };

        if let Some(openai) =
            Self::provider_from_env("OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL")
        {
            config.providers.insert(ProviderType::OpenAI, openai);
        }
        if let Some(mut azure) = Self::provider_from_env(
            "AZURE_OPENAI_API_KEY",
            "AZURE_OPENAI_ENDPOINT",
            "AZURE_OPENAI_DEPLOYMENT",
        ) {
            if let Ok(api_version) = env::var("AZURE_OPENAI_API_VERSION") {
                azure = azure.with_extra("api_version", api_version);
            }
            config.providers.insert(ProviderType::Azure, azure);
        }
        if let Some(deepseek) =
            Self::provider_from_env("DEEPSEEK_API_KEY", "DEEPSEEK_BASE_URL", "DEEPSEEK_MODEL")
        {
            config.providers.insert(ProviderType::DeepSeek, deepseek);
        }
        if let Some(qwen) = Self::provider_from_env("QWEN_API_KEY", "QWEN_BASE_URL", "QWEN_MODEL")
        {
            config.providers.insert(ProviderType::Qwen, qwen);
        }

        // The model named in `LLM_MODEL` wins over the provider-specific `*_MODEL` /
        // `AZURE_OPENAI_DEPLOYMENT` variable. If that provider has no API key there is
        // nothing to override; the default-provider check below reports that case.
        if let Some(model_name) = model_override {
            if let Some(provider_config) = config.providers.get_mut(&config.default_provider) {
                provider_config.default_model = Some(model_name);
            }
        }

        config.timeout_seconds =
            Self::parse_env_var("LLM_TIMEOUT_SECONDS", "Invalid timeout value")?;
        config.max_retries = Self::parse_env_var("LLM_MAX_RETRIES", "Invalid max retries value")?;

        if config.providers.is_empty() {
            return Err(GatewayError::Config(
                "No providers configured. Please set at least one provider's API key.".to_string(),
            ));
        }
        if !config.providers.contains_key(&config.default_provider) {
            return Err(GatewayError::Config(format!(
                "Default provider {:?} is not configured",
                config.default_provider
            )));
        }
        Ok(config)
    }

    /// Returns the settings registered for `provider_type`, if any.
    pub fn get_provider_config(&self, provider_type: &ProviderType) -> Option<&ProviderConfig> {
        self.providers.get(provider_type)
    }

    /// Returns the settings of `default_provider`.
    ///
    /// Returns `None` if the default provider has no entry. That can happen for configs
    /// built with `new` / `with_provider`; `from_env` rejects it.
    pub fn get_default_provider_config(&self) -> Option<&ProviderConfig> {
        self.get_provider_config(&self.default_provider)
    }

    /// Reads `LLM_MODEL` and splits it at the first `/` into provider and model.
    ///
    /// Returns `Ok(None)` when the variable is absent.
    fn llm_model_from_env() -> Result<Option<(ProviderType, String)>> {
        let provider_model = match env::var("LLM_MODEL") {
            Ok(value) => value,
            Err(_) => return Ok(None),
        };
        // Split only at the first '/': model ids such as `Qwen/Qwen2.5-72B` contain '/'.
        let (provider_str, model_name) = match provider_model.split_once('/') {
            Some((provider, model)) if !model.is_empty() => (provider, model),
            _ => {
                return Err(GatewayError::Config(
                    "LLM_MODEL must be in format 'provider/model' (e.g., 'openai/gpt-4')"
                        .to_string(),
                ))
            }
        };
        let provider_type = provider_str.parse::<ProviderType>().map_err(|_| {
            GatewayError::Config(format!("Invalid provider in LLM_MODEL: {}", provider_str))
        })?;
        Ok(Some((provider_type, model_name.to_string())))
    }

    /// Builds a provider entry from its API-key, base-URL and model variables.
    ///
    /// Returns `None` when the API-key variable is absent. The other two variables
    /// are optional.
    fn provider_from_env(
        api_key_var: &str,
        base_url_var: &str,
        model_var: &str,
    ) -> Option<ProviderConfig> {
        let mut provider_config = ProviderConfig::new(env::var(api_key_var).ok()?);
        if let Ok(base_url) = env::var(base_url_var) {
            provider_config = provider_config.with_base_url(base_url);
        }
        if let Ok(model) = env::var(model_var) {
            provider_config = provider_config.with_default_model(model);
        }
        Some(provider_config)
    }

    /// Parses the variable `name` with `FromStr`.
    ///
    /// Returns `Ok(None)` when the variable is absent, and `GatewayError::Config(error_message)`
    /// when it is present but does not parse.
    fn parse_env_var<T: std::str::FromStr>(name: &str, error_message: &str) -> Result<Option<T>> {
        env::var(name)
            .ok()
            .map(|raw| {
                raw.parse::<T>()
                    .map_err(|_| GatewayError::Config(error_message.to_string()))
            })
            .transpose()
    }
}