1use std::collections::HashMap;
2use std::fs;
3use std::io::Cursor;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::sync::OnceLock;
7
8use anyhow::{Context, Result, bail, format_err};
9use serde_json::json;
10use tract_onnx::prelude::*;
11
/// Length of every embedding vector this module hands out; shorter vectors are
/// zero-padded and longer ones truncated to this size.
pub const EMBEDDING_DIMENSIONS: usize = 384;
/// Fixed number of token slots fed to the model per text (padding included).
pub const MAX_SEQUENCE_LENGTH: usize = 128;
/// Model file name looked up when `--embeddings` points at a directory.
pub const DEFAULT_MODEL_FILENAME: &str = "minilm_model_quint8_avx2.onnx";
/// Vocabulary file name looked up next to the model file.
pub const DEFAULT_VOCAB_FILENAME: &str = "vocab.txt";
/// Byte length of the embedded ONNX model; also the array length of
/// `EMBEDDED_MODEL_BYTES`.
pub const EMBEDDED_MODEL_SIZE: usize = 23_046_789;
/// Expected SHA-256 (hex) of the embedded model bytes; compared in the unit tests.
pub const EMBEDDED_MODEL_SHA256: &str =
    "b941bf19f1f1283680f449fa6a7336bb5600bdcd5f84d10ddc5cd72218a0fd21";
/// Expected byte length of the embedded vocabulary text; compared in the unit tests.
pub const EMBEDDED_VOCAB_SIZE: usize = 231_508;
/// Expected SHA-256 (hex) of the embedded vocabulary text; compared in the unit tests.
pub const EMBEDDED_VOCAB_SHA256: &str =
    "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3";
22
/// ONNX model bytes compiled into the binary.
///
/// Only present when the build sets the `has_embedded_embeddings` cfg. The
/// array length is pinned to `EMBEDDED_MODEL_SIZE`, so a model file of a
/// different size fails to compile.
#[cfg(has_embedded_embeddings)]
#[used]
pub static EMBEDDED_MODEL_BYTES: [u8; EMBEDDED_MODEL_SIZE] =
    *include_bytes!("../weights/minilm_model_quint8_avx2.onnx");
/// Vocabulary text compiled into the binary, one token per line.
///
/// Only present when the build sets the `has_embedded_embeddings` cfg.
#[cfg(has_embedded_embeddings)]
pub static EMBEDDED_VOCAB: &str = include_str!("../weights/vocab.txt");
29
/// Shared handle to the optimized, ready-to-run model.
type RunnableMiniLm = Arc<TypedRunnableModel>;

/// Externally supplied model/vocabulary locations; set at most once by
/// `configure_embeddings_path`.
static EXTERNAL_EMBEDDINGS: OnceLock<EmbeddingPaths> = OnceLock::new();
/// Lazily loaded model. The load result (success or failure) is cached, so a
/// failed load is not retried.
static MODEL: OnceLock<TractResult<RunnableMiniLm>> = OnceLock::new();
/// Lazily loaded token -> id table. The load result (success or failure) is
/// cached, so a failed load is not retried.
static VOCAB: OnceLock<TractResult<HashMap<String, i64>>> = OnceLock::new();
35
/// Resolved on-disk locations of the external embedding assets.
#[derive(Debug, Clone)]
struct EmbeddingPaths {
    /// Path of the ONNX model file.
    model_path: PathBuf,
    /// Path of the vocabulary text file.
    vocab_path: PathBuf,
}
41
42pub fn configure_embeddings_path(path: impl Into<PathBuf>) -> Result<()> {
43 if MODEL.get().is_some() || VOCAB.get().is_some() {
44 bail!("--embeddings must be configured before embeddings are first used");
45 }
46
47 let paths = resolve_embeddings_path(path.into())?;
48 EXTERNAL_EMBEDDINGS
49 .set(paths)
50 .map_err(|_| format_err!("--embeddings was configured more than once"))?;
51
52 Ok(())
53}
54
55fn resolve_embeddings_path(path: PathBuf) -> Result<EmbeddingPaths> {
56 let (model_path, vocab_path) = if path.is_dir() {
57 (
58 path.join(DEFAULT_MODEL_FILENAME),
59 path.join(DEFAULT_VOCAB_FILENAME),
60 )
61 } else {
62 let vocab_path = path
63 .parent()
64 .filter(|parent| !parent.as_os_str().is_empty())
65 .unwrap_or_else(|| Path::new("."))
66 .join(DEFAULT_VOCAB_FILENAME);
67 (path, vocab_path)
68 };
69
70 if !model_path.is_file() {
71 bail!("embedding model file not found at {}", model_path.display());
72 }
73
74 if !vocab_path.is_file() {
75 bail!(
76 "embedding vocabulary file not found at {}",
77 vocab_path.display()
78 );
79 }
80
81 Ok(EmbeddingPaths {
82 model_path,
83 vocab_path,
84 })
85}
86
/// Returns the byte length of the model compiled into the binary.
///
/// Only available when the build sets the `has_embedded_embeddings` cfg.
#[cfg(has_embedded_embeddings)]
pub fn embedded_model_size() -> usize {
    EMBEDDED_MODEL_BYTES.len()
}
91
/// Returns the ONNX model bytes compiled into the binary.
///
/// Only available when the build sets the `has_embedded_embeddings` cfg.
#[cfg(has_embedded_embeddings)]
pub fn embedded_model_bytes() -> &'static [u8] {
    &EMBEDDED_MODEL_BYTES
}
96
/// Embeds `text` into an `EMBEDDING_DIMENSIONS`-long vector.
///
/// Public entry point; delegates to `minilm_embedding`.
///
/// # Errors
///
/// Propagates any error from loading the vocabulary or model, or from running
/// the model.
pub fn embed_text(text: &str) -> TractResult<Vec<f32>> {
    minilm_embedding(text)
}
100
101pub fn blend(content_embedding: &[f32], tag_embedding: &[f32]) -> Vec<f32> {
102 let mut blended = vec![0.0; EMBEDDING_DIMENSIONS];
103
104 for (index, value) in blended.iter_mut().enumerate() {
105 *value = content_embedding.get(index).copied().unwrap_or_default() * 0.75
106 + tag_embedding.get(index).copied().unwrap_or_default() * 0.25;
107 }
108
109 normalize(&mut blended);
110 blended
111}
112
/// Returns the cosine similarity of two vectors, in `[-1.0, 1.0]`.
///
/// The dot product is divided by the product of both magnitudes, so the
/// result is a true cosine even when the inputs are not unit length. When the
/// slices differ in length, the shorter one is treated as zero-padded. If
/// either vector has zero magnitude the similarity is defined as `0.0`.
pub fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    let dot = left
        .iter()
        .zip(right.iter())
        .map(|(left, right)| left * right)
        .sum::<f32>();
    let magnitude = |vector: &[f32]| vector.iter().map(|value| value * value).sum::<f32>().sqrt();
    // Multiply the two roots (rather than rooting the product) to avoid
    // overflowing on large components.
    let scale = magnitude(left) * magnitude(right);

    if scale == 0.0 {
        // A zero vector has no direction; avoid a 0/0 NaN.
        return 0.0;
    }

    // Rounding can push the quotient marginally outside the valid range.
    (dot / scale).clamp(-1.0, 1.0)
}
120
121pub fn encode_embedding(embedding: &[f32]) -> String {
122 serde_json::to_string(embedding).unwrap_or_else(|_| json!([]).to_string())
123}
124
125pub fn decode_embedding(raw: &str) -> Vec<f32> {
126 let mut embedding = serde_json::from_str::<Vec<f32>>(raw).unwrap_or_default();
127 embedding.resize(EMBEDDING_DIMENSIONS, 0.0);
128 embedding.truncate(EMBEDDING_DIMENSIONS);
129 normalize(&mut embedding);
130 embedding
131}
132
133fn minilm_embedding(text: &str) -> TractResult<Vec<f32>> {
134 let encoded = encode_text(text)?;
135 let shape = [1, MAX_SEQUENCE_LENGTH];
136 let input_ids = Tensor::from_shape(&shape, &encoded.input_ids)?.into_tvalue();
137 let attention_mask = Tensor::from_shape(&shape, &encoded.attention_mask)?.into_tvalue();
138 let token_type_ids = Tensor::from_shape(&shape, &encoded.token_type_ids)?.into_tvalue();
139 let outputs = load_model()?.run(tvec!(input_ids, attention_mask, token_type_ids))?;
140 let last_hidden_state = outputs[0].to_plain_array_view::<f32>()?;
141 let hidden_size = last_hidden_state.shape().get(2).copied().unwrap_or(0);
142 let mut embedding = vec![0.0; hidden_size];
143 let mut token_count = 0.0_f32;
144
145 for token_index in 0..MAX_SEQUENCE_LENGTH {
146 if encoded.attention_mask[token_index] == 0 {
147 continue;
148 }
149
150 token_count += 1.0;
151 for hidden_index in 0..hidden_size {
152 embedding[hidden_index] += last_hidden_state[[0, token_index, hidden_index]];
153 }
154 }
155
156 if token_count > 0.0 {
157 for value in &mut embedding {
158 *value /= token_count;
159 }
160 }
161
162 embedding.resize(EMBEDDING_DIMENSIONS, 0.0);
163 embedding.truncate(EMBEDDING_DIMENSIONS);
164 normalize(&mut embedding);
165 Ok(embedding)
166}
167
168fn load_model() -> TractResult<&'static RunnableMiniLm> {
169 MODEL
170 .get_or_init(load_model_from_source)
171 .as_ref()
172 .map_err(|error| format_err!("failed to load MiniLM model: {error}"))
173}
174
175fn load_model_from_source() -> TractResult<RunnableMiniLm> {
176 if let Some(paths) = EXTERNAL_EMBEDDINGS.get() {
177 let model_bytes = fs::read(&paths.model_path).with_context(|| {
178 format!(
179 "failed to read embedding model {}",
180 paths.model_path.display()
181 )
182 })?;
183 let mut model_bytes = Cursor::new(model_bytes);
184
185 return tract_onnx::onnx()
186 .model_for_read(&mut model_bytes)?
187 .into_optimized()?
188 .into_runnable();
189 }
190
191 if let Some(model_bytes) = embedded_model_bytes_if_available() {
192 let mut model_bytes = Cursor::new(model_bytes);
193
194 return tract_onnx::onnx()
195 .model_for_read(&mut model_bytes)?
196 .into_optimized()?
197 .into_runnable();
198 }
199
200 bail!("{}", missing_embeddings_message())
201}
202
/// Model-ready encoding of one text; `encode_text` fills each vector to
/// `MAX_SEQUENCE_LENGTH` entries.
#[derive(Debug)]
struct EncodedText {
    /// Vocabulary ids: `[CLS]`, the word pieces, `[SEP]`, then `[PAD]` filler.
    input_ids: Vec<i64>,
    /// `1` for real tokens, `0` for padding positions.
    attention_mask: Vec<i64>,
    /// Segment ids; always `0` as produced by `encode_text`.
    token_type_ids: Vec<i64>,
}
209
210fn encode_text(text: &str) -> TractResult<EncodedText> {
211 let vocab = vocab()?;
212 let pad_id = token_id(vocab, "[PAD]");
213 let unknown_id = token_id(vocab, "[UNK]");
214 let cls_id = token_id(vocab, "[CLS]");
215 let sep_id = token_id(vocab, "[SEP]");
216 let mut input_ids = Vec::with_capacity(MAX_SEQUENCE_LENGTH);
217
218 input_ids.push(cls_id);
219 for token in basic_tokens(text) {
220 for piece in wordpiece(&token, vocab, unknown_id) {
221 if input_ids.len() >= MAX_SEQUENCE_LENGTH - 1 {
222 break;
223 }
224 input_ids.push(piece);
225 }
226
227 if input_ids.len() >= MAX_SEQUENCE_LENGTH - 1 {
228 break;
229 }
230 }
231 input_ids.push(sep_id);
232
233 let mut attention_mask = vec![1; input_ids.len()];
234 let mut token_type_ids = vec![0; input_ids.len()];
235
236 input_ids.resize(MAX_SEQUENCE_LENGTH, pad_id);
237 attention_mask.resize(MAX_SEQUENCE_LENGTH, 0);
238 token_type_ids.resize(MAX_SEQUENCE_LENGTH, 0);
239
240 Ok(EncodedText {
241 input_ids,
242 attention_mask,
243 token_type_ids,
244 })
245}
246
247fn vocab() -> TractResult<&'static HashMap<String, i64>> {
248 VOCAB
249 .get_or_init(load_vocab_from_source)
250 .as_ref()
251 .map_err(|error| format_err!("failed to load MiniLM vocabulary: {error}"))
252}
253
254fn load_vocab_from_source() -> TractResult<HashMap<String, i64>> {
255 let vocab = if let Some(paths) = EXTERNAL_EMBEDDINGS.get() {
256 fs::read_to_string(&paths.vocab_path).with_context(|| {
257 format!(
258 "failed to read embedding vocabulary {}",
259 paths.vocab_path.display()
260 )
261 })?
262 } else if let Some(vocab) = embedded_vocab_if_available() {
263 vocab.to_string()
264 } else {
265 bail!("{}", missing_embeddings_message());
266 };
267
268 Ok(vocab
269 .lines()
270 .enumerate()
271 .map(|(index, token)| (token.trim_end().to_string(), index as i64))
272 .collect())
273}
274
/// Looks up the id of `token`, falling back to `100` when it is not in the
/// vocabulary.
fn token_id(vocab: &HashMap<String, i64>, token: &str) -> i64 {
    match vocab.get(token) {
        Some(&id) => id,
        None => 100,
    }
}
278
279fn basic_tokens(text: &str) -> Vec<String> {
280 let mut tokens = Vec::new();
281 let mut current = String::new();
282
283 for character in text.chars().flat_map(char::to_lowercase) {
284 if character.is_whitespace() {
285 push_current_token(&mut tokens, &mut current);
286 } else if is_punctuation(character) {
287 push_current_token(&mut tokens, &mut current);
288 tokens.push(character.to_string());
289 } else if !character.is_control() {
290 current.push(character);
291 }
292 }
293
294 push_current_token(&mut tokens, &mut current);
295 tokens
296}
297
/// Moves the accumulated word in `current` onto `tokens`, leaving `current`
/// empty. Does nothing when `current` is already empty.
fn push_current_token(tokens: &mut Vec<String>, current: &mut String) {
    if current.is_empty() {
        return;
    }

    let finished = std::mem::take(current);
    tokens.push(finished);
}
303
/// Reports whether `character` is treated as punctuation by the tokenizer:
/// any ASCII punctuation, or a code point in U+2000..=U+206F or
/// U+2E00..=U+2E7F.
fn is_punctuation(character: char) -> bool {
    if character.is_ascii_punctuation() {
        return true;
    }

    matches!(
        character,
        '\u{2000}'..='\u{206F}' | '\u{2E00}'..='\u{2E7F}'
    )
}
308
/// Splits one token into vocabulary ids using greedy longest-match-first.
///
/// At each position the longest substring present in `vocab` is taken;
/// pieces that do not start the token are looked up with a `##` prefix.
/// Returns `[unknown_id]` when the token is longer than 100 characters or
/// when some position has no matching piece. An empty token yields no ids.
fn wordpiece(token: &str, vocab: &HashMap<String, i64>, unknown_id: i64) -> Vec<i64> {
    let symbols: Vec<char> = token.chars().collect();
    if symbols.len() > 100 {
        return vec![unknown_id];
    }

    let mut ids = Vec::new();
    let mut cursor = 0;
    let mut candidate = String::new();

    while cursor < symbols.len() {
        let prefix = if cursor > 0 { "##" } else { "" };

        // Try the longest remaining substring first, shrinking from the right.
        let longest_match = (cursor + 1..=symbols.len()).rev().find_map(|end| {
            candidate.clear();
            candidate.push_str(prefix);
            candidate.extend(&symbols[cursor..end]);
            vocab.get(candidate.as_str()).map(|&id| (end, id))
        });

        match longest_match {
            Some((end, id)) => {
                ids.push(id);
                cursor = end;
            }
            // One unmatched position makes the whole token unknown.
            None => return vec![unknown_id],
        }
    }

    ids
}
346
/// Error text used when neither an external nor an embedded model/vocabulary
/// is available.
///
/// The message mentions the `MII_MEMORY_EMBEDDINGS` environment variable; this
/// file does not read it, so it is presumably handled by the caller.
fn missing_embeddings_message() -> &'static str {
    "this mii-memory binary was built without embedded embeddings; pass --embeddings <PATH> or set MII_MEMORY_EMBEDDINGS to a directory containing minilm_model_quint8_avx2.onnx and vocab.txt"
}
350
/// Returns the model compiled into the binary.
///
/// Variant for builds with the `has_embedded_embeddings` cfg; always `Some`.
#[cfg(has_embedded_embeddings)]
fn embedded_model_bytes_if_available() -> Option<&'static [u8]> {
    Some(embedded_model_bytes())
}
355
/// Reports that no model is compiled into the binary.
///
/// Variant for builds without the `has_embedded_embeddings` cfg; always `None`.
#[cfg(not(has_embedded_embeddings))]
fn embedded_model_bytes_if_available() -> Option<&'static [u8]> {
    None
}
360
/// Returns the vocabulary compiled into the binary.
///
/// Variant for builds with the `has_embedded_embeddings` cfg; always `Some`.
#[cfg(has_embedded_embeddings)]
fn embedded_vocab_if_available() -> Option<&'static str> {
    Some(EMBEDDED_VOCAB)
}
365
/// Reports that no vocabulary is compiled into the binary.
///
/// Variant for builds without the `has_embedded_embeddings` cfg; always `None`.
#[cfg(not(has_embedded_embeddings))]
fn embedded_vocab_if_available() -> Option<&'static str> {
    None
}
370
/// Scales `embedding` in place to unit Euclidean length.
///
/// A vector whose length is zero (including an empty slice) is left untouched.
fn normalize(embedding: &mut [f32]) {
    let squared_length = embedding
        .iter()
        .map(|component| component * component)
        .sum::<f32>();

    // The root is zero exactly when the sum of squares is zero; bail out
    // before dividing by it.
    if squared_length == 0.0 {
        return;
    }

    let length = squared_length.sqrt();
    embedding
        .iter_mut()
        .for_each(|component| *component /= length);
}
386
// Unit tests. Those that need real model assets are compiled only when the
// build sets the `has_embedded_embeddings` cfg.
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(has_embedded_embeddings)]
    use sha2::{Digest, Sha256};

    // A query should be closer to text on the same topic than to unrelated text.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn related_text_scores_higher_than_unrelated_text() {
        let query = embed_text("rust sqlite memory tags").expect("query embedding");
        let related =
            embed_text("sqlite backed rust memory store with tags").expect("related embedding");
        let unrelated = embed_text("fresh bread and ceramic cups").expect("unrelated embedding");

        assert!(cosine_similarity(&query, &related) > cosine_similarity(&query, &unrelated));
    }

    // The embedding has the advertised width, is not all zeros, and has unit length.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn minilm_embedding_returns_normalized_vector() {
        let embedding = minilm_embedding("rust sqlite memory tags").expect("MiniLM embedding");
        let length = embedding
            .iter()
            .map(|value| value * value)
            .sum::<f32>()
            .sqrt();

        assert_eq!(embedding.len(), EMBEDDING_DIMENSIONS);
        assert!(embedding.iter().any(|value| *value != 0.0));
        assert!((length - 1.0).abs() < 0.0001);
    }

    // The embedded assets match the sizes and SHA-256 digests pinned in the constants.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn minilm_model_and_vocab_are_embedded() {
        let model_hash = Sha256::digest(embedded_model_bytes());
        let vocab_hash = Sha256::digest(EMBEDDED_VOCAB.as_bytes());

        assert_eq!(embedded_model_size(), EMBEDDED_MODEL_SIZE);
        assert_eq!(hex::encode(model_hash), EMBEDDED_MODEL_SHA256);
        assert_eq!(EMBEDDED_VOCAB.len(), EMBEDDED_VOCAB_SIZE);
        assert_eq!(hex::encode(vocab_hash), EMBEDDED_VOCAB_SHA256);
    }

    // Without embedded assets (and with no external path configured) embedding
    // fails with the guidance message.
    #[cfg(not(has_embedded_embeddings))]
    #[test]
    fn embedding_requires_external_assets_when_not_embedded() {
        let error = embed_text("rust sqlite memory tags")
            .unwrap_err()
            .to_string();

        assert!(error.contains("--embeddings <PATH>"));
    }
}