lean_ctx/core/embeddings/mod.rs

//! Embedding engine for semantic code search.
//!
//! Provides dense vector embeddings for code chunks using a local ONNX model
//! (all-MiniLM-L6-v2). Feature-gated under `embeddings` — falls back gracefully
//! to BM25-only search when the feature or model is not available.
//!
//! Architecture:
//!   WordPieceTokenizer → ONNX Model (rten) → Mean Pooling → L2 Normalize → Vec<f32>
10pub mod download;
11pub mod pooling;
12pub mod tokenizer;
13
14use std::path::{Path, PathBuf};
15
16#[cfg(feature = "embeddings")]
17use std::sync::Arc;
18
19use tokenizer::{TokenizedInput, WordPieceTokenizer};
20
21#[cfg(feature = "embeddings")]
22use rten::Model;
23
// Fallback embedding width; all-MiniLM-L6-v2 emits 384-dimensional vectors.
// Used only when runtime dimension detection (`detect_dimensions`) fails.
#[cfg(feature = "embeddings")]
const DEFAULT_DIMENSIONS: usize = 384;
// Token budget handed to the tokenizer per encode call; presumably longer
// inputs are truncated by `WordPieceTokenizer::encode` — confirm in tokenizer.rs.
#[cfg(feature = "embeddings")]
const DEFAULT_MAX_SEQ_LEN: usize = 256;
28
/// Engine that turns text into dense embedding vectors via a local ONNX model.
///
/// Without the `embeddings` feature the model-related fields are compiled out
/// and `load` always errors, so no instance can exist in that configuration.
pub struct EmbeddingEngine {
    // Loaded ONNX model; Arc makes sharing a handle a cheap refcount bump.
    #[cfg(feature = "embeddings")]
    model: Arc<Model>,
    // WordPiece tokenizer built from vocab.txt.
    tokenizer: WordPieceTokenizer,
    // Embedding width: detected at load time, else DEFAULT_DIMENSIONS.
    dimensions: usize,
    // Token budget passed to the tokenizer for every encode call.
    max_seq_len: usize,
    // Graph node ids for the model's three BERT-style inputs.
    #[cfg(feature = "embeddings")]
    input_names: InputNodeIds,
    // Graph node id of the first model output requested during inference.
    #[cfg(feature = "embeddings")]
    output_id: rten::NodeId,
}
40
/// rten graph node ids for the three BERT-style model inputs, resolved once
/// at load time and reused for every inference.
#[cfg(feature = "embeddings")]
struct InputNodeIds {
    input_ids: rten::NodeId,
    attention_mask: rten::NodeId,
    token_type_ids: rten::NodeId,
}
47
impl EmbeddingEngine {
    /// Load embedding model and vocabulary from a directory.
    /// Downloads model automatically from HuggingFace if not present.
    ///
    /// Expected files (auto-downloaded):
    /// - `model.onnx` — all-MiniLM-L6-v2 ONNX embedding model
    /// - `vocab.txt` — WordPiece vocabulary (one token per line)
    ///
    /// # Errors
    /// Fails if the download, tokenizer load, or model load fails, or if the
    /// model graph does not expose the three BERT-style inputs and one output.
    #[cfg(feature = "embeddings")]
    pub fn load(model_dir: &Path) -> anyhow::Result<Self> {
        // Fetch model files if missing; presumably a no-op when already on disk.
        download::ensure_model(model_dir)?;

        let vocab_path = model_dir.join("vocab.txt");
        let model_path = model_dir.join("model.onnx");

        let tokenizer = WordPieceTokenizer::from_file(&vocab_path)?;
        let model = Model::load_file(&model_path)?;

        // Bail early if the graph doesn't look like a BERT-style encoder.
        let model_inputs = model.input_ids();
        if model_inputs.len() < 3 {
            anyhow::bail!(
                "Expected BERT-style model with 3 inputs, got {}",
                model_inputs.len()
            );
        }

        // NOTE(review): assumes the exported ONNX graph lists inputs in the
        // canonical order (input_ids, attention_mask, token_type_ids) —
        // confirm against the model export used by `download`.
        let input_names = InputNodeIds {
            input_ids: model_inputs[0],
            attention_mask: model_inputs[1],
            token_type_ids: model_inputs[2],
        };

        let output_id = *model
            .output_ids()
            .first()
            .ok_or_else(|| anyhow::anyhow!("Model has no outputs"))?;

        // Probe the model with a dummy inference to learn the embedding
        // width; fall back to the known all-MiniLM-L6-v2 width on failure.
        let dimensions = Self::detect_dimensions(&model, &tokenizer, &input_names, output_id)
            .unwrap_or(DEFAULT_DIMENSIONS);

        tracing::info!(
            "Embedding engine loaded: {}d, max_seq_len={}",
            dimensions,
            DEFAULT_MAX_SEQ_LEN
        );

        Ok(Self {
            model: Arc::new(model),
            tokenizer,
            dimensions,
            max_seq_len: DEFAULT_MAX_SEQ_LEN,
            input_names,
            output_id,
        })
    }

    /// Stub compiled when the `embeddings` feature is off: always errors, so
    /// callers fall back to BM25-only search.
    #[cfg(not(feature = "embeddings"))]
    pub fn load(_model_dir: &Path) -> anyhow::Result<Self> {
        anyhow::bail!("Embeddings feature not enabled. Compile with --features embeddings")
    }

    /// Load from default model directory (~/.lean-ctx/models/).
    pub fn load_default() -> anyhow::Result<Self> {
        Self::load(&Self::model_directory())
    }

    /// Generate an embedding vector for a single text.
    ///
    /// The text is tokenized with a budget of `max_seq_len` tokens; the
    /// result is L2-normalized (see `run_inference`).
    pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        let input = self.tokenizer.encode(text, self.max_seq_len);
        self.run_inference(&input)
    }

    /// Generate embedding vectors for multiple texts.
    ///
    /// Runs sequentially, one inference per text; stops at the first error.
    pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Width of the vectors produced by [`Self::embed`].
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Resolve the model directory (respects LEAN_CTX_MODELS_DIR env).
    ///
    /// Precedence: `LEAN_CTX_MODELS_DIR` env var → `~/.lean-ctx/models` →
    /// relative `models/` as a last resort (no detectable home directory).
    pub fn model_directory() -> PathBuf {
        if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
            return PathBuf::from(dir);
        }
        if let Some(home) = dirs::home_dir() {
            return home.join(".lean-ctx").join("models");
        }
        PathBuf::from("models")
    }

    /// Check if the model files are present and loadable.
    ///
    /// NOTE(review): this only checks file *existence* — a corrupt or
    /// truncated download still returns `true`; `load` is the real check.
    pub fn is_available() -> bool {
        let dir = Self::model_directory();
        dir.join("model.onnx").exists() && dir.join("vocab.txt").exists()
    }

    /// Run the ONNX model on one tokenized input and post-process the raw
    /// hidden states into a single L2-normalized sentence embedding.
    #[cfg(feature = "embeddings")]
    fn run_inference(&self, input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
        use rten_tensor::{AsView, NdTensor};

        let seq_len = input.input_ids.len();

        // Batch size is fixed at 1: all three inputs are shaped [1, seq_len].
        let ids_tensor = NdTensor::from_data([1, seq_len], input.input_ids.clone());
        let mask_tensor = NdTensor::from_data([1, seq_len], input.attention_mask.clone());
        let type_tensor = NdTensor::from_data([1, seq_len], input.token_type_ids.clone());

        let inputs = vec![
            (self.input_names.input_ids, ids_tensor.into()),
            (self.input_names.attention_mask, mask_tensor.into()),
            (self.input_names.token_type_ids, type_tensor.into()),
        ];

        let outputs = self.model.run(inputs, &[self.output_id], None)?;

        // Only one output id was requested, so take the first result and
        // flatten it to a Vec<f32> (presumably [1, seq_len, dim] values —
        // see detect_dimensions).
        let hidden: Vec<f32> = outputs
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No output from model"))?
            .into_tensor::<f32>()
            .ok_or_else(|| anyhow::anyhow!("Model output is not float32"))?
            .to_vec();

        // Mean-pool over tokens using the attention mask, then L2-normalize
        // so downstream cosine similarity reduces to a plain dot product.
        let mut embedding =
            pooling::mean_pool(&hidden, &input.attention_mask, seq_len, self.dimensions);
        pooling::normalize_l2(&mut embedding);

        Ok(embedding)
    }

    /// Stub when `embeddings` is compiled out. Unreachable in practice: the
    /// feature-off `load` always bails, so no `Self` exists to call this on.
    #[cfg(not(feature = "embeddings"))]
    fn run_inference(&self, _input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
        anyhow::bail!("Embeddings feature not enabled")
    }

    /// Detect embedding dimensions by running a dummy inference.
    ///
    /// Returns `None` on any failure so the caller can fall back to
    /// `DEFAULT_DIMENSIONS` instead of aborting the load.
    #[cfg(feature = "embeddings")]
    fn detect_dimensions(
        model: &Model,
        tokenizer: &WordPieceTokenizer,
        input_names: &InputNodeIds,
        output_id: rten::NodeId,
    ) -> Option<usize> {
        use rten_tensor::{Layout, NdTensor};

        // Tiny probe input; 8 tokens is enough to observe the output shape.
        let dummy = tokenizer.encode("test", 8);
        let seq_len = dummy.input_ids.len();

        let ids = NdTensor::from_data([1, seq_len], dummy.input_ids);
        let mask = NdTensor::from_data([1, seq_len], dummy.attention_mask);
        let types = NdTensor::from_data([1, seq_len], dummy.token_type_ids);

        let inputs = vec![
            (input_names.input_ids, ids.into()),
            (input_names.attention_mask, mask.into()),
            (input_names.token_type_ids, types.into()),
        ];

        let outputs = model.run(inputs, &[output_id], None).ok()?;
        let tensor = outputs.into_iter().next()?.into_tensor::<f32>()?;
        let shape = tensor.shape();

        // Shape is [batch=1, seq_len, dim]; the last axis is the embedding width.
        shape.last().copied()
    }
}
214
/// Cosine similarity of two vectors that are already L2-normalized.
///
/// For unit-length vectors the dot product equals the cosine of the angle
/// between them, so no division by norms is performed here. Both inputs must
/// be pre-normalized and of equal length for a meaningful result.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");
    let mut dot = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
    }
    dot
}
221
/// Cosine similarity for vectors of arbitrary magnitude.
///
/// Computes dot(a, b) / (|a| · |b|) in a single pass over both slices.
/// Returns 0.0 when either vector has zero norm, where the similarity is
/// undefined.
pub fn cosine_similarity_raw(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
233
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Serializes tests that mutate process-global environment variables.
    // `cargo test` runs tests on multiple threads by default and
    // `std::env::set_var`/`remove_var` affect the entire process, so
    // unguarded env mutation can race with any concurrent test that reads
    // LEAN_CTX_MODELS_DIR (e.g. via `EmbeddingEngine::model_directory`).
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_unnormalized() {
        let a = vec![3.0, 4.0];
        let b = vec![3.0, 4.0];
        assert!((cosine_similarity_raw(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        assert_eq!(cosine_similarity_raw(&a, &b), 0.0);
    }

    #[test]
    fn model_directory_env_override_and_availability() {
        // Hold the lock for the whole test so no other test observes the
        // temporarily-overridden variable; recover from poisoning so one
        // failed test doesn't cascade into the others.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let unique = "/tmp/lean_ctx_test_embed_42xyz";
        std::env::set_var("LEAN_CTX_MODELS_DIR", unique);
        let dir = EmbeddingEngine::model_directory();
        assert_eq!(dir.to_string_lossy(), unique);
        assert!(!EmbeddingEngine::is_available());
        std::env::remove_var("LEAN_CTX_MODELS_DIR");
    }
}