Skip to main content

ceres_client/
provider.rs

//! Embedding provider factory and dynamic dispatch.
//!
//! This module provides a unified interface for working with different
//! embedding providers through the [`EmbeddingProviderEnum`] enum.
//!
//! # Why an Enum Instead of `dyn Trait`?
//!
//! The [`EmbeddingProvider`] trait uses `impl Future` return types (RPITIT),
//! which makes it not object-safe: it cannot be used as `dyn EmbeddingProvider`.
//! We use an enum instead, which still lets the provider be selected at runtime
//! while keeping the ergonomic async trait syntax.
//!
//! # Usage
//!
//! ```no_run
//! use ceres_client::provider::EmbeddingProviderEnum;
//! use ceres_core::traits::EmbeddingProvider;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Create provider based on configuration
//! let provider = EmbeddingProviderEnum::gemini("your-api-key")?;
//!
//! // Use the provider generically
//! println!("Using {} provider ({} dimensions)", provider.name(), provider.dimension());
//! let embedding = provider.generate("Hello world").await?;
//! # Ok(())
//! # }
//! ```
28
29use ceres_core::error::AppError;
30use ceres_core::traits::EmbeddingProvider;
31
32use crate::{GeminiClient, OpenAIClient};
33
34/// Unified embedding provider that wraps concrete implementations.
35///
36/// This enum allows runtime selection of embedding providers while
37/// implementing the `EmbeddingProvider` trait.
38#[derive(Clone)]
39pub enum EmbeddingProviderEnum {
40    /// Google Gemini embedding provider (768 dimensions).
41    Gemini(GeminiClient),
42    /// OpenAI embedding provider (1536 or 3072 dimensions).
43    OpenAI(OpenAIClient),
44}
45
46impl EmbeddingProviderEnum {
47    /// Creates a Gemini embedding provider.
48    ///
49    /// # Arguments
50    ///
51    /// * `api_key` - Google Gemini API key
52    pub fn gemini(api_key: &str) -> Result<Self, AppError> {
53        Ok(Self::Gemini(GeminiClient::new(api_key)?))
54    }
55
56    /// Creates an OpenAI embedding provider with the default model.
57    ///
58    /// Uses `text-embedding-3-small` (1536 dimensions).
59    ///
60    /// # Arguments
61    ///
62    /// * `api_key` - OpenAI API key (starts with `sk-`)
63    pub fn openai(api_key: &str) -> Result<Self, AppError> {
64        Ok(Self::OpenAI(OpenAIClient::new(api_key)?))
65    }
66
67    /// Creates an OpenAI embedding provider with a specific model.
68    ///
69    /// # Arguments
70    ///
71    /// * `api_key` - OpenAI API key
72    /// * `model` - Model name (e.g., `text-embedding-3-large`)
73    pub fn openai_with_model(api_key: &str, model: &str) -> Result<Self, AppError> {
74        Ok(Self::OpenAI(OpenAIClient::with_model(api_key, model)?))
75    }
76}
77
78impl EmbeddingProvider for EmbeddingProviderEnum {
79    fn name(&self) -> &'static str {
80        match self {
81            Self::Gemini(c) => c.name(),
82            Self::OpenAI(c) => c.name(),
83        }
84    }
85
86    fn dimension(&self) -> usize {
87        match self {
88            Self::Gemini(c) => c.dimension(),
89            Self::OpenAI(c) => c.dimension(),
90        }
91    }
92
93    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
94        match self {
95            Self::Gemini(c) => c.generate(text).await,
96            Self::OpenAI(c) => c.generate(text).await,
97        }
98    }
99
100    async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
101        match self {
102            Self::Gemini(c) => c.generate_batch(texts).await,
103            Self::OpenAI(c) => c.generate_batch(texts).await,
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    // `expect` (rather than `assert!(is_ok())` + `unwrap()`) makes a failing
    // constructor print the actual `AppError` instead of a bare
    // "assertion failed" message.

    #[test]
    fn test_gemini_provider_creation() {
        let provider =
            EmbeddingProviderEnum::gemini("test-key").expect("Gemini provider should build");
        assert_eq!(provider.name(), "gemini");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider =
            EmbeddingProviderEnum::openai("sk-test").expect("OpenAI provider should build");
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_openai_large_model() {
        let provider =
            EmbeddingProviderEnum::openai_with_model("sk-test", "text-embedding-3-large")
                .expect("OpenAI provider with explicit model should build");
        assert_eq!(provider.dimension(), 3072);
    }

    /// The enum derives `Clone`; a clone must report the same backend.
    #[test]
    fn test_clone_preserves_provider() {
        let provider =
            EmbeddingProviderEnum::gemini("test-key").expect("Gemini provider should build");
        let cloned = provider.clone();
        assert_eq!(cloned.name(), provider.name());
        assert_eq!(cloned.dimension(), provider.dimension());
    }
}