1use crate::types::{Error, Result};
7use reqwest::Client;
8use serde::{Deserialize, Serialize};
9use std::env;
10use tokio::process::Command as AsyncCommand;
11use std::sync::{Mutex, Once};
12use std::sync::OnceLock;
13use tokio::io::AsyncWriteExt;
14
/// OpenAI embeddings endpoint used by the `OpenAI` provider.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/embeddings";
/// Model requested from the OpenAI embeddings API.
const OPENAI_MODEL: &str = "text-embedding-ada-002";
/// Vector dimension returned by `text-embedding-ada-002`.
const OPENAI_DIMENSION: usize = 1536;

/// Model name reported for the Local/Remote providers when no override applies.
const DEFAULT_LOCAL_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Fallback dimension when the configured local model is not recognized.
const DEFAULT_LOCAL_DIMENSION: usize = 384;
32
33fn get_local_embedding_model() -> fastembed::EmbeddingModel {
35 use fastembed::EmbeddingModel;
36
37 if let Ok(model_str) = env::var("AVOCADODB_EMBEDDING_MODEL") {
38 match model_str.to_lowercase().as_str() {
39 "allminilml6v2" | "all-minilm-l6-v2" | "minilm6" => EmbeddingModel::AllMiniLML6V2,
40 "allminilml12v2" | "all-minilm-l12-v2" | "minilm12" => EmbeddingModel::AllMiniLML12V2,
41 "bgesmallen" | "bge-small-en-v1.5" | "bgesmall" => EmbeddingModel::BGESmallENV15,
42 "bgelargeen" | "bge-large-en-v1.5" | "bgelarge" => EmbeddingModel::BGELargeENV15,
43 "nomicv1" | "nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1,
44 "nomicv15" | "nomic-embed-text-v1.5" | "nomic" => EmbeddingModel::NomicEmbedTextV15,
45 _ => {
46 log::warn!("Unknown embedding model '{}', using default AllMiniLML6V2", model_str);
47 EmbeddingModel::AllMiniLML6V2
48 }
49 }
50 } else {
51 EmbeddingModel::AllMiniLML6V2
52 }
53}
54
55fn get_local_embedding_dimension() -> usize {
57 use fastembed::EmbeddingModel;
58
59 match get_local_embedding_model() {
60 EmbeddingModel::AllMiniLML6V2 => 384,
61 EmbeddingModel::AllMiniLML12V2 => 384,
62 EmbeddingModel::BGESmallENV15 => 384,
63 EmbeddingModel::BGELargeENV15 => 1024,
64 EmbeddingModel::NomicEmbedTextV1 => 768,
65 EmbeddingModel::NomicEmbedTextV15 => 768,
66 _ => DEFAULT_LOCAL_DIMENSION, }
68}
69
70fn get_local_model_name() -> &'static str {
72 use fastembed::EmbeddingModel;
73
74 match get_local_embedding_model() {
75 EmbeddingModel::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
76 EmbeddingModel::AllMiniLML12V2 => "sentence-transformers/all-MiniLM-L12-v2",
77 EmbeddingModel::BGESmallENV15 => "BAAI/bge-small-en-v1.5",
78 EmbeddingModel::BGELargeENV15 => "BAAI/bge-large-en-v1.5",
79 EmbeddingModel::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
80 EmbeddingModel::NomicEmbedTextV15 => "nomic-ai/nomic-embed-text-v1.5",
81 _ => DEFAULT_LOCAL_MODEL, }
83}
84
/// Backend used to generate embedding vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingProvider {
    /// In-process fastembed model, with Python and hash-based fallbacks.
    Local,
    /// OpenAI embeddings API (needs an API key).
    OpenAI,
    /// User-configured HTTP endpoint (`AVOCADODB_EMBEDDING_URL`).
    Remote,
}
95
96impl Default for EmbeddingProvider {
97 fn default() -> Self {
98 EmbeddingProvider::Local
100 }
101}
102
103impl EmbeddingProvider {
104 pub fn from_env() -> Self {
106 if env::var("AVOCADODB_EMBEDDING_PROVIDER").is_ok() {
109 match env::var("AVOCADODB_EMBEDDING_PROVIDER")
110 .unwrap()
111 .to_lowercase()
112 .as_str()
113 {
114 "openai" => EmbeddingProvider::OpenAI,
115 "local" => EmbeddingProvider::Local,
116 "remote" => EmbeddingProvider::Remote,
117 _ => EmbeddingProvider::Local,
118 }
119 } else {
120 EmbeddingProvider::Local
121 }
122 }
123
124 pub fn dimension(&self) -> usize {
125 match self {
126 EmbeddingProvider::Local => get_local_embedding_dimension(),
127 EmbeddingProvider::OpenAI => OPENAI_DIMENSION,
128 EmbeddingProvider::Remote => {
129 env::var("AVOCADODB_EMBEDDING_DIM")
131 .ok()
132 .and_then(|s| s.parse::<usize>().ok())
133 .unwrap_or_else(get_local_embedding_dimension)
134 }
135 }
136 }
137
138 pub fn model_name(&self) -> &'static str {
139 match self {
140 EmbeddingProvider::Local => get_local_model_name(),
141 EmbeddingProvider::OpenAI => OPENAI_MODEL,
142 EmbeddingProvider::Remote => DEFAULT_LOCAL_MODEL,
144 }
145 }
146}
147
/// Request body for the OpenAI `/v1/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    /// Model id, e.g. `text-embedding-ada-002`.
    model: String,
    /// Texts to embed in a single batch.
    input: Vec<String>,
}
154
/// Top-level OpenAI embeddings response (only the fields we use).
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// One entry per input text.
    data: Vec<EmbeddingData>,
}
159
/// A single embedding entry from the OpenAI response.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector itself.
    embedding: Vec<f32>,
    /// Position in the input batch; used to restore request order after sorting.
    index: usize,
}
165
166pub async fn embed_text(
181 text: &str,
182 provider: Option<EmbeddingProvider>,
183 api_key: Option<&str>,
184) -> Result<Vec<f32>> {
185 let results = embed_batch(vec![text], provider, api_key).await?;
186 results.into_iter().next().ok_or_else(|| {
187 Error::Embedding("No embedding returned".to_string())
188 })
189}
190
191pub async fn embed_batch(
206 texts: Vec<&str>,
207 provider: Option<EmbeddingProvider>,
208 api_key: Option<&str>,
209) -> Result<Vec<Vec<f32>>> {
210 let provider = provider.unwrap_or_else(EmbeddingProvider::from_env);
211
212 if texts.is_empty() {
213 return Ok(vec![]);
214 }
215
216 match provider {
217 EmbeddingProvider::Local => embed_batch_local(texts).await,
218 EmbeddingProvider::OpenAI => embed_batch_openai(texts, api_key).await,
219 EmbeddingProvider::Remote => embed_batch_remote(texts).await,
220 }
221}
222
223async fn embed_batch_local(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
239 if let Ok(embeddings) = embed_batch_local_rust(texts.clone()).await {
241 return Ok(embeddings);
242 }
243
244 if matches!(std::env::var("AVOCADODB_FORBID_FALLBACKS").ok().as_deref(), Some("1" | "true" | "TRUE" | "yes" | "YES")) {
246 return Err(Error::Embedding(
247 "Local fastembed failed and fallbacks are disabled (AVOCADODB_FORBID_FALLBACKS=1)".to_string()
248 ));
249 }
250
251 static PY_WARN_ONCE: Once = Once::new();
253 PY_WARN_ONCE.call_once(|| {
254 log::warn!("Falling back to Python sentence-transformers for embeddings. Install Rust fastembed for best performance.");
255 });
256 if let Ok(embeddings) = embed_batch_local_python(texts.clone()).await {
257 return Ok(embeddings);
258 }
259
260 static HASH_WARN_ONCE: Once = Once::new();
262 HASH_WARN_ONCE.call_once(|| {
263 log::error!("Falling back to HASH-BASED embeddings (NOT SEMANTIC). This mode is for emergencies only.");
264 });
265 embed_batch_local_hash(texts).await
266}
267
268async fn embed_batch_local_rust(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
276 use fastembed::{TextEmbedding, InitOptions};
277 use tokio::task;
278
279 if texts.is_empty() {
280 return Ok(vec![]);
281 }
282
283 let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
285
286 static FASTEMBED_MODEL: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
288
289 let embeddings = task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
292 let model_mutex = FASTEMBED_MODEL.get_or_init(|| {
294 let embedding_model = get_local_embedding_model();
295 let model = TextEmbedding::try_new(
296 InitOptions::new(embedding_model)
297 .with_show_download_progress(false)
298 )
299 .expect("Failed to initialize fastembed model");
300 Mutex::new(model)
301 });
302
303 let embeddings = model_mutex
305 .lock()
306 .map_err(|_| Error::Embedding("Failed to lock fastembed model".to_string()))?
307 .embed(texts_owned, None)
308 .map_err(|e| Error::Embedding(format!("Failed to generate embeddings: {}", e)))?;
309
310 let expected_dim = get_local_embedding_dimension();
312 for emb in &embeddings {
313 if emb.len() != expected_dim {
314 return Err(Error::Embedding(format!(
315 "Unexpected embedding dimension: {} (expected {})",
316 emb.len(),
317 expected_dim
318 )));
319 }
320 }
321
322 Ok(embeddings)
323 })
324 .await
325 .map_err(|e| Error::Embedding(format!("Task join error: {}", e)))??;
326
327 Ok(embeddings)
328}
329
330async fn embed_batch_local_python(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
338 let python = which_python()?;
340
341 let script = format!(r#"
343import sys
344import json
345
346try:
347 from sentence_transformers import SentenceTransformer
348 import numpy as np
349
350 # Load model (cached after first use)
351 model = SentenceTransformer('all-MiniLM-L6-v2')
352
353 # Read texts from stdin (one per line)
354 texts = []
355 for line in sys.stdin:
356 texts.append(line.strip())
357
358 # Generate embeddings
359 embeddings = model.encode(texts, normalize_embeddings=True)
360
361 # Output as JSON array
362 result = [emb.tolist() for emb in embeddings]
363 print(json.dumps(result))
364 sys.exit(0)
365except ImportError:
366 print(json.dumps({{"error": "sentence-transformers not installed. Install with: pip install sentence-transformers"}}), file=sys.stderr)
367 sys.exit(1)
368except Exception as e:
369 print(json.dumps({{"error": str(e)}}), file=sys.stderr)
370 sys.exit(1)
371"#);
372
373 let mut child = AsyncCommand::new(&python)
375 .arg("-c")
376 .arg(&script)
377 .stdin(std::process::Stdio::piped())
378 .stdout(std::process::Stdio::piped())
379 .stderr(std::process::Stdio::piped())
380 .spawn()
381 .map_err(|e| Error::Embedding(format!("Failed to spawn Python process: {}", e)))?;
382
383 if let Some(mut stdin) = child.stdin.take() {
385 for text in &texts {
386 stdin.write_all(text.as_bytes())
387 .await
388 .map_err(|e| Error::Embedding(format!("Failed to write to Python stdin: {}", e)))?;
389 stdin.write_all(b"\n")
390 .await
391 .map_err(|e| Error::Embedding(format!("Failed to write to Python stdin: {}", e)))?;
392 }
393 stdin.shutdown().await
394 .map_err(|e| Error::Embedding(format!("Failed to close Python stdin: {}", e)))?;
395 }
396
397 let output = child.wait_with_output()
399 .await
400 .map_err(|e| Error::Embedding(format!("Failed to wait for Python process: {}", e)))?;
401
402 if !output.status.success() {
403 let stderr = String::from_utf8_lossy(&output.stderr);
404 return Err(Error::Embedding(format!("Python embedding failed: {}", stderr)));
405 }
406
407 let stdout = String::from_utf8_lossy(&output.stdout);
409 let embeddings: Vec<Vec<f32>> = serde_json::from_str(&stdout)
410 .map_err(|e| Error::Embedding(format!("Failed to parse Python output: {}", e)))?;
411
412 let expected_dim = get_local_embedding_dimension();
414 for emb in &embeddings {
415 if emb.len() != expected_dim {
416 return Err(Error::Embedding(format!(
417 "Unexpected embedding dimension: {} (expected {})",
418 emb.len(),
419 expected_dim
420 )));
421 }
422 }
423
424 if embeddings.len() != texts.len() {
425 return Err(Error::Embedding(format!(
426 "Mismatched embedding count: {} embeddings for {} texts",
427 embeddings.len(),
428 texts.len()
429 )));
430 }
431
432 Ok(embeddings)
433}
434
435fn which_python() -> Result<String> {
437 for cmd in &["python3", "python"] {
439 if std::process::Command::new(cmd)
440 .arg("--version")
441 .output()
442 .is_ok()
443 {
444 return Ok(cmd.to_string());
445 }
446 }
447 Err(Error::Embedding("Python not found. Install Python 3 to use local embeddings.".to_string()))
448}
449
450async fn embed_batch_local_hash(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
455 use std::collections::hash_map::DefaultHasher;
456 use std::hash::{Hash, Hasher};
457
458 let embeddings: Vec<Vec<f32>> = texts
459 .iter()
460 .map(|text| {
461 let mut hasher = DefaultHasher::new();
462 text.hash(&mut hasher);
463 let hash = hasher.finish();
464
465 let dim = get_local_embedding_dimension();
466 let mut embedding = vec![0.0f32; dim];
467 for i in 0..dim {
468 let seed = hash.wrapping_add(i as u64);
469 embedding[i] = ((seed % 2000) as f32 - 1000.0) / 1000.0;
470 }
471
472 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
473 if norm > 0.0 {
474 for x in &mut embedding {
475 *x /= norm;
476 }
477 }
478
479 embedding
480 })
481 .collect();
482
483 Ok(embeddings)
484}
485
486async fn embed_batch_openai(
488 texts: Vec<&str>,
489 api_key: Option<&str>,
490) -> Result<Vec<Vec<f32>>> {
491 let api_key = api_key
492 .map(|s| s.to_string())
493 .or_else(|| env::var("OPENAI_API_KEY").ok())
494 .ok_or_else(|| {
495 Error::Embedding(
496 "OPENAI_API_KEY environment variable not set and no API key provided".to_string(),
497 )
498 })?;
499
500 if texts.len() > 2048 {
502 return Err(Error::InvalidInput(format!(
503 "Too many texts to embed at once: {} (max 2048)",
504 texts.len()
505 )));
506 }
507
508 let client = Client::new();
509
510 let request = EmbeddingRequest {
511 model: OPENAI_MODEL.to_string(),
512 input: texts.iter().map(|s| s.to_string()).collect(),
513 };
514
515 let response = client
516 .post(OPENAI_API_URL)
517 .header("Authorization", format!("Bearer {}", api_key))
518 .header("Content-Type", "application/json")
519 .json(&request)
520 .send()
521 .await
522 .map_err(|e| Error::Embedding(format!("API request failed: {}", e)))?;
523
524 if !response.status().is_success() {
525 let status = response.status();
526 let body = response.text().await.unwrap_or_default();
527 return Err(Error::Embedding(format!(
528 "API returned error {}: {}",
529 status, body
530 )));
531 }
532
533 let embedding_response: EmbeddingResponse = response
534 .json()
535 .await
536 .map_err(|e| Error::Embedding(format!("Failed to parse response: {}", e)))?;
537
538 let mut data = embedding_response.data;
540 data.sort_by_key(|d| d.index);
541
542 let embeddings: Vec<Vec<f32>> = data.into_iter().map(|d| d.embedding).collect();
543
544 for emb in &embeddings {
546 if emb.len() != OPENAI_DIMENSION {
547 return Err(Error::Embedding(format!(
548 "Unexpected embedding dimension: {} (expected {})",
549 emb.len(),
550 OPENAI_DIMENSION
551 )));
552 }
553 }
554
555 Ok(embeddings)
556}
557
/// Embed via a user-configured HTTP endpoint.
///
/// Environment configuration:
/// - `AVOCADODB_EMBEDDING_URL` (required): POST target.
/// - `AVOCADODB_EMBEDDING_API_KEY` (optional): sent as a bearer token.
/// - `AVOCADODB_EMBEDDING_MODEL` (optional): forwarded in the request body.
///
/// Two response shapes are accepted:
/// 1. an object `{"embeddings": [[...], ...]}` with an optional `"dimension"`
///    field that overrides the expected dimension check;
/// 2. a bare JSON array `[[...], ...]`.
async fn embed_batch_remote(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use serde_json::json;

    let url = env::var("AVOCADODB_EMBEDDING_URL")
        .map_err(|_| Error::Embedding("AVOCADODB_EMBEDDING_URL not set for remote provider".to_string()))?;
    if texts.is_empty() {
        return Ok(vec![]);
    }

    let client = Client::new();
    let mut req = client.post(&url).header("Content-Type", "application/json");

    // Optional bearer auth; skipped when the variable is unset or empty.
    if let Ok(api_key) = env::var("AVOCADODB_EMBEDDING_API_KEY") {
        if !api_key.is_empty() {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }
    }

    // Forward the model name only when one is configured.
    let model = env::var("AVOCADODB_EMBEDDING_MODEL").ok();
    let body = if let Some(model_name) = model {
        json!({ "inputs": texts, "model": model_name })
    } else {
        json!({ "inputs": texts })
    };

    let resp = req
        .json(&body)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("Remote request failed: {}", e)))?;

    if !resp.status().is_success() {
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        return Err(Error::Embedding(format!("Remote returned error {}: {}", status, text)));
    }

    // Expected dimension: AVOCADODB_EMBEDDING_DIM or the local model default.
    let expected_dim = EmbeddingProvider::Remote.dimension();
    let text_body = resp.text().await.map_err(|e| Error::Embedding(format!("Failed reading remote body: {}", e)))?;

    if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text_body) {
        // Shape 1: object carrying an "embeddings" array.
        if let Some(arr) = v.get("embeddings").and_then(|x| x.as_array()) {
            let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
            for item in arr {
                // NOTE(review): filter_map silently drops non-numeric entries,
                // so a malformed vector shrinks instead of erroring here.
                let vec_opt = item.as_array().map(|nums| {
                    nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect::<Vec<f32>>()
                });
                let vec = vec_opt.ok_or_else(|| Error::Embedding("Invalid embeddings array format".to_string()))?;
                // NOTE(review): empty vectors bypass the dimension check.
                if !vec.is_empty() && vec.len() != expected_dim {
                    // A server-reported "dimension" field takes precedence over
                    // the locally configured expectation.
                    if let Some(dim) = v.get("dimension").and_then(|d| d.as_u64()).map(|d| d as usize) {
                        if vec.len() != dim {
                            return Err(Error::Embedding(format!(
                                "Unexpected embedding dimension: {} (expected {})",
                                vec.len(),
                                expected_dim
                            )));
                        }
                    } else {
                        return Err(Error::Embedding(format!(
                            "Unexpected embedding dimension: {} (expected {})",
                            vec.len(),
                            expected_dim
                        )));
                    }
                }
                embeddings.push(vec);
            }
            // One embedding per input, in order.
            if embeddings.len() != texts.len() {
                return Err(Error::Embedding(format!(
                    "Mismatched embedding count: got {}, expected {}",
                    embeddings.len(),
                    texts.len()
                )));
            }
            return Ok(embeddings);
        }

        // Shape 2: bare top-level array of vectors (no dimension override).
        if let Some(arr) = v.as_array() {
            let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
            for item in arr {
                let vec_opt = item.as_array().map(|nums| {
                    nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect::<Vec<f32>>()
                });
                let vec = vec_opt.ok_or_else(|| Error::Embedding("Invalid embeddings array format".to_string()))?;
                if !vec.is_empty() && vec.len() != expected_dim {
                    return Err(Error::Embedding(format!(
                        "Unexpected embedding dimension: {} (expected {})",
                        vec.len(),
                        expected_dim
                    )));
                }
                embeddings.push(vec);
            }
            if embeddings.len() != texts.len() {
                return Err(Error::Embedding(format!(
                    "Mismatched embedding count: got {}, expected {}",
                    embeddings.len(),
                    texts.len()
                )));
            }
            return Ok(embeddings);
        }
    }

    Err(Error::Embedding("Failed to parse remote embedding response".to_string()))
}
682pub fn embedding_model() -> &'static str {
684 EmbeddingProvider::from_env().model_name()
685}
686
687pub fn embedding_dimension() -> usize {
689 EmbeddingProvider::from_env().dimension()
690}
691
#[cfg(test)]
mod tests {
    use super::*;

    // Default provider is Local, reporting the locally configured dimension.
    #[test]
    fn test_embedding_provider_default() {
        let provider = EmbeddingProvider::default();
        assert_eq!(provider, EmbeddingProvider::Local);
        assert_eq!(provider.dimension(), get_local_embedding_dimension());
    }

    // Dimension accessors for the fixed-dimension providers.
    #[test]
    fn test_embedding_dimensions() {
        assert_eq!(EmbeddingProvider::Local.dimension(), get_local_embedding_dimension());
        assert_eq!(EmbeddingProvider::OpenAI.dimension(), 1536);
    }

    // NOTE(review): exercises the whole local fallback chain (fastembed →
    // Python → hash), so it may download model weights on first run and can
    // pass via the non-semantic hash path.
    #[tokio::test]
    async fn test_embed_batch_local() {
        let texts = vec!["Hello", "World", "Test"];
        let result = embed_batch_local(texts).await;

        assert!(result.is_ok());
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in embeddings {
            assert_eq!(emb.len(), get_local_embedding_dimension());
        }
    }

    // Requires OPENAI_API_KEY and network access; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore] async fn test_embed_text_openai() {
        let result = embed_text("Hello, world!", Some(EmbeddingProvider::OpenAI), None).await;
        if env::var("OPENAI_API_KEY").is_ok() {
            let embedding = result.unwrap();
            assert_eq!(embedding.len(), OPENAI_DIMENSION);
        } else {
            assert!(result.is_err());
        }
    }
}