herolib-ai 0.3.13

AI client with multi-provider support (Groq, OpenRouter, SambaNova) and automatic failover
Documentation
//! AI provider definitions.
//!
//! This module defines the supported AI providers and their configurations.

use serde::{Deserialize, Serialize};

/// An AI inference backend this crate can send requests to.
///
/// Each variant's endpoints share the same path layout beneath
/// [`Provider::base_url`] (`/chat/completions`, `/embeddings`, ...).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Deserialize, Serialize)]
pub enum Provider {
    /// Groq, a fast inference service.
    ///
    /// Chat endpoint: <https://api.groq.com/openai/v1/chat/completions>
    Groq,
    /// OpenRouter, a single API in front of many models.
    ///
    /// Chat endpoint: <https://openrouter.ai/api/v1/chat/completions>
    OpenRouter,
    /// SambaNova, a high-performance inference service.
    ///
    /// Chat endpoint: <https://api.sambanova.ai/v1/chat/completions>
    SambaNova,
}

impl Provider {
    /// Returns the base URL for this provider's API.
    pub fn base_url(&self) -> &'static str {
        match self {
            Provider::Groq => "https://api.groq.com/openai/v1",
            Provider::OpenRouter => "https://openrouter.ai/api/v1",
            Provider::SambaNova => "https://api.sambanova.ai/v1",
        }
    }

    /// Returns the chat completions endpoint URL.
    pub fn chat_completions_url(&self) -> String {
        format!("{}/chat/completions", self.base_url())
    }

    /// Returns the embeddings endpoint URL.
    pub fn embeddings_url(&self) -> String {
        format!("{}/embeddings", self.base_url())
    }

    /// Returns the audio transcriptions endpoint URL.
    pub fn transcriptions_url(&self) -> String {
        format!("{}/audio/transcriptions", self.base_url())
    }

    /// Returns the audio translations endpoint URL.
    pub fn translations_url(&self) -> String {
        format!("{}/audio/translations", self.base_url())
    }

    /// Returns the environment variable name for the API key.
    pub fn api_key_env_var(&self) -> &'static str {
        match self {
            Provider::Groq => "GROQ_API_KEY",
            Provider::OpenRouter => "OPENROUTER_API_KEY",
            Provider::SambaNova => "SAMBANOVA_API_KEY",
        }
    }

    /// Returns the provider name as a string.
    pub fn name(&self) -> &'static str {
        match self {
            Provider::Groq => "Groq",
            Provider::OpenRouter => "OpenRouter",
            Provider::SambaNova => "SambaNova",
        }
    }
}

/// Formats the provider as its human-readable [`Provider::name`].
///
/// Width, fill, alignment and precision flags are honoured, e.g.
/// `format!("{:<6}|", Provider::Groq)` yields `"Groq  |"`. Plain `{}` output
/// is exactly the provider name.
impl std::fmt::Display for Provider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's formatting flags; `write!(f, "{}", ..)`
        // would silently ignore them (e.g. in aligned tables or log columns).
        f.pad(self.name())
    }
}

/// Configuration for a specific provider.
///
/// `Debug` is implemented by hand so the API key is never written to logs or
/// panic messages; every other field is shown as-is.
#[derive(Clone)]
pub struct ProviderConfig {
    /// The provider type.
    pub provider: Provider,
    /// API key for authentication (redacted in `Debug` output).
    pub api_key: String,
    /// Optional custom base URL (overrides the provider's default).
    pub base_url: Option<String>,
    /// Request timeout in seconds.
    pub timeout_secs: u64,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key is present, never its contents.
        let api_key: &str = if self.api_key.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("ProviderConfig")
            .field("provider", &self.provider)
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("timeout_secs", &self.timeout_secs)
            .finish()
    }
}

impl ProviderConfig {
    /// Request timeout applied by [`ProviderConfig::new`], in seconds.
    const DEFAULT_TIMEOUT_SECS: u64 = 120;

    /// Creates a configuration for `provider` authenticated with `api_key`.
    ///
    /// The provider's default base URL is used and the timeout is 120 seconds;
    /// adjust them with [`with_base_url`](Self::with_base_url) and
    /// [`with_timeout`](Self::with_timeout).
    pub fn new(provider: Provider, api_key: impl Into<String>) -> Self {
        Self {
            provider,
            api_key: api_key.into(),
            base_url: None,
            timeout_secs: Self::DEFAULT_TIMEOUT_SECS,
        }
    }

    /// Creates a configuration from the provider's API-key environment
    /// variable (see [`Provider::api_key_env_var`]).
    ///
    /// Leading and trailing whitespace (such as a newline left over from a
    /// secrets file) is stripped from the key.
    ///
    /// # Returns
    ///
    /// Returns `None` if the variable is unset, is not valid Unicode, or is
    /// empty / whitespace-only.
    pub fn from_env(provider: Provider) -> Option<Self> {
        let raw = std::env::var(provider.api_key_env_var()).ok()?;
        let key = raw.trim();
        if key.is_empty() {
            return None;
        }
        Some(Self::new(provider, key))
    }

    /// Sets a custom base URL; trailing slashes are ignored when building
    /// endpoint URLs.
    #[must_use]
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Sets the request timeout, in seconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
        self.timeout_secs = timeout_secs;
        self
    }

    /// Builds the URL for `path` (which starts with `/`) under the effective
    /// base URL: the custom one with trailing slashes removed, or else the
    /// provider's default.
    fn endpoint_url(&self, path: &str) -> String {
        let base = match self.base_url.as_deref() {
            Some(custom) => custom.trim_end_matches('/'),
            None => self.provider.base_url(),
        };
        format!("{}{}", base, path)
    }

    /// Returns the chat completions URL.
    pub fn chat_completions_url(&self) -> String {
        self.endpoint_url("/chat/completions")
    }

    /// Returns the embeddings URL.
    pub fn embeddings_url(&self) -> String {
        self.endpoint_url("/embeddings")
    }

    /// Returns the audio transcriptions URL.
    pub fn transcriptions_url(&self) -> String {
        self.endpoint_url("/audio/transcriptions")
    }

    /// Returns the audio translations URL.
    pub fn translations_url(&self) -> String {
        self.endpoint_url("/audio/translations")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every provider variant, so per-provider checks cannot silently skip one.
    const ALL_PROVIDERS: [Provider; 3] = [Provider::Groq, Provider::OpenRouter, Provider::SambaNova];

    #[test]
    fn test_provider_urls() {
        assert_eq!(
            Provider::Groq.chat_completions_url(),
            "https://api.groq.com/openai/v1/chat/completions"
        );
        assert_eq!(
            Provider::OpenRouter.chat_completions_url(),
            "https://openrouter.ai/api/v1/chat/completions"
        );
        assert_eq!(
            Provider::SambaNova.chat_completions_url(),
            "https://api.sambanova.ai/v1/chat/completions"
        );
    }

    #[test]
    fn test_provider_endpoints_extend_base_url() {
        for provider in ALL_PROVIDERS.iter() {
            let base = provider.base_url();
            // A trailing slash here would produce `//` in every endpoint below.
            assert!(!base.ends_with('/'), "{} base URL ends with '/'", provider.name());
            assert_eq!(provider.chat_completions_url(), format!("{}/chat/completions", base));
            assert_eq!(provider.embeddings_url(), format!("{}/embeddings", base));
            assert_eq!(
                provider.transcriptions_url(),
                format!("{}/audio/transcriptions", base)
            );
            assert_eq!(
                provider.translations_url(),
                format!("{}/audio/translations", base)
            );
        }
    }

    #[test]
    fn test_provider_names_and_env_vars() {
        let expected = [
            (Provider::Groq, "Groq", "GROQ_API_KEY"),
            (Provider::OpenRouter, "OpenRouter", "OPENROUTER_API_KEY"),
            (Provider::SambaNova, "SambaNova", "SAMBANOVA_API_KEY"),
        ];
        for &(provider, name, env_var) in expected.iter() {
            assert_eq!(provider.name(), name);
            assert_eq!(provider.to_string(), name);
            assert_eq!(provider.api_key_env_var(), env_var);
        }
    }

    #[test]
    fn test_provider_config() {
        let config = ProviderConfig::new(Provider::Groq, "test-key")
            .with_timeout(60)
            .with_base_url("https://custom.api.com/v1");

        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.timeout_secs, 60);
        assert_eq!(
            config.chat_completions_url(),
            "https://custom.api.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_provider_config_defaults() {
        let config = ProviderConfig::new(Provider::OpenRouter, String::from("key"));
        assert_eq!(config.provider, Provider::OpenRouter);
        assert_eq!(config.timeout_secs, 120);
        assert!(config.base_url.is_none());
        // Without an override, every endpoint matches the provider default.
        assert_eq!(
            config.chat_completions_url(),
            Provider::OpenRouter.chat_completions_url()
        );
        assert_eq!(config.embeddings_url(), Provider::OpenRouter.embeddings_url());
        assert_eq!(
            config.transcriptions_url(),
            Provider::OpenRouter.transcriptions_url()
        );
        assert_eq!(
            config.translations_url(),
            Provider::OpenRouter.translations_url()
        );
    }

    #[test]
    fn test_custom_base_url_trailing_slashes_are_trimmed() {
        let config = ProviderConfig::new(Provider::SambaNova, "key")
            .with_base_url("https://proxy.example.com/v1//");
        assert_eq!(
            config.chat_completions_url(),
            "https://proxy.example.com/v1/chat/completions"
        );
        assert_eq!(
            config.embeddings_url(),
            "https://proxy.example.com/v1/embeddings"
        );
        assert_eq!(
            config.transcriptions_url(),
            "https://proxy.example.com/v1/audio/transcriptions"
        );
        assert_eq!(
            config.translations_url(),
            "https://proxy.example.com/v1/audio/translations"
        );
    }
}