// lean_ctx/core/embeddings/mod.rs
pub mod download;
pub mod pooling;
pub mod tokenizer;

use std::path::{Path, PathBuf};

#[cfg(feature = "embeddings")]
use std::sync::Arc;

use tokenizer::{TokenizedInput, WordPieceTokenizer};

#[cfg(feature = "embeddings")]
use rten::Model;

#[cfg(feature = "embeddings")]
const DEFAULT_DIMENSIONS: usize = 384;
#[cfg(feature = "embeddings")]
const DEFAULT_MAX_SEQ_LEN: usize = 256;
28
/// Text-embedding engine: WordPiece tokenization plus ONNX inference via rten.
///
/// With the `embeddings` feature disabled, only the tokenizer-related fields
/// exist and `load` always returns an error, so no instance can be built.
pub struct EmbeddingEngine {
    /// Loaded ONNX model, wrapped in `Arc` for cheap sharing.
    #[cfg(feature = "embeddings")]
    model: Arc<Model>,
    /// WordPiece tokenizer built from `vocab.txt`.
    tokenizer: WordPieceTokenizer,
    /// Embedding width; probed at load time, else `DEFAULT_DIMENSIONS`.
    dimensions: usize,
    /// Maximum token sequence length passed to the tokenizer on `embed`.
    max_seq_len: usize,
    /// Node ids for the model's three inputs, resolved once at load time.
    #[cfg(feature = "embeddings")]
    input_names: InputNodeIds,
    /// Node id of the output tensor fetched on every inference run.
    #[cfg(feature = "embeddings")]
    output_id: rten::NodeId,
}
40
/// Node ids of the three standard BERT-style model inputs, captured from the
/// model's declared input order in `EmbeddingEngine::load`.
#[cfg(feature = "embeddings")]
struct InputNodeIds {
    input_ids: rten::NodeId,
    attention_mask: rten::NodeId,
    token_type_ids: rten::NodeId,
}
47
impl EmbeddingEngine {
    /// Loads the engine from `model_dir`, first ensuring the model files are
    /// present via `download::ensure_model`.
    ///
    /// Expects a BERT-style ONNX model (`model.onnx`) with at least three
    /// inputs, plus a WordPiece vocabulary (`vocab.txt`).
    ///
    /// # Errors
    /// Fails if the download, tokenizer load, or model load fails, if the
    /// model declares fewer than three inputs, or if it has no outputs.
    #[cfg(feature = "embeddings")]
    pub fn load(model_dir: &Path) -> anyhow::Result<Self> {
        download::ensure_model(model_dir)?;

        let vocab_path = model_dir.join("vocab.txt");
        let model_path = model_dir.join("model.onnx");

        let tokenizer = WordPieceTokenizer::from_file(&vocab_path)?;
        let model = Model::load_file(&model_path)?;

        let model_inputs = model.input_ids();
        if model_inputs.len() < 3 {
            anyhow::bail!(
                "Expected BERT-style model with 3 inputs, got {}",
                model_inputs.len()
            );
        }

        // NOTE(review): relies on the model declaring its inputs in the order
        // (input_ids, attention_mask, token_type_ids). Typical for BERT
        // exports, but nothing here verifies the names — confirm per model.
        let input_names = InputNodeIds {
            input_ids: model_inputs[0],
            attention_mask: model_inputs[1],
            token_type_ids: model_inputs[2],
        };

        let output_id = *model
            .output_ids()
            .first()
            .ok_or_else(|| anyhow::anyhow!("Model has no outputs"))?;

        // Probe the actual embedding width with a dummy inference; fall back
        // to the compile-time default if the probe fails for any reason.
        let dimensions =
            Self::detect_dimensions(&model, &tokenizer, &input_names, output_id)
                .unwrap_or(DEFAULT_DIMENSIONS);

        tracing::info!(
            "Embedding engine loaded: {}d, max_seq_len={}",
            dimensions,
            DEFAULT_MAX_SEQ_LEN
        );

        Ok(Self {
            model: Arc::new(model),
            tokenizer,
            dimensions,
            max_seq_len: DEFAULT_MAX_SEQ_LEN,
            input_names,
            output_id,
        })
    }

    /// Feature-gated stub: always errors when built without `embeddings`.
    #[cfg(not(feature = "embeddings"))]
    pub fn load(_model_dir: &Path) -> anyhow::Result<Self> {
        anyhow::bail!("Embeddings feature not enabled. Compile with --features embeddings")
    }

    /// Loads from the default model directory (see [`Self::model_directory`]).
    pub fn load_default() -> anyhow::Result<Self> {
        Self::load(&Self::model_directory())
    }

    /// Embeds a single text: tokenizes (bounded by `max_seq_len`) and runs
    /// inference. The result is mean-pooled and L2-normalized (see
    /// `run_inference`), with length [`Self::dimensions`].
    pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        let input = self.tokenizer.encode(text, self.max_seq_len);
        self.run_inference(&input)
    }

    /// Embeds each text in order, short-circuiting on the first error.
    // NOTE(review): sequential under the hood — one inference per text, no
    // true batched forward pass.
    pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Width of the vectors produced by [`Self::embed`].
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Resolves the model directory, in priority order:
    /// 1. the `LEAN_CTX_MODELS_DIR` environment variable,
    /// 2. `~/.lean-ctx/models`,
    /// 3. a relative `models` directory as a last resort.
    pub fn model_directory() -> PathBuf {
        if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
            return PathBuf::from(dir);
        }
        if let Some(home) = dirs::home_dir() {
            return home.join(".lean-ctx").join("models");
        }
        PathBuf::from("models")
    }

    /// True when both `model.onnx` and `vocab.txt` already exist in the
    /// resolved model directory.
    pub fn is_available() -> bool {
        let dir = Self::model_directory();
        dir.join("model.onnx").exists() && dir.join("vocab.txt").exists()
    }

    /// Runs the model on one tokenized input and builds the final embedding:
    /// mean-pools the token vectors using the attention mask, then
    /// L2-normalizes in place.
    #[cfg(feature = "embeddings")]
    fn run_inference(&self, input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
        use rten_tensor::{AsView, NdTensor};

        let seq_len = input.input_ids.len();

        // Batch dimension fixed at 1 — a single text per inference call.
        let ids_tensor = NdTensor::from_data([1, seq_len], input.input_ids.clone());
        let mask_tensor = NdTensor::from_data([1, seq_len], input.attention_mask.clone());
        let type_tensor = NdTensor::from_data([1, seq_len], input.token_type_ids.clone());

        let inputs = vec![
            (self.input_names.input_ids, ids_tensor.into()),
            (self.input_names.attention_mask, mask_tensor.into()),
            (self.input_names.token_type_ids, type_tensor.into()),
        ];

        let outputs = self.model.run(inputs, &[self.output_id], None)?;

        // Flatten the output into a Vec<f32>. The pooling call below implies
        // an assumed layout of [1, seq_len, dimensions] — TODO confirm this
        // matches the model export.
        let hidden: Vec<f32> = outputs
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No output from model"))?
            .into_tensor::<f32>()
            .ok_or_else(|| anyhow::anyhow!("Model output is not float32"))?
            .to_vec();

        let mut embedding =
            pooling::mean_pool(&hidden, &input.attention_mask, seq_len, self.dimensions);
        pooling::normalize_l2(&mut embedding);

        Ok(embedding)
    }

    /// Feature-gated stub; unreachable in practice because `load` also fails
    /// without the `embeddings` feature, so no `self` can exist.
    #[cfg(not(feature = "embeddings"))]
    fn run_inference(&self, _input: &TokenizedInput) -> anyhow::Result<Vec<f32>> {
        anyhow::bail!("Embeddings feature not enabled")
    }

    /// Probes the embedding width by running a tiny dummy inference and
    /// reading the last axis of the output shape. Returns `None` if the
    /// probe run or the float32 conversion fails.
    #[cfg(feature = "embeddings")]
    fn detect_dimensions(
        model: &Model,
        tokenizer: &WordPieceTokenizer,
        input_names: &InputNodeIds,
        output_id: rten::NodeId,
    ) -> Option<usize> {
        use rten_tensor::{Layout, NdTensor};

        let dummy = tokenizer.encode("test", 8);
        let seq_len = dummy.input_ids.len();

        let ids = NdTensor::from_data([1, seq_len], dummy.input_ids);
        let mask = NdTensor::from_data([1, seq_len], dummy.attention_mask);
        let types = NdTensor::from_data([1, seq_len], dummy.token_type_ids);

        let inputs = vec![
            (input_names.input_ids, ids.into()),
            (input_names.attention_mask, mask.into()),
            (input_names.token_type_ids, types.into()),
        ];

        let outputs = model.run(inputs, &[output_id], None).ok()?;
        let tensor = outputs.into_iter().next()?.into_tensor::<f32>()?;
        let shape = tensor.shape();

        // Assumes the last axis is the hidden size (e.g. [batch, seq, hidden])
        // — TODO confirm for models whose first output is already pooled.
        shape.last().copied()
    }
}
215
/// Dot product of two equal-length vectors.
///
/// This equals cosine similarity only when both inputs are already
/// L2-normalized (as the embeddings produced here are); for arbitrary
/// vectors use [`cosine_similarity_raw`].
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have equal dimensions");
    let mut dot = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
    }
    dot
}
222
/// Full cosine similarity for arbitrary (not necessarily normalized)
/// vectors: `a · b / (|a| * |b|)`. Returns 0.0 if either vector has zero
/// magnitude, avoiding a division by zero.
pub fn cosine_similarity_raw(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot = a
        .iter()
        .zip(b.iter())
        .fold(0.0f32, |acc, (x, y)| acc + x * y);
    dot / (norm_a * norm_b)
}
234
#[cfg(test)]
mod tests {
    use super::*;

    // `cosine_similarity` is a plain dot product, so the cases below use
    // unit vectors where dot == cosine.

    #[test]
    fn cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    // The raw variant normalizes internally: [3,4] has norm 5, so the
    // similarity of a vector with itself is exactly 1.
    #[test]
    fn cosine_similarity_raw_unnormalized() {
        let a = vec![3.0, 4.0];
        let b = vec![3.0, 4.0];
        assert!((cosine_similarity_raw(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_raw_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        assert_eq!(cosine_similarity_raw(&a, &b), 0.0);
    }

    // NOTE(review): this test mutates process-wide environment variables.
    // It can race with any other test that reads LEAN_CTX_MODELS_DIR when
    // the harness runs tests in parallel, and the override leaks if an
    // assert fails before remove_var. Consider a serial-test guard.
    #[test]
    fn model_directory_env_override_and_availability() {
        let unique = "/tmp/lean_ctx_test_embed_42xyz";
        std::env::set_var("LEAN_CTX_MODELS_DIR", unique);
        let dir = EmbeddingEngine::model_directory();
        assert_eq!(dir.to_string_lossy(), unique);
        // Directory does not exist, so neither model file can be found.
        assert!(!EmbeddingEngine::is_available());
        std::env::remove_var("LEAN_CTX_MODELS_DIR");
    }
}