//! ubiquity-core 0.1.1
//!
//! Core types and traits for Ubiquity consciousness-aware mesh.
//!
//! LLM configuration helpers and builders

use crate::{LLMConfig, LLMProvider, Result, UbiquityError};
use std::time::Duration;
use std::env;
use serde::{Deserialize, Serialize};

/// LLM configuration builder for easier setup
///
/// Collects provider, credentials, model and tuning knobs incrementally;
/// `build()` validates the required fields and produces an [`LLMConfig`].
#[derive(Debug, Clone)]
pub struct LLMConfigBuilder {
    // Required: which backend to use; build() errors if unset.
    provider: Option<LLMProvider>,
    // Required: credential for the provider; build() errors if unset.
    api_key: Option<String>,
    // Optional: build() substitutes a per-provider default when unset.
    model: Option<String>,
    // Sampling temperature; the setter clamps to [0.0, 1.0].
    temperature: f32,
    max_tokens: usize,
    timeout: Duration,
    retry_attempts: u32,
    retry_delay: Duration,
}

impl Default for LLMConfigBuilder {
    fn default() -> Self {
        Self {
            provider: None,
            api_key: None,
            model: None,
            temperature: 0.7,
            max_tokens: 4096,
            timeout: Duration::from_secs(120),
            retry_attempts: 3,
            retry_delay: Duration::from_secs(1),
        }
    }
}

impl LLMConfigBuilder {
    /// Create a builder pre-loaded with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Choose the backend provider (required).
    pub fn provider(self, provider: LLMProvider) -> Self {
        Self { provider: Some(provider), ..self }
    }

    /// Supply the API credential (required).
    pub fn api_key(self, api_key: impl Into<String>) -> Self {
        Self { api_key: Some(api_key.into()), ..self }
    }

    /// Pick an explicit model; when omitted, `build()` falls back to a
    /// per-provider default.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self { model: Some(model.into()), ..self }
    }

    /// Sampling temperature; values outside `[0.0, 1.0]` are clamped.
    pub fn temperature(self, temperature: f32) -> Self {
        Self { temperature: temperature.clamp(0.0, 1.0), ..self }
    }

    /// Cap on generated tokens per request.
    pub fn max_tokens(self, max_tokens: usize) -> Self {
        Self { max_tokens, ..self }
    }

    /// Per-request timeout.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Number of retry attempts on failure.
    pub fn retry_attempts(self, retry_attempts: u32) -> Self {
        Self { retry_attempts, ..self }
    }

    /// Pause between retry attempts.
    pub fn retry_delay(self, retry_delay: Duration) -> Self {
        Self { retry_delay, ..self }
    }

    /// Validate the builder and assemble the final [`LLMConfig`].
    ///
    /// # Errors
    /// Returns [`UbiquityError::ConfigError`] when either the provider or
    /// the API key was never supplied.
    pub fn build(self) -> Result<LLMConfig> {
        let provider = match self.provider {
            Some(p) => p,
            None => return Err(UbiquityError::ConfigError("Provider not specified".to_string())),
        };

        let api_key = match self.api_key {
            Some(k) => k,
            None => return Err(UbiquityError::ConfigError("API key not specified".to_string())),
        };

        // Substitute the provider's canonical model when none was chosen.
        let model = match self.model {
            Some(m) => m,
            None => default_model_for_provider(provider),
        };

        Ok(LLMConfig {
            provider,
            api_key,
            model,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            timeout: self.timeout,
            retry_attempts: self.retry_attempts,
            retry_delay: self.retry_delay,
        })
    }
}

/// Get default model for a provider
///
/// Returns the canonical model identifier used when the builder is
/// given none.
fn default_model_for_provider(provider: LLMProvider) -> String {
    let name = match provider {
        LLMProvider::Claude => "claude-3-opus-20240229",
        LLMProvider::OpenAI => "gpt-4-turbo-preview",
        LLMProvider::Local => "llama2",
        LLMProvider::Mock => "mock-model",
    };
    name.to_owned()
}

/// Configuration presets for common use cases
///
/// Each preset returns a partially-filled builder; the caller still has
/// to supply a provider and an API key before calling `build()`.
pub struct LLMConfigPresets;

impl LLMConfigPresets {
    /// Fast, cheap configuration for simple tasks: low temperature,
    /// small token budget, short timeout.
    pub fn fast() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .timeout(Duration::from_secs(30))
            .max_tokens(1024)
            .temperature(0.3)
    }

    /// Balanced configuration for general use.
    pub fn balanced() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .timeout(Duration::from_secs(60))
            .max_tokens(4096)
            .temperature(0.7)
    }

    /// High quality configuration for complex tasks: higher temperature,
    /// large token budget, generous timeout and extra retries.
    pub fn quality() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .retry_attempts(5)
            .timeout(Duration::from_secs(180))
            .max_tokens(8192)
            .temperature(0.9)
    }

    /// Configuration for code generation: near-deterministic sampling
    /// with a large token budget.
    pub fn code_generation() -> LLMConfigBuilder {
        LLMConfigBuilder::new()
            .timeout(Duration::from_secs(120))
            .max_tokens(8192)
            .temperature(0.2)
    }
}

/// Load LLM configuration from environment variables
pub struct LLMConfigFromEnv;

impl LLMConfigFromEnv {
    /// Build a Claude configuration from `CLAUDE_API_KEY` (or
    /// `ANTHROPIC_API_KEY` as a fallback) plus the optional
    /// `CLAUDE_MODEL` override.
    ///
    /// # Errors
    /// Returns [`UbiquityError::ConfigError`] when neither key variable
    /// is present in the environment.
    pub fn claude() -> Result<LLMConfig> {
        let key = env::var("CLAUDE_API_KEY")
            .or_else(|_| env::var("ANTHROPIC_API_KEY"))
            .map_err(|_| {
                UbiquityError::ConfigError(
                    "CLAUDE_API_KEY or ANTHROPIC_API_KEY not found in environment".to_string(),
                )
            })?;

        let model = match env::var("CLAUDE_MODEL") {
            Ok(m) => m,
            Err(_) => "claude-3-opus-20240229".to_string(),
        };

        LLMConfigBuilder::new()
            .provider(LLMProvider::Claude)
            .api_key(key)
            .model(model)
            .build()
    }

    /// Build an OpenAI configuration from `OPENAI_API_KEY` plus the
    /// optional `OPENAI_MODEL` override.
    ///
    /// # Errors
    /// Returns [`UbiquityError::ConfigError`] when the key variable is
    /// missing from the environment.
    pub fn openai() -> Result<LLMConfig> {
        let key = env::var("OPENAI_API_KEY").map_err(|_| {
            UbiquityError::ConfigError("OPENAI_API_KEY not found in environment".to_string())
        })?;

        let model = match env::var("OPENAI_MODEL") {
            Ok(m) => m,
            Err(_) => "gpt-4-turbo-preview".to_string(),
        };

        LLMConfigBuilder::new()
            .provider(LLMProvider::OpenAI)
            .api_key(key)
            .model(model)
            .build()
    }

    /// Build a local-LLM configuration from `LOCAL_LLM_URL` and
    /// `LOCAL_LLM_MODEL`, defaulting to `http://localhost:11434` and
    /// `llama2` respectively.
    pub fn local() -> Result<LLMConfig> {
        let endpoint = match env::var("LOCAL_LLM_URL") {
            Ok(u) => u,
            Err(_) => "http://localhost:11434".to_string(),
        };

        let model = match env::var("LOCAL_LLM_MODEL") {
            Ok(m) => m,
            Err(_) => "llama2".to_string(),
        };

        LLMConfigBuilder::new()
            .provider(LLMProvider::Local)
            // For the local provider the api_key slot carries the base URL.
            .api_key(endpoint)
            .model(model)
            .build()
    }

    /// Probe providers in order of preference (Claude, OpenAI, local)
    /// and return the first configuration that loads successfully.
    ///
    /// # Errors
    /// Returns [`UbiquityError::ConfigError`] when none of them can be
    /// built from the environment.
    pub fn any() -> Result<LLMConfig> {
        Self::claude()
            .or_else(|_| Self::openai())
            .or_else(|_| Self::local())
            .map_err(|_| {
                UbiquityError::ConfigError("No LLM configuration found in environment".to_string())
            })
    }
}

/// Multiple LLM configurations for fallback support
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfigChain {
    /// Primary configuration
    pub primary: LLMConfig,
    /// Fallback configurations in order of preference
    pub fallbacks: Vec<LLMConfig>,
}

impl LLMConfigChain {
    /// Start a chain with `primary` and no fallbacks.
    pub fn new(primary: LLMConfig) -> Self {
        LLMConfigChain {
            primary,
            fallbacks: Vec::new(),
        }
    }

    /// Append a fallback; fallbacks are kept in insertion order.
    pub fn with_fallback(mut self, config: LLMConfig) -> Self {
        self.fallbacks.push(config);
        self
    }

    /// All configurations in order: primary first, then each fallback.
    pub fn all_configs(&self) -> Vec<&LLMConfig> {
        std::iter::once(&self.primary)
            .chain(self.fallbacks.iter())
            .collect()
    }
}

/// Rate limiting configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum requests per minute
    pub requests_per_minute: u32,
    /// Maximum tokens per minute
    pub tokens_per_minute: u32,
    /// Maximum concurrent requests
    pub max_concurrent: usize,
}

impl Default for RateLimitConfig {
    // Defaults: 60 requests/min, 100K tokens/min, 10 in-flight requests.
    // NOTE(review): these appear to be project-chosen values, not tied to
    // any specific provider's published limits — confirm before relying
    // on them for a particular backend.
    fn default() -> Self {
        Self {
            requests_per_minute: 60,
            tokens_per_minute: 100000,
            max_concurrent: 10,
        }
    }
}

/// Extended LLM configuration with additional features
///
/// Wraps a base [`LLMConfig`] with optional operational concerns:
/// rate limiting, custom HTTP headers, logging/metrics toggles and
/// cost tracking. `None`/`false` fields leave the feature disabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMConfigExtended {
    /// Base configuration
    pub base: LLMConfig,
    /// Rate limiting configuration
    pub rate_limit: Option<RateLimitConfig>,
    /// Custom headers for API requests
    pub custom_headers: Option<std::collections::HashMap<String, String>>,
    /// Whether to enable request/response logging
    pub enable_logging: bool,
    /// Whether to enable metrics collection
    pub enable_metrics: bool,
    /// Cost tracking configuration
    pub cost_tracking: Option<CostTrackingConfig>,
}

/// Cost tracking configuration
///
/// Rates are interpreted as *per 1,000 tokens* in `currency` by
/// `calculate_cost`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostTrackingConfig {
    /// Cost per 1K input tokens
    pub input_token_cost: f64,
    /// Cost per 1K output tokens
    pub output_token_cost: f64,
    /// Currency (e.g., "USD")
    pub currency: String,
}

impl CostTrackingConfig {
    /// Total cost of a request: input and output token counts are each
    /// scaled to thousands and multiplied by their per-1K rate, then
    /// summed.
    pub fn calculate_cost(&self, input_tokens: usize, output_tokens: usize) -> f64 {
        // Same operation order as scaling each count to 1K units first.
        let priced = |tokens: usize, rate: f64| (tokens as f64 / 1000.0) * rate;
        priced(input_tokens, self.input_token_cost) + priced(output_tokens, self.output_token_cost)
    }
}

/// Common cost configurations
impl CostTrackingConfig {
    /// Pricing preset for Claude 3 Opus.
    ///
    /// NOTE(review): Anthropic's published Opus pricing is $15 / $75 per
    /// *million* tokens, but `calculate_cost` consumes these fields as
    /// per-1K rates (and the unit test in this file pins that reading,
    /// yielding e.g. $15 for 1K input tokens). Confirm whether
    /// 0.015 / 0.075 was intended before using these values for billing.
    pub fn claude_opus() -> Self {
        Self {
            input_token_cost: 15.0,
            output_token_cost: 75.0,
            currency: "USD".to_string(),
        }
    }

    /// Pricing preset for GPT-4 Turbo.
    ///
    /// NOTE(review): same unit concern as `claude_opus` — $10 / $30 are
    /// per-million-token list prices, yet the fields are treated as
    /// per-1K rates by `calculate_cost`.
    pub fn gpt4_turbo() -> Self {
        Self {
            input_token_cost: 10.0,
            output_token_cost: 30.0,
            currency: "USD".to_string(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builder round-trip: every explicitly-set field must survive build().
    #[test]
    fn test_llm_config_builder() {
        let built = LLMConfigBuilder::new()
            .provider(LLMProvider::Claude)
            .api_key("test-key")
            .model("claude-3-opus")
            .temperature(0.5)
            .max_tokens(2048)
            .build()
            .unwrap();

        assert_eq!(built.provider, LLMProvider::Claude);
        assert_eq!(built.api_key, "test-key");
        assert_eq!(built.model, "claude-3-opus");
        assert_eq!(built.temperature, 0.5);
        assert_eq!(built.max_tokens, 2048);
    }

    /// The `fast` preset must carry its tuned values through build().
    #[test]
    fn test_config_presets() {
        let cfg = LLMConfigPresets::fast()
            .provider(LLMProvider::Mock)
            .api_key("test")
            .build()
            .unwrap();

        assert_eq!(cfg.temperature, 0.3);
        assert_eq!(cfg.max_tokens, 1024);
        assert_eq!(cfg.timeout, Duration::from_secs(30));
    }

    /// Cost math treats rates as per-1K tokens:
    /// 1K input * 15.0/1K + 2K output * 75.0/1K = 15 + 150.
    #[test]
    fn test_cost_calculation() {
        let pricing = CostTrackingConfig::claude_opus();
        let total = pricing.calculate_cost(1000, 2000);

        assert_eq!(total, 15.0 + 150.0);
    }
}