avocado_core/
embedding.rs

//! Embedding generation with local and OpenAI support
//!
//! By default, uses local embeddings (all-MiniLM-L6-v2 via ONNX) for self-sufficiency.
//! OpenAI embeddings are optional: set AVOCADODB_EMBEDDING_PROVIDER=openai and
//! provide an OPENAI_API_KEY.

use crate::types::{Error, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::{Mutex, Once, OnceLock};
use tokio::io::AsyncWriteExt;
use tokio::process::Command as AsyncCommand;

// OpenAI constants
const OPENAI_API_URL: &str = "https://api.openai.com/v1/embeddings";
const OPENAI_MODEL: &str = "text-embedding-ada-002";
const OPENAI_DIMENSION: usize = 1536;

// Local embedding model configuration
// Default: all-MiniLM-L6-v2 (384 dimensions) - fast and efficient
// Can be overridden via the AVOCADODB_EMBEDDING_MODEL environment variable
// Available models and their dimensions:
//   - AllMiniLML6V2: 384 dims (default, fastest)
//   - AllMiniLML12V2: 384 dims (slightly better quality)
//   - BGESmallENV15: 384 dims (good for English)
//   - BGELargeENV15: 1024 dims (higher quality, slower)
//   - NomicEmbedTextV1: 768 dims (good balance)
//   - NomicEmbedTextV15: 768 dims (improved version)
const DEFAULT_LOCAL_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
const DEFAULT_LOCAL_DIMENSION: usize = 384;
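
// Example configuration (illustrative; any alias recognized by
// get_local_embedding_model below works as the value):
//
//   export AVOCADODB_EMBEDDING_PROVIDER=local
//   export AVOCADODB_EMBEDDING_MODEL=bge-small-en-v1.5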

/// Get the local embedding model enum based on environment variable
fn get_local_embedding_model() -> fastembed::EmbeddingModel {
    use fastembed::EmbeddingModel;

    if let Ok(model_str) = env::var("AVOCADODB_EMBEDDING_MODEL") {
        match model_str.to_lowercase().as_str() {
            "allminilml6v2" | "all-minilm-l6-v2" | "minilm6" => EmbeddingModel::AllMiniLML6V2,
            "allminilml12v2" | "all-minilm-l12-v2" | "minilm12" => EmbeddingModel::AllMiniLML12V2,
            "bgesmallen" | "bge-small-en-v1.5" | "bgesmall" => EmbeddingModel::BGESmallENV15,
            "bgelargeen" | "bge-large-en-v1.5" | "bgelarge" => EmbeddingModel::BGELargeENV15,
            "nomicv1" | "nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1,
            "nomicv15" | "nomic-embed-text-v1.5" | "nomic" => EmbeddingModel::NomicEmbedTextV15,
            _ => {
                log::warn!("Unknown embedding model '{}', using default AllMiniLML6V2", model_str);
                EmbeddingModel::AllMiniLML6V2
            }
        }
    } else {
        EmbeddingModel::AllMiniLML6V2
    }
}

/// Get the dimension for the selected local embedding model
fn get_local_embedding_dimension() -> usize {
    use fastembed::EmbeddingModel;

    match get_local_embedding_model() {
        EmbeddingModel::AllMiniLML6V2 => 384,
        EmbeddingModel::AllMiniLML12V2 => 384,
        EmbeddingModel::BGESmallENV15 => 384,
        EmbeddingModel::BGELargeENV15 => 1024,
        EmbeddingModel::NomicEmbedTextV1 => 768,
        EmbeddingModel::NomicEmbedTextV15 => 768,
        _ => DEFAULT_LOCAL_DIMENSION, // Fallback
    }
}

/// Get the model name string for the selected local embedding model
fn get_local_model_name() -> &'static str {
    use fastembed::EmbeddingModel;

    match get_local_embedding_model() {
        EmbeddingModel::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
        EmbeddingModel::AllMiniLML12V2 => "sentence-transformers/all-MiniLM-L12-v2",
        EmbeddingModel::BGESmallENV15 => "BAAI/bge-small-en-v1.5",
        EmbeddingModel::BGELargeENV15 => "BAAI/bge-large-en-v1.5",
        EmbeddingModel::NomicEmbedTextV1 => "nomic-ai/nomic-embed-text-v1",
        EmbeddingModel::NomicEmbedTextV15 => "nomic-ai/nomic-embed-text-v1.5",
        _ => DEFAULT_LOCAL_MODEL, // Fallback
    }
}

/// Embedding provider type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingProvider {
    /// Local embeddings using fastembed (default, no API required)
    Local,
    /// OpenAI embeddings (requires OPENAI_API_KEY)
    OpenAI,
    /// Remote HTTP embeddings (GPU sandbox or custom service)
    Remote,
    /// Ollama local server (e.g., bge-m3, nomic-embed-text)
    Ollama,
}

impl Default for EmbeddingProvider {
    fn default() -> Self {
        // Default to local for self-sufficiency
        EmbeddingProvider::Local
    }
}

impl EmbeddingProvider {
    /// Detect provider from environment or use default
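    ///
    /// # Example
    ///
    /// A minimal sketch (doctest ignored since it mutates process env):
    ///
    /// ```ignore
    /// std::env::set_var("AVOCADODB_EMBEDDING_PROVIDER", "ollama");
    /// assert_eq!(EmbeddingProvider::from_env(), EmbeddingProvider::Ollama);
    /// ```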
    pub fn from_env() -> Self {
        // The provider is selected via AVOCADODB_EMBEDDING_PROVIDER; default
        // to local for self-sufficiency (OpenAI must be opted into explicitly)
        match env::var("AVOCADODB_EMBEDDING_PROVIDER")
            .map(|s| s.to_lowercase())
            .as_deref()
        {
            Ok("openai") => EmbeddingProvider::OpenAI,
            Ok("local") | Ok("fastembed") => EmbeddingProvider::Local,
            Ok("remote") => EmbeddingProvider::Remote,
            Ok("ollama") => EmbeddingProvider::Ollama,
            _ => EmbeddingProvider::Local,
        }
    }

    /// Get the embedding dimension for this provider
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingProvider::Local => get_local_embedding_dimension(),
            EmbeddingProvider::OpenAI => OPENAI_DIMENSION,
            EmbeddingProvider::Ollama => get_ollama_embedding_dimension(),
            EmbeddingProvider::Remote => {
                // Allow overriding remote dimension via env; default to local dimension for compatibility
                env::var("AVOCADODB_EMBEDDING_DIM")
                    .ok()
                    .and_then(|s| s.parse::<usize>().ok())
                    .unwrap_or_else(get_local_embedding_dimension)
            }
        }
    }

    /// Get the model name for this provider
    pub fn model_name(&self) -> &'static str {
        match self {
            EmbeddingProvider::Local => get_local_model_name(),
            EmbeddingProvider::OpenAI => OPENAI_MODEL,
            EmbeddingProvider::Ollama => get_ollama_model_name(),
            // Remote model name is not fixed; callers can optionally set AVOCADODB_EMBEDDING_MODEL
            EmbeddingProvider::Remote => DEFAULT_LOCAL_MODEL,
        }
    }
}

// Ollama configuration
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
const DEFAULT_OLLAMA_MODEL: &str = "bge-m3";

/// Get the Ollama model from environment
fn get_ollama_model_name() -> &'static str {
    // Cached in a function-local static: the OnceLock (and the String inside)
    // lives for the program's lifetime, so a &'static str is available
    // directly, without leaking or transmuting
    static OLLAMA_MODEL: OnceLock<String> = OnceLock::new();
    OLLAMA_MODEL
        .get_or_init(|| {
            env::var("AVOCADODB_OLLAMA_MODEL").unwrap_or_else(|_| DEFAULT_OLLAMA_MODEL.to_string())
        })
        .as_str()
}

/// Get the Ollama embedding dimension based on model name
fn get_ollama_embedding_dimension() -> usize {
    let model = get_ollama_model_name();
    match model {
        m if m.contains("bge-m3") => 1024,
        m if m.contains("bge-large") => 1024,
        m if m.contains("nomic") => 768,
        m if m.contains("mxbai") => 1024,
        m if m.contains("minilm") || m.contains("all-minilm") => 384,
        m if m.contains("snowflake") => 1024,
        _ => {
            // Allow explicit dimension override
            env::var("AVOCADODB_EMBEDDING_DIM")
                .ok()
                .and_then(|s| s.parse::<usize>().ok())
                .unwrap_or(1024) // Default to 1024 for unknown models
        }
    }
}

// OpenAI API request/response types
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

/// Embed a single text string
///
/// Uses local embeddings by default (no API required).
/// Set AVOCADODB_EMBEDDING_PROVIDER=openai to use OpenAI.
///
/// # Arguments
///
/// * `text` - The text to embed
/// * `provider` - Embedding provider (defaults to Local)
/// * `api_key` - OpenAI API key (only used if provider is OpenAI)
///
/// # Returns
///
/// A vector of floats representing the embedding (384 for the default local model, 1536 for OpenAI)
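///
/// # Example
///
/// A minimal sketch (ignored doctest; assumes a Tokio runtime and the default
/// local provider):
///
/// ```ignore
/// let embedding = embed_text("hello world", None, None).await?;
/// assert_eq!(embedding.len(), embedding_dimension());
/// ```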
pub async fn embed_text(
    text: &str,
    provider: Option<EmbeddingProvider>,
    api_key: Option<&str>,
) -> Result<Vec<f32>> {
    let results = embed_batch(vec![text], provider, api_key).await?;
    results.into_iter().next().ok_or_else(|| {
        Error::Embedding("No embedding returned".to_string())
    })
}

/// Embed multiple text strings
///
/// Uses local embeddings by default (no API required).
/// Set AVOCADODB_EMBEDDING_PROVIDER=openai to use OpenAI.
///
/// # Arguments
///
/// * `texts` - Vector of text strings to embed
/// * `provider` - Embedding provider (defaults to Local)
/// * `api_key` - OpenAI API key (only used if provider is OpenAI)
///
/// # Returns
///
/// A vector of embeddings, in the same order as the input texts
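///
/// # Example
///
/// A minimal sketch (ignored doctest; outputs are in input order):
///
/// ```ignore
/// let embeddings = embed_batch(vec!["first", "second"], Some(EmbeddingProvider::Local), None).await?;
/// assert_eq!(embeddings.len(), 2);
/// ```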
pub async fn embed_batch(
    texts: Vec<&str>,
    provider: Option<EmbeddingProvider>,
    api_key: Option<&str>,
) -> Result<Vec<Vec<f32>>> {
    let provider = provider.unwrap_or_else(EmbeddingProvider::from_env);

    if texts.is_empty() {
        return Ok(vec![]);
    }

    match provider {
        EmbeddingProvider::Local => embed_batch_local(texts).await,
        EmbeddingProvider::OpenAI => embed_batch_openai(texts, api_key).await,
        EmbeddingProvider::Remote => embed_batch_remote(texts).await,
        EmbeddingProvider::Ollama => embed_batch_ollama(texts).await,
    }
}

/// Local embedding generation with multiple strategies
///
/// Pure Rust implementation using fastembed (ONNX-based) for semantic embeddings.
/// Falls back to Python subprocess, then hash-based if needed.
///
/// Strategy priority:
/// 1. Pure Rust with fastembed (semantic, good quality, no Python required)
///    - Uses all-MiniLM-L6-v2 model (384 dimensions) by default
///    - ONNX-based, fast and efficient
///    - Model cached after first download (~90MB)
/// 2. Python subprocess with sentence-transformers (fallback if fastembed fails)
///    - Requires: `pip install sentence-transformers`
/// 3. Hash-based fallback (deterministic, but NOT semantic)
///    - Works without dependencies
///    - Poor semantic quality (similar texts don't cluster)
async fn embed_batch_local(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    // Try pure Rust fastembed first (best performance, no Python required)
    if let Ok(embeddings) = embed_batch_local_rust(texts.clone()).await {
        return Ok(embeddings);
    }

    // Respect hard-fail configuration to forbid non-semantic fallbacks in production
    if matches!(std::env::var("AVOCADODB_FORBID_FALLBACKS").ok().as_deref(), Some("1" | "true" | "TRUE" | "yes" | "YES")) {
        return Err(Error::Embedding(
            "Local fastembed failed and fallbacks are disabled (AVOCADODB_FORBID_FALLBACKS=1)".to_string()
        ));
    }

    // Fallback to Python sentence-transformers (if available)
    static PY_WARN_ONCE: Once = Once::new();
    PY_WARN_ONCE.call_once(|| {
        log::warn!("fastembed unavailable; falling back to Python sentence-transformers for embeddings.");
    });
    if let Ok(embeddings) = embed_batch_local_python(texts.clone()).await {
        return Ok(embeddings);
    }

    // Final fallback to hash-based embeddings (works always, but not semantic)
    static HASH_WARN_ONCE: Once = Once::new();
    HASH_WARN_ONCE.call_once(|| {
        log::error!("Falling back to HASH-BASED embeddings (NOT SEMANTIC). This mode is for emergencies only.");
    });
    embed_batch_local_hash(texts).await
}

/// Pure Rust embeddings using fastembed (ONNX-based, no Python required)
///
/// Uses the fastembed crate with the configured model (all-MiniLM-L6-v2 by default).
/// This is the preferred method as it's pure Rust, fast, and doesn't require Python.
///
/// Model is downloaded from HuggingFace on first use and cached locally.
/// fastembed handles model caching internally, so initialization is fast after first use.
async fn embed_batch_local_rust(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use fastembed::{TextEmbedding, InitOptions};
    use tokio::task;

    if texts.is_empty() {
        return Ok(vec![]);
    }

    // Convert &str to String for the blocking task
    let texts_owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

    // Cache the TextEmbedding model instance across the process to avoid repeated initialization
    static FASTEMBED_MODEL: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();

    // fastembed is synchronous, so we run it in a blocking task
    // Note: fastembed handles model caching internally, so initialization is fast
    let embeddings = task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
        // Initialize or reuse cached model (downloads on first use, then caches).
        // A panic here (e.g., model download failure) is caught by spawn_blocking
        // and surfaces as a join error below, letting the caller fall back.
        let model_mutex = FASTEMBED_MODEL.get_or_init(|| {
            let embedding_model = get_local_embedding_model();
            let model = TextEmbedding::try_new(
                InitOptions::new(embedding_model)
                    .with_show_download_progress(false)
            )
            .expect("Failed to initialize fastembed model");
            Mutex::new(model)
        });

        // Generate embeddings (fastembed handles normalization)
        let embeddings = model_mutex
            .lock()
            .map_err(|_| Error::Embedding("Failed to lock fastembed model".to_string()))?
            .embed(texts_owned, None)
            .map_err(|e| Error::Embedding(format!("Failed to generate embeddings: {}", e)))?;

        // Verify dimensions (get expected dimension for selected model)
        let expected_dim = get_local_embedding_dimension();
        for emb in &embeddings {
            if emb.len() != expected_dim {
                return Err(Error::Embedding(format!(
                    "Unexpected embedding dimension: {} (expected {})",
                    emb.len(),
                    expected_dim
                )));
            }
        }

        Ok(embeddings)
    })
    .await
    .map_err(|e| Error::Embedding(format!("Task join error: {}", e)))??;

    Ok(embeddings)
}

/// Local embeddings using Python sentence-transformers (semantic, good quality)
///
/// Uses a Python subprocess to call sentence-transformers with all-MiniLM-L6-v2.
/// This provides semantic embeddings without API keys.
///
/// Requires: Python with sentence-transformers installed
///   pip install sentence-transformers
async fn embed_batch_local_python(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    // Check if Python is available
    let python = which_python()?;

    // A Python script that reads one text per line from stdin and prints a
    // JSON array of embeddings. Note: texts containing newlines would be
    // split into multiple inputs by this line-based protocol.
    let script = r#"
import sys
import json

try:
    from sentence_transformers import SentenceTransformer
    import numpy as np

    # Load model (cached after first use)
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Read texts from stdin (one per line)
    texts = []
    for line in sys.stdin:
        texts.append(line.strip())

    # Generate embeddings
    embeddings = model.encode(texts, normalize_embeddings=True)

    # Output as JSON array
    result = [emb.tolist() for emb in embeddings]
    print(json.dumps(result))
    sys.exit(0)
except ImportError:
    print(json.dumps({"error": "sentence-transformers not installed. Install with: pip install sentence-transformers"}), file=sys.stderr)
    sys.exit(1)
except Exception as e:
    print(json.dumps({"error": str(e)}), file=sys.stderr)
    sys.exit(1)
"#;

    // Run Python script
    let mut child = AsyncCommand::new(&python)
        .arg("-c")
        .arg(script)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| Error::Embedding(format!("Failed to spawn Python process: {}", e)))?;

    // Write texts to stdin
    if let Some(mut stdin) = child.stdin.take() {
        for text in &texts {
            stdin.write_all(text.as_bytes())
                .await
                .map_err(|e| Error::Embedding(format!("Failed to write to Python stdin: {}", e)))?;
            stdin.write_all(b"\n")
                .await
                .map_err(|e| Error::Embedding(format!("Failed to write to Python stdin: {}", e)))?;
        }
        stdin.shutdown().await
            .map_err(|e| Error::Embedding(format!("Failed to close Python stdin: {}", e)))?;
    }

    // Wait for output
    let output = child.wait_with_output()
        .await
        .map_err(|e| Error::Embedding(format!("Failed to wait for Python process: {}", e)))?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(Error::Embedding(format!("Python embedding failed: {}", stderr)));
    }

    // Parse JSON output
    let stdout = String::from_utf8_lossy(&output.stdout);
    let embeddings: Vec<Vec<f32>> = serde_json::from_str(&stdout)
        .map_err(|e| Error::Embedding(format!("Failed to parse Python output: {}", e)))?;

    // Verify dimensions. The script always uses all-MiniLM-L6-v2 (384 dims),
    // so if a larger model was selected via env this check fails and the
    // caller proceeds to the next fallback.
    let expected_dim = get_local_embedding_dimension();
    for emb in &embeddings {
        if emb.len() != expected_dim {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                emb.len(),
                expected_dim
            )));
        }
    }

    if embeddings.len() != texts.len() {
        return Err(Error::Embedding(format!(
            "Mismatched embedding count: {} embeddings for {} texts",
            embeddings.len(),
            texts.len()
        )));
    }

    Ok(embeddings)
}

/// Find Python executable (python3 or python)
fn which_python() -> Result<String> {
    // Try python3 first, then python
    for cmd in &["python3", "python"] {
        if std::process::Command::new(cmd)
            .arg("--version")
            .output()
            .is_ok()
        {
            return Ok(cmd.to_string());
        }
    }
    Err(Error::Embedding("Python not found. Install Python 3 to use local embeddings.".to_string()))
}

/// Hash-based embeddings (fallback, NOT semantic)
///
/// Deterministic but not semantic - similar texts won't have similar embeddings.
/// Used as fallback when Python/sentence-transformers is unavailable.
async fn embed_batch_local_hash(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let embeddings: Vec<Vec<f32>> = texts
        .iter()
        .map(|text| {
            let mut hasher = DefaultHasher::new();
            text.hash(&mut hasher);
            let hash = hasher.finish();

            let dim = get_local_embedding_dimension();
            let mut embedding = vec![0.0f32; dim];
            for i in 0..dim {
                let seed = hash.wrapping_add(i as u64);
                embedding[i] = ((seed % 2000) as f32 - 1000.0) / 1000.0;
            }

            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for x in &mut embedding {
                    *x /= norm;
                }
            }

            embedding
        })
        .collect();

    Ok(embeddings)
}

/// OpenAI embedding generation
async fn embed_batch_openai(
    texts: Vec<&str>,
    api_key: Option<&str>,
) -> Result<Vec<Vec<f32>>> {
    let api_key = api_key
        .map(|s| s.to_string())
        .or_else(|| env::var("OPENAI_API_KEY").ok())
        .ok_or_else(|| {
            Error::Embedding(
                "OPENAI_API_KEY environment variable not set and no API key provided".to_string(),
            )
        })?;

    // OpenAI limit is 2048 inputs per request
    if texts.len() > 2048 {
        return Err(Error::InvalidInput(format!(
            "Too many texts to embed at once: {} (max 2048)",
            texts.len()
        )));
    }

    let client = Client::new();

    let request = EmbeddingRequest {
        model: OPENAI_MODEL.to_string(),
        input: texts.iter().map(|s| s.to_string()).collect(),
    };

    let response = client
        .post(OPENAI_API_URL)
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("API request failed: {}", e)))?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        return Err(Error::Embedding(format!(
            "API returned error {}: {}",
            status, body
        )));
    }

    let embedding_response: EmbeddingResponse = response
        .json()
        .await
        .map_err(|e| Error::Embedding(format!("Failed to parse response: {}", e)))?;

    // Sort by index to ensure correct ordering
    let mut data = embedding_response.data;
    data.sort_by_key(|d| d.index);

    let embeddings: Vec<Vec<f32>> = data.into_iter().map(|d| d.embedding).collect();

    // Verify all embeddings have correct dimension
    for emb in &embeddings {
        if emb.len() != OPENAI_DIMENSION {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                emb.len(),
                OPENAI_DIMENSION
            )));
        }
    }

    Ok(embeddings)
}

/// Remote HTTP embedding generation
///
/// The remote service is configured via:
/// - AVOCADODB_EMBEDDING_URL: required, e.g. https://your-modal-fn.modal.run/embed
/// - AVOCADODB_EMBEDDING_API_KEY: optional, sent as Bearer token
/// - AVOCADODB_EMBEDDING_MODEL: optional, forwarded to remote
/// - AVOCADODB_EMBEDDING_DIM: optional, expected dimension (defaults to local dim)
///
/// Expected request body:
///   { "inputs": ["text1", "text2"], "model": "BAAI/bge-small-en-v1.5" }
///
/// Expected response body (either of the following):
///   { "embeddings": [[..],[..]], "dimension": 384 }
///   [[..],[..]]
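///
/// Example configuration (illustrative values):
///   export AVOCADODB_EMBEDDING_PROVIDER=remote
///   export AVOCADODB_EMBEDDING_URL=https://your-modal-fn.modal.run/embed
///   export AVOCADODB_EMBEDDING_DIM=384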
async fn embed_batch_remote(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use serde_json::json;

    let url = env::var("AVOCADODB_EMBEDDING_URL")
        .map_err(|_| Error::Embedding("AVOCADODB_EMBEDDING_URL not set for remote provider".to_string()))?;
    if texts.is_empty() {
        return Ok(vec![]);
    }

    let client = Client::new();
    let mut req = client.post(&url).header("Content-Type", "application/json");

    if let Ok(api_key) = env::var("AVOCADODB_EMBEDDING_API_KEY") {
        if !api_key.is_empty() {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }
    }

    let model = env::var("AVOCADODB_EMBEDDING_MODEL").ok();
    let body = if let Some(model_name) = model {
        json!({ "inputs": texts, "model": model_name })
    } else {
        json!({ "inputs": texts })
    };

    let resp = req
        .json(&body)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("Remote request failed: {}", e)))?;

    if !resp.status().is_success() {
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        return Err(Error::Embedding(format!("Remote returned error {}: {}", status, text)));
    }

    // Try to parse as { embeddings: [...], dimension?: N }
    let expected_dim = EmbeddingProvider::Remote.dimension();
    let text_body = resp.text().await.map_err(|e| Error::Embedding(format!("Failed reading remote body: {}", e)))?;

    // First attempt: object with embeddings
    if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text_body) {
        if let Some(arr) = v.get("embeddings").and_then(|x| x.as_array()) {
            let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
            for item in arr {
                let vec_opt = item.as_array().map(|nums| {
                    nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect::<Vec<f32>>()
                });
                let vec = vec_opt.ok_or_else(|| Error::Embedding("Invalid embeddings array format".to_string()))?;
                if !vec.is_empty() && vec.len() != expected_dim {
                    // Allow remote to communicate dimension if provided
                    if let Some(dim) = v.get("dimension").and_then(|d| d.as_u64()).map(|d| d as usize) {
                        if vec.len() != dim {
                            return Err(Error::Embedding(format!(
                                "Unexpected embedding dimension: {} (expected {})",
                                vec.len(),
                                dim
                            )));
                        }
                    } else {
                        return Err(Error::Embedding(format!(
                            "Unexpected embedding dimension: {} (expected {})",
                            vec.len(),
                            expected_dim
                        )));
                    }
                }
                embeddings.push(vec);
            }
            if embeddings.len() != texts.len() {
                return Err(Error::Embedding(format!(
                    "Mismatched embedding count: got {}, expected {}",
                    embeddings.len(),
                    texts.len()
                )));
            }
            return Ok(embeddings);
        }

        // Second attempt: top-level array
        if let Some(arr) = v.as_array() {
            let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
            for item in arr {
                let vec_opt = item.as_array().map(|nums| {
                    nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect::<Vec<f32>>()
                });
                let vec = vec_opt.ok_or_else(|| Error::Embedding("Invalid embeddings array format".to_string()))?;
                if !vec.is_empty() && vec.len() != expected_dim {
                    return Err(Error::Embedding(format!(
                        "Unexpected embedding dimension: {} (expected {})",
                        vec.len(),
                        expected_dim
                    )));
                }
                embeddings.push(vec);
            }
            if embeddings.len() != texts.len() {
                return Err(Error::Embedding(format!(
                    "Mismatched embedding count: got {}, expected {}",
                    embeddings.len(),
                    texts.len()
                )));
            }
            return Ok(embeddings);
        }
    }

    Err(Error::Embedding("Failed to parse remote embedding response".to_string()))
}

/// Ollama embedding generation
///
/// Uses local Ollama server with configurable model.
/// Configure via:
/// - AVOCADODB_OLLAMA_URL: Ollama server URL (default: http://localhost:11434)
/// - AVOCADODB_OLLAMA_MODEL: Model name (default: bge-m3)
///
/// Supports models like:
/// - bge-m3 (1024 dimensions, multilingual)
/// - nomic-embed-text (768 dimensions)
/// - mxbai-embed-large (1024 dimensions)
/// - all-minilm (384 dimensions)
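///
/// Example setup (illustrative; assumes a local Ollama install):
///   ollama pull bge-m3
///   export AVOCADODB_OLLAMA_MODEL=bge-m3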
async fn embed_batch_ollama(texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    use serde_json::json;

    let base_url = env::var("AVOCADODB_OLLAMA_URL")
        .unwrap_or_else(|_| DEFAULT_OLLAMA_URL.to_string());
    let model = get_ollama_model_name();
    let expected_dim = get_ollama_embedding_dimension();

    if texts.is_empty() {
        return Ok(vec![]);
    }

    let client = Client::new();

    // Try batch endpoint first (Ollama 0.4.0+)
    let url = format!("{}/api/embed", base_url);
    let body = json!({
        "model": model,
        "input": texts,
    });

    let resp = client
        .post(&url)
        .header("Content-Type", "application/json")
        .json(&body)
        .send()
        .await
        .map_err(|e| Error::Embedding(format!("Ollama request failed: {}", e)))?;

    if resp.status().is_success() {
        let text_body = resp.text().await
            .map_err(|e| Error::Embedding(format!("Failed reading Ollama response: {}", e)))?;

        if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text_body) {
            if let Some(arr) = v.get("embeddings").and_then(|x| x.as_array()) {
                let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(arr.len());
                for item in arr {
                    let vec: Vec<f32> = item.as_array()
                        .map(|nums| nums.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect())
                        .ok_or_else(|| Error::Embedding("Invalid embedding array".to_string()))?;
                    // Verify dimension (mirrors the check in the single-text fallback)
                    if vec.len() != expected_dim {
                        return Err(Error::Embedding(format!(
                            "Unexpected embedding dimension: {} (expected {})",
                            vec.len(),
                            expected_dim
                        )));
                    }
                    embeddings.push(vec);
                }
                if embeddings.len() != texts.len() {
                    return Err(Error::Embedding(format!(
                        "Mismatched embedding count: got {}, expected {}",
                        embeddings.len(),
                        texts.len()
                    )));
                }
                return Ok(embeddings);
            }
        }
    }

    // Fall back to single-text endpoint for older Ollama versions
    let url = format!("{}/api/embeddings", base_url);
    let mut embeddings = Vec::with_capacity(texts.len());

    for text in texts {
        let body = json!({
            "model": model,
            "prompt": text,
        });

        let resp = client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| Error::Embedding(format!("Ollama request failed: {}", e)))?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(Error::Embedding(format!(
                "Ollama API error {}: {}",
                status, body
            )));
        }

        let text_body = resp.text().await
            .map_err(|e| Error::Embedding(format!("Failed reading Ollama response: {}", e)))?;

        let v: serde_json::Value = serde_json::from_str(&text_body)
            .map_err(|e| Error::Embedding(format!("Failed parsing Ollama response: {}", e)))?;

        let embedding: Vec<f32> = v.get("embedding")
            .and_then(|e| e.as_array())
            .map(|arr| arr.iter().filter_map(|n| n.as_f64().map(|f| f as f32)).collect())
            .ok_or_else(|| Error::Embedding("No embedding in Ollama response".to_string()))?;

        if embedding.len() != expected_dim {
            return Err(Error::Embedding(format!(
                "Unexpected embedding dimension: {} (expected {})",
                embedding.len(),
                expected_dim
            )));
        }

        embeddings.push(embedding);
    }

    Ok(embeddings)
}

/// Get the embedding model name (based on current provider)
pub fn embedding_model() -> &'static str {
    EmbeddingProvider::from_env().model_name()
}

/// Get the embedding dimension (based on current provider)
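///
/// # Example
///
/// A minimal sketch (ignored doctest; assumes no embedding env vars are set,
/// so the local default of 384 dims applies):
///
/// ```ignore
/// assert_eq!(embedding_dimension(), 384);
/// ```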
pub fn embedding_dimension() -> usize {
    EmbeddingProvider::from_env().dimension()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_provider_default() {
        // Default should be local
        let provider = EmbeddingProvider::default();
        assert_eq!(provider, EmbeddingProvider::Local);
        assert_eq!(provider.dimension(), get_local_embedding_dimension());
    }

    #[test]
    fn test_embedding_dimensions() {
        // Default model is AllMiniLML6V2 with 384 dimensions
        assert_eq!(EmbeddingProvider::Local.dimension(), get_local_embedding_dimension());
        assert_eq!(EmbeddingProvider::OpenAI.dimension(), 1536);
    }

    #[tokio::test]
    async fn test_embed_batch_local() {
        // Test local embeddings (should work without API key)
        let texts = vec!["Hello", "World", "Test"];
        let result = embed_batch_local(texts).await;

        assert!(result.is_ok());
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in embeddings {
            assert_eq!(emb.len(), get_local_embedding_dimension());
        }
    }

    #[tokio::test]
    #[ignore] // Only run when OPENAI_API_KEY is set
    async fn test_embed_text_openai() {
        let result = embed_text("Hello, world!", Some(EmbeddingProvider::OpenAI), None).await;
        if env::var("OPENAI_API_KEY").is_ok() {
            let embedding = result.unwrap();
            assert_eq!(embedding.len(), OPENAI_DIMENSION);
        } else {
            assert!(result.is_err());
        }
    }
}