avocado_core/
embedding.rs

//! Embedding generation with local and OpenAI support
//!
//! By default, uses local embeddings (all-MiniLM-L6-v2 via ONNX) for self-sufficiency.
//! OpenAI embeddings are optional and can be enabled via OPENAI_API_KEY.
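//!
//! A minimal usage sketch (assuming this crate is importable as `avocado_core`;
//! the default local provider needs no API key):
//!
//! ```ignore
//! # async fn demo() {
//! use avocado_core::embedding::{embed_text, EmbeddingProvider};
//!
//! let vector = embed_text("hello world", Some(EmbeddingProvider::Local), None)
//!     .await
//!     .unwrap();
//! assert_eq!(vector.len(), 384); // default all-MiniLM-L6-v2 dimension
//! # }
//! ```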

use crate::types::{Error, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::{Mutex, Once, OnceLock};
use tokio::io::AsyncWriteExt;
use tokio::process::Command as AsyncCommand;

// OpenAI constants
const OPENAI_API_URL: &str = "https://api.openai.com/v1/embeddings";
const OPENAI_MODEL: &str = "text-embedding-ada-002";
const OPENAI_DIMENSION: usize = 1536;

// Local embedding model configuration
// Default: all-MiniLM-L6-v2 (384 dimensions) - fast and efficient
// Can be overridden via the AVOCADODB_EMBEDDING_MODEL environment variable
// Available models and their dimensions:
//   - AllMiniLML6V2: 384 dims (default, fastest)
//   - AllMiniLML12V2: 384 dims (slightly better quality)
//   - BGESmallENV15: 384 dims (good for English)
//   - BGELargeENV15: 1024 dims (higher quality, slower)
//   - NomicEmbedTextV1: 768 dims (good balance)
//   - NomicEmbedTextV15: 768 dims (improved version)
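// For example (aliases accepted by get_local_embedding_model below):
//   AVOCADODB_EMBEDDING_MODEL=bge-small-en-v1.5  -> BGESmallENV15 (384 dims)
//   AVOCADODB_EMBEDDING_MODEL=nomic              -> NomicEmbedTextV15 (768 dims)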
const DEFAULT_LOCAL_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
const DEFAULT_LOCAL_DIMENSION: usize = 384;

/// Get the local embedding model enum based on the environment variable
fn get_local_embedding_model() -> fastembed::EmbeddingModel {
    use fastembed::EmbeddingModel;

    if let Ok(model_str) = env::var("AVOCADODB_EMBEDDING_MODEL") {
        match model_str.to_lowercase().as_str() {
            "allminilml6v2" | "all-minilm-l6-v2" | "minilm6" => EmbeddingModel::AllMiniLML6V2,
            "allminilml12v2" | "all-minilm-l12-v2" | "minilm12" => EmbeddingModel::AllMiniLML12V2,
            "bgesmallen" | "bge-small-en-v1.5" | "bgesmall" => EmbeddingModel::BGESmallENV15,
            "bgelargeen" | "bge-large-en-v1.5" | "bgelarge" => EmbeddingModel::BGELargeENV15,
            "nomicv1" | "nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1,
            "nomicv15" | "nomic-embed-text-v1.5" | "nomic" => EmbeddingModel::NomicEmbedTextV15,
            _ => {
                log::warn!("Unknown embedding model '{}', using default AllMiniLML6V2", model_str);
                EmbeddingModel::AllMiniLML6V2
            }
        }
    } else {
        EmbeddingModel::AllMiniLML6V2
    }
}

/// Get the dimension for the selected local embedding model
fn get_local_embedding_dimension() -> usize {
    use fastembed::EmbeddingModel;

    match get_local_embedding_model() {
        EmbeddingModel::AllMiniLML6V2 => 384,
        EmbeddingModel::AllMiniLML12V2 => 384,
        EmbeddingModel::BGESmallENV15 => 384,
        EmbeddingModel::BGELargeENV15 => 1024,
        EmbeddingModel::NomicEmbedTextV1 => 768,
        EmbeddingModel::NomicEmbedTextV15 => 768,
        _ => DEFAULT_LOCAL_DIMENSION, // Fallback for variants not selectable above
    }
}

/// Get the model name string for the selected local embedding model
fn get_local_model_name() -> &'static str {
    use fastembed::EmbeddingModel;

    match get_local_embedding_model() {
        EmbeddingModel::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
        EmbeddingModel::AllMiniLML12V2 => "sentence-transformers/all-MiniLM-L12-v2",
        EmbeddingModel::BGESmallENV15 => "BAAI/bge-small-en-v1.5",
        EmbeddingModel::BGELargeENV15 => "BAAI/bge-large-en-v1.5",
        EmbeddingModel::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
        EmbeddingModel::NomicEmbedTextV15 => "nomic-ai/nomic-embed-text-v1.5",
        _ => DEFAULT_LOCAL_MODEL, // Fallback for variants not selectable above
    }
}

/// Embedding provider type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingProvider {
    /// Local embeddings using fastembed/ONNX (default, no API required)
    Local,
    /// OpenAI embeddings (requires OPENAI_API_KEY)
    OpenAI,
    /// Remote HTTP embeddings (GPU sandbox or custom service)
    Remote,
}

impl Default for EmbeddingProvider {
    fn default() -> Self {
        // Default to local for self-sufficiency
        EmbeddingProvider::Local
    }
}

impl EmbeddingProvider {
    /// Detect provider from environment or use the default
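    ///
    /// A small sketch (assuming the crate is importable as `avocado_core`):
    ///
    /// ```ignore
    /// use avocado_core::embedding::EmbeddingProvider;
    ///
    /// // With AVOCADODB_EMBEDDING_PROVIDER=remote set in the environment:
    /// assert_eq!(EmbeddingProvider::from_env(), EmbeddingProvider::Remote);
    /// ```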
    pub fn from_env() -> Self {
        // Read AVOCADODB_EMBEDDING_PROVIDER once; default to local for
        // self-sufficiency when it is unset or unrecognized
        match env::var("AVOCADODB_EMBEDDING_PROVIDER") {
            Ok(value) => match value.to_lowercase().as_str() {
                "openai" => EmbeddingProvider::OpenAI,
                "local" => EmbeddingProvider::Local,
                "remote" => EmbeddingProvider::Remote,
                _ => EmbeddingProvider::Local,
            },
            Err(_) => EmbeddingProvider::Local,
        }
    }

    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingProvider::Local => get_local_embedding_dimension(),
            EmbeddingProvider::OpenAI => OPENAI_DIMENSION,
            EmbeddingProvider::Remote => {
                // Allow overriding the remote dimension via env; default to the
                // local dimension for compatibility
                env::var("AVOCADODB_EMBEDDING_DIM")
                    .ok()
                    .and_then(|s| s.parse::<usize>().ok())
                    .unwrap_or_else(get_local_embedding_dimension)
            }
        }
    }

    pub fn model_name(&self) -> &'static str {
        match self {
            EmbeddingProvider::Local => get_local_model_name(),
            EmbeddingProvider::OpenAI => OPENAI_MODEL,
            // Remote model name is not fixed; callers can optionally set
            // AVOCADODB_EMBEDDING_MODEL
            EmbeddingProvider::Remote => DEFAULT_LOCAL_MODEL,
        }
    }
}

// OpenAI API request/response types
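// As serialized/parsed by the structs below, the wire format is (sketch;
// response fields other than `data` are ignored by this deserialization):
//   request:  {"model": "text-embedding-ada-002", "input": ["a", "b"]}
//   response: {"data": [{"embedding": [0.1, ...], "index": 0}, ...]}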
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

/// Embed a single text string
///
/// Uses local embeddings by default (no API required).
/// Set AVOCADODB_EMBEDDING_PROVIDER=openai to use OpenAI.
///
/// # Arguments
///
/// * `text` - The text to embed
/// * `provider` - Embedding provider (defaults to Local)
/// * `api_key` - OpenAI API key (only used if provider is OpenAI)
///
/// # Returns
///
/// A vector of floats representing the embedding (e.g., 384 for the default
/// local model, 1536 for OpenAI)
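///
/// A minimal sketch (paths assume this crate is `avocado_core`):
///
/// ```ignore
/// # async fn demo() {
/// use avocado_core::embedding::{embed_text, EmbeddingProvider};
///
/// // Local (default provider): no API key required
/// let local = embed_text("hello", None, None).await.unwrap();
///
/// // OpenAI: needs a key argument or OPENAI_API_KEY in the environment
/// let openai = embed_text("hello", Some(EmbeddingProvider::OpenAI), None)
///     .await
///     .unwrap();
/// # }
/// ```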
pub async fn embed_text(
    text: &str,
    provider: Option<EmbeddingProvider>,
    api_key: Option<&str>,
) -> Result<Vec<f32>> {
    let results = embed_batch(vec![text], provider, api_key).await?;
    results
        .into_iter()
        .next()
        .ok_or_else(|| Error::Embedding("No embedding returned".to_string()))
}

/// Embed multiple text strings
///
/// Uses local embeddings by default (no API required).
/// Set AVOCADODB_EMBEDDING_PROVIDER=openai to use OpenAI.
///
/// # Arguments
///
/// * `texts` - Vector of text strings to embed
/// * `provider` - Embedding provider (defaults to Local)
/// * `api_key` - OpenAI API key (only used if provider is OpenAI)
///
/// # Returns
///
/// A vector of embeddings, in the same order as the input texts
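///
/// A minimal sketch (paths assume this crate is `avocado_core`):
///
/// ```ignore
/// # async fn demo() {
/// use avocado_core::embedding::embed_batch;
///
/// let embs = embed_batch(vec!["first", "second"], None, None).await.unwrap();
/// assert_eq!(embs.len(), 2); // one embedding per input, same order
/// # }
/// ```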
pub async fn embed_batch(
    texts: Vec<&str>,
    provider: Option<EmbeddingProvider>,
    api_key: Option<&str>,
) -> Result<Vec<Vec<f32>>> {
    let provider = provider.unwrap_or_else(EmbeddingProvider::from_env);

    if texts.is_empty() {
        return Ok(vec![]);
    }

    match provider {
        EmbeddingProvider::Local => embed_batch_local(texts).await,
        EmbeddingProvider::OpenAI => embed_batch_openai(texts, api_key).await,
        EmbeddingProvider::Remote => embed_batch_remote(texts).await,
    }
}

/// Local embedding generation with multiple strategies
///
/// Pure Rust implementation using fastembed (ONNX-based) for semantic embeddings.
/// Falls back to a Python subprocess, then to hash-based embeddings if needed.
///
/// Strategy priority:
/// 1. Pure Rust with fastembed (semantic, good quality, no Python required)
///    - Uses the all-MiniLM-L6-v2 model by default (384 dimensions)
///    - ONNX-based, fast and efficient
///    - Model cached after first download (~90MB)
/// 2. Python subprocess with sentence-transformers (fallback if fastembed fails)
///    - Requires: `pip install sentence-transformers`
/// 3. Hash-based fallback (deterministic, but NOT semantic)
///    - Works without dependencies
///    - Poor semantic quality (similar texts don't cluster)
async fn embed_batch_local(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    // Try pure Rust fastembed first (best performance, no Python required).
    // Log the error instead of silently swallowing it, so failures are traceable.
    match embed_batch_local_rust(texts.clone()).await {
        Ok(embeddings) => return Ok(embeddings),
        Err(e) => log::debug!("fastembed embedding failed, trying fallbacks: {}", e),
    }

    // Respect hard-fail configuration to forbid non-semantic fallbacks in production
    if matches!(
        env::var("AVOCADODB_FORBID_FALLBACKS").ok().as_deref(),
        Some("1" | "true" | "TRUE" | "yes" | "YES")
    ) {
        return Err(Error::Embedding(
            "Local fastembed failed and fallbacks are disabled (AVOCADODB_FORBID_FALLBACKS=1)"
                .to_string(),
        ));
    }

    // Fallback to Python sentence-transformers (if available)
    static PY_WARN_ONCE: Once = Once::new();
    PY_WARN_ONCE.call_once(|| {
        log::warn!("Falling back to Python sentence-transformers for embeddings. Install Rust fastembed for best performance.");
    });
    if let Ok(embeddings) = embed_batch_local_python(texts.clone()).await {
        return Ok(embeddings);
    }

    // Final fallback to hash-based embeddings (always works, but NOT semantic)
    static HASH_WARN_ONCE: Once = Once::new();
    HASH_WARN_ONCE.call_once(|| {
        log::error!("Falling back to HASH-BASED embeddings (NOT SEMANTIC). This mode is for emergencies only.");
    });
    embed_batch_local_hash(texts).await
}

/// Pure Rust embeddings using fastembed (ONNX-based, no Python required)
///
/// Uses the fastembed crate with the configured model (all-MiniLM-L6-v2 by
/// default) for semantic embeddings. This is the preferred method: pure Rust,
/// fast, and no Python dependency.
///
/// The model is downloaded from HuggingFace on first use and cached locally by
/// fastembed, so initialization is fast after the first run.
async fn embed_batch_local_rust(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use fastembed::{InitOptions, TextEmbedding};
    use tokio::task;

    if texts.is_empty() {
        return Ok(vec![]);
    }

    // Convert &str to String for the blocking task
    let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

    // Cache the TextEmbedding model instance across the process to avoid repeated
    // initialization. The Mutex serializes embed calls, since embedding requires
    // mutable access to the model.
    static FASTEMBED_MODEL: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();

    // fastembed is synchronous, so we run it in a blocking task.
    // Note: fastembed caches the downloaded model on disk, so initialization is
    // fast after the first use.
    let embeddings = task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
        // Initialize or reuse the cached model. If initialization fails, the
        // expect() panic is caught by spawn_blocking and surfaces below as a
        // join error, which lets the caller fall back to other strategies.
        let model_mutex = FASTEMBED_MODEL.get_or_init(|| {
            let embedding_model = get_local_embedding_model();
            let model = TextEmbedding::try_new(
                InitOptions::new(embedding_model).with_show_download_progress(false),
            )
            .expect("Failed to initialize fastembed model");
            Mutex::new(model)
        });

        // Generate embeddings (fastembed handles normalization)
        let embeddings = model_mutex
            .lock()
            .map_err(|_| Error::Embedding("Failed to lock fastembed model".to_string()))?
            .embed(texts_owned, None)
            .map_err(|e| Error::Embedding(format!("Failed to generate embeddings: {}", e)))?;

        // Verify dimensions against the selected model's expected dimension
        let expected_dim = get_local_embedding_dimension();
        for emb in &embeddings {
            if emb.len() != expected_dim {
                return Err(Error::Embedding(format!(
                    "Unexpected embedding dimension: {} (expected {})",
                    emb.len(),
                    expected_dim
                )));
            }
        }

        Ok(embeddings)
    })
    .await
    .map_err(|e| Error::Embedding(format!("Task join error: {}", e)))??;

    Ok(embeddings)
}

/// Local embeddings using Python sentence-transformers (semantic, good quality)
///
/// Uses a Python subprocess to call sentence-transformers with all-MiniLM-L6-v2.
/// This provides semantic embeddings without API keys.
///
/// Requires: Python with sentence-transformers installed
///   pip install sentence-transformers
async fn embed_batch_local_python(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    // Check if Python is available
    let python = which_python()?;

    // Python script that reads one text per line from stdin and prints a JSON
    // array of embeddings. Note: this fallback always uses all-MiniLM-L6-v2,
    // regardless of AVOCADODB_EMBEDDING_MODEL, and texts containing newlines
    // would be split by this line-based protocol.
    let script = r#"
import sys
import json

try:
    from sentence_transformers import SentenceTransformer

    # Load model (cached after first use)
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Read texts from stdin (one per line)
    texts = []
    for line in sys.stdin:
        texts.append(line.strip())

    # Generate embeddings
    embeddings = model.encode(texts, normalize_embeddings=True)

    # Output as JSON array
    result = [emb.tolist() for emb in embeddings]
    print(json.dumps(result))
    sys.exit(0)
except ImportError:
    print(json.dumps({"error": "sentence-transformers not installed. Install with: pip install sentence-transformers"}), file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(json.dumps({"error": str(e)}), file=sys.stderr)
    sys.exit(1)
"#;

    // Run Python script
    let mut child = AsyncCommand::new(&python)
        .arg("-c")
        .arg(script)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| Error::Embedding(format!("Failed to spawn Python process: {}", e)))?;

    // Write texts to stdin (one per line)
    if let Some(mut stdin) = child.stdin.take() {
        for text in &texts {
            stdin
                .write_all(text.as_bytes())
                .await
                .map_err(|e| Error::Embedding(format!("Failed to write to Python stdin: {}", e)))?;
            stdin
                .write_all(b"\n")
                .await
                .map_err(|e| Error::Embedding(format!("Failed to write to Python stdin: {}", e)))?;
        }
        stdin
            .shutdown()
            .await
            .map_err(|e| Error::Embedding(format!("Failed to close Python stdin: {}", e)))?;
    }

    // Wait for output
    let output = child
        .wait_with_output()
        .await
        .map_err(|e| Error::Embedding(format!("Failed to wait for Python process: {}", e)))?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(Error::Embedding(format!("Python embedding failed: {}", stderr)));
    }

    // Parse JSON output
    let stdout = String::from_utf8_lossy(&output.stdout);
    let embeddings: Vec<Vec<f32>> = serde_json::from_str(&stdout)
        .map_err(|e| Error::Embedding(format!("Failed to parse Python output: {}", e)))?;

    // Verify dimensions. The Python fallback always uses the default model
    // (all-MiniLM-L6-v2), so check against the default dimension rather than
    // the configured model's dimension.
    let expected_dim = DEFAULT_LOCAL_DIMENSION;
    for emb in &embeddings {
        if emb.len() != expected_dim {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                emb.len(),
                expected_dim
            )));
        }
    }

    if embeddings.len() != texts.len() {
        return Err(Error::Embedding(format!(
            "Mismatched embedding count: {} embeddings for {} texts",
            embeddings.len(),
            texts.len()
        )));
    }

    Ok(embeddings)
}

/// Find a Python executable (python3 or python)
fn which_python() -> Result<String> {
    // Try python3 first, then python
    for cmd in &["python3", "python"] {
        if std::process::Command::new(cmd)
            .arg("--version")
            .output()
            .is_ok()
        {
            return Ok(cmd.to_string());
        }
    }
    Err(Error::Embedding(
        "Python not found. Install Python 3 to use local embeddings.".to_string(),
    ))
}

/// Hash-based embeddings (fallback, NOT semantic)
///
/// Deterministic but not semantic - similar texts won't have similar embeddings.
/// Used as a last resort when both fastembed and Python/sentence-transformers
/// are unavailable.
async fn embed_batch_local_hash(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let embeddings: Vec<Vec<f32>> = texts
        .iter()
        .map(|text| {
            // Derive a deterministic pseudo-vector from the text's hash
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            let hash = hasher.finish();

            let dim = get_local_embedding_dimension();
            let mut embedding = vec![0.0f32; dim];
            for i in 0..dim {
                let seed = hash.wrapping_add(i as u64);
                embedding[i] = ((seed % 2000) as f32 - 1000.0) / 1000.0;
            }

            // Normalize to unit length for cosine-similarity compatibility
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for x in &mut embedding {
                    *x /= norm;
                }
            }

            embedding
        })
        .collect();

    Ok(embeddings)
}

/// OpenAI embedding generation
async fn embed_batch_openai(texts: Vec<&str>, api_key: Option<&str>) -> Result<Vec<Vec<f32>>> {
    let api_key = api_key
        .map(|s| s.to_string())
        .or_else(|| env::var("OPENAI_API_KEY").ok())
        .ok_or_else(|| {
            Error::Embedding(
                "OPENAI_API_KEY environment variable not set and no API key provided".to_string(),
            )
        })?;

    // OpenAI limit is 2048 inputs per request
    if texts.len() > 2048 {
        return Err(Error::InvalidInput(format!(
            "Too many texts to embed at once: {} (max 2048)",
            texts.len()
        )));
    }

    let client = Client::new();

    let request = EmbeddingRequest {
        model: OPENAI_MODEL.to_string(),
        input: texts.iter().map(|s| s.to_string()).collect(),
    };

    let response = client
        .post(OPENAI_API_URL)
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("API request failed: {}", e)))?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        return Err(Error::Embedding(format!(
            "API returned error {}: {}",
            status, body
        )));
    }

    let embedding_response: EmbeddingResponse = response
        .json()
        .await
        .map_err(|e| Error::Embedding(format!("Failed to parse response: {}", e)))?;

    // Sort by index to ensure correct ordering
    let mut data = embedding_response.data;
    data.sort_by_key(|d| d.index);

    let embeddings: Vec<Vec<f32>> = data.into_iter().map(|d| d.embedding).collect();

    // Verify all embeddings have the correct dimension
    for emb in &embeddings {
        if emb.len() != OPENAI_DIMENSION {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                emb.len(),
                OPENAI_DIMENSION
            )));
        }
    }

    Ok(embeddings)
}

/// Remote HTTP embedding generation
///
/// The remote service is configured via:
/// - AVOCADODB_EMBEDDING_URL: required, e.g. https://your-modal-fn.modal.run/embed
/// - AVOCADODB_EMBEDDING_API_KEY: optional, sent as a Bearer token
/// - AVOCADODB_EMBEDDING_MODEL: optional, forwarded to the remote service
/// - AVOCADODB_EMBEDDING_DIM: optional, expected dimension (defaults to the local dim)
///
/// Expected request body:
///   { "inputs": ["text1", "text2"], "model": "BAAI/bge-small-en-v1.5" }
///
/// Expected response body (either of the following):
///   { "embeddings": [[..],[..]], "dimension": 384 }
///   [[..],[..]]
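///
/// Example configuration (hypothetical endpoint URL):
///   AVOCADODB_EMBEDDING_URL=https://embeddings.example.com/embed
///   AVOCADODB_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5
///   AVOCADODB_EMBEDDING_DIM=384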
async fn embed_batch_remote(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use serde_json::json;

    let url = env::var("AVOCADODB_EMBEDDING_URL")
        .map_err(|_| Error::Embedding("AVOCADODB_EMBEDDING_URL not set for remote provider".to_string()))?;
    if texts.is_empty() {
        return Ok(vec![]);
    }

    let client = Client::new();
    let mut req = client.post(&url).header("Content-Type", "application/json");

    if let Ok(api_key) = env::var("AVOCADODB_EMBEDDING_API_KEY") {
        if !api_key.is_empty() {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }
    }

    let model = env::var("AVOCADODB_EMBEDDING_MODEL").ok();
    let body = if let Some(model_name) = model {
        json!({ "inputs": texts, "model": model_name })
    } else {
        json!({ "inputs": texts })
    };

    let resp = req
        .json(&body)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("Remote request failed: {}", e)))?;

    if !resp.status().is_success() {
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        return Err(Error::Embedding(format!("Remote returned error {}: {}", status, text)));
    }

    let expected_dim = EmbeddingProvider::Remote.dimension();
    let text_body = resp
        .text()
        .await
        .map_err(|e| Error::Embedding(format!("Failed reading remote body: {}", e)))?;

    // Shared parsing for both response shapes: a JSON array of rows, each row an
    // array of numbers. Validates dimensions and the embedding count.
    fn rows_to_embeddings(
        arr: &[serde_json::Value],
        expected_dim: usize,
        expected_count: usize,
    ) -> Result<Vec<Vec<f32>>> {
        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
        for item in arr {
            let row = item
                .as_array()
                .map(|nums| {
                    nums.iter()
                        .filter_map(|n| n.as_f64().map(|f| f as f32))
                        .collect::<Vec<f32>>()
                })
                .ok_or_else(|| Error::Embedding("Invalid embeddings array format".to_string()))?;
            if !row.is_empty() && row.len() != expected_dim {
                return Err(Error::Embedding(format!(
                    "Unexpected embedding dimension: {} (expected {})",
                    row.len(),
                    expected_dim
                )));
            }
            embeddings.push(row);
        }
        if embeddings.len() != expected_count {
            return Err(Error::Embedding(format!(
                "Mismatched embedding count: got {}, expected {}",
                embeddings.len(),
                expected_count
            )));
        }
        Ok(embeddings)
    }

    if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text_body) {
        // First shape: { "embeddings": [[..],[..]], "dimension": N }
        if let Some(arr) = v.get("embeddings").and_then(|x| x.as_array()) {
            // If the remote reports its own dimension, validate against it;
            // otherwise fall back to the configured expected dimension
            let dim = v
                .get("dimension")
                .and_then(|d| d.as_u64())
                .map(|d| d as usize)
                .unwrap_or(expected_dim);
            return rows_to_embeddings(arr, dim, texts.len());
        }

        // Second shape: a bare top-level array [[..],[..]]
        if let Some(arr) = v.as_array() {
            return rows_to_embeddings(arr, expected_dim, texts.len());
        }
    }

    Err(Error::Embedding("Failed to parse remote embedding response".to_string()))
}

/// Get the embedding model name (based on the current provider)
pub fn embedding_model() -> &'static str {
    EmbeddingProvider::from_env().model_name()
}

/// Get the embedding dimension (based on the current provider)
pub fn embedding_dimension() -> usize {
    EmbeddingProvider::from_env().dimension()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_provider_default() {
        // Default should be local
        let provider = EmbeddingProvider::default();
        assert_eq!(provider, EmbeddingProvider::Local);
        assert_eq!(provider.dimension(), get_local_embedding_dimension());
    }

    #[test]
    fn test_embedding_dimensions() {
        // Default model is AllMiniLML6V2 with 384 dimensions
        assert_eq!(EmbeddingProvider::Local.dimension(), get_local_embedding_dimension());
        assert_eq!(EmbeddingProvider::OpenAI.dimension(), 1536);
    }

    #[tokio::test]
    async fn test_embed_batch_local() {
        // Local embeddings should work without an API key
        let texts = vec!["Hello", "World", "Test"];
        let result = embed_batch_local(texts).await;

        assert!(result.is_ok());
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in embeddings {
            assert_eq!(emb.len(), get_local_embedding_dimension());
        }
    }

    #[tokio::test]
    #[ignore] // Only run when OPENAI_API_KEY is set
    async fn test_embed_text_openai() {
        let result = embed_text("Hello, world!", Some(EmbeddingProvider::OpenAI), None).await;
        if env::var("OPENAI_API_KEY").is_ok() {
            let embedding = result.unwrap();
            assert_eq!(embedding.len(), OPENAI_DIMENSION);
        } else {
            assert!(result.is_err());
        }
    }
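
    // Added coverage sketch: exercises only the non-semantic hash fallback,
    // which should be deterministic and emit unit-norm vectors.
    #[tokio::test]
    async fn test_embed_batch_local_hash_properties() {
        let a = embed_batch_local_hash(vec!["same text"]).await.unwrap();
        let b = embed_batch_local_hash(vec!["same text"]).await.unwrap();
        assert_eq!(a, b); // deterministic for identical input

        let emb = &a[0];
        assert_eq!(emb.len(), get_local_embedding_dimension());
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4); // normalized to unit length
    }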
}