use crate::{LLMConfig, LLMProvider, Result, UbiquityError};
use std::time::Duration;
use std::env;
use serde::{Deserialize, Serialize};
/// Builder for [`LLMConfig`] with sensible defaults for every optional field.
///
/// `Clone` is derived so a partially-configured builder can be reused as a
/// template. `Debug` is implemented by hand (below) instead of derived so the
/// API key is never printed into logs, `dbg!` output, or panic messages.
#[derive(Clone)]
pub struct LLMConfigBuilder {
    provider: Option<LLMProvider>,
    api_key: Option<String>,
    model: Option<String>,
    temperature: f32,
    max_tokens: usize,
    timeout: Duration,
    retry_attempts: u32,
    retry_delay: Duration,
}

impl std::fmt::Debug for LLMConfigBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LLMConfigBuilder")
            .field("provider", &self.provider)
            // Redact the secret: show only whether a key has been set.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("timeout", &self.timeout)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay", &self.retry_delay)
            .finish()
    }
}
impl Default for LLMConfigBuilder {
fn default() -> Self {
Self {
provider: None,
api_key: None,
model: None,
temperature: 0.7,
max_tokens: 4096,
timeout: Duration::from_secs(120),
retry_attempts: 3,
retry_delay: Duration::from_secs(1),
}
}
}
impl LLMConfigBuilder {
    /// Creates a builder populated with the defaults from [`Default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Selects which LLM backend the configuration targets.
    pub fn provider(mut self, provider: LLMProvider) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the credential used to authenticate requests.
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Sets an explicit model identifier; otherwise the provider's
    /// default model is used at build time.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the sampling temperature, clamped into the `[0.0, 1.0]` range.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature.clamp(0.0, 1.0);
        self
    }

    /// Sets the maximum number of tokens to generate per request.
    pub fn max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the per-request timeout.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets how many times a failed request is retried.
    pub fn retry_attempts(mut self, attempts: u32) -> Self {
        self.retry_attempts = attempts;
        self
    }

    /// Sets the delay between retry attempts.
    pub fn retry_delay(mut self, delay: Duration) -> Self {
        self.retry_delay = delay;
        self
    }

    /// Finalizes the configuration.
    ///
    /// # Errors
    /// Returns `UbiquityError::ConfigError` when no provider or no API key
    /// was supplied. A missing model is not an error: it falls back to the
    /// provider's default via `default_model_for_provider`.
    pub fn build(self) -> Result<LLMConfig> {
        let provider = match self.provider {
            Some(p) => p,
            None => return Err(UbiquityError::ConfigError("Provider not specified".to_string())),
        };
        let api_key = match self.api_key {
            Some(k) => k,
            None => return Err(UbiquityError::ConfigError("API key not specified".to_string())),
        };
        let model = match self.model {
            Some(m) => m,
            None => default_model_for_provider(provider),
        };
        Ok(LLMConfig {
            provider,
            api_key,
            model,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            timeout: self.timeout,
            retry_attempts: self.retry_attempts,
            retry_delay: self.retry_delay,
        })
    }
}
/// Returns the model identifier used when the caller did not pick one
/// explicitly for the given provider.
fn default_model_for_provider(provider: LLMProvider) -> String {
    let model = match provider {
        LLMProvider::Claude => "claude-3-opus-20240229",
        LLMProvider::OpenAI => "gpt-4-turbo-preview",
        LLMProvider::Local => "llama2",
        LLMProvider::Mock => "mock-model",
    };
    model.to_string()
}
/// Namespace for pre-tuned builders covering common workloads.
///
/// Each preset returns a builder, so callers still supply the provider
/// and API key (and may override any preset value) before `build()`.
pub struct LLMConfigPresets;

impl LLMConfigPresets {
    /// Low-latency preset: short outputs, low temperature, 30 s timeout.
    pub fn fast() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .timeout(Duration::from_secs(30))
            .max_tokens(1024)
            .temperature(0.3)
    }

    /// Middle-of-the-road preset for general-purpose use.
    pub fn balanced() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .timeout(Duration::from_secs(60))
            .max_tokens(4096)
            .temperature(0.7)
    }

    /// Quality-first preset: long outputs, high temperature, generous
    /// timeout, and extra retries.
    pub fn quality() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .retry_attempts(5)
            .timeout(Duration::from_secs(180))
            .max_tokens(8192)
            .temperature(0.9)
    }

    /// Preset tuned for code generation: near-deterministic sampling
    /// with a large output budget.
    pub fn code_generation() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .timeout(Duration::from_secs(120))
            .max_tokens(8192)
            .temperature(0.2)
    }
}
/// Namespace for assembling [`LLMConfig`] values from well-known
/// environment variables.
pub struct LLMConfigFromEnv;

impl LLMConfigFromEnv {
    /// Configures the Claude provider from `CLAUDE_API_KEY` (with
    /// `ANTHROPIC_API_KEY` as a fallback) and the optional `CLAUDE_MODEL`.
    ///
    /// # Errors
    /// Returns `UbiquityError::ConfigError` when neither key variable is set.
    pub fn claude() -> Result<LLMConfig> {
        let key = env::var("CLAUDE_API_KEY").or_else(|_| env::var("ANTHROPIC_API_KEY"));
        let api_key = key.map_err(|_| {
            UbiquityError::ConfigError(
                "CLAUDE_API_KEY or ANTHROPIC_API_KEY not found in environment".to_string(),
            )
        })?;
        let model =
            env::var("CLAUDE_MODEL").unwrap_or_else(|_| "claude-3-opus-20240229".to_string());
        LLMConfigBuilder::new()
            .provider(LLMProvider::Claude)
            .api_key(api_key)
            .model(model)
            .build()
    }

    /// Configures the OpenAI provider from `OPENAI_API_KEY` and the
    /// optional `OPENAI_MODEL`.
    ///
    /// # Errors
    /// Returns `UbiquityError::ConfigError` when `OPENAI_API_KEY` is unset.
    pub fn openai() -> Result<LLMConfig> {
        let api_key = env::var("OPENAI_API_KEY").map_err(|_| {
            UbiquityError::ConfigError("OPENAI_API_KEY not found in environment".to_string())
        })?;
        let model =
            env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4-turbo-preview".to_string());
        LLMConfigBuilder::new()
            .provider(LLMProvider::OpenAI)
            .api_key(api_key)
            .model(model)
            .build()
    }

    /// Configures a local provider from `LOCAL_LLM_URL` and
    /// `LOCAL_LLM_MODEL`, both optional with sensible defaults.
    ///
    /// NOTE(review): the base URL is stored in the builder's `api_key`
    /// slot — presumably the Local provider reads its endpoint from
    /// there; confirm against the Local provider's client code.
    pub fn local() -> Result<LLMConfig> {
        let base_url =
            env::var("LOCAL_LLM_URL").unwrap_or_else(|_| "http://localhost:11434".to_string());
        let model = env::var("LOCAL_LLM_MODEL").unwrap_or_else(|_| "llama2".to_string());
        LLMConfigBuilder::new()
            .provider(LLMProvider::Local)
            .api_key(base_url)
            .model(model)
            .build()
    }

    /// Tries Claude, then OpenAI, then the local provider, returning the
    /// first configuration that can be assembled from the environment.
    ///
    /// # Errors
    /// Returns `UbiquityError::ConfigError` when every attempt fails.
    pub fn any() -> Result<LLMConfig> {
        let attempt = Self::claude()
            .or_else(|_| Self::openai())
            .or_else(|_| Self::local());
        attempt.map_err(|_| {
            UbiquityError::ConfigError("No LLM configuration found in environment".to_string())
        })
    }
}
/// A primary configuration plus an ordered list of fallbacks to try
/// if the primary fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfigChain {
    pub primary: LLMConfig,
    pub fallbacks: Vec<LLMConfig>,
}

impl LLMConfigChain {
    /// Starts a chain with `primary` and no fallbacks.
    pub fn new(primary: LLMConfig) -> Self {
        LLMConfigChain {
            primary,
            fallbacks: Vec::new(),
        }
    }

    /// Appends `config` to the end of the fallback list (builder-style).
    pub fn with_fallback(mut self, config: LLMConfig) -> Self {
        self.fallbacks.push(config);
        self
    }

    /// Returns every configuration in priority order: the primary first,
    /// then the fallbacks in insertion order.
    pub fn all_configs(&self) -> Vec<&LLMConfig> {
        std::iter::once(&self.primary)
            .chain(self.fallbacks.iter())
            .collect()
    }
}
/// Client-side throttling limits applied to a single provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    pub requests_per_minute: u32,
    pub tokens_per_minute: u32,
    pub max_concurrent: usize,
}

impl Default for RateLimitConfig {
    /// Conservative defaults: 60 requests and 100 000 tokens per minute,
    /// with at most 10 requests in flight at once.
    fn default() -> Self {
        RateLimitConfig {
            max_concurrent: 10,
            requests_per_minute: 60,
            tokens_per_minute: 100_000,
        }
    }
}
/// An [`LLMConfig`] augmented with operational concerns: rate limiting,
/// extra HTTP headers, observability toggles, and cost tracking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfigExtended {
// Core provider/model/credential settings.
pub base: LLMConfig,
// Optional client-side throttling; None means unthrottled.
pub rate_limit: Option<RateLimitConfig>,
// Extra headers attached to each request, if any.
pub custom_headers: Option<std::collections::HashMap<String, String>>,
// Toggle for request/response logging.
pub enable_logging: bool,
// Toggle for metrics collection.
pub enable_metrics: bool,
// Optional per-token pricing used to estimate spend.
pub cost_tracking: Option<CostTrackingConfig>,
}
/// Per-token pricing used by [`CostTrackingConfig::calculate_cost`] to
/// estimate request spend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostTrackingConfig {
// Rate applied to input (prompt) tokens; calculate_cost treats it as
// cost per 1 000 tokens.
pub input_token_cost: f64,
// Rate applied to output (completion) tokens, same per-1000 unit.
pub output_token_cost: f64,
// Currency label for the computed cost (e.g. "USD").
pub currency: String,
}
impl CostTrackingConfig {
    /// Estimated price of one request in `self.currency`, given the input
    /// and output token counts. Rates are applied per 1 000 tokens.
    ///
    /// NOTE(review): vendor list prices are usually quoted per *million*
    /// tokens; with the per-1000 divisor here the preset rates produce
    /// very large figures — confirm the intended unit.
    pub fn calculate_cost(&self, input_tokens: usize, output_tokens: usize) -> f64 {
        let per_thousand = |tokens: usize, rate: f64| (tokens as f64 / 1000.0) * rate;
        per_thousand(input_tokens, self.input_token_cost)
            + per_thousand(output_tokens, self.output_token_cost)
    }
}
impl CostTrackingConfig {
pub fn claude_opus() -> Self {
Self {
input_token_cost: 15.0,
output_token_cost: 75.0,
currency: "USD".to_string(),
}
}
pub fn gpt4_turbo() -> Self {
Self {
input_token_cost: 10.0,
output_token_cost: 30.0,
currency: "USD".to_string(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every explicitly-set builder field must appear unchanged in the
    /// built config.
    #[test]
    fn test_llm_config_builder() {
        let config = LLMConfigBuilder::new()
            .max_tokens(2048)
            .temperature(0.5)
            .model("claude-3-opus")
            .api_key("test-key")
            .provider(LLMProvider::Claude)
            .build()
            .expect("fully-specified builder must succeed");
        assert_eq!(config.provider, LLMProvider::Claude);
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.model, "claude-3-opus");
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.max_tokens, 2048);
    }

    /// The `fast` preset should carry its tuned values through `build`.
    #[test]
    fn test_config_presets() {
        let fast = LLMConfigPresets::fast()
            .api_key("test")
            .provider(LLMProvider::Mock)
            .build()
            .expect("preset with provider and key must succeed");
        assert_eq!(fast.temperature, 0.3);
        assert_eq!(fast.max_tokens, 1024);
        assert_eq!(fast.timeout, Duration::from_secs(30));
    }

    /// 1000 input + 2000 output tokens at Opus rates: 15.0 + 2 * 75.0.
    #[test]
    fn test_cost_calculation() {
        let cost_config = CostTrackingConfig::claude_opus();
        let cost = cost_config.calculate_cost(1000, 2000);
        assert_eq!(cost, 15.0 + 150.0);
    }
}