lean_ctx/core/embeddings/
mod.rs1pub mod download;
14pub mod model_registry;
15pub mod pooling;
16pub mod tokenizer;
17
18use std::path::{Path, PathBuf};
19
20use model_registry::{EmbeddingModel, ModelConfig, VocabSource};
21use tokenizer::{TokenizedInput, WordPieceTokenizer};
22
23#[cfg(feature = "embeddings")]
24use std::sync::Arc;
25
26#[cfg(feature = "embeddings")]
27use rten::Model;
28
29pub struct EmbeddingEngine {
30 #[cfg(feature = "embeddings")]
31 model: Arc<Model>,
32 tokenizer: TokenizerKind,
33 dimensions: usize,
34 max_seq_len: usize,
35 model_id: EmbeddingModel,
36 model_config: ModelConfig,
37 #[cfg(feature = "embeddings")]
38 input_names: InputNodeIds,
39 #[cfg(feature = "embeddings")]
40 output_id: rten::NodeId,
41}
42
43enum TokenizerKind {
45 WordPiece(WordPieceTokenizer),
46 HfTokenizer(tokenizer::HfTokenizerWrapper),
47}
48
/// Graph node ids of the model inputs that receive the tokenized tensors.
#[cfg(feature = "embeddings")]
struct InputNodeIds {
    /// Node fed with the token id tensor.
    input_ids: rten::NodeId,
    /// Node fed with the attention mask tensor.
    attention_mask: rten::NodeId,
    /// Node fed with the token type id tensor; `None` when the model does not
    /// expose a third input.
    token_type_ids: Option<rten::NodeId>,
}
55
56impl EmbeddingEngine {
57 #[cfg(feature = "embeddings")]
60 pub fn load(model_dir: &Path) -> anyhow::Result<Self> {
61 let selected = model_registry::resolve_model();
62 Self::load_model(model_dir, selected)
63 }
64
65 #[cfg(feature = "embeddings")]
67 pub fn load_model(base_dir: &Path, model_id: EmbeddingModel) -> anyhow::Result<Self> {
68 let config = model_id.config();
69 let model_dir = base_dir.join(model_id.storage_dir_name());
70
71 download::ensure_model(&model_dir, &config)?;
72
73 let tokenizer = load_tokenizer(&model_dir, &config)?;
74 let model_path = model_dir.join("model.onnx");
75 let model = Model::load_file(&model_path)?;
76
77 let model_inputs = model.input_ids();
78 if model_inputs.len() < 2 {
79 anyhow::bail!(
80 "Expected model with at least 2 inputs (input_ids, attention_mask), got {}",
81 model_inputs.len()
82 );
83 }
84
85 let token_type_ids = if config.needs_token_type_ids {
86 if model_inputs.len() < 3 {
87 anyhow::bail!(
88 "Model {} requires token_type_ids but only has {} inputs",
89 config.name,
90 model_inputs.len()
91 );
92 }
93 Some(model_inputs[2])
94 } else if model_inputs.len() >= 3 {
95 Some(model_inputs[2])
96 } else {
97 None
98 };
99
100 let input_names = InputNodeIds {
101 input_ids: model_inputs[0],
102 attention_mask: model_inputs[1],
103 token_type_ids,
104 };
105
106 let output_id = *model
107 .output_ids()
108 .first()
109 .ok_or_else(|| anyhow::anyhow!("Model has no outputs"))?;
110
111 let dimensions = detect_dimensions(
112 &model,
113 &tokenizer,
114 &input_names,
115 output_id,
116 config.max_seq_len,
117 )
118 .unwrap_or(config.dimensions);
119
120 tracing::info!(
121 "Embedding engine loaded: model={}, {}d, max_seq_len={}",
122 config.name,
123 dimensions,
124 config.max_seq_len,
125 );
126
127 Ok(Self {
128 model: Arc::new(model),
129 tokenizer,
130 dimensions,
131 max_seq_len: config.max_seq_len,
132 model_id,
133 model_config: config,
134 input_names,
135 output_id,
136 })
137 }
138
139 #[cfg(not(feature = "embeddings"))]
140 pub fn load(_model_dir: &Path) -> anyhow::Result<Self> {
141 anyhow::bail!("Embeddings feature not enabled. Compile with --features embeddings")
142 }
143
144 pub fn load_default() -> anyhow::Result<Self> {
146 Self::load(&Self::model_directory())
147 }
148
149 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
151 let prefixed;
152 let input_text = if let Some(prefix) = self.model_config.document_prefix {
153 prefixed = format!("{prefix}{text}");
154 &prefixed
155 } else {
156 text
157 };
158 let input = tokenize(&self.tokenizer, input_text, self.max_seq_len);
159 self.run_inference(&input)
160 }
161
162 pub fn embed_query(&self, query: &str) -> anyhow::Result<Vec<f32>> {
165 let prefixed;
166 let input_text = if let Some(prefix) = self.model_config.query_prefix {
167 prefixed = format!("{prefix}{query}");
168 &prefixed
169 } else {
170 query
171 };
172 let input = tokenize(&self.tokenizer, input_text, self.max_seq_len);
173 self.run_inference(&input)
174 }
175
176 pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
178 texts.iter().map(|t| self.embed(t)).collect()
179 }
180
181 pub fn dimensions(&self) -> usize {
182 self.dimensions
183 }
184
185 pub fn model_id(&self) -> EmbeddingModel {
186 self.model_id
187 }
188
189 pub fn model_name(&self) -> &str {
190 self.model_config.name
191 }
192
193 pub fn model_directory() -> PathBuf {
195 if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
196 return PathBuf::from(dir);
197 }
198 if let Ok(d) = crate::core::data_dir::lean_ctx_data_dir() {
199 return d.join("models");
200 }
201 PathBuf::from("models")
202 }
203
204 pub fn is_available() -> bool {
206 let base_dir = Self::model_directory();
207 let selected = model_registry::resolve_model();
208 let config = selected.config();
209 let model_dir = base_dir.join(selected.storage_dir_name());
210 model_dir.join("model.onnx").exists()
211 && model_dir.join(config.vocab_file.filename()).exists()
212 }
213
214 #[cfg(feature = "embeddings")]
215 fn run_inference(&self, input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
216 use rten_tensor::{AsView, NdTensor};
217
218 let seq_len = input.input_ids.len();
219
220 let ids_tensor = NdTensor::from_data([1, seq_len], input.input_ids.clone());
221 let mask_tensor = NdTensor::from_data([1, seq_len], input.attention_mask.clone());
222
223 let mut inputs = vec![
224 (self.input_names.input_ids, ids_tensor.into()),
225 (self.input_names.attention_mask, mask_tensor.into()),
226 ];
227
228 if let Some(type_id) = self.input_names.token_type_ids {
229 let type_tensor = NdTensor::from_data([1, seq_len], input.token_type_ids.clone());
230 inputs.push((type_id, type_tensor.into()));
231 }
232
233 let outputs = self.model.run(inputs, &[self.output_id], None)?;
234
235 let hidden: Vec<f32> = outputs
236 .into_iter()
237 .next()
238 .ok_or_else(|| anyhow::anyhow!("No output from model"))?
239 .into_tensor::<f32>()
240 .ok_or_else(|| anyhow::anyhow!("Model output is not float32"))?
241 .to_vec();
242
243 let mut embedding =
244 pooling::mean_pool(&hidden, &input.attention_mask, seq_len, self.dimensions);
245 pooling::normalize_l2(&mut embedding);
246
247 Ok(embedding)
248 }
249
250 #[cfg(not(feature = "embeddings"))]
251 fn run_inference(&self, _input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
252 anyhow::bail!("Embeddings feature not enabled")
253 }
254}
255
256fn load_tokenizer(model_dir: &Path, config: &ModelConfig) -> anyhow::Result<TokenizerKind> {
258 match config.vocab_file {
259 VocabSource::VocabTxt(filename) => {
260 let path = model_dir.join(filename);
261 let tok = WordPieceTokenizer::from_file(&path)?;
262 Ok(TokenizerKind::WordPiece(tok))
263 }
264 VocabSource::TokenizerJson(filename) => {
265 let path = model_dir.join(filename);
266 let tok = tokenizer::HfTokenizerWrapper::from_file(&path)?;
267 Ok(TokenizerKind::HfTokenizer(tok))
268 }
269 }
270}
271
272fn tokenize(tokenizer: &TokenizerKind, text: &str, max_len: usize) -> TokenizedInput {
274 match tokenizer {
275 TokenizerKind::WordPiece(wp) => wp.encode(text, max_len),
276 TokenizerKind::HfTokenizer(hf) => hf.encode(text, max_len),
277 }
278}
279
/// Probes the model with a tiny input and reports the size of the last axis
/// of its output.
///
/// The probe text is `"test"`, tokenized with a length limit of
/// `min(max_seq_len, 8)`. Returns `None` when the run fails, produces no
/// output, produces a non-`f32` output, or the output has no axes.
#[cfg(feature = "embeddings")]
fn detect_dimensions(
    model: &Model,
    tokenizer: &TokenizerKind,
    input_names: &InputNodeIds,
    output_id: rten::NodeId,
    max_seq_len: usize,
) -> Option<usize> {
    use rten_tensor::{Layout, NdTensor};

    let probe = tokenize(tokenizer, "test", max_seq_len.min(8));
    let token_count = probe.input_ids.len();
    // All probe tensors are a batch of one.
    let batch_shape = [1, token_count];

    let ids_tensor = NdTensor::from_data(batch_shape, probe.input_ids);
    let mask_tensor = NdTensor::from_data(batch_shape, probe.attention_mask);

    let mut feeds = vec![
        (input_names.input_ids, ids_tensor.into()),
        (input_names.attention_mask, mask_tensor.into()),
    ];

    // Only models that expose a token type input get that tensor.
    if let Some(type_node) = input_names.token_type_ids {
        let type_tensor = NdTensor::from_data(batch_shape, probe.token_type_ids);
        feeds.push((type_node, type_tensor.into()));
    }

    let results = model.run(feeds, &[output_id], None).ok()?;
    let first_output = results.into_iter().next()?;
    let hidden = first_output.into_tensor::<f32>()?;

    hidden.shape().last().copied()
}
314
315pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
321 debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");
322 crate::core::embedding_quant::dot_f32(a, b)
323}
324
325pub fn cosine_similarity_raw(a: &[f32], b: &[f32]) -> f32 {
327 debug_assert_eq!(a.len(), b.len());
328 use crate::core::embedding_quant::dot_f32;
329 let dot = dot_f32(a, b);
330 let norm_a = dot_f32(a, a).sqrt();
331 let norm_b = dot_f32(b, b).sqrt();
332 if norm_a == 0.0 || norm_b == 0.0 {
333 return 0.0;
334 }
335 dot / (norm_a * norm_b)
336}
337
/// Process-wide slot holding the outcome of loading the default engine.
///
/// It is written at most once; both a successfully loaded engine and a load
/// error are kept for the rest of the process lifetime.
#[cfg(feature = "embeddings")]
static SHARED_ENGINE: std::sync::OnceLock<anyhow::Result<EmbeddingEngine>> =
    std::sync::OnceLock::new();
341
/// Returns the process-wide engine, loading it on first use.
///
/// The first call runs `EmbeddingEngine::load_default` and stores the
/// outcome; later calls reuse it. Because a failure is stored as well, a
/// failed load is not retried and every call returns `None` from then on.
/// The failure reason is logged once, at the moment the load fails.
#[cfg(feature = "embeddings")]
pub fn shared_engine() -> Option<&'static EmbeddingEngine> {
    SHARED_ENGINE
        .get_or_init(|| {
            let loaded = EmbeddingEngine::load_default();
            // The error is discarded by `.ok()` below, so this is the only
            // place where the cause of an unavailable engine can surface.
            if let Err(err) = &loaded {
                tracing::warn!("Shared embedding engine failed to load: {:#}", err);
            }
            loaded
        })
        .as_ref()
        .ok()
}
353
/// Returns the shared engine only if it has already been loaded successfully.
///
/// Never triggers loading: yields `None` both when initialization has not
/// happened yet and when it happened but failed.
#[cfg(feature = "embeddings")]
pub fn try_shared_engine() -> Option<&'static EmbeddingEngine> {
    match SHARED_ENGINE.get() {
        Some(Ok(engine)) => Some(engine),
        Some(Err(_)) | None => None,
    }
}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the lifetime of the guard and puts
    /// the previous state back on drop, so a failing assertion cannot leak
    /// the override into other tests and a pre-existing value is not lost.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn cosine_similarity_identical() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = [1.0, 0.0, 0.0];
        let b = [-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_unnormalized() {
        let a = [3.0, 4.0];
        let b = [3.0, 4.0];
        assert!((cosine_similarity_raw(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_zero_vector() {
        let a = [0.0, 0.0];
        let b = [1.0, 2.0];
        assert_eq!(cosine_similarity_raw(&a, &b), 0.0);
    }

    #[test]
    fn model_directory_env_override_and_availability() {
        let unique = "/tmp/lean_ctx_test_embed_42xyz";
        // Restored automatically, even if one of the assertions below panics.
        let _guard = EnvVarGuard::set("LEAN_CTX_MODELS_DIR", unique);

        let dir = EmbeddingEngine::model_directory();
        assert_eq!(dir.to_string_lossy(), unique);
        assert!(!EmbeddingEngine::is_available());
    }

    #[test]
    #[cfg(feature = "embeddings")]
    fn try_shared_engine_returns_none_when_not_initialized() {
        let result = try_shared_engine();
        assert!(
            result.is_none(),
            "try_shared_engine should return None without triggering load"
        );
    }
}