// lean_ctx/core/embeddings/mod.rs

1//! Embedding engine for semantic code search.
2//!
3//! Provides dense vector embeddings for code chunks using a local ONNX model
4//! (all-MiniLM-L6-v2). Feature-gated under `embeddings` — falls back gracefully
5//! to BM25-only search when the feature or model is not available.
6//!
7//! Architecture:
8//!   WordPieceTokenizer → ONNX Model (rten) → Mean Pooling → L2 Normalize → Vec<f32>
9
10pub mod download;
11pub mod pooling;
12pub mod tokenizer;
13
14use std::path::{Path, PathBuf};
15
16#[cfg(feature = "embeddings")]
17use std::sync::Arc;
18
19use tokenizer::{TokenizedInput, WordPieceTokenizer};
20
21#[cfg(feature = "embeddings")]
22use rten::Model;
23
/// Fallback embedding width when runtime detection fails; the module header
/// targets all-MiniLM-L6-v2, whose hidden size is 384 — TODO confirm against
/// the shipped model if the default ever changes.
#[cfg(feature = "embeddings")]
const DEFAULT_DIMENSIONS: usize = 384;
/// Maximum number of tokens fed to the model per text; longer inputs are
/// truncated by the tokenizer (see `embed`, which passes this to `encode`).
#[cfg(feature = "embeddings")]
const DEFAULT_MAX_SEQ_LEN: usize = 256;
28
/// Engine that converts text into dense embedding vectors.
///
/// Construct with [`EmbeddingEngine::load`] or [`EmbeddingEngine::load_default`];
/// without the `embeddings` feature, loading always fails and no instance exists.
pub struct EmbeddingEngine {
    /// Loaded ONNX model, reference-counted for cheap sharing.
    #[cfg(feature = "embeddings")]
    model: Arc<Model>,
    /// WordPiece tokenizer built from `vocab.txt`.
    tokenizer: WordPieceTokenizer,
    /// Output vector width, detected at load time (falls back to 384).
    dimensions: usize,
    /// Token-count cap applied when encoding each input text.
    max_seq_len: usize,
    /// Graph node ids for the three BERT-style model inputs.
    #[cfg(feature = "embeddings")]
    input_names: InputNodeIds,
    /// Graph node id of the model output read by inference.
    #[cfg(feature = "embeddings")]
    output_id: rten::NodeId,
}
40
/// Node ids of the three inputs a BERT-style model expects, captured once at
/// load time so inference can bind tensors without re-querying the graph.
/// Assumes the model declares them in the order ids/mask/types — see `load`.
#[cfg(feature = "embeddings")]
struct InputNodeIds {
    /// Token id sequence input.
    input_ids: rten::NodeId,
    /// 1/0 padding mask input.
    attention_mask: rten::NodeId,
    /// Segment id input (all zeros for single-sentence encoding).
    token_type_ids: rten::NodeId,
}
47
impl EmbeddingEngine {
    /// Load embedding model and vocabulary from a directory.
    /// Downloads model automatically from HuggingFace if not present.
    ///
    /// Expected files (auto-downloaded):
    /// - `model.onnx` — all-MiniLM-L6-v2 ONNX embedding model
    /// - `vocab.txt` — WordPiece vocabulary (one token per line)
    ///
    /// # Errors
    /// Fails if download, tokenizer, or model loading fails, if the model does
    /// not expose at least 3 inputs, or if it has no outputs.
    #[cfg(feature = "embeddings")]
    pub fn load(model_dir: &Path) -> anyhow::Result<Self> {
        download::ensure_model(model_dir)?;

        let vocab_path = model_dir.join("vocab.txt");
        let model_path = model_dir.join("model.onnx");

        let tokenizer = WordPieceTokenizer::from_file(&vocab_path)?;
        let model = Model::load_file(&model_path)?;

        // BERT-style models take (input_ids, attention_mask, token_type_ids).
        // NOTE(review): input order is assumed positional here, not matched by
        // name — confirm the exported ONNX graph lists inputs in this order.
        let model_inputs = model.input_ids();
        if model_inputs.len() < 3 {
            anyhow::bail!(
                "Expected BERT-style model with 3 inputs, got {}",
                model_inputs.len()
            );
        }

        let input_names = InputNodeIds {
            input_ids: model_inputs[0],
            attention_mask: model_inputs[1],
            token_type_ids: model_inputs[2],
        };

        // Read only the first declared output (the hidden-state tensor).
        let output_id = *model
            .output_ids()
            .first()
            .ok_or_else(|| anyhow::anyhow!("Model has no outputs"))?;

        // Probe the model with a dummy inference to learn the embedding width;
        // fall back to the compile-time default if the probe fails.
        let dimensions =
            Self::detect_dimensions(&model, &tokenizer, &input_names, output_id)
                .unwrap_or(DEFAULT_DIMENSIONS);

        tracing::info!(
            "Embedding engine loaded: {}d, max_seq_len={}",
            dimensions,
            DEFAULT_MAX_SEQ_LEN
        );

        Ok(Self {
            model: Arc::new(model),
            tokenizer,
            dimensions,
            max_seq_len: DEFAULT_MAX_SEQ_LEN,
            input_names,
            output_id,
        })
    }

    /// Stub used when the `embeddings` feature is compiled out: always fails,
    /// so callers fall back to BM25-only search (see module docs).
    #[cfg(not(feature = "embeddings"))]
    pub fn load(_model_dir: &Path) -> anyhow::Result<Self> {
        anyhow::bail!("Embeddings feature not enabled. Compile with --features embeddings")
    }

    /// Load from default model directory (~/.lean-ctx/models/).
    ///
    /// # Errors
    /// Same failure modes as [`EmbeddingEngine::load`].
    pub fn load_default() -> anyhow::Result<Self> {
        Self::load(&Self::model_directory())
    }

    /// Generate an embedding vector for a single text.
    ///
    /// The text is tokenized (truncated to `max_seq_len` tokens) and run
    /// through the model; the result is mean-pooled and L2-normalized, so
    /// outputs are directly comparable with [`cosine_similarity`].
    pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        let input = self.tokenizer.encode(text, self.max_seq_len);
        self.run_inference(&input)
    }

    /// Generate embedding vectors for multiple texts.
    ///
    /// Sequential convenience wrapper over [`EmbeddingEngine::embed`];
    /// short-circuits on the first error.
    pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Width of the vectors produced by this engine.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Resolve the model directory (respects LEAN_CTX_MODELS_DIR env).
    ///
    /// Resolution order: `LEAN_CTX_MODELS_DIR` env var, then
    /// `~/.lean-ctx/models`, then a relative `models` directory as a last
    /// resort (e.g. when no home directory can be determined).
    pub fn model_directory() -> PathBuf {
        if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
            return PathBuf::from(dir);
        }
        if let Some(home) = dirs::home_dir() {
            return home.join(".lean-ctx").join("models");
        }
        PathBuf::from("models")
    }

    /// Check whether the model files exist on disk.
    ///
    /// This is an existence check only — it does NOT verify the files are
    /// valid or loadable; `load` may still fail even when this returns true.
    pub fn is_available() -> bool {
        let dir = Self::model_directory();
        dir.join("model.onnx").exists() && dir.join("vocab.txt").exists()
    }

    /// Run the model on one tokenized input and post-process into a final
    /// embedding: hidden states → masked mean pool → L2 normalize.
    #[cfg(feature = "embeddings")]
    fn run_inference(&self, input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
        use rten_tensor::{AsView, NdTensor};

        let seq_len = input.input_ids.len();

        // Batch dimension is fixed at 1; each tensor is [1, seq_len].
        let ids_tensor = NdTensor::from_data([1, seq_len], input.input_ids.clone());
        let mask_tensor = NdTensor::from_data([1, seq_len], input.attention_mask.clone());
        let type_tensor = NdTensor::from_data([1, seq_len], input.token_type_ids.clone());

        let inputs = vec![
            (self.input_names.input_ids, ids_tensor.into()),
            (self.input_names.attention_mask, mask_tensor.into()),
            (self.input_names.token_type_ids, type_tensor.into()),
        ];

        let outputs = self.model.run(inputs, &[self.output_id], None)?;

        // Flatten the [1, seq_len, dim] hidden-state tensor into a flat Vec;
        // mean_pool below re-interprets it using seq_len and dimensions.
        let hidden: Vec<f32> = outputs
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No output from model"))?
            .into_tensor::<f32>()
            .ok_or_else(|| anyhow::anyhow!("Model output is not float32"))?
            .to_vec();

        // Mask-weighted mean over token positions, then L2 normalize so that
        // dot products equal cosine similarity.
        let mut embedding =
            pooling::mean_pool(&hidden, &input.attention_mask, seq_len, self.dimensions);
        pooling::normalize_l2(&mut embedding);

        Ok(embedding)
    }

    /// Stub used when the `embeddings` feature is compiled out: always fails.
    #[cfg(not(feature = "embeddings"))]
    fn run_inference(&self, _input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
        anyhow::bail!("Embeddings feature not enabled")
    }

    /// Detect embedding dimensions by running a dummy inference.
    ///
    /// Returns `None` if the probe inference fails or the output is not a
    /// float32 tensor; the caller falls back to `DEFAULT_DIMENSIONS`.
    #[cfg(feature = "embeddings")]
    fn detect_dimensions(
        model: &Model,
        tokenizer: &WordPieceTokenizer,
        input_names: &InputNodeIds,
        output_id: rten::NodeId,
    ) -> Option<usize> {
        use rten_tensor::{Layout, NdTensor};

        // Tiny probe input: "test" truncated to at most 8 tokens.
        let dummy = tokenizer.encode("test", 8);
        let seq_len = dummy.input_ids.len();

        let ids = NdTensor::from_data([1, seq_len], dummy.input_ids);
        let mask = NdTensor::from_data([1, seq_len], dummy.attention_mask);
        let types = NdTensor::from_data([1, seq_len], dummy.token_type_ids);

        let inputs = vec![
            (input_names.input_ids, ids.into()),
            (input_names.attention_mask, mask.into()),
            (input_names.token_type_ids, types.into()),
        ];

        let outputs = model.run(inputs, &[output_id], None).ok()?;
        let tensor = outputs.into_iter().next()?.into_tensor::<f32>()?;
        let shape = tensor.shape();

        // Shape is [batch=1, seq_len, dim]
        shape.last().copied()
    }
}
215
/// Cosine similarity of two vectors that are already L2-normalized.
///
/// For unit vectors the cosine reduces to a plain dot product, which is all
/// this function computes. Callers must pre-normalize both inputs.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");
    let mut dot = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
    }
    dot
}
222
/// Cosine similarity for vectors of any magnitude (no pre-normalization
/// required). Returns 0.0 when either vector has zero norm, so degenerate
/// inputs never produce NaN.
pub fn cosine_similarity_raw(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_unnormalized() {
        let a = vec![3.0, 4.0];
        let b = vec![3.0, 4.0];
        assert!((cosine_similarity_raw(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        assert_eq!(cosine_similarity_raw(&a, &b), 0.0);
    }

    #[test]
    fn model_directory_env_override_and_availability() {
        let unique = "/tmp/lean_ctx_test_embed_42xyz";
        // Save any pre-existing value so this test does not clobber the
        // caller's environment: the previous version unconditionally
        // remove_var'd LEAN_CTX_MODELS_DIR at the end, wiping a real setting.
        // NOTE(review): set_var/remove_var mutate process-global state and are
        // racy if another test reads the same variable concurrently — keep
        // this the only test touching LEAN_CTX_MODELS_DIR, or run tests with
        // --test-threads=1.
        let previous = std::env::var("LEAN_CTX_MODELS_DIR").ok();
        std::env::set_var("LEAN_CTX_MODELS_DIR", unique);
        let dir = EmbeddingEngine::model_directory();
        assert_eq!(dir.to_string_lossy(), unique);
        assert!(!EmbeddingEngine::is_available());
        // Restore the prior state instead of blindly removing the variable.
        match previous {
            Some(v) => std::env::set_var("LEAN_CTX_MODELS_DIR", v),
            None => std::env::remove_var("LEAN_CTX_MODELS_DIR"),
        }
    }
}