1use anyhow::Context;
30use ceres_core::config::EmbeddingProviderType;
31use ceres_core::error::AppError;
32use ceres_core::traits::EmbeddingProvider;
33
34use crate::{GeminiClient, OllamaClient, OpenAIClient};
35
/// Declarative configuration used to construct an [`EmbeddingProviderEnum`]
/// via `from_config`. `provider` selects the backend; the remaining fields
/// are backend-specific and optional.
///
/// NOTE(review): `Debug` is intentionally not derived — `gemini_api_key` and
/// `openai_api_key` hold secrets that must not leak into logs.
#[derive(Clone)]
pub struct EmbeddingConfig {
    /// Provider name, parsed into `EmbeddingProviderType` (e.g. "gemini",
    /// "openai", "ollama").
    pub provider: String,
    /// Required when `provider` is "gemini".
    pub gemini_api_key: Option<String>,
    /// Required when `provider` is "openai".
    pub openai_api_key: Option<String>,
    /// Optional model override (OpenAI and Ollama honor it; the Gemini path
    /// currently ignores it — see `from_config`).
    pub embedding_model: Option<String>,
    /// Optional Ollama endpoint URL; backend default is used when `None`.
    pub ollama_endpoint: Option<String>,
}
47
/// Deterministic, network-free embedding client for tests.
///
/// Only compiled when the `test-support` feature is enabled.
#[cfg(feature = "test-support")]
#[derive(Clone, Debug)]
pub struct MockEmbeddingClient {
    // Length of the vectors returned by `generate`; 768 when built via `new`.
    dimension: usize,
}
56
#[cfg(feature = "test-support")]
impl MockEmbeddingClient {
    /// Creates a mock client that produces 768-dimensional embeddings
    /// (the dimension this file's tests use for Gemini/Ollama defaults).
    pub fn new() -> Self {
        Self::with_dimension(768)
    }

    /// Creates a mock client producing embeddings of an arbitrary
    /// `dimension` — useful for exercising dimension-mismatch handling
    /// without a real backend. Backward-compatible addition; `new()` is
    /// unchanged.
    pub fn with_dimension(dimension: usize) -> Self {
        Self { dimension }
    }
}
63
#[cfg(feature = "test-support")]
impl Default for MockEmbeddingClient {
    /// Equivalent to [`MockEmbeddingClient::new`].
    fn default() -> Self {
        MockEmbeddingClient::new()
    }
}
70
#[cfg(feature = "test-support")]
impl EmbeddingProvider for MockEmbeddingClient {
    fn name(&self) -> &'static str {
        "mock"
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn max_batch_size(&self) -> usize {
        100
    }

    /// Produces a deterministic pseudo-embedding derived solely from the
    /// input's byte length, so tests get stable vectors without any I/O.
    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
        let seed = text.len() as f32;
        let mut embedding = Vec::with_capacity(self.dimension);
        for i in 0..self.dimension {
            embedding.push((seed + i as f32) / 1000.0);
        }
        Ok(embedding)
    }
}
89
/// Statically-dispatched embedding provider: one variant per supported
/// backend client. Implements [`EmbeddingProvider`] by delegating every
/// call to the wrapped client.
#[derive(Clone)]
pub enum EmbeddingProviderEnum {
    /// Gemini-backed embeddings.
    Gemini(GeminiClient),
    /// OpenAI-backed embeddings.
    OpenAI(OpenAIClient),
    /// Ollama-backed embeddings.
    Ollama(OllamaClient),
    /// Deterministic in-memory mock; only built with the `test-support`
    /// feature.
    #[cfg(feature = "test-support")]
    Mock(MockEmbeddingClient),
}
106
107impl EmbeddingProviderEnum {
108 pub fn gemini(api_key: &str) -> Result<Self, AppError> {
114 Ok(Self::Gemini(GeminiClient::new(api_key)?))
115 }
116
117 pub fn openai(api_key: &str) -> Result<Self, AppError> {
125 Ok(Self::OpenAI(OpenAIClient::new(api_key)?))
126 }
127
128 pub fn openai_with_model(api_key: &str, model: &str) -> Result<Self, AppError> {
135 Ok(Self::OpenAI(OpenAIClient::with_model(api_key, model)?))
136 }
137
138 pub fn ollama() -> Result<Self, AppError> {
142 Ok(Self::Ollama(OllamaClient::new()?))
143 }
144
145 pub fn ollama_with_config(model: &str, endpoint: Option<&str>) -> Result<Self, AppError> {
147 Ok(Self::Ollama(OllamaClient::with_config(model, endpoint)?))
148 }
149
150 #[cfg(feature = "test-support")]
152 pub fn mock() -> Self {
153 Self::Mock(MockEmbeddingClient::new())
154 }
155
156 pub fn from_config(config: &EmbeddingConfig) -> anyhow::Result<Self> {
161 let provider_type: EmbeddingProviderType = config
162 .provider
163 .parse()
164 .context("Invalid embedding provider")?;
165
166 match provider_type {
167 EmbeddingProviderType::Gemini => {
168 let api_key = config.gemini_api_key.as_ref().ok_or_else(|| {
169 anyhow::anyhow!("GEMINI_API_KEY required when using gemini provider")
170 })?;
171 Self::gemini(api_key).context("Failed to initialize Gemini client")
172 }
173 EmbeddingProviderType::OpenAI => {
174 let api_key = config.openai_api_key.as_ref().ok_or_else(|| {
175 anyhow::anyhow!("OPENAI_API_KEY required when using openai provider")
176 })?;
177
178 if let Some(model) = &config.embedding_model {
179 Self::openai_with_model(api_key, model)
180 .context("Failed to initialize OpenAI client")
181 } else {
182 Self::openai(api_key).context("Failed to initialize OpenAI client")
183 }
184 }
185 EmbeddingProviderType::Ollama => {
186 let model = config
187 .embedding_model
188 .as_deref()
189 .unwrap_or("nomic-embed-text");
190 let endpoint = config.ollama_endpoint.as_deref();
191 Self::ollama_with_config(model, endpoint)
192 .context("Failed to initialize Ollama client")
193 }
194 }
195 }
196}
197
198impl EmbeddingProvider for EmbeddingProviderEnum {
199 fn name(&self) -> &'static str {
200 match self {
201 Self::Gemini(c) => c.name(),
202 Self::OpenAI(c) => c.name(),
203 Self::Ollama(c) => c.name(),
204 #[cfg(feature = "test-support")]
205 Self::Mock(c) => c.name(),
206 }
207 }
208
209 fn dimension(&self) -> usize {
210 match self {
211 Self::Gemini(c) => c.dimension(),
212 Self::OpenAI(c) => c.dimension(),
213 Self::Ollama(c) => c.dimension(),
214 #[cfg(feature = "test-support")]
215 Self::Mock(c) => c.dimension(),
216 }
217 }
218
219 fn max_batch_size(&self) -> usize {
220 match self {
221 Self::Gemini(c) => c.max_batch_size(),
222 Self::OpenAI(c) => c.max_batch_size(),
223 Self::Ollama(c) => c.max_batch_size(),
224 #[cfg(feature = "test-support")]
225 Self::Mock(c) => c.max_batch_size(),
226 }
227 }
228
229 async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
230 match self {
231 Self::Gemini(c) => c.generate(text).await,
232 Self::OpenAI(c) => c.generate(text).await,
233 Self::Ollama(c) => c.generate(text).await,
234 #[cfg(feature = "test-support")]
235 Self::Mock(c) => c.generate(text).await,
236 }
237 }
238
239 async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
240 match self {
241 Self::Gemini(c) => c.generate_batch(texts).await,
242 Self::OpenAI(c) => c.generate_batch(texts).await,
243 Self::Ollama(c) => c.generate_batch(texts).await,
244 #[cfg(feature = "test-support")]
245 Self::Mock(c) => c.generate_batch(texts).await,
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `EmbeddingConfig` for the given provider name with every
    /// optional field unset.
    fn base_config(provider: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            provider: provider.to_owned(),
            gemini_api_key: None,
            openai_api_key: None,
            embedding_model: None,
            ollama_endpoint: None,
        }
    }

    #[test]
    fn test_gemini_provider_creation() {
        let client = EmbeddingProviderEnum::gemini("test-key").unwrap();
        assert_eq!(client.name(), "gemini");
        assert_eq!(client.dimension(), 768);
    }

    #[test]
    fn test_openai_provider_creation() {
        let client = EmbeddingProviderEnum::openai("sk-test").unwrap();
        assert_eq!(client.name(), "openai");
        assert_eq!(client.dimension(), 1536);
    }

    #[test]
    fn test_openai_large_model() {
        let client =
            EmbeddingProviderEnum::openai_with_model("sk-test", "text-embedding-3-large").unwrap();
        assert_eq!(client.dimension(), 3072);
    }

    #[test]
    fn test_from_config_gemini() {
        let config = EmbeddingConfig {
            gemini_api_key: Some("test-key".to_owned()),
            ..base_config("gemini")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Gemini(_)));
    }

    #[test]
    fn test_from_config_openai_default_model() {
        let config = EmbeddingConfig {
            openai_api_key: Some("sk-test".to_owned()),
            ..base_config("openai")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::OpenAI(_)));
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_from_config_openai_custom_model() {
        let config = EmbeddingConfig {
            openai_api_key: Some("sk-test".to_owned()),
            embedding_model: Some("text-embedding-3-large".to_owned()),
            ..base_config("openai")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_from_config_invalid_provider() {
        assert!(EmbeddingProviderEnum::from_config(&base_config("invalid")).is_err());
    }

    #[test]
    fn test_from_config_missing_gemini_key() {
        assert!(EmbeddingProviderEnum::from_config(&base_config("gemini")).is_err());
    }

    #[test]
    fn test_from_config_missing_openai_key() {
        assert!(EmbeddingProviderEnum::from_config(&base_config("openai")).is_err());
    }

    #[test]
    fn test_ollama_provider_creation() {
        let client = EmbeddingProviderEnum::ollama().unwrap();
        assert_eq!(client.name(), "ollama");
        assert_eq!(client.dimension(), 768);
    }

    #[test]
    fn test_ollama_provider_custom_model() {
        let client = EmbeddingProviderEnum::ollama_with_config("mxbai-embed-large", None).unwrap();
        assert_eq!(client.dimension(), 1024);
    }

    #[test]
    fn test_from_config_ollama() {
        let provider = EmbeddingProviderEnum::from_config(&base_config("ollama")).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_from_config_ollama_custom_model() {
        let config = EmbeddingConfig {
            embedding_model: Some("mxbai-embed-large".to_owned()),
            ..base_config("ollama")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_from_config_ollama_custom_endpoint() {
        let config = EmbeddingConfig {
            ollama_endpoint: Some("http://myhost:11434".to_owned()),
            ..base_config("ollama")
        };
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
    }
}