Skip to main content

ripvec_core/encoder/ripvec/
static_model.rs

1//! In-process reimplementation of the Model2Vec static embedder.
2//!
3//! Replaces the `model2vec-rs` 0.2 dependency. Reasons:
4//!
5//! 1. **Parallelism**: `model2vec_rs::StaticModel::encode_with_args` runs
6//!    `pool_ids` in a serial inner loop and calls `tokenizers::Tokenizer::encode_batch_fast`
7//!    (which spawns its own rayon pool internally). Calling that path
8//!    from inside an outer rayon `par_chunks` produced ~60% `__psynch_cvwait`
9//!    in our linux-corpus profile — nested rayon scopes parking on each
10//!    other. This implementation: tokenize ONCE across the full corpus on
11//!    the unfettered thread pool, then mean-pool every encoding in parallel
12//!    via a single `par_iter`. No nesting.
13//!
14//! 2. **ndarray version**: model2vec-rs pinned `ndarray 0.15`; ripvec-core
15//!    uses `ndarray 0.17`. The two `Array2<f32>` types are not
16//!    interchangeable. Owning the load path here lets us use the workspace
17//!    ndarray directly.
18//!
19//! 3. **Allocator pressure**: model2vec-rs builds intermediate
20//!    `Vec<String>` clones inside `encode_with_args`. The local
21//!    implementation tokenizes from `&[&str]` references directly.
22//!
23//! The file format is the published Model2Vec layout (tokenizer.json +
24//! model.safetensors + config.json). Local paths only — if Hub download
25//! is needed, pre-stage the files via `curl` (see
26//! `crates/ripvec-core/tests/ripvec_port_parity.rs` for the recipe).
27//!
28//! ## Behavioural parity
29//!
30//! Identical math to `model2vec_rs::StaticModel::encode_with_args`:
31//!
32//! - Truncate input strings by char count = `max_tokens * median_token_length`
33//!   (HF tokenizers can be slow on huge strings).
34//! - Tokenize via `tokenizers::Tokenizer::encode_batch_fast`.
35//! - Drop UNK tokens.
36//! - Truncate token ID list to `max_tokens`.
37//! - Pool: for each token, look up the embedding row (optionally remapped
38//!   via `token_mapping`), scale by the per-token weight (default 1.0),
39//!   accumulate.
40//! - Divide by token count; L2-normalize if `normalize` is set.
41//!
42//! Verified by the integration test
43//! `crates/ripvec-core/tests/ripvec_port_parity.rs` which exercises the
44//! end-to-end pipeline against `minishlab/potion-code-16M`.
45
46use std::path::Path;
47
48use anyhow::{Context, Result, anyhow};
49use ndarray::Array2;
50use rayon::prelude::*;
51use safetensors::SafeTensors;
52use safetensors::tensor::Dtype;
53use serde_json::Value;
54use tokenizers::Tokenizer;
55use wide::f32x8;
56
/// Upper bound on the number of tokens pooled for any single input text.
///
/// Same default as `model2vec_rs`; code chunks usually tokenize to far
/// fewer tokens than this, so the cap rarely bites.
pub const DEFAULT_MAX_TOKENS: usize = 512;

/// Number of texts handed to each `encode_batch_fast` call inside
/// [`StaticEmbedModel::encode_batch`].
///
/// The tokenizer already fans each call out over rayon internally. One
/// call spanning the whole corpus spent most of its wall time on
/// `Encoding` allocation and internal chunk scheduling; 1024-text
/// sub-batches match `model2vec_rs`'s internal default and measured
/// noticeably faster on a 92K-file linux-source corpus.
const BATCH_SIZE: usize = 1024;
69
70/// Loaded Model2Vec static embedder.
71///
72/// Constructed via [`StaticEmbedModel::from_path`]. Use
73/// [`encode_query`](Self::encode_query) for a single text and
74/// [`encode_batch`](Self::encode_batch) for many — the batch path is
75/// where the parallel-pool win lives.
76pub struct StaticEmbedModel {
77    tokenizer: Tokenizer,
78    /// `(vocab_size, hidden_dim)` row-major embedding table.
79    embeddings: Array2<f32>,
80    /// Per-token scalar weight (typically present in quantized models).
81    /// `None` means use 1.0 for every token.
82    weights: Option<Vec<f32>>,
83    /// Optional remap from token-id → embedding-row index.
84    /// `None` means use the token-id directly.
85    token_mapping: Option<Vec<usize>>,
86    /// Whether to L2-normalize the pooled output. Read from `config.json`.
87    normalize: bool,
88    /// Median bytes-per-token across the tokenizer vocab. Used for the
89    /// char-level truncation heuristic (avoids pathological tokenization
90    /// of multi-MB strings).
91    median_token_length: usize,
92    /// Token id to drop after tokenization (typically the BPE
93    /// `[UNK]`/`<unk>` id). `None` if the tokenizer has no unk token.
94    unk_token_id: Option<usize>,
95}
96
97impl StaticEmbedModel {
98    /// Load from a local directory containing
99    /// `tokenizer.json`, `model.safetensors`, and `config.json`.
100    ///
101    /// `normalize_override` lets callers force-enable or force-disable
102    /// L2 normalization regardless of what `config.json` says. Pass
103    /// `None` to honor the config.
104    pub fn from_path(path: &Path, normalize_override: Option<bool>) -> Result<Self> {
105        let tokenizer_path = path.join("tokenizer.json");
106        let model_path = path.join("model.safetensors");
107        let config_path = path.join("config.json");
108        let tokenizer_bytes =
109            std::fs::read(&tokenizer_path).context("read tokenizer.json failed")?;
110        let model_bytes = std::fs::read(&model_path).context("read model.safetensors failed")?;
111        let config_bytes = std::fs::read(&config_path).context("read config.json failed")?;
112        Self::from_bytes(
113            &tokenizer_bytes,
114            &model_bytes,
115            &config_bytes,
116            normalize_override,
117        )
118    }
119
120    /// Load from in-memory bytes (e.g., for embedded resources or tests).
121    #[allow(clippy::too_many_lines)]
122    pub fn from_bytes(
123        tokenizer_bytes: &[u8],
124        model_bytes: &[u8],
125        config_bytes: &[u8],
126        normalize_override: Option<bool>,
127    ) -> Result<Self> {
128        let mut tokenizer = Tokenizer::from_bytes(tokenizer_bytes)
129            .map_err(|e| anyhow!("tokenizer load failed: {e}"))?;
130        // Disable padding/truncation. The published Model2Vec tokenizer
131        // configs (e.g. minishlab/potion-code-16M) set
132        // `padding.strategy = "BatchLongest"`, which causes
133        // `encode_batch_fast` to pad every encoding in a batch up to
134        // the longest. On a 250K-item batch this dominates wall time —
135        // we measured 33s+ in `Encoding::pad` and 70% cvar parking
136        // before disabling. We do our own per-token filtering and
137        // length cap inside `pool_ids`/`filter_ids`, so the tokenizer's
138        // pad/trunc layer is pure overhead.
139        tokenizer.with_padding(None).with_truncation(None).ok();
140
141        let cfg: Value = serde_json::from_slice(config_bytes).context("config.json parse")?;
142        let cfg_norm = cfg
143            .get("normalize")
144            .and_then(Value::as_bool)
145            .unwrap_or(true);
146        let normalize = normalize_override.unwrap_or(cfg_norm);
147
148        let safet = SafeTensors::deserialize(model_bytes).context("safetensors deserialize")?;
149
150        // The embedding tensor is named "embeddings" in canonical
151        // Model2Vec packs, "0" in some sentence-transformers exports,
152        // and "embedding.weight" in older variants. Try in that order.
153        let embed_tensor = safet
154            .tensor("embeddings")
155            .or_else(|_| safet.tensor("0"))
156            .or_else(|_| safet.tensor("embedding.weight"))
157            .map_err(|_| anyhow!("embeddings tensor not found in safetensors"))?;
158        let [rows, cols]: [usize; 2] = embed_tensor
159            .shape()
160            .try_into()
161            .map_err(|_| anyhow!("embedding tensor is not 2-D"))?;
162        let raw = embed_tensor.data();
163        let floats: Vec<f32> = match embed_tensor.dtype() {
164            Dtype::F32 => raw
165                .chunks_exact(4)
166                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
167                .collect(),
168            Dtype::F16 => raw
169                .chunks_exact(2)
170                .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
171                .collect(),
172            Dtype::I8 => raw.iter().map(|&b| f32::from(b.cast_signed())).collect(),
173            other => return Err(anyhow!("unsupported embedding dtype: {other:?}")),
174        };
175        let embeddings = Array2::from_shape_vec((rows, cols), floats)
176            .context("embedding matrix shape mismatch")?;
177
178        // Optional "weights" tensor (per-token scales, in some packs).
179        let weights = safet.tensor("weights").ok().map(|t| {
180            let raw = t.data();
181            match t.dtype() {
182                Dtype::F64 => raw
183                    .chunks_exact(8)
184                    .map(|b| {
185                        // Per-token weights only need f32 precision; f64
186                        // values in published Model2Vec packs are
187                        // always small constants well within f32 range.
188                        #[expect(
189                            clippy::cast_possible_truncation,
190                            reason = "weights are bounded; f32 precision is sufficient downstream"
191                        )]
192                        let v = f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])
193                            as f32;
194                        v
195                    })
196                    .collect::<Vec<f32>>(),
197                Dtype::F32 => raw
198                    .chunks_exact(4)
199                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
200                    .collect::<Vec<f32>>(),
201                Dtype::F16 => raw
202                    .chunks_exact(2)
203                    .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
204                    .collect::<Vec<f32>>(),
205                _ => Vec::new(),
206            }
207        });
208
209        // Optional "mapping" tensor (token-id → embedding row).
210        //
211        // Published Model2Vec packs serialize this as **`int64`** (the
212        // numpy default for index arrays); some older packs use `int32`.
213        // Read the on-disk dtype rather than assuming a width: an i32
214        // read against an i64 tensor splits each true index into low/
215        // high halves and corrupts every embedding lookup, which silently
216        // turned the bi-encoder into a random hash.
217        let token_mapping = safet.tensor("mapping").ok().map(|t| {
218            let raw = t.data();
219            #[expect(
220                clippy::cast_sign_loss,
221                clippy::cast_possible_truncation,
222                reason = "mapping values are non-negative row indices well within usize range"
223            )]
224            match t.dtype() {
225                Dtype::I64 => raw
226                    .chunks_exact(8)
227                    .map(|b| {
228                        i64::from_le_bytes([
229                            b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
230                        ]) as usize
231                    })
232                    .collect::<Vec<usize>>(),
233                Dtype::I32 => raw
234                    .chunks_exact(4)
235                    .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]) as usize)
236                    .collect::<Vec<usize>>(),
237                Dtype::U32 => raw
238                    .chunks_exact(4)
239                    .map(|b| u32::from_le_bytes([b[0], b[1], b[2], b[3]]) as usize)
240                    .collect::<Vec<usize>>(),
241                Dtype::U64 => raw
242                    .chunks_exact(8)
243                    .map(|b| {
244                        u64::from_le_bytes([
245                            b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
246                        ]) as usize
247                    })
248                    .collect::<Vec<usize>>(),
249                _ => Vec::new(),
250            }
251        });
252
253        let (median_token_length, unk_token_id) = compute_metadata(&tokenizer)?;
254
255        Ok(Self {
256            tokenizer,
257            embeddings,
258            weights,
259            token_mapping,
260            normalize,
261            median_token_length,
262            unk_token_id,
263        })
264    }
265
266    /// Embedding dimension.
267    #[must_use]
268    pub fn hidden_dim(&self) -> usize {
269        self.embeddings.ncols()
270    }
271
272    /// Encode a single text into a row vector.
273    ///
274    /// Used at query time. The tokenization step is single-text so the
275    /// nested-rayon trap doesn't apply, but it's a separate code path
276    /// that avoids the unnecessary `encode_batch_fast` setup.
277    pub fn encode_query(&self, text: &str) -> Vec<f32> {
278        let truncated = truncate_chars(text, DEFAULT_MAX_TOKENS, self.median_token_length);
279        let Ok(encoding) = self.tokenizer.encode_fast(truncated, false) else {
280            return vec![0.0; self.hidden_dim()];
281        };
282        let ids = filter_ids(encoding.get_ids(), self.unk_token_id, DEFAULT_MAX_TOKENS);
283        self.pool_ids(&ids)
284    }
285
286    /// Encode a batch of texts.
287    ///
288    /// Iterates over fixed-size sub-batches (`BATCH_SIZE = 1024`), each
289    /// tokenized via `encode_batch_fast` (parallel internally inside
290    /// tokenizers) and then mean-pooled via `par_iter` on the rayon
291    /// pool. Calling one giant `encode_batch_fast` on a 250K-item
292    /// corpus dominates wall time in `Encoding` allocation + internal
293    /// chunk scheduling; the 1024-batch shape mirrors
294    /// `model2vec_rs`'s internal default and measured noticeably
295    /// faster on a 92K-file linux-source corpus.
296    pub fn encode_batch(&self, texts: &[&str]) -> Vec<Vec<f32>> {
297        if texts.is_empty() {
298            return Vec::new();
299        }
300        let mut out: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
301        for chunk in texts.chunks(BATCH_SIZE) {
302            let truncated: Vec<String> = chunk
303                .iter()
304                .map(|t| {
305                    truncate_chars(t, DEFAULT_MAX_TOKENS, self.median_token_length).to_string()
306                })
307                .collect();
308            let Ok(encodings) = self.tokenizer.encode_batch_fast::<String>(truncated, false) else {
309                out.extend(std::iter::repeat_n(
310                    vec![0.0; self.hidden_dim()],
311                    chunk.len(),
312                ));
313                continue;
314            };
315            let pooled: Vec<Vec<f32>> = encodings
316                .par_iter()
317                .map(|enc| {
318                    let ids = filter_ids(enc.get_ids(), self.unk_token_id, DEFAULT_MAX_TOKENS);
319                    self.pool_ids(&ids)
320                })
321                .collect();
322            out.extend(pooled);
323        }
324        out
325    }
326
327    /// Mean-pool a list of token ids into one row vector.
328    ///
329    /// Hot kernel: the inner accumulator runs O(tokens × hidden_dim)
330    /// per chunk and was profile-visible at 3.5% self on the linux
331    /// corpus (~38s of 104s wall). Hand-vectorized with `wide::f32x8`
332    /// (8-lane SIMD: NEON x2 on aarch64, AVX2 on x86_64). For
333    /// `potion-code-16M` (hidden_dim = 256), the inner loop is 32
334    /// 8-wide adds per token instead of 256 scalar adds — ~4x
335    /// reduction in instruction count, with fused multiply-add on
336    /// the weighted-token path.
337    ///
338    /// `pool_ids` itself is serial — parallelism is per-chunk via the
339    /// caller's `par_iter`.
340    fn pool_ids(&self, ids: &[u32]) -> Vec<f32> {
341        let dim = self.hidden_dim();
342        let mut sum = vec![0.0_f32; dim];
343        let mut count: usize = 0;
344        // `as_slice()` returns `Some(&[f32])` for standard-layout
345        // arrays. `from_shape_vec` always produces standard layout,
346        // so this never returns None for our embedding matrix —
347        // expect with a clear panic message in case that ever
348        // changes.
349        let embeddings_slice = self
350            .embeddings
351            .as_slice()
352            .expect("embedding matrix is non-contiguous; static_model load invariant violated");
353        let nrows = self.embeddings.nrows();
354        for &id in ids {
355            let tok = id as usize;
356            let row_idx = self
357                .token_mapping
358                .as_deref()
359                .and_then(|m| m.get(tok).copied())
360                .unwrap_or(tok);
361            if row_idx >= nrows {
362                continue;
363            }
364            let row_start = row_idx * dim;
365            let row = &embeddings_slice[row_start..row_start + dim];
366            let scale = self
367                .weights
368                .as_deref()
369                .and_then(|w| w.get(tok).copied())
370                .unwrap_or(1.0);
371            // Bit-exact comparison against 1.0 is intentional: the
372            // weights tensor (when present) stores small constants that
373            // are either exactly 1.0 (no scaling, fast path) or genuine
374            // per-token scalars. Treating a near-1.0 weight as "skip
375            // scaling" would silently bias the embedding.
376            #[expect(
377                clippy::float_cmp,
378                reason = "bit-exact 1.0 check is the intended fast-path gate"
379            )]
380            let no_scale = scale == 1.0;
381            if no_scale {
382                accumulate_f32x8(&mut sum, row);
383            } else {
384                accumulate_scaled_f32x8(&mut sum, row, scale);
385            }
386            count += 1;
387        }
388        let denom = count.max(1) as f32;
389        scale_in_place_f32x8(&mut sum, 1.0 / denom);
390        if self.normalize {
391            let norm = l2_norm_f32x8(&sum).max(1e-12);
392            scale_in_place_f32x8(&mut sum, 1.0 / norm);
393        }
394        sum
395    }
396}
397
/// Cap `s` at `max_tokens * median_len` characters (the product saturates
/// rather than overflowing), cutting only on a char boundary so the result
/// is always valid UTF-8.
///
/// Mirrors Model2Vec's pre-tokenization safety cap: running BPE over a
/// multi-megabyte string is pathologically slow.
fn truncate_chars(s: &str, max_tokens: usize, median_len: usize) -> &str {
    let char_budget = max_tokens.saturating_mul(median_len);
    // A string with no more bytes than the budget cannot hold more chars.
    if s.len() <= char_budget {
        return s;
    }
    match s.char_indices().nth(char_budget) {
        Some((cut, _)) => &s[..cut],
        None => s,
    }
}
406
407// ---------------------------------------------------------------------------
408// SIMD pool kernels.
409//
410// All three helpers process `f32x8` blocks (8 lanes) followed by a scalar
411// tail for `len % 8`. f32x8 maps to two NEON `float32x4_t` registers on
412// aarch64 and one AVX2 `__m256` register on x86_64; portable via the `wide`
413// crate. The weighted accumulator uses `mul_add` which lowers to FMA where
414// available (vfmaq_f32 / vfmadd231ps).
415//
416// For the canonical `potion-code-16M` model (hidden_dim = 256, 8-divisible),
417// the scalar tail is never entered.
418// ---------------------------------------------------------------------------
419
420/// `acc[i] += row[i]` for `i in 0..acc.len()`, vectorized.
421fn accumulate_f32x8(acc: &mut [f32], row: &[f32]) {
422    debug_assert_eq!(acc.len(), row.len(), "pool dim mismatch");
423    let n = acc.len();
424    let body = n - (n % 8);
425    let (acc_body, acc_tail) = acc.split_at_mut(body);
426    let (row_body, row_tail) = row.split_at(body);
427    for (a_chunk, r_chunk) in acc_body.chunks_exact_mut(8).zip(row_body.chunks_exact(8)) {
428        let a = f32x8::from(<[f32; 8]>::try_from(&*a_chunk).unwrap());
429        let r = f32x8::from(<[f32; 8]>::try_from(r_chunk).unwrap());
430        a_chunk.copy_from_slice((a + r).as_array());
431    }
432    for (a, &r) in acc_tail.iter_mut().zip(row_tail.iter()) {
433        *a += r;
434    }
435}
436
437/// `acc[i] += row[i] * scale` for `i in 0..acc.len()`, vectorized with FMA.
438fn accumulate_scaled_f32x8(acc: &mut [f32], row: &[f32], scale: f32) {
439    debug_assert_eq!(acc.len(), row.len(), "pool dim mismatch");
440    let n = acc.len();
441    let body = n - (n % 8);
442    let (acc_body, acc_tail) = acc.split_at_mut(body);
443    let (row_body, row_tail) = row.split_at(body);
444    let scale_v = f32x8::splat(scale);
445    for (a_chunk, r_chunk) in acc_body.chunks_exact_mut(8).zip(row_body.chunks_exact(8)) {
446        let a = f32x8::from(<[f32; 8]>::try_from(&*a_chunk).unwrap());
447        let r = f32x8::from(<[f32; 8]>::try_from(r_chunk).unwrap());
448        // mul_add: a + (r * scale_v); lowers to vfmaq_f32 on aarch64.
449        a_chunk.copy_from_slice(r.mul_add(scale_v, a).as_array());
450    }
451    for (a, &r) in acc_tail.iter_mut().zip(row_tail.iter()) {
452        *a += r * scale;
453    }
454}
455
456/// `v[i] *= factor`, vectorized.
457fn scale_in_place_f32x8(v: &mut [f32], factor: f32) {
458    let n = v.len();
459    let body = n - (n % 8);
460    let (body_slice, tail) = v.split_at_mut(body);
461    let factor_v = f32x8::splat(factor);
462    for chunk in body_slice.chunks_exact_mut(8) {
463        let x = f32x8::from(<[f32; 8]>::try_from(&*chunk).unwrap());
464        chunk.copy_from_slice((x * factor_v).as_array());
465    }
466    for x in tail.iter_mut() {
467        *x *= factor;
468    }
469}
470
471/// L2 norm of `v`, vectorized.
472fn l2_norm_f32x8(v: &[f32]) -> f32 {
473    let n = v.len();
474    let body = n - (n % 8);
475    let (body_slice, tail) = v.split_at(body);
476    let mut acc_v = f32x8::splat(0.0);
477    for chunk in body_slice.chunks_exact(8) {
478        let x = f32x8::from(<[f32; 8]>::try_from(chunk).unwrap());
479        acc_v = x.mul_add(x, acc_v);
480    }
481    let mut sum_sq: f32 = acc_v.as_array().iter().sum();
482    for &x in tail {
483        sum_sq += x * x;
484    }
485    sum_sq.sqrt()
486}
487
/// Remove every occurrence of `unk_id` (when set), then keep at most the
/// first `max_tokens` surviving ids. The result is an owned `Vec<u32>`,
/// so callers need not tie its lifetime to the encoding object.
fn filter_ids(ids: &[u32], unk_id: Option<usize>, max_tokens: usize) -> Vec<u32> {
    ids.iter()
        .copied()
        .filter(|&id| unk_id != Some(id as usize))
        .take(max_tokens)
        .collect()
}
500
501/// Compute the tokenizer-derived metadata (median token length + unk id).
502fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
503    let mut lens: Vec<usize> = tokenizer
504        .get_vocab(false)
505        .keys()
506        .map(std::string::String::len)
507        .collect();
508    lens.sort_unstable();
509    let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
510
511    let spec: Value =
512        serde_json::to_value(tokenizer).context("tokenizer serialize for unk lookup")?;
513    let unk_token = spec
514        .get("model")
515        .and_then(|m| m.get("unk_token"))
516        .and_then(Value::as_str);
517    let unk_token_id = match unk_token {
518        Some(tok) => tokenizer.token_to_id(tok).map(|id| id as usize),
519        None => None,
520    };
521    Ok((median_token_length, unk_token_id))
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Replays the tail of `pool_ids` for an empty id list. The sum stays
    /// zero, `count.max(1)` turns the mean into a divide-by-one, and the
    /// `max(1e-12)` floor on the norm keeps normalization finite. The
    /// result is therefore an all-zero vector rather than NaNs.
    #[test]
    fn pool_ids_empty_input() {
        let mut sum = vec![0.0_f32; 12];
        scale_in_place_f32x8(&mut sum, 1.0);
        let norm = l2_norm_f32x8(&sum).max(1e-12);
        scale_in_place_f32x8(&mut sum, 1.0 / norm);
        assert!(sum.iter().all(|&x| x == 0.0));
        assert!(sum.iter().all(|x| x.is_finite()));
    }

    /// Length 11 exercises both the 8-lane SIMD body and the scalar tail.
    #[test]
    fn accumulate_kernels_cover_simd_body_and_tail() {
        let row: Vec<f32> = (0..11_u8).map(f32::from).collect();
        let mut plain = vec![1.0_f32; row.len()];
        accumulate_f32x8(&mut plain, &row);
        let mut scaled = vec![1.0_f32; row.len()];
        accumulate_scaled_f32x8(&mut scaled, &row, 0.5);
        for ((&p, &s), &x) in plain.iter().zip(&scaled).zip(&row) {
            assert!((p - (1.0 + x)).abs() < 1e-6);
            assert!((s - (1.0 + 0.5 * x)).abs() < 1e-6);
        }
    }

    /// 3-4-5 triangle with the 3 in the SIMD body and the 4 in the tail.
    #[test]
    fn scale_and_norm_kernels() {
        let mut v = vec![0.0_f32; 10];
        v[0] = 3.0;
        v[9] = 4.0;
        assert!((l2_norm_f32x8(&v) - 5.0).abs() < 1e-6);
        scale_in_place_f32x8(&mut v, 0.2);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[9] - 0.8).abs() < 1e-6);
        assert!((l2_norm_f32x8(&v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn filter_ids_drops_unk_then_caps() {
        assert_eq!(filter_ids(&[1, 0, 2, 0, 3], Some(0), 10), vec![1, 2, 3]);
        assert_eq!(filter_ids(&[1, 0, 2, 0, 3], Some(0), 2), vec![1, 2]);
        assert_eq!(filter_ids(&[4, 5, 6], None, 2), vec![4, 5]);
    }

    #[test]
    fn truncate_chars_respects_utf8_boundaries() {
        assert_eq!(truncate_chars("héllo", 2, 1), "hé");
        assert_eq!(truncate_chars("abc", 10, 1), "abc");
        assert_eq!(truncate_chars("abc", usize::MAX, 2), "abc");
        assert_eq!(truncate_chars("abc", 0, 4), "");
    }
}