use anyhow::{Context, Result};
use async_trait::async_trait;
use backoff::{future::retry, ExponentialBackoff};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{info, warn};
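
/// Abstraction over embedding backends, so callers can depend on one
/// interface and swap providers (OpenAI, Ollama, or the in-process mock).
///
/// A minimal usage sketch; the crate path is unknown here, so the doctest is
/// marked `ignore`:
///
/// ```ignore
/// async fn embed_all(svc: &dyn EmbeddingService, texts: &[String]) -> anyhow::Result<Vec<Vec<f32>>> {
///     svc.health_check().await?;
///     let mut out = Vec::with_capacity(texts.len());
///     for text in texts {
///         out.push(svc.generate_embedding(text).await?);
///     }
///     Ok(out)
/// }
/// ```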
#[async_trait]
pub trait EmbeddingService: Send + Sync {
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;
    async fn health_check(&self) -> Result<()>;
}
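
/// Embedding client with a primary model plus an ordered list of fallback
/// models; see `generate_embedding_with_fallback`.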
#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    client: Client,
    api_key: String,
    model: String,
    base_url: String,
    provider: EmbeddingProvider,
    fallback_models: Vec<String>,
}
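
/// Supported embedding backends.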
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingProvider {
    OpenAI,
    Ollama,
    Mock,
}

#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    input: String,
    model: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}

#[derive(Debug, Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Debug, Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

impl SimpleEmbedder {
    /// Creates an OpenAI-backed embedder with default settings
    /// (`text-embedding-3-small` against the public API).
    pub fn new(api_key: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key,
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com".to_string(),
            provider: EmbeddingProvider::OpenAI,
            fallback_models: vec![
                "text-embedding-3-large".to_string(),
                "text-embedding-ada-002".to_string(),
            ],
        }
    }

    /// Creates an Ollama-backed embedder, with a longer timeout than the
    /// OpenAI client to accommodate slower local inference.
    pub fn new_ollama(base_url: String, model: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(60))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            // Ollama does not require an API key.
            api_key: String::new(),
            model,
            base_url,
            provider: EmbeddingProvider::Ollama,
            fallback_models: vec![
                "nomic-embed-text".to_string(),
                "mxbai-embed-large".to_string(),
                "all-minilm".to_string(),
                "all-mpnet-base-v2".to_string(),
            ],
        }
    }

    /// Creates an offline embedder that produces deterministic vectors (see
    /// `generate_mock_embedding`). Useful for tests and local development.
    pub fn new_mock() -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(1))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(),
            model: "mock-model".to_string(),
            base_url: "http://mock:11434".to_string(),
            provider: EmbeddingProvider::Mock,
            fallback_models: vec!["mock-model-2".to_string()],
        }
    }

    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }
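
    /// Generates an embedding for `text`, retrying rate-limited requests with
    /// exponential backoff for up to 60 seconds of elapsed time.
    ///
    /// A usage sketch; the surrounding crate path is an assumption, so the
    /// doctest is marked `ignore`:
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::new(std::env::var("OPENAI_API_KEY")?);
    /// let vector = embedder.generate_embedding("Hello, world!").await?;
    /// assert_eq!(vector.len(), embedder.embedding_dimension());
    /// ```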
    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        info!("Generating embedding for text of length: {}", text.len());

        let operation = || async {
            match self.generate_embedding_internal(text).await {
                Ok(embedding) => Ok(embedding),
                Err(e) => {
                    // Only rate limits are worth retrying; anything else
                    // (bad request, auth failure) is treated as permanent.
                    if e.to_string().contains("Rate limited") {
                        Err(backoff::Error::transient(e))
                    } else {
                        Err(backoff::Error::permanent(e))
                    }
                }
            }
        };

        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        retry(backoff, operation).await
    }

    async fn generate_embedding_internal(&self, text: &str) -> Result<Vec<f32>> {
        match self.provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
            EmbeddingProvider::Mock => self.generate_mock_embedding(text).await,
        }
    }

    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OpenAIEmbeddingRequest {
            input: text.to_string(),
            model: self.model.clone(),
        };

        let response = self
            .client
            .post(format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            // The "Rate limited" prefix is what marks the error as transient
            // for the retry logic in `generate_embedding`.
            if status.as_u16() == 429 {
                warn!("Rate limited by OpenAI API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "OpenAI API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OpenAIEmbeddingResponse = response.json().await?;

        if let Some(embedding_data) = embedding_response.data.first() {
            Ok(embedding_data.embedding.clone())
        } else {
            Err(anyhow::anyhow!("No embedding data in OpenAI response"))
        }
    }

    async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OllamaEmbeddingRequest {
            model: self.model.clone(),
            prompt: text.to_string(),
        };

        let response = self
            .client
            .post(format!("{}/api/embeddings", self.base_url))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by Ollama API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "Ollama API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OllamaEmbeddingResponse = response.json().await?;
        Ok(embedding_response.embedding)
    }

    /// Produces a deterministic pseudo-embedding without any network access:
    /// the text's hash seeds a simple linear congruential generator, and the
    /// resulting vector is L2-normalized.
    async fn generate_mock_embedding(&self, text: &str) -> Result<Vec<f32>> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        let dimensions = self.embedding_dimension();
        let mut embedding = Vec::with_capacity(dimensions);

        // LCG seeded by the hash; each step yields a value in roughly [-0.5, 0.5).
        let mut seed = hash;
        for _ in 0..dimensions {
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            let value = ((seed >> 16) % 1000) as f32 / 1000.0 - 0.5;
            embedding.push(value);
        }

        // Normalize so mock vectors behave like real embeddings in
        // cosine-similarity comparisons.
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for val in &mut embedding {
                *val /= magnitude;
            }
        }

        Ok(embedding)
    }
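
    /// Embeds `texts` sequentially in chunks of 10, pausing 100 ms between
    /// requests as crude rate limiting; fails fast on the first error.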
    pub async fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        info!("Generating embeddings for {} texts", texts.len());

        let mut embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(10) {
            let mut chunk_embeddings = Vec::with_capacity(chunk.len());

            for text in chunk {
                match self.generate_embedding(text).await {
                    Ok(embedding) => chunk_embeddings.push(embedding),
                    Err(e) => {
                        warn!("Failed to generate embedding for text: {}", e);
                        return Err(e);
                    }
                }

                // Small delay between requests to avoid hammering the API.
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            embeddings.extend(chunk_embeddings);
        }

        Ok(embeddings)
    }
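
    /// Returns the output dimensionality for the configured provider/model,
    /// falling back to a conservative default for unknown models.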
    pub fn embedding_dimension(&self) -> usize {
        match self.provider {
            EmbeddingProvider::OpenAI => match self.model.as_str() {
                "text-embedding-3-small" => 1536,
                "text-embedding-3-large" => 3072,
                "text-embedding-ada-002" => 1536,
                // Unknown OpenAI models default to 1536.
                _ => 1536,
            },
            EmbeddingProvider::Ollama => match self.model.as_str() {
                "gpt-oss:20b" => 4096,
                "nomic-embed-text" => 768,
                "mxbai-embed-large" => 1024,
                "all-minilm" => 384,
                // Unknown Ollama models default to 768.
                _ => 768,
            },
            EmbeddingProvider::Mock => 768,
        }
    }

    pub fn provider(&self) -> &EmbeddingProvider {
        &self.provider
    }
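
    /// Auto-detects embedding models on an Ollama server and picks the best
    /// one; see `select_best_model` for the preference rules.
    ///
    /// A usage sketch; the URL is an assumption, so the doctest is marked
    /// `ignore`:
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::auto_configure("http://localhost:11434".to_string()).await?;
    /// println!("selected model produces {}D vectors", embedder.embedding_dimension());
    /// ```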
    pub async fn auto_configure(base_url: String) -> Result<Self> {
        info!("🔍 Auto-detecting best available embedding model...");

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client")?;

        let available_models = Self::detect_ollama_models(&client, &base_url).await?;

        if available_models.is_empty() {
            return Err(anyhow::anyhow!(
                "No embedding models found on Ollama server"
            ));
        }

        let selected_model = Self::select_best_model(&available_models)?;

        info!(
            "✅ Selected model: {} ({}D)",
            selected_model.name, selected_model.dimensions
        );

        let mut embedder = Self::new_ollama(base_url, selected_model.name.clone());
        embedder.fallback_models = available_models
            .into_iter()
            .filter(|m| m.name != embedder.model)
            .map(|m| m.name)
            .collect();

        Ok(embedder)
    }
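
    /// Tries the primary model first, then each fallback model in order,
    /// returning the first embedding that succeeds.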
    pub async fn generate_embedding_with_fallback(&self, text: &str) -> Result<Vec<f32>> {
        match self.generate_embedding(text).await {
            Ok(embedding) => return Ok(embedding),
            Err(e) => {
                warn!("Primary model '{}' failed: {}", self.model, e);
            }
        }

        for fallback_model in &self.fallback_models {
            info!("🔄 Trying fallback model: {}", fallback_model);

            let mut fallback_embedder = self.clone();
            fallback_embedder.model = fallback_model.clone();

            match fallback_embedder.generate_embedding(text).await {
                Ok(embedding) => {
                    info!("✅ Fallback model '{}' succeeded", fallback_model);
                    return Ok(embedding);
                }
                Err(e) => {
                    warn!("Fallback model '{}' failed: {}", fallback_model, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!(
            "All embedding models failed, including fallbacks"
        ))
    }
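
    /// Runs a live round-trip ("Health check test") against the configured
    /// backend and reports status, latency, and output dimensionality.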
    pub async fn health_check(&self) -> Result<EmbeddingHealth> {
        let start_time = std::time::Instant::now();

        let test_result = self.generate_embedding("Health check test").await;
        let response_time = start_time.elapsed();

        let health = match test_result {
            Ok(embedding) => EmbeddingHealth {
                status: "healthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: embedding.len(),
                error: None,
            },
            Err(e) => EmbeddingHealth {
                status: "unhealthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: 0,
                error: Some(e.to_string()),
            },
        };

        Ok(health)
    }
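
    /// Lists models via Ollama's `/api/tags` endpoint and keeps only those
    /// that look like embedding models.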
    async fn detect_ollama_models(
        client: &Client,
        base_url: &str,
    ) -> Result<Vec<EmbeddingModelInfo>> {
        let response = client
            .get(format!("{base_url}/api/tags"))
            .send()
            .await
            .context("Failed to connect to Ollama API")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!(
                "Ollama API returned error: {}",
                response.status()
            ));
        }

        let models_response: OllamaModelsResponse = response
            .json()
            .await
            .context("Failed to parse Ollama models response")?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = Self::classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        Ok(embedding_models)
    }
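
    /// Maps a model name to embedding metadata: exact entries for a table of
    /// known models, a 768-dimension guess for anything merely named like an
    /// embedding model, and `None` otherwise.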
    fn classify_embedding_model(model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // (pattern, dimensions, description, preferred)
        let known_models = [
            (
                "nomic-embed-text",
                768,
                "High-quality text embeddings",
                true,
            ),
            (
                "mxbai-embed-large",
                1024,
                "Large multilingual embeddings",
                true,
            ),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            (
                "all-mpnet-base-v2",
                768,
                "Sentence transformer embeddings",
                false,
            ),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        // All patterns are lowercase, so matching against `name_lower` alone
        // is sufficient.
        for (pattern, dimensions, description, preferred) in known_models {
            if name_lower.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Heuristic for unknown models whose names suggest embeddings.
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768,
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }
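
    /// Prefers any model flagged `preferred` in the known-model table;
    /// otherwise falls back to the first detected model.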
    fn select_best_model(available_models: &[EmbeddingModelInfo]) -> Result<&EmbeddingModelInfo> {
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            return Ok(preferred);
        }

        available_models
            .first()
            .ok_or_else(|| anyhow::anyhow!("No embedding models available"))
    }
}

#[async_trait]
impl EmbeddingService for SimpleEmbedder {
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        SimpleEmbedder::generate_embedding(self, text).await
    }

    async fn health_check(&self) -> Result<()> {
        let health = SimpleEmbedder::health_check(self).await?;
        if health.status == "healthy" {
            Ok(())
        } else {
            Err(anyhow::anyhow!(
                "Embedding service unhealthy: {:?}",
                health.error
            ))
        }
    }
}

#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingHealth {
    pub status: String,
    pub model: String,
    pub provider: String,
    pub response_time_ms: u64,
    pub embedding_dimensions: usize,
    pub error: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore] // Requires a live OpenAI API key; run with `cargo test -- --ignored`.
    async fn test_generate_openai_embedding() {
        let api_key = match std::env::var("OPENAI_API_KEY") {
            Ok(key) => key,
            Err(_) => {
                eprintln!("OPENAI_API_KEY not set, skipping test");
                return;
            }
        };
        let embedder = SimpleEmbedder::new(api_key);

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[tokio::test]
    #[ignore] // Requires a reachable Ollama server; run with `cargo test -- --ignored`.
    async fn test_generate_ollama_embedding() {
        let embedder = SimpleEmbedder::new_ollama(
            "http://192.168.1.110:11434".to_string(),
            "nomic-embed-text".to_string(),
        );

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 768);
    }

    #[test]
    fn test_embedding_dimensions() {
        let embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(embedder.embedding_dimension(), 1536);

        let embedder = embedder.with_model("text-embedding-3-large".to_string());
        assert_eq!(embedder.embedding_dimension(), 3072);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.embedding_dimension(), 768);

        let gpt_oss_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "gpt-oss:20b".to_string(),
        );
        assert_eq!(gpt_oss_embedder.embedding_dimension(), 4096);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.embedding_dimension(), 768);
    }
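
    // A sketch of the mock provider's contract (this test is new, not from
    // the original suite): the same input yields the same vector, and the
    // vector is L2-normalized. Runs fully offline.
    #[tokio::test]
    async fn test_mock_embedding_deterministic_and_normalized() {
        let embedder = SimpleEmbedder::new_mock();

        let a = embedder.generate_embedding("Hello, world!").await.unwrap();
        let b = embedder.generate_embedding("Hello, world!").await.unwrap();
        assert_eq!(a, b);
        assert_eq!(a.len(), embedder.embedding_dimension());

        let magnitude: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-3);
    }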

    #[test]
    fn test_provider_types() {
        let openai_embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(openai_embedder.provider(), &EmbeddingProvider::OpenAI);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.provider(), &EmbeddingProvider::Ollama);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.provider(), &EmbeddingProvider::Mock);
    }
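
    // A sketch exercising the model-classification heuristics (this test is
    // new); the model names below are illustrative inputs, not claims about
    // installed models.
    #[test]
    fn test_classify_embedding_model() {
        // Known models resolve through the lookup table.
        let info = SimpleEmbedder::classify_embedding_model("nomic-embed-text:latest").unwrap();
        assert_eq!(info.dimensions, 768);
        assert!(info.preferred);

        // Names merely containing "embed" fall back to 768 dimensions.
        let info = SimpleEmbedder::classify_embedding_model("my-custom-embedder").unwrap();
        assert_eq!(info.dimensions, 768);
        assert!(!info.preferred);

        // Non-embedding models are rejected.
        assert!(SimpleEmbedder::classify_embedding_model("llama3").is_none());
    }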
}