Skip to main content

mii_memory/
embedding.rs

1use std::collections::HashMap;
2use std::fs;
3use std::io::Cursor;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::sync::OnceLock;
7
8use anyhow::{Context, Result, bail, format_err};
9use serde_json::json;
10use tract_onnx::prelude::*;
11
/// Length of every embedding vector this module returns.
pub const EMBEDDING_DIMENSIONS: usize = 384;

/// Fixed number of token slots in each encoded model input.
pub const MAX_SEQUENCE_LENGTH: usize = 128;

/// Model file name looked up when a directory is configured.
pub const DEFAULT_MODEL_FILENAME: &str = "minilm_model_quint8_avx2.onnx";

/// Vocabulary file name looked up next to the model.
pub const DEFAULT_VOCAB_FILENAME: &str = "vocab.txt";

/// Byte length of the embedded model; also the length of the embedded array type.
pub const EMBEDDED_MODEL_SIZE: usize = 23_046_789;

/// Lowercase hex SHA-256 the embedded model bytes are checked against in tests.
pub const EMBEDDED_MODEL_SHA256: &str =
    "b941bf19f1f1283680f449fa6a7336bb5600bdcd5f84d10ddc5cd72218a0fd21";

/// Byte length the embedded vocabulary text is checked against in tests.
pub const EMBEDDED_VOCAB_SIZE: usize = 231_508;

/// Lowercase hex SHA-256 the embedded vocabulary bytes are checked against in tests.
pub const EMBEDDED_VOCAB_SHA256: &str =
    "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3";
22
/// ONNX model bytes compiled into the binary.
///
/// Only defined when the `has_embedded_embeddings` cfg is set. The array type
/// pins the length to `EMBEDDED_MODEL_SIZE`, so a model file of a different
/// size fails to compile.
#[used]
#[cfg(has_embedded_embeddings)]
pub static EMBEDDED_MODEL_BYTES: [u8; EMBEDDED_MODEL_SIZE] =
    *include_bytes!("../weights/minilm_model_quint8_avx2.onnx");

/// Vocabulary text compiled into the binary.
///
/// Only defined when the `has_embedded_embeddings` cfg is set.
#[cfg(has_embedded_embeddings)]
pub static EMBEDDED_VOCAB: &str = include_str!("../weights/vocab.txt");
29
30type RunnableMiniLm = Arc<TypedRunnableModel>;
31
32static EXTERNAL_EMBEDDINGS: OnceLock<EmbeddingPaths> = OnceLock::new();
33static MODEL: OnceLock<TractResult<RunnableMiniLm>> = OnceLock::new();
34static VOCAB: OnceLock<TractResult<HashMap<String, i64>>> = OnceLock::new();
35
/// Resolved on-disk locations of the external model and its vocabulary.
#[derive(Clone, Debug)]
struct EmbeddingPaths {
    /// Path of the ONNX model file.
    model_path: PathBuf,
    /// Path of the vocabulary text file.
    vocab_path: PathBuf,
}
41
42pub fn configure_embeddings_path(path: impl Into<PathBuf>) -> Result<()> {
43    if MODEL.get().is_some() || VOCAB.get().is_some() {
44        bail!("--embeddings must be configured before embeddings are first used");
45    }
46
47    let paths = resolve_embeddings_path(path.into())?;
48    EXTERNAL_EMBEDDINGS
49        .set(paths)
50        .map_err(|_| format_err!("--embeddings was configured more than once"))?;
51
52    Ok(())
53}
54
55fn resolve_embeddings_path(path: PathBuf) -> Result<EmbeddingPaths> {
56    let (model_path, vocab_path) = if path.is_dir() {
57        (
58            path.join(DEFAULT_MODEL_FILENAME),
59            path.join(DEFAULT_VOCAB_FILENAME),
60        )
61    } else {
62        let vocab_path = path
63            .parent()
64            .filter(|parent| !parent.as_os_str().is_empty())
65            .unwrap_or_else(|| Path::new("."))
66            .join(DEFAULT_VOCAB_FILENAME);
67        (path, vocab_path)
68    };
69
70    if !model_path.is_file() {
71        bail!("embedding model file not found at {}", model_path.display());
72    }
73
74    if !vocab_path.is_file() {
75        bail!(
76            "embedding vocabulary file not found at {}",
77            vocab_path.display()
78        );
79    }
80
81    Ok(EmbeddingPaths {
82        model_path,
83        vocab_path,
84    })
85}
86
/// Returns the byte length of the embedded model.
#[cfg(has_embedded_embeddings)]
pub fn embedded_model_size() -> usize {
    let bytes: &[u8] = &EMBEDDED_MODEL_BYTES;
    bytes.len()
}
91
/// Returns the embedded model as a byte slice.
#[cfg(has_embedded_embeddings)]
pub fn embedded_model_bytes() -> &'static [u8] {
    EMBEDDED_MODEL_BYTES.as_slice()
}
96
97pub fn embed_text(text: &str) -> TractResult<Vec<f32>> {
98    minilm_embedding(text)
99}
100
101pub fn blend(content_embedding: &[f32], tag_embedding: &[f32]) -> Vec<f32> {
102    let mut blended = vec![0.0; EMBEDDING_DIMENSIONS];
103
104    for (index, value) in blended.iter_mut().enumerate() {
105        *value = content_embedding.get(index).copied().unwrap_or_default() * 0.75
106            + tag_embedding.get(index).copied().unwrap_or_default() * 0.25;
107    }
108
109    normalize(&mut blended);
110    blended
111}
112
/// Returns the dot product of `left` and `right`, clamped to `[-1.0, 1.0]`.
///
/// No norms are divided out, so the value is a true cosine similarity only when
/// both inputs are already unit length. If the slices differ in length, the
/// extra components of the longer one are ignored.
pub fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    let dot_product: f32 = left.iter().zip(right).map(|(a, b)| a * b).sum();
    dot_product.clamp(-1.0, 1.0)
}
120
121pub fn encode_embedding(embedding: &[f32]) -> String {
122    serde_json::to_string(embedding).unwrap_or_else(|_| json!([]).to_string())
123}
124
125pub fn decode_embedding(raw: &str) -> Vec<f32> {
126    let mut embedding = serde_json::from_str::<Vec<f32>>(raw).unwrap_or_default();
127    embedding.resize(EMBEDDING_DIMENSIONS, 0.0);
128    embedding.truncate(EMBEDDING_DIMENSIONS);
129    normalize(&mut embedding);
130    embedding
131}
132
133fn minilm_embedding(text: &str) -> TractResult<Vec<f32>> {
134    let encoded = encode_text(text)?;
135    let shape = [1, MAX_SEQUENCE_LENGTH];
136    let input_ids = Tensor::from_shape(&shape, &encoded.input_ids)?.into_tvalue();
137    let attention_mask = Tensor::from_shape(&shape, &encoded.attention_mask)?.into_tvalue();
138    let token_type_ids = Tensor::from_shape(&shape, &encoded.token_type_ids)?.into_tvalue();
139    let outputs = load_model()?.run(tvec!(input_ids, attention_mask, token_type_ids))?;
140    let last_hidden_state = outputs[0].to_plain_array_view::<f32>()?;
141    let hidden_size = last_hidden_state.shape().get(2).copied().unwrap_or(0);
142    let mut embedding = vec![0.0; hidden_size];
143    let mut token_count = 0.0_f32;
144
145    for token_index in 0..MAX_SEQUENCE_LENGTH {
146        if encoded.attention_mask[token_index] == 0 {
147            continue;
148        }
149
150        token_count += 1.0;
151        for hidden_index in 0..hidden_size {
152            embedding[hidden_index] += last_hidden_state[[0, token_index, hidden_index]];
153        }
154    }
155
156    if token_count > 0.0 {
157        for value in &mut embedding {
158            *value /= token_count;
159        }
160    }
161
162    embedding.resize(EMBEDDING_DIMENSIONS, 0.0);
163    embedding.truncate(EMBEDDING_DIMENSIONS);
164    normalize(&mut embedding);
165    Ok(embedding)
166}
167
168fn load_model() -> TractResult<&'static RunnableMiniLm> {
169    MODEL
170        .get_or_init(load_model_from_source)
171        .as_ref()
172        .map_err(|error| format_err!("failed to load MiniLM model: {error}"))
173}
174
175fn load_model_from_source() -> TractResult<RunnableMiniLm> {
176    if let Some(paths) = EXTERNAL_EMBEDDINGS.get() {
177        let model_bytes = fs::read(&paths.model_path).with_context(|| {
178            format!(
179                "failed to read embedding model {}",
180                paths.model_path.display()
181            )
182        })?;
183        let mut model_bytes = Cursor::new(model_bytes);
184
185        return tract_onnx::onnx()
186            .model_for_read(&mut model_bytes)?
187            .into_optimized()?
188            .into_runnable();
189    }
190
191    if let Some(model_bytes) = embedded_model_bytes_if_available() {
192        let mut model_bytes = Cursor::new(model_bytes);
193
194        return tract_onnx::onnx()
195            .model_for_read(&mut model_bytes)?
196            .into_optimized()?
197            .into_runnable();
198    }
199
200    bail!("{}", missing_embeddings_message())
201}
202
/// The three equal-length model inputs produced for one piece of text.
#[derive(Debug)]
struct EncodedText {
    /// Vocabulary ids of the tokens, padded to the fixed sequence length.
    input_ids: Vec<i64>,
    /// `1` for real tokens and `0` for padding positions.
    attention_mask: Vec<i64>,
    /// Segment ids; every position is `0`.
    token_type_ids: Vec<i64>,
}
209
210fn encode_text(text: &str) -> TractResult<EncodedText> {
211    let vocab = vocab()?;
212    let pad_id = token_id(vocab, "[PAD]");
213    let unknown_id = token_id(vocab, "[UNK]");
214    let cls_id = token_id(vocab, "[CLS]");
215    let sep_id = token_id(vocab, "[SEP]");
216    let mut input_ids = Vec::with_capacity(MAX_SEQUENCE_LENGTH);
217
218    input_ids.push(cls_id);
219    for token in basic_tokens(text) {
220        for piece in wordpiece(&token, vocab, unknown_id) {
221            if input_ids.len() >= MAX_SEQUENCE_LENGTH - 1 {
222                break;
223            }
224            input_ids.push(piece);
225        }
226
227        if input_ids.len() >= MAX_SEQUENCE_LENGTH - 1 {
228            break;
229        }
230    }
231    input_ids.push(sep_id);
232
233    let mut attention_mask = vec![1; input_ids.len()];
234    let mut token_type_ids = vec![0; input_ids.len()];
235
236    input_ids.resize(MAX_SEQUENCE_LENGTH, pad_id);
237    attention_mask.resize(MAX_SEQUENCE_LENGTH, 0);
238    token_type_ids.resize(MAX_SEQUENCE_LENGTH, 0);
239
240    Ok(EncodedText {
241        input_ids,
242        attention_mask,
243        token_type_ids,
244    })
245}
246
247fn vocab() -> TractResult<&'static HashMap<String, i64>> {
248    VOCAB
249        .get_or_init(load_vocab_from_source)
250        .as_ref()
251        .map_err(|error| format_err!("failed to load MiniLM vocabulary: {error}"))
252}
253
254fn load_vocab_from_source() -> TractResult<HashMap<String, i64>> {
255    let vocab = if let Some(paths) = EXTERNAL_EMBEDDINGS.get() {
256        fs::read_to_string(&paths.vocab_path).with_context(|| {
257            format!(
258                "failed to read embedding vocabulary {}",
259                paths.vocab_path.display()
260            )
261        })?
262    } else if let Some(vocab) = embedded_vocab_if_available() {
263        vocab.to_string()
264    } else {
265        bail!("{}", missing_embeddings_message());
266    };
267
268    Ok(vocab
269        .lines()
270        .enumerate()
271        .map(|(index, token)| (token.trim_end().to_string(), index as i64))
272        .collect())
273}
274
/// Looks up the id of `token`, falling back to `100` when it is absent.
fn token_id(vocab: &HashMap<String, i64>, token: &str) -> i64 {
    // NOTE(review): 100 is presumably the `[UNK]` id of the shipped vocabulary — confirm.
    const FALLBACK_ID: i64 = 100;

    vocab.get(token).copied().unwrap_or(FALLBACK_ID)
}
278
279fn basic_tokens(text: &str) -> Vec<String> {
280    let mut tokens = Vec::new();
281    let mut current = String::new();
282
283    for character in text.chars().flat_map(char::to_lowercase) {
284        if character.is_whitespace() {
285            push_current_token(&mut tokens, &mut current);
286        } else if is_punctuation(character) {
287            push_current_token(&mut tokens, &mut current);
288            tokens.push(character.to_string());
289        } else if !character.is_control() {
290            current.push(character);
291        }
292    }
293
294    push_current_token(&mut tokens, &mut current);
295    tokens
296}
297
/// Moves a non-empty `current` onto `tokens`, leaving `current` empty.
///
/// Does nothing when `current` is already empty.
fn push_current_token(tokens: &mut Vec<String>, current: &mut String) {
    if current.is_empty() {
        return;
    }

    let finished = std::mem::take(current);
    tokens.push(finished);
}
303
/// Reports whether `character` is treated as punctuation by the tokenizer.
///
/// True for ASCII punctuation and for code points in `U+2000..=U+206F` or
/// `U+2E00..=U+2E7F`.
fn is_punctuation(character: char) -> bool {
    if character.is_ascii_punctuation() {
        return true;
    }

    matches!(character, '\u{2000}'..='\u{206F}' | '\u{2E00}'..='\u{2E7F}')
}
308
/// Splits one token into vocabulary ids by greedy longest-match-first lookup.
///
/// Starting from the left, the longest substring present in `vocab` is taken;
/// every piece after the first is looked up with a `##` prefix. If any position
/// has no matching piece, or the token is longer than 100 characters, the
/// result is the single id `unknown_id`. An empty token yields no ids.
fn wordpiece(token: &str, vocab: &HashMap<String, i64>, unknown_id: i64) -> Vec<i64> {
    const MAX_TOKEN_CHARS: usize = 100;

    let chars: Vec<char> = token.chars().collect();
    if chars.len() > MAX_TOKEN_CHARS {
        return vec![unknown_id];
    }

    let mut ids = Vec::new();
    let mut candidate = String::new();
    let mut cursor = 0;

    while cursor < chars.len() {
        let prefix = if cursor == 0 { "" } else { "##" };

        // Try the longest remaining substring first, shrinking from the right.
        let longest_match = (cursor + 1..=chars.len()).rev().find_map(|stop| {
            candidate.clear();
            candidate.push_str(prefix);
            candidate.extend(&chars[cursor..stop]);
            vocab.get(candidate.as_str()).map(|&id| (stop, id))
        });

        match longest_match {
            Some((stop, id)) => {
                ids.push(id);
                cursor = stop;
            }
            None => return vec![unknown_id],
        }
    }

    ids
}
346
/// Returns the error text used when neither external nor embedded assets exist.
fn missing_embeddings_message() -> &'static str {
    const MESSAGE: &str = "this mii-memory binary was built without embedded embeddings; pass --embeddings <PATH> or set MII_MEMORY_EMBEDDINGS to a directory containing minilm_model_quint8_avx2.onnx and vocab.txt";

    MESSAGE
}
350
/// Returns the embedded model bytes, or `None` in builds without embedded assets.
fn embedded_model_bytes_if_available() -> Option<&'static [u8]> {
    // Exactly one of the two blocks survives cfg stripping and becomes the result.
    #[cfg(has_embedded_embeddings)]
    {
        Some(embedded_model_bytes())
    }
    #[cfg(not(has_embedded_embeddings))]
    {
        None
    }
}
360
/// Returns the embedded vocabulary text, or `None` in builds without embedded assets.
fn embedded_vocab_if_available() -> Option<&'static str> {
    // Exactly one of the two blocks survives cfg stripping and becomes the result.
    #[cfg(has_embedded_embeddings)]
    {
        Some(EMBEDDED_VOCAB)
    }
    #[cfg(not(has_embedded_embeddings))]
    {
        None
    }
}
370
/// Scales `embedding` in place to unit Euclidean length.
///
/// A vector whose length is exactly zero (including an empty slice) is left
/// unchanged.
fn normalize(embedding: &mut [f32]) {
    let squared_length: f32 = embedding
        .iter()
        .map(|component| component * component)
        .sum();
    let length = squared_length.sqrt();

    if length == 0.0 {
        return;
    }

    embedding
        .iter_mut()
        .for_each(|component| *component /= length);
}
386
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(has_embedded_embeddings)]
    use sha2::{Digest, Sha256};

    /// A query should be closer to a topically related sentence than to an unrelated one.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn related_text_scores_higher_than_unrelated_text() {
        let query = embed_text("rust sqlite memory tags").expect("query embedding");
        let related =
            embed_text("sqlite backed rust memory store with tags").expect("related embedding");
        let unrelated = embed_text("fresh bread and ceramic cups").expect("unrelated embedding");

        let related_score = cosine_similarity(&query, &related);
        let unrelated_score = cosine_similarity(&query, &unrelated);

        assert!(related_score > unrelated_score);
    }

    /// The embedding has the fixed dimension, is not all zeros, and has unit length.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn minilm_embedding_returns_normalized_vector() {
        let embedding = minilm_embedding("rust sqlite memory tags").expect("MiniLM embedding");
        let squared_length: f32 = embedding.iter().map(|value| value * value).sum();
        let length = squared_length.sqrt();

        assert_eq!(embedding.len(), EMBEDDING_DIMENSIONS);
        assert!(embedding.iter().any(|value| *value != 0.0));
        assert!((length - 1.0).abs() < 0.0001);
    }

    /// The embedded assets match the sizes and SHA-256 digests recorded in the constants.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn minilm_model_and_vocab_are_embedded() {
        let model_digest = hex::encode(Sha256::digest(embedded_model_bytes()));
        let vocab_digest = hex::encode(Sha256::digest(EMBEDDED_VOCAB.as_bytes()));

        assert_eq!(embedded_model_size(), EMBEDDED_MODEL_SIZE);
        assert_eq!(model_digest, EMBEDDED_MODEL_SHA256);
        assert_eq!(EMBEDDED_VOCAB.len(), EMBEDDED_VOCAB_SIZE);
        assert_eq!(vocab_digest, EMBEDDED_VOCAB_SHA256);
    }

    /// Without embedded assets, embedding fails with a message naming the `--embeddings` option.
    #[cfg(not(has_embedded_embeddings))]
    #[test]
    fn embedding_requires_external_assets_when_not_embedded() {
        let failure = embed_text("rust sqlite memory tags").unwrap_err();
        let message = failure.to_string();

        assert!(message.contains("--embeddings <PATH>"));
    }
}