use crate::types::{Error, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::{Mutex, Once, OnceLock};
use tokio::io::AsyncWriteExt;
use tokio::process::Command as AsyncCommand;

const OPENAI_API_URL: &str = "https://api.openai.com/v1/embeddings";
const OPENAI_MODEL: &str = "text-embedding-ada-002";
const OPENAI_DIMENSION: usize = 1536;

const DEFAULT_LOCAL_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
const DEFAULT_LOCAL_DIMENSION: usize = 384;

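/// Resolves the local fastembed model from `AVOCADODB_EMBEDDING_MODEL`,
/// falling back to `AllMiniLML6V2` for unset or unrecognized values.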
fn get_local_embedding_model() -> fastembed::EmbeddingModel {
    use fastembed::EmbeddingModel;

    if let Ok(model_str) = env::var("AVOCADODB_EMBEDDING_MODEL") {
        match model_str.to_lowercase().as_str() {
            "allminilml6v2" | "all-minilm-l6-v2" | "minilm6" => EmbeddingModel::AllMiniLML6V2,
            "allminilml12v2" | "all-minilm-l12-v2" | "minilm12" => EmbeddingModel::AllMiniLML12V2,
            "bgesmallen" | "bge-small-en-v1.5" | "bgesmall" => EmbeddingModel::BGESmallENV15,
            "bgelargeen" | "bge-large-en-v1.5" | "bgelarge" => EmbeddingModel::BGELargeENV15,
            "nomicv1" | "nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1,
            "nomicv15" | "nomic-embed-text-v1.5" | "nomic" => EmbeddingModel::NomicEmbedTextV15,
            _ => {
                log::warn!("Unknown embedding model '{}', using default AllMiniLML6V2", model_str);
                EmbeddingModel::AllMiniLML6V2
            }
        }
    } else {
        EmbeddingModel::AllMiniLML6V2
    }
}

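/// Output dimension of the configured local model.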
fn get_local_embedding_dimension() -> usize {
    use fastembed::EmbeddingModel;

    match get_local_embedding_model() {
        EmbeddingModel::AllMiniLML6V2 => 384,
        EmbeddingModel::AllMiniLML12V2 => 384,
        EmbeddingModel::BGESmallENV15 => 384,
        EmbeddingModel::BGELargeENV15 => 1024,
        EmbeddingModel::NomicEmbedTextV1 => 768,
        EmbeddingModel::NomicEmbedTextV15 => 768,
        _ => DEFAULT_LOCAL_DIMENSION,
    }
}

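/// Canonical Hugging Face name of the configured local model.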
fn get_local_model_name() -> &'static str {
    use fastembed::EmbeddingModel;

    match get_local_embedding_model() {
        EmbeddingModel::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
        EmbeddingModel::AllMiniLML12V2 => "sentence-transformers/all-MiniLM-L12-v2",
        EmbeddingModel::BGESmallENV15 => "BAAI/bge-small-en-v1.5",
        EmbeddingModel::BGELargeENV15 => "BAAI/bge-large-en-v1.5",
        EmbeddingModel::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
        EmbeddingModel::NomicEmbedTextV15 => "nomic-ai/nomic-embed-text-v1.5",
        _ => DEFAULT_LOCAL_MODEL,
    }
}

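/// Backend used to generate embeddings.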
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingProvider {
    /// In-process fastembed models (with Python and hash fallbacks).
    Local,
    /// OpenAI embeddings API.
    OpenAI,
    /// Generic HTTP endpoint configured via `AVOCADODB_EMBEDDING_URL`.
    Remote,
    /// Local or remote Ollama server.
    Ollama,
}

impl Default for EmbeddingProvider {
    fn default() -> Self {
        EmbeddingProvider::Local
    }
}

impl EmbeddingProvider {
    /// Selects a provider from `AVOCADODB_EMBEDDING_PROVIDER`; unset or
    /// unrecognized values select `Local`.
    pub fn from_env() -> Self {
        match env::var("AVOCADODB_EMBEDDING_PROVIDER")
            .unwrap_or_default()
            .to_lowercase()
            .as_str()
        {
            "openai" => EmbeddingProvider::OpenAI,
            "local" | "fastembed" => EmbeddingProvider::Local,
            "remote" => EmbeddingProvider::Remote,
            "ollama" => EmbeddingProvider::Ollama,
            _ => EmbeddingProvider::Local,
        }
    }

    /// Embedding vector dimension produced by this provider.
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingProvider::Local => get_local_embedding_dimension(),
            EmbeddingProvider::OpenAI => OPENAI_DIMENSION,
            EmbeddingProvider::Ollama => get_ollama_embedding_dimension(),
            EmbeddingProvider::Remote => env::var("AVOCADODB_EMBEDDING_DIM")
                .ok()
                .and_then(|s| s.parse::<usize>().ok())
                .unwrap_or_else(get_local_embedding_dimension),
        }
    }

    /// Human-readable model identifier for this provider.
    pub fn model_name(&self) -> &'static str {
        match self {
            EmbeddingProvider::Local => get_local_model_name(),
            EmbeddingProvider::OpenAI => OPENAI_MODEL,
            EmbeddingProvider::Ollama => get_ollama_model_name(),
            EmbeddingProvider::Remote => DEFAULT_LOCAL_MODEL,
        }
    }
}

const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
const DEFAULT_OLLAMA_MODEL: &str = "bge-m3";

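/// Returns the Ollama model name from `AVOCADODB_OLLAMA_MODEL`, cached for
/// the lifetime of the process.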
fn get_ollama_model_name() -> &'static str {
    static OLLAMA_MODEL: OnceLock<String> = OnceLock::new();
    // The OnceLock is a `static`, so the initialized value lives for the
    // whole program and can be borrowed as `&'static str` without unsafe.
    OLLAMA_MODEL
        .get_or_init(|| {
            env::var("AVOCADODB_OLLAMA_MODEL").unwrap_or_else(|_| DEFAULT_OLLAMA_MODEL.to_string())
        })
        .as_str()
}

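/// Best-effort dimension lookup for common Ollama embedding models, with
/// `AVOCADODB_EMBEDDING_DIM` (default 1024) as the escape hatch for models
/// not covered by the heuristics.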
fn get_ollama_embedding_dimension() -> usize {
    let model = get_ollama_model_name();
    match model {
        m if m.contains("bge-m3") => 1024,
        m if m.contains("bge-large") => 1024,
        m if m.contains("nomic") => 768,
        m if m.contains("mxbai") => 1024,
        m if m.contains("minilm") || m.contains("all-minilm") => 384,
        m if m.contains("snowflake") => 1024,
        _ => env::var("AVOCADODB_EMBEDDING_DIM")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(1024),
    }
}

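/// Request body for the OpenAI embeddings API.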
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    model: String,
    input: Vec<String>,
}

/// Top-level response body from the embeddings API.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

/// One embedding together with its position in the input batch.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

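/// Embeds a single text and returns its vector.
///
/// Convenience wrapper around [`embed_batch`]. `provider` defaults to the
/// value of `AVOCADODB_EMBEDDING_PROVIDER`; `api_key` is only consulted by
/// the OpenAI provider and falls back to `OPENAI_API_KEY`.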
pub async fn embed_text(
    text: &str,
    provider: Option<EmbeddingProvider>,
    api_key: Option<&str>,
) -> Result<Vec<f32>> {
    let results = embed_batch(vec![text], provider, api_key).await?;
    results
        .into_iter()
        .next()
        .ok_or_else(|| Error::Embedding("No embedding returned".to_string()))
}

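/// Embeds a batch of texts with the chosen provider.
///
/// Returns one vector per input text, in input order. An empty input yields
/// an empty result without touching any backend.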
pub async fn embed_batch(
    texts: Vec<&str>,
    provider: Option<EmbeddingProvider>,
    api_key: Option<&str>,
) -> Result<Vec<Vec<f32>>> {
    let provider = provider.unwrap_or_else(EmbeddingProvider::from_env);

    if texts.is_empty() {
        return Ok(vec![]);
    }

    match provider {
        EmbeddingProvider::Local => embed_batch_local(texts).await,
        EmbeddingProvider::OpenAI => embed_batch_openai(texts, api_key).await,
        EmbeddingProvider::Remote => embed_batch_remote(texts).await,
        EmbeddingProvider::Ollama => embed_batch_ollama(texts).await,
    }
}

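/// Local embedding with a three-tier fallback chain: Rust fastembed, then a
/// Python sentence-transformers subprocess, then deterministic hash vectors.
/// Setting `AVOCADODB_FORBID_FALLBACKS=1` turns fastembed failures into hard
/// errors instead of falling through.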
async fn embed_batch_local(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    // Tier 1: in-process Rust fastembed.
    if let Ok(embeddings) = embed_batch_local_rust(texts.clone()).await {
        return Ok(embeddings);
    }

    if matches!(
        std::env::var("AVOCADODB_FORBID_FALLBACKS").ok().as_deref(),
        Some("1" | "true" | "TRUE" | "yes" | "YES")
    ) {
        return Err(Error::Embedding(
            "Local fastembed failed and fallbacks are disabled (AVOCADODB_FORBID_FALLBACKS=1)".to_string(),
        ));
    }

    // Tier 2: Python sentence-transformers subprocess.
    static PY_WARN_ONCE: Once = Once::new();
    PY_WARN_ONCE.call_once(|| {
        log::warn!("Falling back to Python sentence-transformers for embeddings. Install Rust fastembed for best performance.");
    });
    if let Ok(embeddings) = embed_batch_local_python(texts.clone()).await {
        return Ok(embeddings);
    }

    // Tier 3: deterministic hash vectors (not semantic; last resort).
    static HASH_WARN_ONCE: Once = Once::new();
    HASH_WARN_ONCE.call_once(|| {
        log::error!("Falling back to HASH-BASED embeddings (NOT SEMANTIC). This mode is for emergencies only.");
    });
    embed_batch_local_hash(texts).await
}

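/// Embeds via the in-process fastembed (ONNX) model, validating that every
/// returned vector has the configured dimension.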
async fn embed_batch_local_rust(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use fastembed::{InitOptions, TextEmbedding};
    use tokio::task;

    if texts.is_empty() {
        return Ok(vec![]);
    }

    let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

    // The model is expensive to load, so it is initialized once and reused.
    static FASTEMBED_MODEL: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();

    let embeddings = task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
        let model_mutex = FASTEMBED_MODEL.get_or_init(|| {
            let embedding_model = get_local_embedding_model();
            let model = TextEmbedding::try_new(
                InitOptions::new(embedding_model).with_show_download_progress(false),
            )
            // A panic here is caught by spawn_blocking and surfaces as a join
            // error, which the caller treats as "local embedding unavailable".
            .expect("Failed to initialize fastembed model");
            Mutex::new(model)
        });

        let embeddings = model_mutex
            .lock()
            .map_err(|_| Error::Embedding("Failed to lock fastembed model".to_string()))?
            .embed(texts_owned, None)
            .map_err(|e| Error::Embedding(format!("Failed to generate embeddings: {}", e)))?;

        let expected_dim = get_local_embedding_dimension();
        for emb in &embeddings {
            if emb.len() != expected_dim {
                return Err(Error::Embedding(format!(
                    "Unexpected embedding dimension: {} (expected {})",
                    emb.len(),
                    expected_dim
                )));
            }
        }

        Ok(embeddings)
    })
    .await
    .map_err(|e| Error::Embedding(format!("Task join error: {}", e)))??;

    Ok(embeddings)
}

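/// Fallback path: shells out to Python and `sentence-transformers`.
///
/// Note: the embedded script always loads the default `all-MiniLM-L6-v2`
/// model, so if a non-default local model is configured the dimension check
/// below rejects the result rather than return mismatched vectors.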
async fn embed_batch_local_python(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    let python = which_python()?;

    let script = r#"
import sys
import json

try:
    from sentence_transformers import SentenceTransformer

    # Load model (cached on disk after first use)
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Read texts as a JSON array from stdin; this is robust to texts that
    # contain embedded newlines, unlike a line-based protocol
    texts = json.load(sys.stdin)

    # Generate normalized embeddings
    embeddings = model.encode(texts, normalize_embeddings=True)

    # Output as a JSON array of arrays
    print(json.dumps([emb.tolist() for emb in embeddings]))
    sys.exit(0)
except ImportError:
    print(json.dumps({"error": "sentence-transformers not installed. Install with: pip install sentence-transformers"}), file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(json.dumps({"error": str(e)}), file=sys.stderr)
    sys.exit(1)
"#;

    let mut child = AsyncCommand::new(&python)
        .arg("-c")
        .arg(script)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| Error::Embedding(format!("Failed to spawn Python process: {}", e)))?;

    if let Some(mut stdin) = child.stdin.take() {
        // Send the whole batch as one JSON array.
        let payload = serde_json::to_string(&texts)
            .map_err(|e| Error::Embedding(format!("Failed to encode texts: {}", e)))?;
        stdin
            .write_all(payload.as_bytes())
            .await
            .map_err(|e| Error::Embedding(format!("Failed to write to Python stdin: {}", e)))?;
        stdin
            .shutdown()
            .await
            .map_err(|e| Error::Embedding(format!("Failed to close Python stdin: {}", e)))?;
    }

    let output = child
        .wait_with_output()
        .await
        .map_err(|e| Error::Embedding(format!("Failed to wait for Python process: {}", e)))?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(Error::Embedding(format!("Python embedding failed: {}", stderr)));
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let embeddings: Vec<Vec<f32>> = serde_json::from_str(&stdout)
        .map_err(|e| Error::Embedding(format!("Failed to parse Python output: {}", e)))?;

    let expected_dim = get_local_embedding_dimension();
    for emb in &embeddings {
        if emb.len() != expected_dim {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                emb.len(),
                expected_dim
            )));
        }
    }

    if embeddings.len() != texts.len() {
        return Err(Error::Embedding(format!(
            "Mismatched embedding count: {} embeddings for {} texts",
            embeddings.len(),
            texts.len()
        )));
    }

    Ok(embeddings)
}

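/// Finds a usable Python interpreter on PATH, preferring `python3`.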
fn which_python() -> Result<String> {
    for cmd in &["python3", "python"] {
        if std::process::Command::new(cmd)
            .arg("--version")
            .output()
            .is_ok()
        {
            return Ok(cmd.to_string());
        }
    }
    Err(Error::Embedding("Python not found. Install Python 3 to use local embeddings.".to_string()))
}

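/// Emergency fallback: deterministic, L2-normalized pseudo-vectors derived
/// from a hash of the text. These are NOT semantic embeddings; they only
/// keep the pipeline running when no real backend is available.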
async fn embed_batch_local_hash(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let embeddings: Vec<Vec<f32>> = texts
        .iter()
        .map(|text| {
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            let hash = hasher.finish();

            // Derive each component from the hash, mapped into [-1.0, 1.0).
            let dim = get_local_embedding_dimension();
            let mut embedding = vec![0.0f32; dim];
            for i in 0..dim {
                let seed = hash.wrapping_add(i as u64);
                embedding[i] = ((seed % 2000) as f32 - 1000.0) / 1000.0;
            }

            // L2-normalize so the vectors behave under cosine similarity.
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for x in &mut embedding {
                    *x /= norm;
                }
            }

            embedding
        })
        .collect();

    Ok(embeddings)
}

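/// Embeds via the OpenAI embeddings API, using the supplied key or
/// `OPENAI_API_KEY`. Batches are capped at 2048 inputs per request,
/// matching the API limit.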
async fn embed_batch_openai(
    texts: Vec<&str>,
    api_key: Option<&str>,
) -> Result<Vec<Vec<f32>>> {
    let api_key = api_key
        .map(|s| s.to_string())
        .or_else(|| env::var("OPENAI_API_KEY").ok())
        .ok_or_else(|| {
            Error::Embedding(
                "OPENAI_API_KEY environment variable not set and no API key provided".to_string(),
            )
        })?;

    if texts.len() > 2048 {
        return Err(Error::InvalidInput(format!(
            "Too many texts to embed at once: {} (max 2048)",
            texts.len()
        )));
    }

    let client = Client::new();

    let request = EmbeddingRequest {
        model: OPENAI_MODEL.to_string(),
        input: texts.iter().map(|s| s.to_string()).collect(),
    };

    let response = client
        .post(OPENAI_API_URL)
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("API request failed: {}", e)))?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        return Err(Error::Embedding(format!(
            "API returned error {}: {}",
            status, body
        )));
    }

    let embedding_response: EmbeddingResponse = response
        .json()
        .await
        .map_err(|e| Error::Embedding(format!("Failed to parse response: {}", e)))?;

    // The API does not guarantee ordering; restore input order by index.
    let mut data = embedding_response.data;
    data.sort_by_key(|d| d.index);

    let embeddings: Vec<Vec<f32>> = data.into_iter().map(|d| d.embedding).collect();

    for emb in &embeddings {
        if emb.len() != OPENAI_DIMENSION {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                emb.len(),
                OPENAI_DIMENSION
            )));
        }
    }

    Ok(embeddings)
}

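/// Embeds via a generic HTTP endpoint (`AVOCADODB_EMBEDDING_URL`), optionally
/// authenticated with `AVOCADODB_EMBEDDING_API_KEY`. Accepts either
/// `{"embeddings": [[...], ...]}` or a bare `[[...], ...]` response body.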
async fn embed_batch_remote(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use serde_json::json;

    let url = env::var("AVOCADODB_EMBEDDING_URL")
        .map_err(|_| Error::Embedding("AVOCADODB_EMBEDDING_URL not set for remote provider".to_string()))?;
    if texts.is_empty() {
        return Ok(vec![]);
    }

    let client = Client::new();
    let mut req = client.post(&url).header("Content-Type", "application/json");

    if let Ok(api_key) = env::var("AVOCADODB_EMBEDDING_API_KEY") {
        if !api_key.is_empty() {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }
    }

    let model = env::var("AVOCADODB_EMBEDDING_MODEL").ok();
    let body = if let Some(model_name) = model {
        json!({ "inputs": texts, "model": model_name })
    } else {
        json!({ "inputs": texts })
    };

    let resp = req
        .json(&body)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("Remote request failed: {}", e)))?;

    if !resp.status().is_success() {
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        return Err(Error::Embedding(format!("Remote returned error {}: {}", status, text)));
    }

    let expected_dim = EmbeddingProvider::Remote.dimension();
    let text_body = resp.text().await.map_err(|e| Error::Embedding(format!("Failed reading remote body: {}", e)))?;

    if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text_body) {
        // Shape 1: {"embeddings": [[...], ...]} with an optional "dimension".
        if let Some(arr) = v.get("embeddings").and_then(|x| x.as_array()) {
            let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
            for item in arr {
                let vec_opt = item.as_array().map(|nums| {
                    nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect::<Vec<f32>>()
                });
                let vec = vec_opt.ok_or_else(|| Error::Embedding("Invalid embeddings array format".to_string()))?;
                if !vec.is_empty() && vec.len() != expected_dim {
                    // If the server reports its own dimension, accept vectors
                    // that match it instead of the locally configured value.
                    if let Some(dim) = v.get("dimension").and_then(|d| d.as_u64()).map(|d| d as usize) {
                        if vec.len() != dim {
                            return Err(Error::Embedding(format!(
                                "Unexpected embedding dimension: {} (expected {})",
                                vec.len(),
                                expected_dim
                            )));
                        }
                    } else {
                        return Err(Error::Embedding(format!(
                            "Unexpected embedding dimension: {} (expected {})",
                            vec.len(),
                            expected_dim
                        )));
                    }
                }
                embeddings.push(vec);
            }
            if embeddings.len() != texts.len() {
                return Err(Error::Embedding(format!(
                    "Mismatched embedding count: got {}, expected {}",
                    embeddings.len(),
                    texts.len()
                )));
            }
            return Ok(embeddings);
        }

        // Shape 2: a bare top-level array of embedding vectors.
        if let Some(arr) = v.as_array() {
            let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
            for item in arr {
                let vec_opt = item.as_array().map(|nums| {
                    nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect::<Vec<f32>>()
                });
                let vec = vec_opt.ok_or_else(|| Error::Embedding("Invalid embeddings array format".to_string()))?;
                if !vec.is_empty() && vec.len() != expected_dim {
                    return Err(Error::Embedding(format!(
                        "Unexpected embedding dimension: {} (expected {})",
                        vec.len(),
                        expected_dim
                    )));
                }
                embeddings.push(vec);
            }
            if embeddings.len() != texts.len() {
                return Err(Error::Embedding(format!(
                    "Mismatched embedding count: got {}, expected {}",
                    embeddings.len(),
                    texts.len()
                )));
            }
            return Ok(embeddings);
        }
    }

    Err(Error::Embedding("Failed to parse remote embedding response".to_string()))
}

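/// Embeds via an Ollama server (`AVOCADODB_OLLAMA_URL`). Tries the batch
/// `/api/embed` endpoint first and falls back to the legacy per-text
/// `/api/embeddings` endpoint for older servers.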
async fn embed_batch_ollama(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use serde_json::json;

    let base_url = env::var("AVOCADODB_OLLAMA_URL")
        .unwrap_or_else(|_| DEFAULT_OLLAMA_URL.to_string());
    let model = get_ollama_model_name();
    let expected_dim = get_ollama_embedding_dimension();

    if texts.is_empty() {
        return Ok(vec![]);
    }

    let client = Client::new();

    // Preferred path: the batch /api/embed endpoint on newer Ollama servers.
    let url = format!("{}/api/embed", base_url);
    let body = json!({
        "model": model,
        "input": texts,
    });

    let resp = client
        .post(&url)
        .header("Content-Type", "application/json")
        .json(&body)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("Ollama request failed: {}", e)))?;

    if resp.status().is_success() {
        let text_body = resp.text().await
            .map_err(|e| Error::Embedding(format!("Failed reading Ollama response: {}", e)))?;

        if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text_body) {
            if let Some(arr) = v.get("embeddings").and_then(|x| x.as_array()) {
                let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
                for item in arr {
                    let vec: Vec<f32> = item.as_array()
                        .map(|nums| nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect())
                        .ok_or_else(|| Error::Embedding("Invalid embedding array".to_string()))?;
                    embeddings.push(vec);
                }
                if embeddings.len() != texts.len() {
                    return Err(Error::Embedding(format!(
                        "Mismatched embedding count: got {}, expected {}",
                        embeddings.len(),
                        texts.len()
                    )));
                }
                return Ok(embeddings);
            }
        }
    }

    // Fallback: the legacy /api/embeddings endpoint, one request per text.
    let url = format!("{}/api/embeddings", base_url);
    let mut embeddings = Vec::with_capacity(texts.len());

    for text in texts {
        let body = json!({
            "model": model,
            "prompt": text,
        });

        let resp = client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| Error::Embedding(format!("Ollama request failed: {}", e)))?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(Error::Embedding(format!(
                "Ollama API error {}: {}",
                status, body
            )));
        }

        let text_body = resp.text().await
            .map_err(|e| Error::Embedding(format!("Failed reading Ollama response: {}", e)))?;

        let v: serde_json::Value = serde_json::from_str(&text_body)
            .map_err(|e| Error::Embedding(format!("Failed parsing Ollama response: {}", e)))?;

        let embedding: Vec<f32> = v.get("embedding")
            .and_then(|e| e.as_array())
            .map(|arr| arr.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect())
            .ok_or_else(|| Error::Embedding("No embedding in Ollama response".to_string()))?;

        if embedding.len() != expected_dim {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                embedding.len(),
                expected_dim
            )));
        }

        embeddings.push(embedding);
    }

    Ok(embeddings)
}

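/// Model name for the provider selected by the environment.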
pub fn embedding_model() -> &'static str {
    EmbeddingProvider::from_env().model_name()
}

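/// Embedding dimension for the provider selected by the environment.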
pub fn embedding_dimension() -> usize {
    EmbeddingProvider::from_env().dimension()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_provider_default() {
        let provider = EmbeddingProvider::default();
        assert_eq!(provider, EmbeddingProvider::Local);
        assert_eq!(provider.dimension(), get_local_embedding_dimension());
    }

    #[test]
    fn test_embedding_dimensions() {
        assert_eq!(EmbeddingProvider::Local.dimension(), get_local_embedding_dimension());
        assert_eq!(EmbeddingProvider::OpenAI.dimension(), 1536);
    }

    #[tokio::test]
    async fn test_embed_batch_local() {
        let texts = vec!["Hello", "World", "Test"];
        let result = embed_batch_local(texts).await;

        assert!(result.is_ok());
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in embeddings {
            assert_eq!(emb.len(), get_local_embedding_dimension());
        }
    }

    // Hits the live OpenAI API; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_embed_text_openai() {
        let result = embed_text("Hello, world!", Some(EmbeddingProvider::OpenAI), None).await;
        if env::var("OPENAI_API_KEY").is_ok() {
            let embedding = result.unwrap();
            assert_eq!(embedding.len(), OPENAI_DIMENSION);
        } else {
            assert!(result.is_err());
        }
    }
}