use anyhow::{Context, Result};
use backoff::{future::retry, ExponentialBackoff};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{info, warn};

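/// A small multi-provider embedding client supporting OpenAI, Ollama, and a
/// deterministic mock backend, with retry, fallback models, and health checks.
///
/// Usage sketch (not a doctest; the Ollama server address is an assumption):
///
/// ```ignore
/// let embedder = SimpleEmbedder::new_ollama(
///     "http://localhost:11434".to_string(),
///     "nomic-embed-text".to_string(),
/// );
/// let embedding = embedder.generate_embedding("Hello, world!").await?;
/// assert_eq!(embedding.len(), embedder.embedding_dimension());
/// ```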
#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    client: Client,
    api_key: String,
    model: String,
    base_url: String,
    provider: EmbeddingProvider,
    fallback_models: Vec<String>,
}

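/// The backend a `SimpleEmbedder` sends requests to.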
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingProvider {
    OpenAI,
    Ollama,
    Mock,
}

/// Request body for the OpenAI embeddings endpoint.
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    input: String,
    model: String,
}

/// Response body from the OpenAI embeddings endpoint.
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}

/// Request body for the Ollama embeddings endpoint.
#[derive(Debug, Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

/// Response body from the Ollama embeddings endpoint.
#[derive(Debug, Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

/// A single model entry from Ollama's `/api/tags` listing.
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

/// Response body from Ollama's `/api/tags` endpoint.
#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

impl SimpleEmbedder {
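    /// Creates an OpenAI-backed embedder defaulting to `text-embedding-3-small`,
    /// with the other OpenAI embedding models registered as fallbacks.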
    pub fn new(api_key: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key,
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com".to_string(),
            provider: EmbeddingProvider::OpenAI,
            fallback_models: vec![
                "text-embedding-3-large".to_string(),
                "text-embedding-ada-002".to_string(),
            ],
        }
    }

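    /// Creates an Ollama-backed embedder for the given server URL and model.
    /// Uses a longer 60s timeout, since local models can be slow to respond.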
    pub fn new_ollama(base_url: String, model: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(60))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(),
            model,
            base_url,
            provider: EmbeddingProvider::Ollama,
            fallback_models: vec![
                "nomic-embed-text".to_string(),
                "mxbai-embed-large".to_string(),
                "all-minilm".to_string(),
                "all-mpnet-base-v2".to_string(),
            ],
        }
    }

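    /// Creates an offline embedder that returns deterministic, hash-derived
    /// vectors; useful for tests that must not touch the network.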
    pub fn new_mock() -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(1))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(),
            model: "mock-model".to_string(),
            base_url: "http://mock:11434".to_string(),
            provider: EmbeddingProvider::Mock,
            fallback_models: vec!["mock-model-2".to_string()],
        }
    }

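    /// Overrides the model name (builder-style).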
    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

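    /// Overrides the API base URL (builder-style).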
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

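    /// Generates an embedding for `text`, retrying with exponential backoff
    /// for up to 60 seconds when the provider reports a rate limit. All other
    /// errors are treated as permanent and returned immediately.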
    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        info!("Generating embedding for text of length: {}", text.len());

        let operation = || async {
            match self.generate_embedding_internal(text).await {
                Ok(embedding) => Ok(embedding),
                Err(e) => {
                    if e.to_string().contains("Rate limited") {
                        Err(backoff::Error::transient(e))
                    } else {
                        Err(backoff::Error::permanent(e))
                    }
                }
            }
        };

        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        retry(backoff, operation).await
    }

    async fn generate_embedding_internal(&self, text: &str) -> Result<Vec<f32>> {
        match self.provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
            EmbeddingProvider::Mock => self.generate_mock_embedding(text).await,
        }
    }

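    /// Calls the OpenAI `/v1/embeddings` endpoint and returns the first
    /// embedding in the response. A 429 status is surfaced as a
    /// "Rate limited" error so the retry layer treats it as transient.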
    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OpenAIEmbeddingRequest {
            input: text.to_string(),
            model: self.model.clone(),
        };

        let response = self
            .client
            .post(&format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by OpenAI API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "OpenAI API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OpenAIEmbeddingResponse = response.json().await?;

        if let Some(embedding_data) = embedding_response.data.first() {
            Ok(embedding_data.embedding.clone())
        } else {
            Err(anyhow::anyhow!("No embedding data in OpenAI response"))
        }
    }

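    /// Calls the Ollama `/api/embeddings` endpoint. As with OpenAI, a 429
    /// status is surfaced as a "Rate limited" error for the retry layer.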
    async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OllamaEmbeddingRequest {
            model: self.model.clone(),
            prompt: text.to_string(),
        };

        let response = self
            .client
            .post(&format!("{}/api/embeddings", self.base_url))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by Ollama API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "Ollama API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OllamaEmbeddingResponse = response.json().await?;
        Ok(embedding_response.embedding)
    }

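    /// Produces a deterministic embedding by seeding a simple linear
    /// congruential generator with the hash of `text` and normalizing the
    /// result to unit length. Identical inputs yield identical vectors.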
    async fn generate_mock_embedding(&self, text: &str) -> Result<Vec<f32>> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        let dimensions = self.embedding_dimension();
        let mut embedding = Vec::with_capacity(dimensions);

        // Simple LCG seeded by the text hash; values land in [-0.5, 0.5).
        let mut seed = hash;
        for _ in 0..dimensions {
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            let value = ((seed >> 16) % 1000) as f32 / 1000.0 - 0.5;
            embedding.push(value);
        }

        // Normalize to unit length so the mock behaves like real embeddings.
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for val in &mut embedding {
                *val /= magnitude;
            }
        }

        Ok(embedding)
    }

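    /// Embeds `texts` sequentially in chunks of 10, pausing 100ms between
    /// requests to stay under provider rate limits. Fails fast on the first
    /// error rather than returning partial results.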
    pub async fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        info!("Generating embeddings for {} texts", texts.len());

        let mut embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(10) {
            let mut chunk_embeddings = Vec::with_capacity(chunk.len());

            for text in chunk {
                match self.generate_embedding(text).await {
                    Ok(embedding) => chunk_embeddings.push(embedding),
                    Err(e) => {
                        warn!("Failed to generate embedding for text: {}", e);
                        return Err(e);
                    }
                }

                // Brief pause between requests to avoid hammering the API.
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            embeddings.extend(chunk_embeddings);
        }

        Ok(embeddings)
    }

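    /// Returns the output dimensionality for the configured provider and
    /// model, falling back to a common default for unknown model names.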
    pub fn embedding_dimension(&self) -> usize {
        match self.provider {
            EmbeddingProvider::OpenAI => match self.model.as_str() {
                "text-embedding-3-small" => 1536,
                "text-embedding-3-large" => 3072,
                "text-embedding-ada-002" => 1536,
                _ => 1536,
            },
            EmbeddingProvider::Ollama => match self.model.as_str() {
                "gpt-oss:20b" => 4096,
                "nomic-embed-text" => 768,
                "mxbai-embed-large" => 1024,
                "all-minilm" => 384,
                _ => 768,
            },
            EmbeddingProvider::Mock => 768,
        }
    }

    pub fn provider(&self) -> &EmbeddingProvider {
        &self.provider
    }

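    /// Queries an Ollama server for installed models, selects the best
    /// available embedding model, and registers the rest as fallbacks.
    ///
    /// Usage sketch (not a doctest; the server address is an assumption):
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::auto_configure(
    ///     "http://localhost:11434".to_string(),
    /// ).await?;
    /// ```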
    pub async fn auto_configure(base_url: String) -> Result<Self> {
        info!("🔍 Auto-detecting best available embedding model...");

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client")?;

        let available_models = Self::detect_ollama_models(&client, &base_url).await?;

        if available_models.is_empty() {
            return Err(anyhow::anyhow!(
                "No embedding models found on Ollama server"
            ));
        }

        let selected_model = Self::select_best_model(&available_models)?;

        info!(
            "✅ Selected model: {} ({}D)",
            selected_model.name, selected_model.dimensions
        );

        let mut embedder = Self::new_ollama(base_url, selected_model.name.clone());
        embedder.fallback_models = available_models
            .into_iter()
            .filter(|m| m.name != embedder.model)
            .map(|m| m.name)
            .collect();

        Ok(embedder)
    }

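    /// Tries the primary model first, then each fallback model in order,
    /// returning the first successful embedding. Note that a fallback model
    /// may produce vectors of a different dimensionality than the primary.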
    pub async fn generate_embedding_with_fallback(&self, text: &str) -> Result<Vec<f32>> {
        match self.generate_embedding(text).await {
            Ok(embedding) => return Ok(embedding),
            Err(e) => {
                warn!("Primary model '{}' failed: {}", self.model, e);
            }
        }

        for fallback_model in &self.fallback_models {
            info!("🔄 Trying fallback model: {}", fallback_model);

            let mut fallback_embedder = self.clone();
            fallback_embedder.model = fallback_model.clone();

            match fallback_embedder.generate_embedding(text).await {
                Ok(embedding) => {
                    info!("✅ Fallback model '{}' succeeded", fallback_model);
                    return Ok(embedding);
                }
                Err(e) => {
                    warn!("Fallback model '{}' failed: {}", fallback_model, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!(
            "All embedding models failed, including fallbacks"
        ))
    }

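    /// Runs a small end-to-end embedding request and reports status, model,
    /// latency, and output dimensionality.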
    pub async fn health_check(&self) -> Result<EmbeddingHealth> {
        let start_time = std::time::Instant::now();

        let test_result = self.generate_embedding("Health check test").await;
        let response_time = start_time.elapsed();

        let health = match test_result {
            Ok(embedding) => EmbeddingHealth {
                status: "healthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: embedding.len(),
                error: None,
            },
            Err(e) => EmbeddingHealth {
                status: "unhealthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: 0,
                error: Some(e.to_string()),
            },
        };

        Ok(health)
    }

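    /// Lists the models installed on an Ollama server via `/api/tags` and
    /// keeps only those recognized as embedding models.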
    async fn detect_ollama_models(
        client: &Client,
        base_url: &str,
    ) -> Result<Vec<EmbeddingModelInfo>> {
        let response = client
            .get(&format!("{}/api/tags", base_url))
            .send()
            .await
            .context("Failed to connect to Ollama API")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!(
                "Ollama API returned error: {}",
                response.status()
            ));
        }

        let models_response: OllamaModelsResponse = response
            .json()
            .await
            .context("Failed to parse Ollama models response")?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = Self::classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        Ok(embedding_models)
    }

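    /// Maps a model name to embedding metadata: first against a table of
    /// known embedding models, then heuristically (names containing "embed",
    /// "sentence", or "vector" are assumed to be 768-dimensional).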
    fn classify_embedding_model(model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // (pattern, dimensions, description, preferred)
        let known_models = [
            (
                "nomic-embed-text",
                768,
                "High-quality text embeddings",
                true,
            ),
            (
                "mxbai-embed-large",
                1024,
                "Large multilingual embeddings",
                true,
            ),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            (
                "all-mpnet-base-v2",
                768,
                "Sentence transformer embeddings",
                false,
            ),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        for (pattern, dimensions, description, preferred) in known_models {
            if name_lower.contains(pattern) || model_name.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Heuristic fallback for models the table does not cover.
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768,
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }

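    /// Prefers a model flagged as `preferred`; otherwise falls back to the
    /// first model in the list.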
    fn select_best_model(available_models: &[EmbeddingModelInfo]) -> Result<&EmbeddingModelInfo> {
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            return Ok(preferred);
        }

        available_models
            .first()
            .ok_or_else(|| anyhow::anyhow!("No embedding models available"))
    }
}

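/// Metadata for an embedding model discovered on an Ollama server.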
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

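/// Result of a [`SimpleEmbedder::health_check`] probe.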
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingHealth {
    pub status: String,
    pub model: String,
    pub provider: String,
    pub response_time_ms: u64,
    pub embedding_dimensions: usize,
    pub error: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore] // Requires OPENAI_API_KEY and network access.
    async fn test_generate_openai_embedding() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let embedder = SimpleEmbedder::new(api_key);

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[tokio::test]
    #[ignore] // Requires a reachable Ollama server.
    async fn test_generate_ollama_embedding() {
        let embedder = SimpleEmbedder::new_ollama(
            "http://192.168.1.110:11434".to_string(),
            "nomic-embed-text".to_string(),
        );

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 768);
    }

    #[test]
    fn test_embedding_dimensions() {
        let embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(embedder.embedding_dimension(), 1536);

        let embedder = embedder.with_model("text-embedding-3-large".to_string());
        assert_eq!(embedder.embedding_dimension(), 3072);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.embedding_dimension(), 768);

        let gpt_oss_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "gpt-oss:20b".to_string(),
        );
        assert_eq!(gpt_oss_embedder.embedding_dimension(), 4096);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.embedding_dimension(), 768);
    }

    #[test]
    fn test_provider_types() {
        let openai_embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(openai_embedder.provider(), &EmbeddingProvider::OpenAI);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.provider(), &EmbeddingProvider::Ollama);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.provider(), &EmbeddingProvider::Mock);
    }
}