use anyhow::{Context, Result};
use backoff::{future::retry, ExponentialBackoff};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{info, warn};

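/// A minimal embedding client that can talk to OpenAI or Ollama, or return
/// deterministic mock vectors for testing.
///
/// Usage sketch (marked `ignore` because it needs network access; the key is
/// a placeholder):
///
/// ```ignore
/// let embedder = SimpleEmbedder::new("sk-placeholder".to_string());
/// let vector = embedder.generate_embedding("hello").await?;
/// assert_eq!(vector.len(), embedder.embedding_dimension());
/// ```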
#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    client: Client,
    api_key: String,
    model: String,
    base_url: String,
    provider: EmbeddingProvider,
    fallback_models: Vec<String>,
}

/// Which backend the embedder talks to.
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingProvider {
    OpenAI,
    Ollama,
    Mock,
}

/// Request body for OpenAI's `/v1/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    input: String,
    model: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}

/// Request body for Ollama's `/api/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Debug, Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

/// A single model entry from Ollama's `/api/tags` listing.
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

impl SimpleEmbedder {
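    /// Creates an OpenAI-backed embedder. Defaults to `text-embedding-3-small`
    /// against `https://api.openai.com`; override with `with_model` and
    /// `with_base_url`.
    ///
    /// Sketch (not run here; the key source is an assumption):
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::new(std::env::var("OPENAI_API_KEY")?)
    ///     .with_model("text-embedding-3-large".to_string());
    /// ```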
    pub fn new(api_key: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key,
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com".to_string(),
            provider: EmbeddingProvider::OpenAI,
            fallback_models: vec![
                "text-embedding-3-large".to_string(),
                "text-embedding-ada-002".to_string(),
            ],
        }
    }

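    /// Creates an Ollama-backed embedder. No API key is required; the longer
    /// 60-second timeout allows for local model load time.
    ///
    /// Sketch (assumes an Ollama server on the default port):
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::new_ollama(
    ///     "http://localhost:11434".to_string(),
    ///     "nomic-embed-text".to_string(),
    /// );
    /// ```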
    pub fn new_ollama(base_url: String, model: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(60))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(),
            model,
            base_url,
            provider: EmbeddingProvider::Ollama,
            fallback_models: vec![
                "nomic-embed-text".to_string(),
                "mxbai-embed-large".to_string(),
                "all-minilm".to_string(),
                "all-mpnet-base-v2".to_string(),
            ],
        }
    }

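    /// Creates a mock embedder for tests: no network calls, deterministic
    /// hash-seeded vectors (see `generate_mock_embedding`).
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::new_mock();
    /// let v = embedder.generate_embedding("same input").await?;
    /// let w = embedder.generate_embedding("same input").await?;
    /// assert_eq!(v, w); // identical input yields identical vectors
    /// ```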
    pub fn new_mock() -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(1))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(),
            model: "mock-model".to_string(),
            base_url: "http://mock:11434".to_string(),
            provider: EmbeddingProvider::Mock,
            fallback_models: vec!["mock-model-2".to_string()],
        }
    }

    /// Overrides the model name.
    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

    /// Overrides the base URL.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

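    /// Generates an embedding for `text`, retrying rate-limited requests with
    /// exponential backoff for up to 60 seconds. Errors whose message contains
    /// "Rate limited" are treated as transient; everything else fails fast.
    ///
    /// ```ignore
    /// let vector = embedder.generate_embedding("some document text").await?;
    /// ```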
    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        info!("Generating embedding for text of length: {}", text.len());

        let operation = || async {
            match self.generate_embedding_internal(text).await {
                Ok(embedding) => Ok(embedding),
                Err(e) => {
                    // Only rate-limit errors are worth retrying; anything else
                    // (auth failures, bad requests, etc.) fails immediately.
                    if e.to_string().contains("Rate limited") {
                        Err(backoff::Error::transient(e))
                    } else {
                        Err(backoff::Error::permanent(e))
                    }
                }
            }
        };

        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        retry(backoff, operation).await
    }

    async fn generate_embedding_internal(&self, text: &str) -> Result<Vec<f32>> {
        match self.provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
            EmbeddingProvider::Mock => self.generate_mock_embedding(text).await,
        }
    }

    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OpenAIEmbeddingRequest {
            input: text.to_string(),
            model: self.model.clone(),
        };

        let response = self
            .client
            .post(&format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            // The "Rate limited" prefix is what generate_embedding keys on to
            // decide whether to retry.
            if status.as_u16() == 429 {
                warn!("Rate limited by OpenAI API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "OpenAI API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OpenAIEmbeddingResponse = response.json().await?;

        if let Some(embedding_data) = embedding_response.data.first() {
            Ok(embedding_data.embedding.clone())
        } else {
            Err(anyhow::anyhow!("No embedding data in OpenAI response"))
        }
    }

    async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OllamaEmbeddingRequest {
            model: self.model.clone(),
            prompt: text.to_string(),
        };

        let response = self
            .client
            .post(&format!("{}/api/embeddings", self.base_url))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by Ollama API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "Ollama API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OllamaEmbeddingResponse = response.json().await?;
        Ok(embedding_response.embedding)
    }

    async fn generate_mock_embedding(&self, text: &str) -> Result<Vec<f32>> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // Seed a simple linear congruential generator with the text's hash so
        // identical inputs always produce identical vectors.
        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        let dimensions = self.embedding_dimension();
        let mut embedding = Vec::with_capacity(dimensions);

        let mut seed = hash;
        for _ in 0..dimensions {
            // Classic LCG step; values land in [-0.5, 0.5).
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            let value = ((seed >> 16) % 1000) as f32 / 1000.0 - 0.5;
            embedding.push(value);
        }

        // Normalize to unit length so the mock vectors behave like real
        // embeddings under cosine similarity.
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for val in &mut embedding {
                *val /= magnitude;
            }
        }

        Ok(embedding)
    }

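    /// Embeds `texts` sequentially in chunks of 10, sleeping 100 ms between
    /// requests as crude rate limiting, and fails on the first error.
    ///
    /// ```ignore
    /// let texts = vec!["first".to_string(), "second".to_string()];
    /// let vectors = embedder.generate_embeddings_batch(&texts).await?;
    /// assert_eq!(vectors.len(), texts.len());
    /// ```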
    pub async fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        info!("Generating embeddings for {} texts", texts.len());

        let mut embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(10) {
            let mut chunk_embeddings = Vec::with_capacity(chunk.len());

            for text in chunk {
                match self.generate_embedding(text).await {
                    Ok(embedding) => chunk_embeddings.push(embedding),
                    Err(e) => {
                        warn!("Failed to generate embedding for text: {}", e);
                        return Err(e);
                    }
                }

                // Crude pacing between requests to stay under rate limits.
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            embeddings.extend(chunk_embeddings);
        }

        Ok(embeddings)
    }

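    /// Returns the output dimensionality for the configured provider/model,
    /// falling back to a provider default for unknown models.
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::new("sk-placeholder".to_string());
    /// assert_eq!(embedder.embedding_dimension(), 1536); // text-embedding-3-small
    /// ```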
    pub fn embedding_dimension(&self) -> usize {
        match self.provider {
            EmbeddingProvider::OpenAI => match self.model.as_str() {
                "text-embedding-3-small" => 1536,
                "text-embedding-3-large" => 3072,
                "text-embedding-ada-002" => 1536,
                _ => 1536,
            },
            EmbeddingProvider::Ollama => match self.model.as_str() {
                "gpt-oss:20b" => 4096,
                "nomic-embed-text" => 768,
                "mxbai-embed-large" => 1024,
                "all-minilm" => 384,
                _ => 768,
            },
            EmbeddingProvider::Mock => 768,
        }
    }

    pub fn provider(&self) -> &EmbeddingProvider {
        &self.provider
    }

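    /// Queries an Ollama server's `/api/tags` endpoint, picks the best
    /// available embedding model (preferred models first), and wires the
    /// remaining detected models up as fallbacks.
    ///
    /// Sketch (assumes a reachable Ollama server):
    ///
    /// ```ignore
    /// let embedder = SimpleEmbedder::auto_configure(
    ///     "http://localhost:11434".to_string(),
    /// ).await?;
    /// ```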
    pub async fn auto_configure(base_url: String) -> Result<Self> {
        info!("🔍 Auto-detecting best available embedding model...");

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client")?;

        let available_models = Self::detect_ollama_models(&client, &base_url).await?;

        if available_models.is_empty() {
            return Err(anyhow::anyhow!("No embedding models found on Ollama server"));
        }

        let selected_model = Self::select_best_model(&available_models)?;

        info!("✅ Selected model: {} ({}D)", selected_model.name, selected_model.dimensions);

        let mut embedder = Self::new_ollama(base_url, selected_model.name.clone());
        embedder.fallback_models = available_models
            .into_iter()
            .filter(|m| m.name != embedder.model)
            .map(|m| m.name)
            .collect();

        Ok(embedder)
    }

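    /// Like `generate_embedding`, but on failure retries with each configured
    /// fallback model in order before giving up.
    ///
    /// ```ignore
    /// let vector = embedder.generate_embedding_with_fallback("text").await?;
    /// ```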
    pub async fn generate_embedding_with_fallback(&self, text: &str) -> Result<Vec<f32>> {
        match self.generate_embedding(text).await {
            Ok(embedding) => return Ok(embedding),
            Err(e) => {
                warn!("Primary model '{}' failed: {}", self.model, e);
            }
        }

        for fallback_model in &self.fallback_models {
            info!("🔄 Trying fallback model: {}", fallback_model);

            let mut fallback_embedder = self.clone();
            fallback_embedder.model = fallback_model.clone();

            match fallback_embedder.generate_embedding(text).await {
                Ok(embedding) => {
                    info!("✅ Fallback model '{}' succeeded", fallback_model);
                    return Ok(embedding);
                }
                Err(e) => {
                    warn!("Fallback model '{}' failed: {}", fallback_model, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!("All embedding models failed, including fallbacks"))
    }

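    /// Probes the backend with a tiny test embedding and reports status,
    /// latency, and output dimensions.
    ///
    /// ```ignore
    /// let health = embedder.health_check().await?;
    /// println!("{} ({} ms)", health.status, health.response_time_ms);
    /// ```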
    pub async fn health_check(&self) -> Result<EmbeddingHealth> {
        let start_time = std::time::Instant::now();

        let test_result = self.generate_embedding("Health check test").await;
        let response_time = start_time.elapsed();

        let health = match test_result {
            Ok(embedding) => EmbeddingHealth {
                status: "healthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: embedding.len(),
                error: None,
            },
            Err(e) => EmbeddingHealth {
                status: "unhealthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: 0,
                error: Some(e.to_string()),
            },
        };

        Ok(health)
    }

    async fn detect_ollama_models(client: &Client, base_url: &str) -> Result<Vec<EmbeddingModelInfo>> {
        let response = client
            .get(&format!("{}/api/tags", base_url))
            .send()
            .await
            .context("Failed to connect to Ollama API")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!("Ollama API returned error: {}", response.status()));
        }

        let models_response: OllamaModelsResponse = response
            .json()
            .await
            .context("Failed to parse Ollama models response")?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = Self::classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        Ok(embedding_models)
    }

    fn classify_embedding_model(model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // Known embedding models: (name pattern, dimensions, description, preferred).
        let known_models = [
            ("nomic-embed-text", 768, "High-quality text embeddings", true),
            ("mxbai-embed-large", 1024, "Large multilingual embeddings", true),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            ("all-mpnet-base-v2", 768, "Sentence transformer embeddings", false),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        for (pattern, dimensions, description, preferred) in known_models {
            // Patterns are lowercase, so matching against the lowercased name covers all casings.
            if name_lower.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Heuristic: anything that sounds like an embedding model gets a
        // conservative 768-dimension default.
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768,
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }

    fn select_best_model(available_models: &[EmbeddingModelInfo]) -> Result<&EmbeddingModelInfo> {
        // Prefer flagged models; otherwise fall back to the first detected one.
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            return Ok(preferred);
        }

        available_models
            .first()
            .ok_or_else(|| anyhow::anyhow!("No embedding models available"))
    }
}

/// Metadata about an embedding model discovered on an Ollama server.
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

/// Result of a `health_check` probe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingHealth {
    pub status: String,
    pub model: String,
    pub provider: String,
    pub response_time_ms: u64,
    pub embedding_dimensions: usize,
    pub error: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "requires OPENAI_API_KEY and network access"]
    async fn test_generate_openai_embedding() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let embedder = SimpleEmbedder::new(api_key);

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[tokio::test]
    #[ignore = "requires a reachable Ollama server"]
    async fn test_generate_ollama_embedding() {
        let embedder = SimpleEmbedder::new_ollama(
            "http://192.168.1.110:11434".to_string(),
            "nomic-embed-text".to_string(),
        );

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 768);
    }

    #[test]
    fn test_embedding_dimensions() {
        let embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(embedder.embedding_dimension(), 1536);

        let embedder = embedder.with_model("text-embedding-3-large".to_string());
        assert_eq!(embedder.embedding_dimension(), 3072);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.embedding_dimension(), 768);

        let gpt_oss_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "gpt-oss:20b".to_string(),
        );
        assert_eq!(gpt_oss_embedder.embedding_dimension(), 4096);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.embedding_dimension(), 768);
    }

    #[test]
    fn test_provider_types() {
        let openai_embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(openai_embedder.provider(), &EmbeddingProvider::OpenAI);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.provider(), &EmbeddingProvider::Ollama);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.provider(), &EmbeddingProvider::Mock);
    }
}