Skip to main content

lean_ctx/core/embeddings/
mod.rs

1//! Embedding engine for semantic code search.
2//!
3//! Provides dense vector embeddings for code chunks using a local ONNX model.
4//! Supports multiple models via `EmbeddingModel` registry — selected via
5//! `LEAN_CTX_EMBEDDING_MODEL` env var (default: all-MiniLM-L6-v2).
6//!
7//! Feature-gated under `embeddings` — falls back gracefully to BM25-only
8//! search when the feature or model is not available.
9//!
10//! Architecture:
11//!   Tokenizer → ONNX Model (rten) → Mean Pooling → L2 Normalize → `Vec<f32>`
12
13pub mod download;
14pub mod model_registry;
15pub mod pooling;
16pub mod tokenizer;
17
18use std::path::{Path, PathBuf};
19
20use model_registry::{EmbeddingModel, ModelConfig, VocabSource};
21use tokenizer::{TokenizedInput, WordPieceTokenizer};
22
23#[cfg(feature = "embeddings")]
24use std::sync::Arc;
25
26#[cfg(feature = "embeddings")]
27use rten::Model;
28
29pub struct EmbeddingEngine {
30    #[cfg(feature = "embeddings")]
31    model: Arc<Model>,
32    tokenizer: TokenizerKind,
33    dimensions: usize,
34    max_seq_len: usize,
35    model_id: EmbeddingModel,
36    model_config: ModelConfig,
37    #[cfg(feature = "embeddings")]
38    input_names: InputNodeIds,
39    #[cfg(feature = "embeddings")]
40    output_id: rten::NodeId,
41}
42
43/// Abstraction over different tokenizer backends.
44enum TokenizerKind {
45    WordPiece(WordPieceTokenizer),
46    HfTokenizer(tokenizer::HfTokenizerWrapper),
47}
48
/// Graph node ids of the model inputs that are fed on every inference call.
#[cfg(feature = "embeddings")]
struct InputNodeIds {
    /// Node receiving the token id sequence.
    input_ids: rten::NodeId,
    /// Node receiving the attention mask.
    attention_mask: rten::NodeId,
    /// Node receiving token type ids; `None` when the model exposes no third input.
    token_type_ids: Option<rten::NodeId>,
}
55
56impl EmbeddingEngine {
57    /// Load embedding model and vocabulary from a directory.
58    /// Downloads model automatically from HuggingFace if not present.
59    #[cfg(feature = "embeddings")]
60    pub fn load(model_dir: &Path) -> anyhow::Result<Self> {
61        let selected = model_registry::resolve_model();
62        Self::load_model(model_dir, selected)
63    }
64
65    /// Load a specific embedding model from a directory.
66    #[cfg(feature = "embeddings")]
67    pub fn load_model(base_dir: &Path, model_id: EmbeddingModel) -> anyhow::Result<Self> {
68        let config = model_id.config();
69        let model_dir = base_dir.join(model_id.storage_dir_name());
70
71        download::ensure_model(&model_dir, &config)?;
72
73        let tokenizer = load_tokenizer(&model_dir, &config)?;
74        let model_path = model_dir.join("model.onnx");
75        let model = Model::load_file(&model_path)?;
76
77        let model_inputs = model.input_ids();
78        if model_inputs.len() < 2 {
79            anyhow::bail!(
80                "Expected model with at least 2 inputs (input_ids, attention_mask), got {}",
81                model_inputs.len()
82            );
83        }
84
85        let token_type_ids = if config.needs_token_type_ids {
86            if model_inputs.len() < 3 {
87                anyhow::bail!(
88                    "Model {} requires token_type_ids but only has {} inputs",
89                    config.name,
90                    model_inputs.len()
91                );
92            }
93            Some(model_inputs[2])
94        } else if model_inputs.len() >= 3 {
95            Some(model_inputs[2])
96        } else {
97            None
98        };
99
100        let input_names = InputNodeIds {
101            input_ids: model_inputs[0],
102            attention_mask: model_inputs[1],
103            token_type_ids,
104        };
105
106        let output_id = *model
107            .output_ids()
108            .first()
109            .ok_or_else(|| anyhow::anyhow!("Model has no outputs"))?;
110
111        let dimensions = detect_dimensions(
112            &model,
113            &tokenizer,
114            &input_names,
115            output_id,
116            config.max_seq_len,
117        )
118        .unwrap_or(config.dimensions);
119
120        tracing::info!(
121            "Embedding engine loaded: model={}, {}d, max_seq_len={}",
122            config.name,
123            dimensions,
124            config.max_seq_len,
125        );
126
127        Ok(Self {
128            model: Arc::new(model),
129            tokenizer,
130            dimensions,
131            max_seq_len: config.max_seq_len,
132            model_id,
133            model_config: config,
134            input_names,
135            output_id,
136        })
137    }
138
139    #[cfg(not(feature = "embeddings"))]
140    pub fn load(_model_dir: &Path) -> anyhow::Result<Self> {
141        anyhow::bail!("Embeddings feature not enabled. Compile with --features embeddings")
142    }
143
144    /// Load from default model directory (~/.lean-ctx/models/).
145    pub fn load_default() -> anyhow::Result<Self> {
146        Self::load(&Self::model_directory())
147    }
148
149    /// Generate an embedding vector for a single text (document/code).
150    pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
151        let prefixed;
152        let input_text = if let Some(prefix) = self.model_config.document_prefix {
153            prefixed = format!("{prefix}{text}");
154            &prefixed
155        } else {
156            text
157        };
158        let input = tokenize(&self.tokenizer, input_text, self.max_seq_len);
159        self.run_inference(&input)
160    }
161
162    /// Generate an embedding vector for a query string.
163    /// Applies query-specific prefix if the model requires one.
164    pub fn embed_query(&self, query: &str) -> anyhow::Result<Vec<f32>> {
165        let prefixed;
166        let input_text = if let Some(prefix) = self.model_config.query_prefix {
167            prefixed = format!("{prefix}{query}");
168            &prefixed
169        } else {
170            query
171        };
172        let input = tokenize(&self.tokenizer, input_text, self.max_seq_len);
173        self.run_inference(&input)
174    }
175
176    /// Generate embedding vectors for multiple texts (documents/code).
177    pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
178        texts.iter().map(|t| self.embed(t)).collect()
179    }
180
181    pub fn dimensions(&self) -> usize {
182        self.dimensions
183    }
184
185    pub fn model_id(&self) -> EmbeddingModel {
186        self.model_id
187    }
188
189    pub fn model_name(&self) -> &str {
190        self.model_config.name
191    }
192
193    /// Resolve the model directory (respects LEAN_CTX_MODELS_DIR env).
194    pub fn model_directory() -> PathBuf {
195        if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
196            return PathBuf::from(dir);
197        }
198        if let Ok(d) = crate::core::data_dir::lean_ctx_data_dir() {
199            return d.join("models");
200        }
201        PathBuf::from("models")
202    }
203
204    /// Check if the model files are present and loadable.
205    pub fn is_available() -> bool {
206        let base_dir = Self::model_directory();
207        let selected = model_registry::resolve_model();
208        let config = selected.config();
209        let model_dir = base_dir.join(selected.storage_dir_name());
210        model_dir.join("model.onnx").exists()
211            && model_dir.join(config.vocab_file.filename()).exists()
212    }
213
214    #[cfg(feature = "embeddings")]
215    fn run_inference(&self, input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
216        use rten_tensor::{AsView, NdTensor};
217
218        let seq_len = input.input_ids.len();
219
220        let ids_tensor = NdTensor::from_data([1, seq_len], input.input_ids.clone());
221        let mask_tensor = NdTensor::from_data([1, seq_len], input.attention_mask.clone());
222
223        let mut inputs = vec![
224            (self.input_names.input_ids, ids_tensor.into()),
225            (self.input_names.attention_mask, mask_tensor.into()),
226        ];
227
228        if let Some(type_id) = self.input_names.token_type_ids {
229            let type_tensor = NdTensor::from_data([1, seq_len], input.token_type_ids.clone());
230            inputs.push((type_id, type_tensor.into()));
231        }
232
233        let outputs = self.model.run(inputs, &[self.output_id], None)?;
234
235        let hidden: Vec<f32> = outputs
236            .into_iter()
237            .next()
238            .ok_or_else(|| anyhow::anyhow!("No output from model"))?
239            .into_tensor::<f32>()
240            .ok_or_else(|| anyhow::anyhow!("Model output is not float32"))?
241            .to_vec();
242
243        let mut embedding =
244            pooling::mean_pool(&hidden, &input.attention_mask, seq_len, self.dimensions);
245        pooling::normalize_l2(&mut embedding);
246
247        Ok(embedding)
248    }
249
250    #[cfg(not(feature = "embeddings"))]
251    fn run_inference(&self, _input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
252        anyhow::bail!("Embeddings feature not enabled")
253    }
254}
255
256/// Load the appropriate tokenizer for the model config.
257fn load_tokenizer(model_dir: &Path, config: &ModelConfig) -> anyhow::Result<TokenizerKind> {
258    match config.vocab_file {
259        VocabSource::VocabTxt(filename) => {
260            let path = model_dir.join(filename);
261            let tok = WordPieceTokenizer::from_file(&path)?;
262            Ok(TokenizerKind::WordPiece(tok))
263        }
264        VocabSource::TokenizerJson(filename) => {
265            let path = model_dir.join(filename);
266            let tok = tokenizer::HfTokenizerWrapper::from_file(&path)?;
267            Ok(TokenizerKind::HfTokenizer(tok))
268        }
269    }
270}
271
272/// Tokenize text using whatever tokenizer backend is loaded.
273fn tokenize(tokenizer: &TokenizerKind, text: &str, max_len: usize) -> TokenizedInput {
274    match tokenizer {
275        TokenizerKind::WordPiece(wp) => wp.encode(text, max_len),
276        TokenizerKind::HfTokenizer(hf) => hf.encode(text, max_len),
277    }
278}
279
/// Probe the model once with a tiny input and report the size of the output's last axis.
///
/// Returns `None` in three cases: the probe inference fails, there is no output, or the
/// output is not an `f32` tensor with at least one axis.
#[cfg(feature = "embeddings")]
fn detect_dimensions(
    model: &Model,
    tokenizer: &TokenizerKind,
    input_names: &InputNodeIds,
    output_id: rten::NodeId,
    max_seq_len: usize,
) -> Option<usize> {
    use rten_tensor::{Layout, NdTensor};

    // A few tokens are enough: only the trailing axis of the output is inspected.
    let probe = tokenize(tokenizer, "test", max_seq_len.min(8));
    let probe_len = probe.input_ids.len();
    let shape2d = [1, probe_len];

    let mut feeds = vec![
        (input_names.input_ids, NdTensor::from_data(shape2d, probe.input_ids).into()),
        (
            input_names.attention_mask,
            NdTensor::from_data(shape2d, probe.attention_mask).into(),
        ),
    ];
    if let Some(type_node) = input_names.token_type_ids {
        feeds.push((type_node, NdTensor::from_data(shape2d, probe.token_type_ids).into()));
    }

    let first_output = model.run(feeds, &[output_id], None).ok()?.into_iter().next()?;
    let tensor = first_output.into_tensor::<f32>()?;

    // Expected layout is [batch = 1, seq_len, dim]; the last axis is the embedding width.
    let shape = tensor.shape();
    shape.last().copied()
}
314
315/// Compute cosine similarity between two L2-normalized vectors.
316/// Both vectors must be pre-normalized for correct results.
317///
318/// Uses the chunked, autovectorizable dot product from [`crate::core::embedding_quant`]
319/// (turbovec-derived) so every semantic-search hot path gets SIMD throughput.
320pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
321    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");
322    crate::core::embedding_quant::dot_f32(a, b)
323}
324
325/// Compute cosine similarity without requiring pre-normalization.
326pub fn cosine_similarity_raw(a: &[f32], b: &[f32]) -> f32 {
327    debug_assert_eq!(a.len(), b.len());
328    use crate::core::embedding_quant::dot_f32;
329    let dot = dot_f32(a, b);
330    let norm_a = dot_f32(a, a).sqrt();
331    let norm_b = dot_f32(b, b).sqrt();
332    if norm_a == 0.0 || norm_b == 0.0 {
333        return 0.0;
334    }
335    dot / (norm_a * norm_b)
336}
337
/// Process-wide cache of the first attempt to load the default engine.
///
/// The outcome is stored permanently, whether it succeeded or failed. A failed load is
/// therefore not retried for the lifetime of the process.
#[cfg(feature = "embeddings")]
static SHARED_ENGINE: std::sync::OnceLock<anyhow::Result<EmbeddingEngine>> =
    std::sync::OnceLock::new();

/// Global singleton embedding engine. It is loaded once and shared across all consumers.
///
/// Returns `None` if the model fails to load; the failure reason is logged once, at
/// initialization.
/// NOTE: This function BLOCKS on first call while loading the ONNX model.
/// For non-blocking access, use `try_shared_engine()` instead.
#[cfg(feature = "embeddings")]
pub fn shared_engine() -> Option<&'static EmbeddingEngine> {
    SHARED_ENGINE
        .get_or_init(|| {
            let loaded = EmbeddingEngine::load_default();
            if let Err(err) = &loaded {
                // The error is cached forever and callers only ever see `None`, so this
                // is the single place where the reason can be surfaced.
                tracing::warn!("Embedding engine failed to load: {err:#}");
            }
            loaded
        })
        .as_ref()
        .ok()
}

/// Non-blocking variant: returns the engine ONLY if it is already loaded successfully.
/// Never triggers model loading or download, so it is safe to call on hot paths.
#[cfg(feature = "embeddings")]
pub fn try_shared_engine() -> Option<&'static EmbeddingEngine> {
    SHARED_ENGINE.get()?.as_ref().ok()
}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Overrides an environment variable for the guard's lifetime. On drop, the previous
    /// value is restored, or the variable is unset if it was unset before. Drop also runs
    /// when an assertion panics, so the override cannot leak into other tests or clobber
    /// a value the developer had set.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_unnormalized() {
        let a = vec![3.0, 4.0];
        let b = vec![3.0, 4.0];
        assert!((cosine_similarity_raw(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        assert_eq!(cosine_similarity_raw(&a, &b), 0.0);
    }

    #[test]
    fn model_directory_env_override_and_availability() {
        let unique = "/tmp/lean_ctx_test_embed_42xyz";
        // Restores the variable's prior state even if an assertion below fails.
        let _guard = EnvVarGuard::set("LEAN_CTX_MODELS_DIR", unique);
        let dir = EmbeddingEngine::model_directory();
        assert_eq!(dir.to_string_lossy(), unique);
        assert!(!EmbeddingEngine::is_available());
    }

    #[test]
    #[cfg(feature = "embeddings")]
    fn try_shared_engine_returns_none_when_not_initialized() {
        // Assumes no other test in this binary calls `shared_engine()`, because that would
        // initialize the process-wide cell.
        let result = try_shared_engine();
        assert!(
            result.is_none(),
            "try_shared_engine should return None without triggering load"
        );
    }
}