ceres_client/provider.rs
1//! Embedding provider factory and dynamic dispatch.
2//!
3//! This module provides a unified interface for working with different
4//! embedding providers through the [`EmbeddingProviderEnum`] enum.
5//!
6//! # Why an Enum Instead of `dyn Trait`?
7//!
//! The [`EmbeddingProvider`] trait uses `impl Future` return types (RPITIT),
//! which makes it not object-safe (dyn-compatible): a `dyn EmbeddingProvider`
//! trait object cannot be formed. An enum over the supported providers gives
//! runtime selection of the backend while keeping the ergonomic async trait syntax.
11//!
12//! # Usage
13//!
14//! ```no_run
15//! use ceres_client::provider::EmbeddingProviderEnum;
16//! use ceres_core::traits::EmbeddingProvider;
17//!
18//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
19//! // Create provider based on configuration
20//! let provider = EmbeddingProviderEnum::gemini("your-api-key")?;
21//!
22//! // Use the provider generically
23//! println!("Using {} provider ({} dimensions)", provider.name(), provider.dimension());
24//! let embedding = provider.generate("Hello world").await?;
25//! # Ok(())
26//! # }
27//! ```
28
29use ceres_core::error::AppError;
30use ceres_core::traits::EmbeddingProvider;
31
32use crate::{GeminiClient, OpenAIClient};
33
/// Unified embedding provider that wraps concrete implementations.
///
/// This enum allows runtime selection of embedding providers while
/// implementing the [`EmbeddingProvider`] trait. Callers can therefore hold a
/// single concrete type regardless of which backend was configured.
///
/// `Clone` is derived, so cloning clones the wrapped client.
#[derive(Clone)]
pub enum EmbeddingProviderEnum {
    /// Google Gemini embedding provider (768 dimensions).
    Gemini(GeminiClient),
    /// OpenAI embedding provider (1536 or 3072 dimensions, depending on the model).
    OpenAI(OpenAIClient),
}
45
46impl EmbeddingProviderEnum {
47 /// Creates a Gemini embedding provider.
48 ///
49 /// # Arguments
50 ///
51 /// * `api_key` - Google Gemini API key
52 pub fn gemini(api_key: &str) -> Result<Self, AppError> {
53 Ok(Self::Gemini(GeminiClient::new(api_key)?))
54 }
55
56 /// Creates an OpenAI embedding provider with the default model.
57 ///
58 /// Uses `text-embedding-3-small` (1536 dimensions).
59 ///
60 /// # Arguments
61 ///
62 /// * `api_key` - OpenAI API key (starts with `sk-`)
63 pub fn openai(api_key: &str) -> Result<Self, AppError> {
64 Ok(Self::OpenAI(OpenAIClient::new(api_key)?))
65 }
66
67 /// Creates an OpenAI embedding provider with a specific model.
68 ///
69 /// # Arguments
70 ///
71 /// * `api_key` - OpenAI API key
72 /// * `model` - Model name (e.g., `text-embedding-3-large`)
73 pub fn openai_with_model(api_key: &str, model: &str) -> Result<Self, AppError> {
74 Ok(Self::OpenAI(OpenAIClient::with_model(api_key, model)?))
75 }
76}
77
78impl EmbeddingProvider for EmbeddingProviderEnum {
79 fn name(&self) -> &'static str {
80 match self {
81 Self::Gemini(c) => c.name(),
82 Self::OpenAI(c) => c.name(),
83 }
84 }
85
86 fn dimension(&self) -> usize {
87 match self {
88 Self::Gemini(c) => c.dimension(),
89 Self::OpenAI(c) => c.dimension(),
90 }
91 }
92
93 async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
94 match self {
95 Self::Gemini(c) => c.generate(text).await,
96 Self::OpenAI(c) => c.generate(text).await,
97 }
98 }
99
100 async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
101 match self {
102 Self::Gemini(c) => c.generate_batch(texts).await,
103 Self::OpenAI(c) => c.generate_batch(texts).await,
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    // `expect` is used instead of `assert!(r.is_ok())` + `unwrap()`. On failure
    // it reports the underlying `AppError` rather than just "assertion failed".

    #[test]
    fn test_gemini_provider_creation() {
        let provider = EmbeddingProviderEnum::gemini("test-key")
            .expect("constructing a Gemini provider should succeed");
        assert_eq!(provider.name(), "gemini");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider = EmbeddingProviderEnum::openai("sk-test")
            .expect("constructing an OpenAI provider should succeed");
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_openai_large_model() {
        let provider =
            EmbeddingProviderEnum::openai_with_model("sk-test", "text-embedding-3-large")
                .expect("constructing an OpenAI provider with an explicit model should succeed");
        // An explicit model must still report the OpenAI backend.
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimension(), 3072);
    }
}
138}