// herolib-ai 0.3.2
//
// AI client with multi-provider support (Groq, OpenRouter, SambaNova) and automatic failover.
// Documentation:
//! AI client with multi-provider support.
//!
//! This module provides a unified client for multiple AI providers with automatic failover.

use std::collections::HashMap;
use std::time::Duration;

use crate::error::{AiError, AiResult};
use crate::model::Model;
use crate::provider::{Provider, ProviderConfig};
use crate::types::{ApiErrorResponse, ChatCompletionRequest, ChatCompletionResponse, Message};

/// AI client with multi-provider support.
///
/// The client automatically tries providers in order of preference until one succeeds.
#[derive(Debug)]
pub struct AiClient {
    /// Configured providers, keyed by provider identity.
    providers: HashMap<Provider, ProviderConfig>,
    /// Default sampling temperature applied when a request does not specify one.
    default_temperature: Option<f32>,
    /// Default max-tokens limit applied when a request does not specify one.
    default_max_tokens: Option<u32>,
}

impl Default for AiClient {
    fn default() -> Self {
        Self::new()
    }
}

impl AiClient {
    /// Creates a new AI client with no providers configured.
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
            default_temperature: None,
            default_max_tokens: None,
        }
    }

    /// Creates a new AI client with providers configured from environment variables.
    ///
    /// Automatically configures any provider that has its API key set in the environment.
    pub fn from_env() -> Self {
        let mut client = Self::new();

        for provider in [Provider::Groq, Provider::OpenRouter, Provider::SambaNova] {
            if let Some(config) = ProviderConfig::from_env(provider) {
                client.providers.insert(provider, config);
            }
        }

        client
    }

    /// Adds a provider configuration.
    pub fn with_provider(mut self, config: ProviderConfig) -> Self {
        self.providers.insert(config.provider, config);
        self
    }

    /// Sets the default temperature for all requests.
    pub fn with_default_temperature(mut self, temperature: f32) -> Self {
        self.default_temperature = Some(temperature);
        self
    }

    /// Sets the default max tokens for all requests.
    pub fn with_default_max_tokens(mut self, max_tokens: u32) -> Self {
        self.default_max_tokens = Some(max_tokens);
        self
    }

    /// Returns whether any providers are configured.
    pub fn has_providers(&self) -> bool {
        !self.providers.is_empty()
    }

    /// Returns the configured providers.
    pub fn providers(&self) -> Vec<Provider> {
        self.providers.keys().copied().collect()
    }

    /// Sends a chat completion request using the specified model.
    ///
    /// Automatically tries providers in order of preference until one succeeds.
    pub fn chat(&self, model: Model, messages: Vec<Message>) -> AiResult<ChatCompletionResponse> {
        let model_info = model.info();
        let mut errors = Vec::new();

        // Try each provider in order
        for mapping in &model_info.providers {
            if let Some(config) = self.providers.get(&mapping.provider) {
                let mut request = ChatCompletionRequest::new(mapping.model_id, messages.clone());

                if let Some(temp) = self.default_temperature {
                    request = request.with_temperature(temp);
                }
                if let Some(max) = self.default_max_tokens {
                    request = request.with_max_tokens(max);
                }

                match self.send_request(config, request) {
                    Ok(response) => return Ok(response),
                    Err(e) => {
                        errors.push((mapping.provider, e.to_string()));
                    }
                }
            }
        }

        if errors.is_empty() {
            Err(AiError::ModelNotAvailable(model.name().to_string()))
        } else {
            Err(AiError::AllProvidersFailed(errors))
        }
    }

    /// Sends a chat completion request with custom options.
    pub fn chat_with_options(
        &self,
        model: Model,
        messages: Vec<Message>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> AiResult<ChatCompletionResponse> {
        let model_info = model.info();
        let mut errors = Vec::new();

        for mapping in &model_info.providers {
            if let Some(config) = self.providers.get(&mapping.provider) {
                let mut request = ChatCompletionRequest::new(mapping.model_id, messages.clone());

                if let Some(temp) = temperature.or(self.default_temperature) {
                    request = request.with_temperature(temp);
                }
                if let Some(max) = max_tokens.or(self.default_max_tokens) {
                    request = request.with_max_tokens(max);
                }

                match self.send_request(config, request) {
                    Ok(response) => return Ok(response),
                    Err(e) => {
                        errors.push((mapping.provider, e.to_string()));
                    }
                }
            }
        }

        if errors.is_empty() {
            Err(AiError::ModelNotAvailable(model.name().to_string()))
        } else {
            Err(AiError::AllProvidersFailed(errors))
        }
    }

    /// Sends a raw chat completion request to a specific provider.
    pub fn chat_raw(
        &self,
        provider: Provider,
        request: ChatCompletionRequest,
    ) -> AiResult<ChatCompletionResponse> {
        let config = self
            .providers
            .get(&provider)
            .ok_or(AiError::ApiKeyNotFound(provider))?;

        self.send_request(config, request)
    }

    /// Sends a request to a specific provider.
    fn send_request(
        &self,
        config: &ProviderConfig,
        request: ChatCompletionRequest,
    ) -> AiResult<ChatCompletionResponse> {
        let url = config.chat_completions_url();

        let agent = ureq::Agent::new_with_config(
            ureq::Agent::config_builder()
                .timeout_global(Some(Duration::from_secs(config.timeout_secs)))
                .build(),
        );

        let response = agent
            .post(&url)
            .header("Authorization", &format!("Bearer {}", config.api_key))
            .header("Content-Type", "application/json")
            .send_json(&request)
            .map_err(|e| self.handle_http_error(config.provider, e))?;

        let body = response
            .into_body()
            .read_to_string()
            .map_err(|e| AiError::ParseError(e.to_string()))?;

        // Try to parse as success response
        if let Ok(completion) = serde_json::from_str::<ChatCompletionResponse>(&body) {
            return Ok(completion);
        }

        // Try to parse as error response
        if let Ok(error) = serde_json::from_str::<ApiErrorResponse>(&body) {
            return Err(AiError::ApiError {
                provider: config.provider,
                message: error.error.message,
                status_code: None,
            });
        }

        Err(AiError::ParseError(format!(
            "Failed to parse response: {}",
            body
        )))
    }

    /// Handles HTTP errors from ureq.
    fn handle_http_error(&self, provider: Provider, error: ureq::Error) -> AiError {
        match error {
            ureq::Error::Timeout(_) => AiError::Timeout(120),
            ureq::Error::StatusCode(status) => {
                if status == 429 {
                    AiError::RateLimitExceeded(provider)
                } else {
                    AiError::ApiError {
                        provider,
                        message: format!("HTTP {}", status),
                        status_code: Some(status),
                    }
                }
            }
            _ => AiError::HttpError(error.to_string()),
        }
    }
}

/// Simple helper function to chat with the default model.
///
/// Uses environment variables for configuration.
pub fn chat_simple(prompt: &str) -> AiResult<String> {
    let client = AiClient::from_env();
    if !client.has_providers() {
        return Err(AiError::NoApiKey);
    }

    let response = client.chat(Model::default_general(), vec![Message::user(prompt)])?;

    // Surface a missing assistant message as a parse error.
    match response.content() {
        Some(text) => Ok(text.to_string()),
        None => Err(AiError::ParseError("No content in response".to_string())),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client has no providers.
    #[test]
    fn test_client_creation() {
        assert!(!AiClient::new().has_providers());
    }

    /// Builder methods register the provider and keep it queryable.
    #[test]
    fn test_client_with_provider() {
        let groq_config = ProviderConfig::new(Provider::Groq, "test-key");
        let client = AiClient::new()
            .with_provider(groq_config)
            .with_default_temperature(0.7);

        assert!(client.has_providers());
        assert!(client.providers().contains(&Provider::Groq));
    }

    /// With nothing configured, chat reports the model as unavailable.
    #[test]
    fn test_model_not_available() {
        let outcome = AiClient::new().chat(Model::Llama3_3_70B, vec![Message::user("Hello")]);
        assert!(matches!(outcome, Err(AiError::ModelNotAvailable(_))));
    }
}