//! litellm-rs 0.1.1
//!
//! A high-performance AI Gateway written in Rust, providing OpenAI-compatible
//! APIs with intelligent routing, load balancing, and enterprise features.
//!
//! SDK configuration module.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Top-level SDK client configuration.
///
/// Combines the set of configured providers with the global client settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// ID of the provider used when a request does not name one explicitly.
    pub default_provider: Option<String>,
    /// All configured providers.
    pub providers: Vec<ProviderConfig>,
    /// Global settings (timeouts, retries, concurrency limits, telemetry).
    pub settings: ClientSettings,
}

impl Default for ClientConfig {
    fn default() -> Self {
        Self {
            default_provider: None,
            providers: Vec::new(),
            settings: ClientSettings::default(),
        }
    }
}

/// Global client settings applied to every request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientSettings {
    /// Request timeout in seconds.
    pub timeout: u64,
    /// Maximum number of retries per request.
    pub max_retries: u32,
    /// Upper bound on concurrently in-flight requests.
    pub max_concurrent_requests: u32,
    /// Enable request logging.
    pub enable_logging: bool,
    /// Enable metrics collection.
    pub enable_metrics: bool,
}

impl Default for ClientSettings {
    fn default() -> Self {
        Self {
            timeout: 30,
            max_retries: 3,
            max_concurrent_requests: 100,
            enable_logging: true,
            enable_metrics: true,
        }
    }
}

/// Configuration for a single upstream provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Unique provider ID (referenced by `ClientConfig::default_provider`).
    pub id: String,
    /// Provider type.
    pub provider_type: ProviderType,
    /// Human-readable display name.
    pub name: String,
    /// API key used to authenticate with the provider.
    pub api_key: String,
    /// Optional base-URL override (e.g. proxies or self-hosted endpoints).
    pub base_url: Option<String>,
    /// Models served by this provider.
    pub models: Vec<String>,
    /// Whether this provider is enabled.
    pub enabled: bool,
    /// Relative weight used for load balancing.
    pub weight: f32,
    /// Optional requests-per-minute rate limit.
    pub rate_limit_rpm: Option<u32>,
    /// Optional tokens-per-minute rate limit.
    pub rate_limit_tpm: Option<u32>,
    /// Provider-specific settings as free-form JSON values.
    pub settings: HashMap<String, serde_json::Value>,
}

/// Provider type enum.
///
/// Serialized in snake_case (e.g. `HuggingFace` -> `"hugging_face"`)
/// via the container-level `rename_all` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderType {
    /// OpenAI provider (GPT models)
    OpenAI,
    /// Anthropic provider (Claude models)
    Anthropic,
    /// Azure OpenAI provider
    Azure,
    /// Google provider (PaLM, Gemini models)
    Google,
    /// Cohere provider
    Cohere,
    /// Hugging Face provider
    HuggingFace,
    /// Ollama provider (local models)
    Ollama,
    /// AWS Bedrock provider
    AwsBedrock,
    /// Google Vertex AI provider
    GoogleVertex,
    /// Mistral provider
    Mistral,
    /// Custom provider with specified name
    Custom(String),
}

impl From<&str> for ProviderType {
    /// Parses a provider type from a case-insensitive string.
    ///
    /// Accepts both the compact spelling (`"huggingface"`) and the
    /// snake_case spelling produced by this enum's
    /// `#[serde(rename_all = "snake_case")]` attribute (`"hugging_face"`),
    /// so a serialized `ProviderType` round-trips back to the same variant.
    /// Unknown strings become `ProviderType::Custom`, preserving the
    /// original (non-lowercased) input.
    fn from(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "openai" => ProviderType::OpenAI,
            "anthropic" => ProviderType::Anthropic,
            "azure" => ProviderType::Azure,
            "google" => ProviderType::Google,
            "cohere" => ProviderType::Cohere,
            // Accept both spellings: serde serializes this variant as
            // "hugging_face", which previously fell through to Custom.
            "huggingface" | "hugging_face" => ProviderType::HuggingFace,
            "ollama" => ProviderType::Ollama,
            "aws_bedrock" => ProviderType::AwsBedrock,
            "google_vertex" => ProviderType::GoogleVertex,
            "mistral" => ProviderType::Mistral,
            _ => ProviderType::Custom(s.to_string()),
        }
    }
}

/// Builder-pattern constructor for [`ClientConfig`].
pub struct ConfigBuilder {
    // Configuration being assembled; returned by `build()`.
    config: ClientConfig,
}

impl ConfigBuilder {
    /// Create a new configuration builder
    pub fn new() -> Self {
        Self {
            config: ClientConfig::default(),
        }
    }

    /// 设置默认provider
    pub fn default_provider(mut self, provider_id: &str) -> Self {
        self.config.default_provider = Some(provider_id.to_string());
        self
    }

    /// 添加provider
    pub fn add_provider(mut self, provider: ProviderConfig) -> Self {
        self.config.providers.push(provider);
        self
    }

    /// 添加OpenAI provider
    pub fn add_openai(self, id: &str, api_key: &str) -> Self {
        self.add_provider(ProviderConfig {
            id: id.to_string(),
            provider_type: ProviderType::OpenAI,
            name: "OpenAI".to_string(),
            api_key: api_key.to_string(),
            base_url: None,
            models: vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()],
            enabled: true,
            weight: 1.0,
            rate_limit_rpm: Some(3000),
            rate_limit_tpm: Some(250000),
            settings: HashMap::new(),
        })
    }

    /// 添加Anthropic provider
    pub fn add_anthropic(self, id: &str, api_key: &str) -> Self {
        self.add_provider(ProviderConfig {
            id: id.to_string(),
            provider_type: ProviderType::Anthropic,
            name: "Anthropic".to_string(),
            api_key: api_key.to_string(),
            base_url: None,
            models: vec![
                "claude-3-opus-20240229".to_string(),
                "claude-3-sonnet-20240229".to_string(),
            ],
            enabled: true,
            weight: 1.0,
            rate_limit_rpm: Some(1000),
            rate_limit_tpm: Some(100000),
            settings: HashMap::new(),
        })
    }

    /// 设置超时
    pub fn timeout(mut self, timeout: u64) -> Self {
        self.config.settings.timeout = timeout;
        self
    }

    /// 设置重试次数
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.config.settings.max_retries = retries;
        self
    }

    /// 构建配置
    pub fn build(self) -> ClientConfig {
        self.config
    }
}

impl Default for ConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Loading a [`ClientConfig`] from the environment or from a file.
impl ClientConfig {
    /// Builds a configuration from environment variables.
    ///
    /// Recognizes `OPENAI_API_KEY` and `ANTHROPIC_API_KEY`; each key that is
    /// present registers the corresponding provider with preset defaults.
    ///
    /// # Errors
    /// Returns a `ConfigError` when neither variable is set.
    pub fn from_env() -> crate::sdk::errors::Result<Self> {
        use std::env;

        let mut builder = ConfigBuilder::new();

        if let Ok(key) = env::var("OPENAI_API_KEY") {
            builder = builder.add_openai("openai", &key);
        }
        if let Ok(key) = env::var("ANTHROPIC_API_KEY") {
            builder = builder.add_anthropic("anthropic", &key);
        }

        let config = builder.build();
        if !config.providers.is_empty() {
            return Ok(config);
        }

        Err(crate::sdk::errors::SDKError::ConfigError(
            "No providers configured. Please set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variables.".to_string()
        ))
    }

    /// Loads a configuration from a YAML file at `path`.
    ///
    /// # Errors
    /// Returns a `ConfigError` when the file cannot be read or parsed.
    pub fn from_file(path: &str) -> crate::sdk::errors::Result<Self> {
        use crate::sdk::errors::SDKError;

        let content = match std::fs::read_to_string(path) {
            Ok(text) => text,
            Err(e) => {
                return Err(SDKError::ConfigError(format!(
                    "Failed to read config file {}: {}",
                    path, e
                )))
            }
        };

        match serde_yaml::from_str(&content) {
            Ok(config) => Ok(config),
            Err(e) => Err(SDKError::ConfigError(format!(
                "Failed to parse config file {}: {}",
                path, e
            ))),
        }
    }
}