cqs/embedder.rs

//! Embedding generation with ort + tokenizers

use lru::LruCache;
use ndarray::Array2;
use once_cell::sync::OnceCell;
use ort::ep::ExecutionProvider as OrtExecutionProvider;
use ort::session::Session;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use thiserror::Error;

// Model configuration - E5-base-v2 (full CUDA coverage, no rotary embedding fallback)
const MODEL_REPO: &str = "intfloat/e5-base-v2";
const MODEL_FILE: &str = "onnx/model.onnx";
const TOKENIZER_FILE: &str = "onnx/tokenizer.json";

// blake3 checksums for model verification (empty = skip validation)
const MODEL_BLAKE3: &str = "";
const TOKENIZER_BLAKE3: &str = "";

#[derive(Error, Debug)]
pub enum EmbedderError {
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    #[error("Tokenizer error: {0}")]
    TokenizerError(String),
    #[error("Inference failed: {0}")]
    InferenceFailed(String),
    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        path: String,
        expected: String,
        actual: String,
    },
    #[error("Query cannot be empty")]
    EmptyQuery,
    #[error("HuggingFace Hub error: {0}")]
    HfHubError(String),
}

impl From<ort::Error> for EmbedderError {
    fn from(e: ort::Error) -> Self {
        EmbedderError::InferenceFailed(e.to_string())
    }
}

/// A 769-dimensional L2-normalized embedding vector
///
/// Embeddings are produced by E5-base-v2 (768-dim) with an
/// optional 769th dimension for sentiment (-1.0 to +1.0).
/// Can be compared using cosine similarity (dot product for normalized vectors).
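///
/// A minimal comparison sketch (both vectors are unit-length, so the dot
/// product is the cosine similarity); assumes the model is already downloaded:
///
/// ```no_run
/// # use cqs::Embedder;
/// let mut embedder = Embedder::new()?;
/// let a = embedder.embed_query("parse configuration file")?;
/// let b = embedder.embed_query("read settings from disk")?;
/// let similarity: f32 = a.as_slice().iter().zip(b.as_slice()).map(|(x, y)| x * y).sum();
/// println!("cosine similarity: {similarity:.3}");
/// # Ok::<(), anyhow::Error>(())
/// ```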
#[derive(Debug, Clone)]
pub struct Embedding(Vec<f32>);

/// Standard embedding dimension from model
pub const MODEL_DIM: usize = 768;
/// Full embedding dimension with sentiment
pub const EMBEDDING_DIM: usize = 769;

impl Embedding {
    /// Create a new embedding from raw vector data
    pub fn new(data: Vec<f32>) -> Self {
        Self(data)
    }

    /// Append sentiment as 769th dimension
    ///
    /// Converts a 768-dim model embedding to 769-dim with sentiment.
    /// Sentiment should be -1.0 (negative) to +1.0 (positive).
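    ///
    /// A short sketch (document embeddings come back 768-dim, so the sentiment
    /// value becomes the 769th entry); the snippet text is just illustrative:
    ///
    /// ```no_run
    /// # use cqs::Embedder;
    /// # let mut embedder = Embedder::new()?;
    /// let doc = embedder
    ///     .embed_documents(&["fn parse_config(path: &str) {}"])?
    ///     .remove(0)
    ///     .with_sentiment(0.25);
    /// assert_eq!(doc.len(), 769);
    /// assert_eq!(doc.sentiment(), Some(0.25));
    /// # Ok::<(), anyhow::Error>(())
    /// ```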
    pub fn with_sentiment(mut self, sentiment: f32) -> Self {
        debug_assert_eq!(self.0.len(), MODEL_DIM, "Expected 768-dim embedding");
        self.0.push(sentiment.clamp(-1.0, 1.0));
        self
    }

    /// Get the sentiment (769th dimension) if present
    pub fn sentiment(&self) -> Option<f32> {
        if self.0.len() == EMBEDDING_DIM {
            Some(self.0[MODEL_DIM])
        } else {
            None
        }
    }

    /// Get the embedding as a slice
    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }

    /// Get a reference to the inner Vec (needed for some APIs like hnsw_rs)
    pub fn as_vec(&self) -> &Vec<f32> {
        &self.0
    }

    /// Consume the embedding and return the inner vector
    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    /// Get the dimension of the embedding
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Check if the embedding is empty
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

/// Hardware execution provider for inference
#[derive(Debug, Clone, Copy)]
pub enum ExecutionProvider {
    /// NVIDIA CUDA (requires CUDA toolkit)
    CUDA { device_id: i32 },
    /// NVIDIA TensorRT (faster than CUDA, requires TensorRT)
    TensorRT { device_id: i32 },
    /// CPU fallback (always available)
    CPU,
}

impl std::fmt::Display for ExecutionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecutionProvider::CUDA { device_id } => write!(f, "CUDA (device {})", device_id),
            ExecutionProvider::TensorRT { device_id } => {
                write!(f, "TensorRT (device {})", device_id)
            }
            ExecutionProvider::CPU => write!(f, "CPU"),
        }
    }
}

/// Text embedding generator using E5-base-v2
///
/// Automatically downloads the model from HuggingFace Hub on first use.
/// Detects GPU availability and uses CUDA/TensorRT when available.
///
/// # Example
///
/// ```no_run
/// use cqs::Embedder;
///
/// let mut embedder = Embedder::new()?;
/// let embedding = embedder.embed_query("parse configuration file")?;
/// println!("Embedding dimension: {}", embedding.len()); // 769 (768 model dims + sentiment)
/// # Ok::<(), anyhow::Error>(())
/// ```
pub struct Embedder {
    /// Lazy-loaded ONNX session (expensive ~500ms init, needs Mutex for run())
    session: OnceCell<Mutex<Session>>,
    /// Lazy-loaded tokenizer
    tokenizer: OnceCell<tokenizers::Tokenizer>,
    /// Cached model paths
    model_path: PathBuf,
    tokenizer_path: PathBuf,
    provider: ExecutionProvider,
    max_length: usize,
    batch_size: usize,
    /// LRU cache for query embeddings (avoids re-computing same queries)
    query_cache: Mutex<LruCache<String, Embedding>>,
}

impl Embedder {
    /// Create a new embedder, downloading the model if necessary
    ///
    /// Automatically detects GPU and uses CUDA/TensorRT when available.
    /// Falls back to CPU if no GPU is found.
    ///
    /// Note: ONNX session is lazy-loaded on first embedding request (~500ms).
    pub fn new() -> Result<Self, EmbedderError> {
        let (model_path, tokenizer_path) = ensure_model()?;
        let provider = select_provider();

        let batch_size = match provider {
            ExecutionProvider::CPU => 4,
            _ => 16,
        };

        let query_cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(100).expect("100 is non-zero"),
        ));

        Ok(Self {
            session: OnceCell::new(),
            tokenizer: OnceCell::new(),
            model_path,
            tokenizer_path,
            provider,
            max_length: 512,
            batch_size,
            query_cache,
        })
    }

    /// Create a CPU-only embedder
    ///
    /// Use this for single-query embedding where CPU is faster than GPU
    /// due to CUDA context setup overhead. GPU only helps for batch embedding.
    pub fn new_cpu() -> Result<Self, EmbedderError> {
        let (model_path, tokenizer_path) = ensure_model()?;

        let query_cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(100).expect("100 is non-zero"),
        ));

        Ok(Self {
            session: OnceCell::new(),
            tokenizer: OnceCell::new(),
            model_path,
            tokenizer_path,
            provider: ExecutionProvider::CPU,
            max_length: 512,
            batch_size: 4,
            query_cache,
        })
    }

    /// Get or initialize the ONNX session
    fn session(&self) -> Result<std::sync::MutexGuard<'_, Session>, EmbedderError> {
        let session = self
            .session
            .get_or_try_init(|| create_session(&self.model_path, self.provider).map(Mutex::new))?;
        Ok(session.lock().unwrap_or_else(|p| p.into_inner()))
    }

    /// Get or initialize the tokenizer
    fn tokenizer(&self) -> Result<&tokenizers::Tokenizer, EmbedderError> {
        self.tokenizer.get_or_try_init(|| {
            tokenizers::Tokenizer::from_file(&self.tokenizer_path)
                .map_err(|e| EmbedderError::TokenizerError(e.to_string()))
        })
    }

    /// Count tokens in a text
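    ///
    /// A small sketch (the tokenizer is lazily loaded on first call):
    ///
    /// ```no_run
    /// # use cqs::Embedder;
    /// # let embedder = Embedder::new()?;
    /// let tokens = embedder.token_count("fn main() { println!(\"hi\"); }")?;
    /// assert!(tokens > 0);
    /// # Ok::<(), anyhow::Error>(())
    /// ```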
    pub fn token_count(&self, text: &str) -> Result<usize, EmbedderError> {
        let encoding = self
            .tokenizer()?
            .encode(text, false)
            .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;
        Ok(encoding.get_ids().len())
    }

    /// Split text into overlapping windows of max_tokens with overlap tokens of context.
    /// Returns Vec of (window_content, window_index).
    /// If text fits in max_tokens, returns single window with index 0.
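    ///
    /// A usage sketch, assuming a chunk that exceeds the window size:
    ///
    /// ```no_run
    /// # use cqs::Embedder;
    /// # let embedder = Embedder::new()?;
    /// let long_source = "fn handler() { /* ... */ }\n".repeat(500);
    /// // 512-token windows with 64 tokens of shared context between neighbours.
    /// let windows = embedder.split_into_windows(&long_source, 512, 64)?;
    /// for (content, idx) in &windows {
    ///     println!("window {idx}: {} chars", content.len());
    /// }
    /// # Ok::<(), anyhow::Error>(())
    /// ```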
    pub fn split_into_windows(
        &self,
        text: &str,
        max_tokens: usize,
        overlap: usize,
    ) -> Result<Vec<(String, u32)>, EmbedderError> {
        let tokenizer = self.tokenizer()?;
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;

        let ids = encoding.get_ids();
        if ids.len() <= max_tokens {
            return Ok(vec![(text.to_string(), 0)]);
        }

        let mut windows = Vec::new();
        let step = max_tokens.saturating_sub(overlap).max(1); // Ensure step >= 1 to prevent infinite loop
        let mut start = 0;
        let mut window_idx = 0u32;

        while start < ids.len() {
            let end = (start + max_tokens).min(ids.len());
            let window_ids: Vec<u32> = ids[start..end].to_vec();

            // Decode back to text
            let window_text = tokenizer
                .decode(&window_ids, true)
                .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;

            windows.push((window_text, window_idx));
            window_idx += 1;

            if end >= ids.len() {
                break;
            }
            start += step;
        }

        Ok(windows)
    }

    /// Embed documents (code chunks). Adds "passage: " prefix for E5.
    pub fn embed_documents(&mut self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedderError> {
        let prefixed: Vec<String> = texts.iter().map(|t| format!("passage: {}", t)).collect();
        self.embed_batch(&prefixed)
    }

    /// Embed a query. Adds "query: " prefix for E5. Uses LRU cache for repeated queries.
    pub fn embed_query(&mut self, text: &str) -> Result<Embedding, EmbedderError> {
        let text = text.trim();
        if text.is_empty() {
            return Err(EmbedderError::EmptyQuery);
        }

        // Check cache first
        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::debug!("Query cache lock poisoned, recovering");
                poisoned.into_inner()
            });
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }

        // Compute embedding
        let prefixed = format!("query: {}", text);
        let results = self.embed_batch(&[prefixed])?;
        let base_embedding = results
            .into_iter()
            .next()
            .expect("embed_batch with single item always returns one result");

        // Add neutral sentiment (0.0) as 769th dimension
        let embedding = base_embedding.with_sentiment(0.0);

        // Store in cache
        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::debug!("Query cache lock poisoned, recovering");
                poisoned.into_inner()
            });
            cache.put(text.to_string(), embedding.clone());
        }

        Ok(embedding)
    }

    /// Get the execution provider being used
    pub fn provider(&self) -> ExecutionProvider {
        self.provider
    }

    /// Get the batch size
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Warm up the model with a dummy inference
    pub fn warm(&mut self) -> Result<(), EmbedderError> {
        let _ = self.embed_query("warmup")?;
        Ok(())
    }

    fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Embedding>, EmbedderError> {
        use ort::value::Tensor;

        let _span = tracing::info_span!("embed_batch", count = texts.len()).entered();

        if texts.is_empty() {
            return Ok(vec![]);
        }

        // Tokenize (lazy init tokenizer)
        let encodings = self
            .tokenizer()?
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;

        // Prepare inputs - INT64 (i64) for ONNX model
        let input_ids: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_ids().iter().map(|&id| id as i64).collect())
            .collect();
        let attention_mask: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().iter().map(|&m| m as i64).collect())
            .collect();

        // Pad to max length in batch
        let max_len = input_ids
            .iter()
            .map(|v| v.len())
            .max()
            .unwrap_or(0)
            .min(self.max_length);

        // Create padded arrays
        let input_ids_arr = pad_2d_i64(&input_ids, max_len, 0);
        let attention_mask_arr = pad_2d_i64(&attention_mask, max_len, 0);
        // token_type_ids: all zeros, same shape as input_ids
        let token_type_ids_arr = Array2::<i64>::zeros((texts.len(), max_len));

        // Create tensors
        let input_ids_tensor = Tensor::from_array(input_ids_arr)?;
        let attention_mask_tensor = Tensor::from_array(attention_mask_arr)?;
        let token_type_ids_tensor = Tensor::from_array(token_type_ids_arr)?;

        // Run inference (lazy init session)
        let mut session = self.session()?;
        let outputs = session.run(ort::inputs![
            "input_ids" => input_ids_tensor,
            "attention_mask" => attention_mask_tensor,
            "token_type_ids" => token_type_ids_tensor,
        ])?;

        // Get the last_hidden_state output: shape [batch, seq_len, 768]
        let (_shape, data) = outputs["last_hidden_state"].try_extract_tensor::<f32>()?;

        // Mean pooling over sequence dimension, weighted by attention mask
        let batch_size = texts.len();
        let seq_len = max_len;
        let embedding_dim = 768;
        let mut results = Vec::with_capacity(batch_size);

        for (i, mask_vec) in attention_mask.iter().enumerate().take(batch_size) {
            let mut sum = vec![0.0f32; embedding_dim];
            let mut count = 0.0f32;

            for j in 0..seq_len {
                let mask = mask_vec.get(j).copied().unwrap_or(0) as f32;
                if mask > 0.0 {
                    count += mask;
                    let offset = i * seq_len * embedding_dim + j * embedding_dim;
                    for (k, sum_val) in sum.iter_mut().enumerate() {
                        *sum_val += data[offset + k] * mask;
                    }
                }
            }

            // Avoid division by zero
            if count > 0.0 {
                for sum_val in &mut sum {
                    *sum_val /= count;
                }
            }

            results.push(Embedding::new(normalize_l2(sum)));
        }

        Ok(results)
    }
}

/// Download model and tokenizer from HuggingFace Hub
fn ensure_model() -> Result<(PathBuf, PathBuf), EmbedderError> {
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| EmbedderError::HfHubError(e.to_string()))?;
    let repo = api.model(MODEL_REPO.to_string());

    let model_path = repo
        .get(MODEL_FILE)
        .map_err(|e| EmbedderError::HfHubError(e.to_string()))?;
    let tokenizer_path = repo
        .get(TOKENIZER_FILE)
        .map_err(|e| EmbedderError::HfHubError(e.to_string()))?;

    // Verify checksums (skip if not configured)
    if !MODEL_BLAKE3.is_empty() {
        verify_checksum(&model_path, MODEL_BLAKE3)?;
    }
    if !TOKENIZER_BLAKE3.is_empty() {
        verify_checksum(&tokenizer_path, TOKENIZER_BLAKE3)?;
    }

    Ok((model_path, tokenizer_path))
}

/// Verify file checksum using blake3
fn verify_checksum(path: &Path, expected: &str) -> Result<(), EmbedderError> {
    let mut file =
        std::fs::File::open(path).map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let mut hasher = blake3::Hasher::new();
    std::io::copy(&mut file, &mut hasher)
        .map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let actual = hasher.finalize().to_hex().to_string();

    if actual != expected {
        return Err(EmbedderError::ChecksumMismatch {
            path: path.display().to_string(),
            expected: expected.to_string(),
            actual,
        });
    }
    Ok(())
}

/// Ensure ort CUDA provider libraries are findable
///
/// The ort crate downloads provider libs to ~/.cache/ort.pyke.io/... but
/// doesn't add them to the library search path. This function creates
/// symlinks in a directory that's already in LD_LIBRARY_PATH.
fn ensure_ort_provider_libs() {
    // Find ort's download cache
    let home = match std::env::var("HOME") {
        Ok(h) => std::path::PathBuf::from(h),
        Err(_) => return,
    };
    let ort_cache = home.join(".cache/ort.pyke.io/dfbin/x86_64-unknown-linux-gnu");

    // Find the versioned subdirectory (hash-named)
    let ort_lib_dir = match std::fs::read_dir(&ort_cache) {
        Ok(entries) => entries
            .filter_map(|e| e.ok())
            .filter(|e| e.path().is_dir())
            .map(|e| e.path())
            .next(),
        Err(_) => return,
    };

    let ort_lib_dir = match ort_lib_dir {
        Some(d) => d,
        None => return,
    };

    // Find target directory from LD_LIBRARY_PATH (skip ort cache dirs to avoid self-symlinks)
    let ld_path = std::env::var("LD_LIBRARY_PATH").unwrap_or_default();
    let ort_cache_str = ort_cache.to_string_lossy();
    let target_dir = ld_path
        .split(':')
        .find(|p| {
            !p.is_empty() && std::path::Path::new(p).is_dir() && !p.contains(ort_cache_str.as_ref())
            // Don't symlink into ort's own cache
        })
        .map(std::path::PathBuf::from);

    let target_dir = match target_dir {
        Some(d) => d,
        None => return, // No writable lib dir in path (or only ort cache in path)
    };

    // Provider libs to symlink
    let provider_libs = [
        "libonnxruntime_providers_shared.so",
        "libonnxruntime_providers_cuda.so",
        "libonnxruntime_providers_tensorrt.so",
    ];

    for lib in &provider_libs {
        let src = ort_lib_dir.join(lib);
        let dst = target_dir.join(lib);

        // Skip if source doesn't exist
        if !src.exists() {
            continue;
        }

        // Skip if symlink already valid
        if dst.symlink_metadata().is_ok() {
            if let Ok(target) = std::fs::read_link(&dst) {
                if target == src {
                    continue; // Already correct
                }
            }
            // Remove stale symlink
            let _ = std::fs::remove_file(&dst);
        }

        // Create symlink
        if let Err(e) = std::os::unix::fs::symlink(&src, &dst) {
            tracing::debug!("Failed to symlink {}: {}", lib, e);
        } else {
            tracing::info!("Created symlink: {} -> {}", dst.display(), src.display());
        }
    }
}

/// Select the best available execution provider
fn select_provider() -> ExecutionProvider {
    use ort::ep::{TensorRT, CUDA};

    // Ensure provider libs are findable before checking availability
    ensure_ort_provider_libs();

    // Try CUDA first
    let cuda = CUDA::default();
    if cuda.is_available().unwrap_or(false) {
        return ExecutionProvider::CUDA { device_id: 0 };
    }

    // Try TensorRT
    let tensorrt = TensorRT::default();
    if tensorrt.is_available().unwrap_or(false) {
        return ExecutionProvider::TensorRT { device_id: 0 };
    }

    ExecutionProvider::CPU
}

/// Create an ort session with the specified provider
fn create_session(
    model_path: &Path,
    provider: ExecutionProvider,
) -> Result<Session, EmbedderError> {
    use ort::ep::{TensorRT, CUDA};

    let builder = Session::builder()?;

    let session = match provider {
        ExecutionProvider::CUDA { device_id } => builder
            .with_execution_providers([CUDA::default().with_device_id(device_id).build()])?
            .commit_from_file(model_path)?,
        ExecutionProvider::TensorRT { device_id } => {
            builder
                .with_execution_providers([
                    TensorRT::default().with_device_id(device_id).build(),
                    // Fallback to CUDA for unsupported ops
                    CUDA::default().with_device_id(device_id).build(),
                ])?
                .commit_from_file(model_path)?
        }
        ExecutionProvider::CPU => builder.commit_from_file(model_path)?,
    };

    Ok(session)
}

/// Pad 2D sequences to a fixed length
fn pad_2d_i64(inputs: &[Vec<i64>], max_len: usize, pad_value: i64) -> Array2<i64> {
    let batch_size = inputs.len();
    let mut arr = Array2::from_elem((batch_size, max_len), pad_value);
    for (i, seq) in inputs.iter().enumerate() {
        for (j, &val) in seq.iter().take(max_len).enumerate() {
            arr[[i, j]] = val;
        }
    }
    arr
}

/// L2 normalize a vector in place (all-zero vectors are returned unchanged)
fn normalize_l2(mut v: Vec<f32>) -> Vec<f32> {
    let norm_sq: f32 = v.iter().fold(0.0, |acc, &x| acc + x * x);
    if norm_sq > 0.0 {
        let inv_norm = 1.0 / norm_sq.sqrt();
        v.iter_mut().for_each(|x| *x *= inv_norm);
    }
    v
}
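
// A minimal sketch of unit tests for the pure helpers above (padding, L2
// normalization, and the sentiment dimension); these make no network calls
// and do not load the model.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pad_2d_pads_and_truncates() {
        let arr = pad_2d_i64(&[vec![1, 2, 3], vec![4]], 2, 0);
        assert_eq!(arr.shape(), &[2, 2]);
        assert_eq!(arr[[0, 1]], 2); // third token dropped: sequences are truncated to max_len
        assert_eq!(arr[[1, 1]], 0); // short sequence padded with pad_value
    }

    #[test]
    fn normalize_l2_yields_unit_vector() {
        let v = normalize_l2(vec![3.0, 4.0]);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn normalize_l2_leaves_zero_vector_untouched() {
        assert_eq!(normalize_l2(vec![0.0, 0.0]), vec![0.0, 0.0]);
    }

    #[test]
    fn sentiment_dimension_is_clamped_and_readable() {
        let e = Embedding::new(vec![0.0; MODEL_DIM]).with_sentiment(2.0);
        assert_eq!(e.len(), EMBEDDING_DIM);
        assert_eq!(e.sentiment(), Some(1.0)); // 2.0 is clamped into [-1.0, 1.0]
    }
}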