ceres-client 0.4.0

HTTP clients for Ceres portal harvesters and embedding providers
Documentation
//! Embedding provider factory and dynamic dispatch.
//!
//! This module provides a unified interface for working with different
//! embedding providers through the [`EmbeddingProviderEnum`] enum.
//!
//! # Why an Enum Instead of `dyn Trait`?
//!
//! The [`EmbeddingProvider`] trait uses `impl Future` return types (RPITIT),
//! which makes it not object-safe. We use an enum to provide dynamic dispatch
//! while maintaining the ergonomic async trait syntax.
//!
//! # Usage
//!
//! ```no_run
//! use ceres_client::provider::EmbeddingProviderEnum;
//! use ceres_core::traits::EmbeddingProvider;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Create provider based on configuration
//! let provider = EmbeddingProviderEnum::gemini("your-api-key")?;
//!
//! // Use the provider generically
//! println!("Using {} provider ({} dimensions)", provider.name(), provider.dimension());
//! let embedding = provider.generate("Hello world").await?;
//! # Ok(())
//! # }
//! ```

use anyhow::Context;
use ceres_core::config::EmbeddingProviderType;
use ceres_core::error::AppError;
use ceres_core::traits::EmbeddingProvider;

use crate::{GeminiClient, OllamaClient, OpenAIClient};

/// Configuration needed to create an embedding provider.
///
/// This struct extracts the embedding-related fields shared between
/// CLI and server configurations, avoiding duplication of the factory logic.
/// Pass it to [`EmbeddingProviderEnum::from_config`] to build a provider.
///
/// The `Debug` implementation redacts API keys, so a config can be logged
/// without leaking credentials.
#[derive(Clone, Default)]
pub struct EmbeddingConfig {
    /// Provider name, parsed into an `EmbeddingProviderType`
    /// (for example `gemini`, `openai`, or `ollama`).
    pub provider: String,
    /// Google Gemini API key; required when `provider` is `gemini`.
    pub gemini_api_key: Option<String>,
    /// OpenAI API key; required when `provider` is `openai`.
    pub openai_api_key: Option<String>,
    /// Optional model override, used by the `openai` and `ollama` providers.
    pub embedding_model: Option<String>,
    /// Optional Ollama server URL, used only by the `ollama` provider.
    pub ollama_endpoint: Option<String>,
}

impl std::fmt::Debug for EmbeddingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a secret is configured, never its value.
        fn redact(secret: Option<&str>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }

        f.debug_struct("EmbeddingConfig")
            .field("provider", &self.provider)
            .field("gemini_api_key", &redact(self.gemini_api_key.as_deref()))
            .field("openai_api_key", &redact(self.openai_api_key.as_deref()))
            .field("embedding_model", &self.embedding_model)
            .field("ollama_endpoint", &self.ollama_endpoint)
            .finish()
    }
}

/// Mock embedding client for testing (returns deterministic 768-dim vectors).
///
/// The vector `generate` returns depends only on the input's byte length, so
/// distinct texts of equal length produce identical embeddings. `generate`
/// performs no I/O.
///
/// Available only with the `test-support` feature.
#[cfg(feature = "test-support")]
#[derive(Clone, Debug)]
pub struct MockEmbeddingClient {
    // Length of every vector produced by `generate`.
    dimension: usize,
}

#[cfg(feature = "test-support")]
impl MockEmbeddingClient {
    /// Builds a mock that emits 768-element vectors, the same width as the
    /// Gemini provider.
    pub fn new() -> Self {
        const DEFAULT_DIMENSION: usize = 768;
        Self {
            dimension: DEFAULT_DIMENSION,
        }
    }
}

#[cfg(feature = "test-support")]
impl Default for MockEmbeddingClient {
    /// Equivalent to [`MockEmbeddingClient::new`].
    fn default() -> Self {
        MockEmbeddingClient::new()
    }
}

#[cfg(feature = "test-support")]
impl EmbeddingProvider for MockEmbeddingClient {
    fn name(&self) -> &'static str {
        "mock"
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn max_batch_size(&self) -> usize {
        const MOCK_BATCH_LIMIT: usize = 100;
        MOCK_BATCH_LIMIT
    }

    /// Returns `dimension` values where element `i` equals
    /// `(byte_len(text) + i) / 1000`; never fails.
    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
        let offset = text.len() as f32;
        let vector = (0..self.dimension)
            .map(|index| (offset + index as f32) / 1000.0)
            .collect::<Vec<f32>>();
        Ok(vector)
    }
}

/// Unified embedding provider that wraps concrete implementations.
///
/// This enum allows runtime selection of embedding providers while
/// implementing the `EmbeddingProvider` trait. Every trait method is
/// forwarded to the wrapped client, so the name, dimension and batch limit
/// are exactly those of the selected implementation.
///
/// Build one with a per-provider constructor such as
/// [`EmbeddingProviderEnum::gemini`], or from configuration with
/// [`EmbeddingProviderEnum::from_config`].
#[derive(Clone)]
pub enum EmbeddingProviderEnum {
    /// Google Gemini embedding provider (768 dimensions).
    Gemini(GeminiClient),
    /// OpenAI embedding provider (1536 dimensions with the default
    /// `text-embedding-3-small`, 3072 with `text-embedding-3-large`).
    OpenAI(OpenAIClient),
    /// Ollama local embedding provider (768 dimensions with the default
    /// `nomic-embed-text` model; other models may differ).
    Ollama(OllamaClient),
    /// Mock embedding provider for testing (768 dimensions).
    ///
    /// Only present when the `test-support` feature is enabled.
    #[cfg(feature = "test-support")]
    Mock(MockEmbeddingClient),
}

impl EmbeddingProviderEnum {
    /// Creates a Gemini embedding provider.
    ///
    /// # Arguments
    ///
    /// * `api_key` - Google Gemini API key
    pub fn gemini(api_key: &str) -> Result<Self, AppError> {
        Ok(Self::Gemini(GeminiClient::new(api_key)?))
    }

    /// Creates an OpenAI embedding provider with the default model.
    ///
    /// Uses `text-embedding-3-small` (1536 dimensions).
    ///
    /// # Arguments
    ///
    /// * `api_key` - OpenAI API key (starts with `sk-`)
    pub fn openai(api_key: &str) -> Result<Self, AppError> {
        Ok(Self::OpenAI(OpenAIClient::new(api_key)?))
    }

    /// Creates an OpenAI embedding provider with a specific model.
    ///
    /// # Arguments
    ///
    /// * `api_key` - OpenAI API key
    /// * `model` - Model name (e.g., `text-embedding-3-large`)
    pub fn openai_with_model(api_key: &str, model: &str) -> Result<Self, AppError> {
        Ok(Self::OpenAI(OpenAIClient::with_model(api_key, model)?))
    }

    /// Creates an Ollama embedding provider with default settings.
    ///
    /// Uses `nomic-embed-text` model at `http://localhost:11434`.
    pub fn ollama() -> Result<Self, AppError> {
        Ok(Self::Ollama(OllamaClient::new()?))
    }

    /// Creates an Ollama embedding provider with custom configuration.
    pub fn ollama_with_config(model: &str, endpoint: Option<&str>) -> Result<Self, AppError> {
        Ok(Self::Ollama(OllamaClient::with_config(model, endpoint)?))
    }

    /// Creates a mock embedding provider for testing.
    #[cfg(feature = "test-support")]
    pub fn mock() -> Self {
        Self::Mock(MockEmbeddingClient::new())
    }

    /// Creates an embedding provider from configuration.
    ///
    /// Parses the provider type and initializes the appropriate client
    /// with the given API key and optional model override.
    pub fn from_config(config: &EmbeddingConfig) -> anyhow::Result<Self> {
        let provider_type: EmbeddingProviderType = config
            .provider
            .parse()
            .context("Invalid embedding provider")?;

        match provider_type {
            EmbeddingProviderType::Gemini => {
                let api_key = config.gemini_api_key.as_ref().ok_or_else(|| {
                    anyhow::anyhow!("GEMINI_API_KEY required when using gemini provider")
                })?;
                Self::gemini(api_key).context("Failed to initialize Gemini client")
            }
            EmbeddingProviderType::OpenAI => {
                let api_key = config.openai_api_key.as_ref().ok_or_else(|| {
                    anyhow::anyhow!("OPENAI_API_KEY required when using openai provider")
                })?;

                if let Some(model) = &config.embedding_model {
                    Self::openai_with_model(api_key, model)
                        .context("Failed to initialize OpenAI client")
                } else {
                    Self::openai(api_key).context("Failed to initialize OpenAI client")
                }
            }
            EmbeddingProviderType::Ollama => {
                let model = config
                    .embedding_model
                    .as_deref()
                    .unwrap_or("nomic-embed-text");
                let endpoint = config.ollama_endpoint.as_deref();
                Self::ollama_with_config(model, endpoint)
                    .context("Failed to initialize Ollama client")
            }
        }
    }
}

impl EmbeddingProvider for EmbeddingProviderEnum {
    fn name(&self) -> &'static str {
        match self {
            Self::Gemini(c) => c.name(),
            Self::OpenAI(c) => c.name(),
            Self::Ollama(c) => c.name(),
            #[cfg(feature = "test-support")]
            Self::Mock(c) => c.name(),
        }
    }

    fn dimension(&self) -> usize {
        match self {
            Self::Gemini(c) => c.dimension(),
            Self::OpenAI(c) => c.dimension(),
            Self::Ollama(c) => c.dimension(),
            #[cfg(feature = "test-support")]
            Self::Mock(c) => c.dimension(),
        }
    }

    fn max_batch_size(&self) -> usize {
        match self {
            Self::Gemini(c) => c.max_batch_size(),
            Self::OpenAI(c) => c.max_batch_size(),
            Self::Ollama(c) => c.max_batch_size(),
            #[cfg(feature = "test-support")]
            Self::Mock(c) => c.max_batch_size(),
        }
    }

    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
        match self {
            Self::Gemini(c) => c.generate(text).await,
            Self::OpenAI(c) => c.generate(text).await,
            Self::Ollama(c) => c.generate(text).await,
            #[cfg(feature = "test-support")]
            Self::Mock(c) => c.generate(text).await,
        }
    }

    async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
        match self {
            Self::Gemini(c) => c.generate_batch(texts).await,
            Self::OpenAI(c) => c.generate_batch(texts).await,
            Self::Ollama(c) => c.generate_batch(texts).await,
            #[cfg(feature = "test-support")]
            Self::Mock(c) => c.generate_batch(texts).await,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config that selects `provider` and leaves every optional field unset.
    fn base_config(provider: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            provider: provider.to_string(),
            gemini_api_key: None,
            openai_api_key: None,
            embedding_model: None,
            ollama_endpoint: None,
        }
    }

    /// Runs `from_config` and returns its error, panicking if it succeeds.
    fn expect_config_error(config: &EmbeddingConfig) -> anyhow::Error {
        match EmbeddingProviderEnum::from_config(config) {
            Ok(provider) => panic!("expected an error, got the {} provider", provider.name()),
            Err(err) => err,
        }
    }

    #[test]
    fn test_gemini_provider_creation() {
        let provider =
            EmbeddingProviderEnum::gemini("test-key").expect("Gemini client should initialize");
        assert_eq!(provider.name(), "gemini");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider =
            EmbeddingProviderEnum::openai("sk-test").expect("OpenAI client should initialize");
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_openai_large_model() {
        let provider =
            EmbeddingProviderEnum::openai_with_model("sk-test", "text-embedding-3-large")
                .expect("OpenAI client should accept text-embedding-3-large");
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_ollama_provider_creation() {
        let provider = EmbeddingProviderEnum::ollama().expect("Ollama client should initialize");
        assert_eq!(provider.name(), "ollama");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_ollama_provider_custom_model() {
        let provider = EmbeddingProviderEnum::ollama_with_config("mxbai-embed-large", None)
            .expect("Ollama client should accept mxbai-embed-large");
        assert_eq!(provider.dimension(), 1024);
    }

    #[cfg(feature = "test-support")]
    #[test]
    fn test_mock_provider() {
        let provider = EmbeddingProviderEnum::mock();
        assert_eq!(provider.name(), "mock");
        assert_eq!(provider.dimension(), 768);
        assert_eq!(provider.max_batch_size(), 100);
    }

    #[test]
    fn test_from_config_gemini() {
        let mut config = base_config("gemini");
        config.gemini_api_key = Some("test-key".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Gemini(_)));
    }

    #[test]
    fn test_from_config_openai_default_model() {
        let mut config = base_config("openai");
        config.openai_api_key = Some("sk-test".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::OpenAI(_)));
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_from_config_openai_custom_model() {
        let mut config = base_config("openai");
        config.openai_api_key = Some("sk-test".to_string());
        config.embedding_model = Some("text-embedding-3-large".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_from_config_invalid_provider() {
        let err = expect_config_error(&base_config("invalid"));
        assert!(err.to_string().contains("Invalid embedding provider"));
    }

    #[test]
    fn test_from_config_missing_gemini_key() {
        let err = expect_config_error(&base_config("gemini"));
        assert!(err.to_string().contains("GEMINI_API_KEY"));
    }

    #[test]
    fn test_from_config_missing_openai_key() {
        let err = expect_config_error(&base_config("openai"));
        assert!(err.to_string().contains("OPENAI_API_KEY"));
    }

    #[test]
    fn test_from_config_openai_does_not_use_gemini_key() {
        let mut config = base_config("openai");
        config.gemini_api_key = Some("test-key".to_string());
        let err = expect_config_error(&config);
        assert!(err.to_string().contains("OPENAI_API_KEY"));
    }

    #[test]
    fn test_from_config_ollama() {
        let config = base_config("ollama");
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_from_config_ollama_custom_model() {
        let mut config = base_config("ollama");
        config.embedding_model = Some("mxbai-embed-large".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_from_config_ollama_custom_endpoint() {
        let mut config = base_config("ollama");
        config.ollama_endpoint = Some("http://myhost:11434".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
        assert_eq!(provider.name(), "ollama");
    }
}