// ceres_client/provider.rs
1//! Embedding provider factory and dynamic dispatch.
2//!
3//! This module provides a unified interface for working with different
4//! embedding providers through the [`EmbeddingProviderEnum`] enum.
5//!
6//! # Why an Enum Instead of `dyn Trait`?
7//!
8//! The [`EmbeddingProvider`] trait uses `impl Future` return types (RPITIT),
9//! which makes it not object-safe. We use an enum to provide dynamic dispatch
10//! while maintaining the ergonomic async trait syntax.
11//!
12//! # Usage
13//!
14//! ```no_run
15//! use ceres_client::provider::EmbeddingProviderEnum;
16//! use ceres_core::traits::EmbeddingProvider;
17//!
18//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
19//! // Create provider based on configuration
20//! let provider = EmbeddingProviderEnum::gemini("your-api-key")?;
21//!
22//! // Use the provider generically
23//! println!("Using {} provider ({} dimensions)", provider.name(), provider.dimension());
24//! let embedding = provider.generate("Hello world").await?;
25//! # Ok(())
26//! # }
27//! ```
28
29use anyhow::Context;
30use ceres_core::config::EmbeddingProviderType;
31use ceres_core::error::AppError;
32use ceres_core::traits::EmbeddingProvider;
33
34use crate::{GeminiClient, OllamaClient, OpenAIClient};
35
/// Configuration needed to create an embedding provider.
///
/// This struct extracts the embedding-related fields shared between
/// CLI and server configurations, avoiding duplication of the factory logic.
pub struct EmbeddingConfig {
    /// Provider selector, parsed into [`EmbeddingProviderType`]
    /// (e.g. "gemini", "openai", "ollama").
    pub provider: String,
    /// Gemini API key; required only when `provider` is "gemini".
    pub gemini_api_key: Option<String>,
    /// OpenAI API key; required only when `provider` is "openai".
    pub openai_api_key: Option<String>,
    /// Optional model override. Used for OpenAI (default model otherwise)
    /// and Ollama (falls back to `nomic-embed-text`).
    pub embedding_model: Option<String>,
    /// Optional Ollama endpoint URL; the client's default endpoint is used
    /// when `None`.
    pub ollama_endpoint: Option<String>,
}
47
/// Mock embedding client for testing (returns deterministic 768-dim vectors).
///
/// Available only with the `test-support` feature.
#[cfg(feature = "test-support")]
#[derive(Clone, Debug)]
pub struct MockEmbeddingClient {
    // Length of the vectors produced by `generate`; fixed at 768 by `new`.
    dimension: usize,
}
56
57#[cfg(feature = "test-support")]
58impl MockEmbeddingClient {
59    pub fn new() -> Self {
60        Self { dimension: 768 }
61    }
62}
63
64#[cfg(feature = "test-support")]
65impl Default for MockEmbeddingClient {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
#[cfg(feature = "test-support")]
impl EmbeddingProvider for MockEmbeddingClient {
    fn name(&self) -> &'static str {
        "mock"
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn max_batch_size(&self) -> usize {
        100
    }

    /// Deterministic embedding derived from the input length:
    /// element `i` equals `(text.len() + i) / 1000`.
    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
        let base = text.len() as f32;
        let mut embedding = Vec::with_capacity(self.dimension);
        for i in 0..self.dimension {
            embedding.push((base + i as f32) / 1000.0);
        }
        Ok(embedding)
    }
}
89
/// Unified embedding provider that wraps concrete implementations.
///
/// This enum allows runtime selection of embedding providers while
/// implementing the `EmbeddingProvider` trait.
///
/// An enum is used instead of `Box<dyn EmbeddingProvider>` because the
/// trait's async (RPITIT) methods make it not object-safe (see module docs).
#[derive(Clone)]
pub enum EmbeddingProviderEnum {
    /// Google Gemini embedding provider (768 dimensions).
    Gemini(GeminiClient),
    /// OpenAI embedding provider (1536 or 3072 dimensions).
    OpenAI(OpenAIClient),
    /// Ollama local embedding provider (default 768 dimensions).
    Ollama(OllamaClient),
    /// Mock embedding provider for testing (768 dimensions).
    /// Only compiled with the `test-support` feature.
    #[cfg(feature = "test-support")]
    Mock(MockEmbeddingClient),
}
106
107impl EmbeddingProviderEnum {
108    /// Creates a Gemini embedding provider.
109    ///
110    /// # Arguments
111    ///
112    /// * `api_key` - Google Gemini API key
113    pub fn gemini(api_key: &str) -> Result<Self, AppError> {
114        Ok(Self::Gemini(GeminiClient::new(api_key)?))
115    }
116
117    /// Creates an OpenAI embedding provider with the default model.
118    ///
119    /// Uses `text-embedding-3-small` (1536 dimensions).
120    ///
121    /// # Arguments
122    ///
123    /// * `api_key` - OpenAI API key (starts with `sk-`)
124    pub fn openai(api_key: &str) -> Result<Self, AppError> {
125        Ok(Self::OpenAI(OpenAIClient::new(api_key)?))
126    }
127
128    /// Creates an OpenAI embedding provider with a specific model.
129    ///
130    /// # Arguments
131    ///
132    /// * `api_key` - OpenAI API key
133    /// * `model` - Model name (e.g., `text-embedding-3-large`)
134    pub fn openai_with_model(api_key: &str, model: &str) -> Result<Self, AppError> {
135        Ok(Self::OpenAI(OpenAIClient::with_model(api_key, model)?))
136    }
137
138    /// Creates an Ollama embedding provider with default settings.
139    ///
140    /// Uses `nomic-embed-text` model at `http://localhost:11434`.
141    pub fn ollama() -> Result<Self, AppError> {
142        Ok(Self::Ollama(OllamaClient::new()?))
143    }
144
145    /// Creates an Ollama embedding provider with custom configuration.
146    pub fn ollama_with_config(model: &str, endpoint: Option<&str>) -> Result<Self, AppError> {
147        Ok(Self::Ollama(OllamaClient::with_config(model, endpoint)?))
148    }
149
150    /// Creates a mock embedding provider for testing.
151    #[cfg(feature = "test-support")]
152    pub fn mock() -> Self {
153        Self::Mock(MockEmbeddingClient::new())
154    }
155
156    /// Creates an embedding provider from configuration.
157    ///
158    /// Parses the provider type and initializes the appropriate client
159    /// with the given API key and optional model override.
160    pub fn from_config(config: &EmbeddingConfig) -> anyhow::Result<Self> {
161        let provider_type: EmbeddingProviderType = config
162            .provider
163            .parse()
164            .context("Invalid embedding provider")?;
165
166        match provider_type {
167            EmbeddingProviderType::Gemini => {
168                let api_key = config.gemini_api_key.as_ref().ok_or_else(|| {
169                    anyhow::anyhow!("GEMINI_API_KEY required when using gemini provider")
170                })?;
171                Self::gemini(api_key).context("Failed to initialize Gemini client")
172            }
173            EmbeddingProviderType::OpenAI => {
174                let api_key = config.openai_api_key.as_ref().ok_or_else(|| {
175                    anyhow::anyhow!("OPENAI_API_KEY required when using openai provider")
176                })?;
177
178                if let Some(model) = &config.embedding_model {
179                    Self::openai_with_model(api_key, model)
180                        .context("Failed to initialize OpenAI client")
181                } else {
182                    Self::openai(api_key).context("Failed to initialize OpenAI client")
183                }
184            }
185            EmbeddingProviderType::Ollama => {
186                let model = config
187                    .embedding_model
188                    .as_deref()
189                    .unwrap_or("nomic-embed-text");
190                let endpoint = config.ollama_endpoint.as_deref();
191                Self::ollama_with_config(model, endpoint)
192                    .context("Failed to initialize Ollama client")
193            }
194        }
195    }
196}
197
198impl EmbeddingProvider for EmbeddingProviderEnum {
199    fn name(&self) -> &'static str {
200        match self {
201            Self::Gemini(c) => c.name(),
202            Self::OpenAI(c) => c.name(),
203            Self::Ollama(c) => c.name(),
204            #[cfg(feature = "test-support")]
205            Self::Mock(c) => c.name(),
206        }
207    }
208
209    fn dimension(&self) -> usize {
210        match self {
211            Self::Gemini(c) => c.dimension(),
212            Self::OpenAI(c) => c.dimension(),
213            Self::Ollama(c) => c.dimension(),
214            #[cfg(feature = "test-support")]
215            Self::Mock(c) => c.dimension(),
216        }
217    }
218
219    fn max_batch_size(&self) -> usize {
220        match self {
221            Self::Gemini(c) => c.max_batch_size(),
222            Self::OpenAI(c) => c.max_batch_size(),
223            Self::Ollama(c) => c.max_batch_size(),
224            #[cfg(feature = "test-support")]
225            Self::Mock(c) => c.max_batch_size(),
226        }
227    }
228
229    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
230        match self {
231            Self::Gemini(c) => c.generate(text).await,
232            Self::OpenAI(c) => c.generate(text).await,
233            Self::Ollama(c) => c.generate(text).await,
234            #[cfg(feature = "test-support")]
235            Self::Mock(c) => c.generate(text).await,
236        }
237    }
238
239    async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
240        match self {
241            Self::Gemini(c) => c.generate_batch(texts).await,
242            Self::OpenAI(c) => c.generate_batch(texts).await,
243            Self::Ollama(c) => c.generate_batch(texts).await,
244            #[cfg(feature = "test-support")]
245            Self::Mock(c) => c.generate_batch(texts).await,
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config with only the provider name set; all optionals `None`.
    fn base_config(provider: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            provider: provider.to_string(),
            gemini_api_key: None,
            openai_api_key: None,
            embedding_model: None,
            ollama_endpoint: None,
        }
    }

    #[test]
    fn test_gemini_provider_creation() {
        let provider = EmbeddingProviderEnum::gemini("test-key").expect("gemini should build");
        assert_eq!(provider.name(), "gemini");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider = EmbeddingProviderEnum::openai("sk-test").expect("openai should build");
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_openai_large_model() {
        let provider = EmbeddingProviderEnum::openai_with_model("sk-test", "text-embedding-3-large")
            .expect("large model should build");
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_from_config_gemini() {
        let config = EmbeddingConfig {
            gemini_api_key: Some("test-key".to_string()),
            ..base_config("gemini")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Gemini(_)));
    }

    #[test]
    fn test_from_config_openai_default_model() {
        let config = EmbeddingConfig {
            openai_api_key: Some("sk-test".to_string()),
            ..base_config("openai")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::OpenAI(_)));
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_from_config_openai_custom_model() {
        let config = EmbeddingConfig {
            openai_api_key: Some("sk-test".to_string()),
            embedding_model: Some("text-embedding-3-large".to_string()),
            ..base_config("openai")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_from_config_invalid_provider() {
        assert!(EmbeddingProviderEnum::from_config(&base_config("invalid")).is_err());
    }

    #[test]
    fn test_from_config_missing_gemini_key() {
        assert!(EmbeddingProviderEnum::from_config(&base_config("gemini")).is_err());
    }

    #[test]
    fn test_from_config_missing_openai_key() {
        assert!(EmbeddingProviderEnum::from_config(&base_config("openai")).is_err());
    }

    #[test]
    fn test_ollama_provider_creation() {
        let provider = EmbeddingProviderEnum::ollama().expect("ollama should build");
        assert_eq!(provider.name(), "ollama");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_ollama_provider_custom_model() {
        let provider = EmbeddingProviderEnum::ollama_with_config("mxbai-embed-large", None)
            .expect("custom ollama model should build");
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_from_config_ollama() {
        let provider = EmbeddingProviderEnum::from_config(&base_config("ollama")).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_from_config_ollama_custom_model() {
        let config = EmbeddingConfig {
            embedding_model: Some("mxbai-embed-large".to_string()),
            ..base_config("ollama")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_from_config_ollama_custom_endpoint() {
        let config = EmbeddingConfig {
            ollama_endpoint: Some("http://myhost:11434".to_string()),
            ..base_config("ollama")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
    }
}
375}