lean_ctx/core/embeddings/
mod.rs1pub mod download;
11pub mod pooling;
12pub mod tokenizer;
13
14use std::path::{Path, PathBuf};
15
16#[cfg(feature = "embeddings")]
17use std::sync::Arc;
18
19use tokenizer::{TokenizedInput, WordPieceTokenizer};
20
21#[cfg(feature = "embeddings")]
22use rten::Model;
23
/// Embedding width assumed when it cannot be probed from the model's output.
#[cfg(feature = "embeddings")]
const DEFAULT_DIMENSIONS: usize = 384;

/// Sequence-length limit handed to the tokenizer on every `embed` call.
#[cfg(feature = "embeddings")]
const DEFAULT_MAX_SEQ_LEN: usize = 256;
28
29pub struct EmbeddingEngine {
30 #[cfg(feature = "embeddings")]
31 model: Arc<Model>,
32 tokenizer: WordPieceTokenizer,
33 dimensions: usize,
34 max_seq_len: usize,
35 #[cfg(feature = "embeddings")]
36 input_names: InputNodeIds,
37 #[cfg(feature = "embeddings")]
38 output_id: rten::NodeId,
39}
40
/// Node ids of the three BERT-style model inputs, resolved once at load time.
#[cfg(feature = "embeddings")]
struct InputNodeIds {
    /// Node receiving the token id tensor.
    input_ids: rten::NodeId,
    /// Node receiving the attention mask tensor.
    attention_mask: rten::NodeId,
    /// Node receiving the token type (segment) id tensor.
    token_type_ids: rten::NodeId,
}
47
48impl EmbeddingEngine {
49 #[cfg(feature = "embeddings")]
56 pub fn load(model_dir: &Path) -> anyhow::Result<Self> {
57 download::ensure_model(model_dir)?;
58
59 let vocab_path = model_dir.join("vocab.txt");
60 let model_path = model_dir.join("model.onnx");
61
62 let tokenizer = WordPieceTokenizer::from_file(&vocab_path)?;
63 let model = Model::load_file(&model_path)?;
64
65 let model_inputs = model.input_ids();
66 if model_inputs.len() < 3 {
67 anyhow::bail!(
68 "Expected BERT-style model with 3 inputs, got {}",
69 model_inputs.len()
70 );
71 }
72
73 let input_names = InputNodeIds {
74 input_ids: model_inputs[0],
75 attention_mask: model_inputs[1],
76 token_type_ids: model_inputs[2],
77 };
78
79 let output_id = *model
80 .output_ids()
81 .first()
82 .ok_or_else(|| anyhow::anyhow!("Model has no outputs"))?;
83
84 let dimensions = Self::detect_dimensions(&model, &tokenizer, &input_names, output_id)
85 .unwrap_or(DEFAULT_DIMENSIONS);
86
87 tracing::info!(
88 "Embedding engine loaded: {}d, max_seq_len={}",
89 dimensions,
90 DEFAULT_MAX_SEQ_LEN
91 );
92
93 Ok(Self {
94 model: Arc::new(model),
95 tokenizer,
96 dimensions,
97 max_seq_len: DEFAULT_MAX_SEQ_LEN,
98 input_names,
99 output_id,
100 })
101 }
102
103 #[cfg(not(feature = "embeddings"))]
104 pub fn load(_model_dir: &Path) -> anyhow::Result<Self> {
105 anyhow::bail!("Embeddings feature not enabled. Compile with --features embeddings")
106 }
107
108 pub fn load_default() -> anyhow::Result<Self> {
110 Self::load(&Self::model_directory())
111 }
112
113 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
115 let input = self.tokenizer.encode(text, self.max_seq_len);
116 self.run_inference(&input)
117 }
118
119 pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
121 texts.iter().map(|t| self.embed(t)).collect()
122 }
123
124 pub fn dimensions(&self) -> usize {
125 self.dimensions
126 }
127
128 pub fn model_directory() -> PathBuf {
130 if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
131 return PathBuf::from(dir);
132 }
133 if let Some(home) = dirs::home_dir() {
134 return home.join(".lean-ctx").join("models");
135 }
136 PathBuf::from("models")
137 }
138
139 pub fn is_available() -> bool {
141 let dir = Self::model_directory();
142 dir.join("model.onnx").exists() && dir.join("vocab.txt").exists()
143 }
144
145 #[cfg(feature = "embeddings")]
146 fn run_inference(&self, input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
147 use rten_tensor::{AsView, NdTensor};
148
149 let seq_len = input.input_ids.len();
150
151 let ids_tensor = NdTensor::from_data([1, seq_len], input.input_ids.clone());
152 let mask_tensor = NdTensor::from_data([1, seq_len], input.attention_mask.clone());
153 let type_tensor = NdTensor::from_data([1, seq_len], input.token_type_ids.clone());
154
155 let inputs = vec![
156 (self.input_names.input_ids, ids_tensor.into()),
157 (self.input_names.attention_mask, mask_tensor.into()),
158 (self.input_names.token_type_ids, type_tensor.into()),
159 ];
160
161 let outputs = self.model.run(inputs, &[self.output_id], None)?;
162
163 let hidden: Vec<f32> = outputs
164 .into_iter()
165 .next()
166 .ok_or_else(|| anyhow::anyhow!("No output from model"))?
167 .into_tensor::<f32>()
168 .ok_or_else(|| anyhow::anyhow!("Model output is not float32"))?
169 .to_vec();
170
171 let mut embedding =
172 pooling::mean_pool(&hidden, &input.attention_mask, seq_len, self.dimensions);
173 pooling::normalize_l2(&mut embedding);
174
175 Ok(embedding)
176 }
177
178 #[cfg(not(feature = "embeddings"))]
179 fn run_inference(&self, _input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
180 anyhow::bail!("Embeddings feature not enabled")
181 }
182
183 #[cfg(feature = "embeddings")]
185 fn detect_dimensions(
186 model: &Model,
187 tokenizer: &WordPieceTokenizer,
188 input_names: &InputNodeIds,
189 output_id: rten::NodeId,
190 ) -> Option<usize> {
191 use rten_tensor::{Layout, NdTensor};
192
193 let dummy = tokenizer.encode("test", 8);
194 let seq_len = dummy.input_ids.len();
195
196 let ids = NdTensor::from_data([1, seq_len], dummy.input_ids);
197 let mask = NdTensor::from_data([1, seq_len], dummy.attention_mask);
198 let types = NdTensor::from_data([1, seq_len], dummy.token_type_ids);
199
200 let inputs = vec![
201 (input_names.input_ids, ids.into()),
202 (input_names.attention_mask, mask.into()),
203 (input_names.token_type_ids, types.into()),
204 ];
205
206 let outputs = model.run(inputs, &[output_id], None).ok()?;
207 let tensor = outputs.into_iter().next()?.into_tensor::<f32>()?;
208 let shape = tensor.shape();
209
210 shape.last().copied()
212 }
213}
214
/// Similarity of two embeddings, computed as their plain dot product.
///
/// This equals cosine similarity only when both inputs already have unit L2
/// norm, such as vectors from `EmbeddingEngine::embed` (which passes them
/// through `pooling::normalize_l2`). Use `cosine_similarity_raw` for arbitrary
/// vectors.
///
/// The lengths must match. This is asserted only in debug builds; in release
/// builds the longer slice is silently truncated to the shorter one.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");
    let pairwise_products = a.iter().zip(b).map(|(&x, &y)| x * y);
    pairwise_products.sum::<f32>()
}
221
/// Cosine similarity of two arbitrary (not necessarily normalized) vectors.
///
/// Returns a value in `[-1.0, 1.0]`, or `0.0` when either vector has zero norm.
///
/// All sums are accumulated in `f64`. Squaring `f32` components larger than
/// about 1.8e19 would overflow to infinity (giving `NaN` from `inf / inf`), and
/// squaring very small components would underflow to zero (giving a spurious
/// `0.0`).
///
/// The lengths must match; this is asserted only in debug builds.
pub fn cosine_similarity_raw(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm = |v: &[f32]| {
        v.iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    };
    let norm_a = norm(a);
    let norm_b = norm(b);
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    // Rounding can push the ratio marginally outside [-1, 1]; clamp so callers
    // can rely on the documented range.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0) as f32
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the guard's lifetime. On drop it
    /// restores the previous value, or removes the variable if it was unset.
    ///
    /// Because restoration happens in `Drop`, it also runs when a test assertion
    /// panics. Env vars are process-global, so tests that read the same variable
    /// can still race with this one when run in parallel.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn cosine_similarity_identical() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = [1.0, 0.0, 0.0];
        let b = [-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_unnormalized() {
        let a = [3.0, 4.0];
        let b = [3.0, 4.0];
        assert!((cosine_similarity_raw(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_zero_vector() {
        let a = [0.0, 0.0];
        let b = [1.0, 2.0];
        assert_eq!(cosine_similarity_raw(&a, &b), 0.0);
    }

    #[test]
    fn model_directory_env_override_and_availability() {
        let unique = "/tmp/lean_ctx_test_embed_42xyz";
        // Held until the end of the test; restores any pre-existing value.
        let _guard = EnvVarGuard::set("LEAN_CTX_MODELS_DIR", unique);
        let dir = EmbeddingEngine::model_directory();
        assert_eq!(dir.to_string_lossy(), unique);
        assert!(!EmbeddingEngine::is_available());
    }
}