Skip to main content

post_cortex_embeddings/embeddings/backends/
model2vec.rs

1// Copyright (c) 2025 Julius ML
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12
13//! Model2Vec backend — static, multilingual embeddings via `model2vec-rs`.
14//!
15//! Loads `minishlab/potion-*` checkpoints from the HuggingFace Hub (or a
16//! local path), encodes batches of text through the precomputed token-vector
17//! table, and returns L2-normalised vectors. All computation is CPU-only and
18//! orders of magnitude faster than the BERT path — typically ms/text on a
19//! laptop CPU with no GPU.
20
21use anyhow::Result;
22use async_trait::async_trait;
23use model2vec_rs::model::StaticModel;
24use std::sync::Arc;
25use tracing::{debug, info};
26
27use crate::embeddings::backend::EmbeddingBackend;
28use crate::embeddings::config::EmbeddingModelType;
29
/// Per-input token cap handed to `StaticModel::encode_with_args`.
///
/// This equals the `model2vec-rs` built-in default, but it is spelled out
/// here so reviewers can see (and audit) the length/quality trade-off.
const DEFAULT_MAX_TOKENS: usize = 512;
34
/// Inner batch size for `model2vec-rs`' own encode loop.
///
/// `LocalEmbeddingEngine` already splits incoming work with its adaptive
/// batching, so this value only shapes how `model2vec-rs` iterates
/// internally within a single call.
const DEFAULT_INNER_BATCH: usize = 1024;
39
40/// Static-embedding backend powered by `model2vec-rs`.
41///
42/// Construct via [`Self::load`] — the call is async and offloads the
43/// (blocking) `StaticModel::from_pretrained` download/load onto
44/// [`tokio::task::spawn_blocking`].
45pub struct Model2VecBackend {
46    model: Arc<StaticModel>,
47    dimension: usize,
48    model_type: EmbeddingModelType,
49}
50
51impl Model2VecBackend {
52    /// Download (if needed) and load a Model2Vec model from the HuggingFace
53    /// Hub. Verifies the runtime embedding dimension against the
54    /// `EmbeddingModelType` constant and returns an error on mismatch —
55    /// silent drift would corrupt every downstream HNSW index.
56    pub async fn load(model_type: EmbeddingModelType) -> Result<Self> {
57        let model_id = model_type.model_id().to_string();
58        info!("Loading Model2Vec backend for model: {}", model_id);
59
60        // `StaticModel::from_pretrained` does blocking I/O (HF Hub fetch
61        // + safetensors mmap) — keep it off the reactor.
62        let model = tokio::task::spawn_blocking({
63            let model_id = model_id.clone();
64            move || -> Result<StaticModel> {
65                StaticModel::from_pretrained(model_id, None, None, None)
66            }
67        })
68        .await
69        .map_err(|e| anyhow::anyhow!("Model2Vec load join error: {e}"))??;
70
71        // Probe the actual dimension by encoding a sentinel — this is the
72        // ground-truth source. If it disagrees with the enum constant the
73        // HNSW index would be sized wrong, so fail loudly.
74        let probe = model.encode_with_args(
75            &["dimension probe".to_string()],
76            Some(DEFAULT_MAX_TOKENS),
77            DEFAULT_INNER_BATCH,
78        );
79        let runtime_dim = probe.first().map(|v| v.len()).unwrap_or(0);
80        let declared_dim = model_type.embedding_dimension();
81        if runtime_dim == 0 {
82            return Err(anyhow::anyhow!(
83                "Model2Vec ({model_id}) produced an empty vector on the dimension probe"
84            ));
85        }
86        if runtime_dim != declared_dim {
87            return Err(anyhow::anyhow!(
88                "Model2Vec ({model_id}) reports dimension {runtime_dim} at runtime but \
89                 EmbeddingModelType::{model_type:?}.embedding_dimension() = {declared_dim}; \
90                 update the enum constant before shipping or HNSW indices will be sized wrong"
91            ));
92        }
93
94        info!(
95            "Model2Vec model loaded — id: {}, dim: {}",
96            model_id, runtime_dim
97        );
98
99        Ok(Self {
100            model: Arc::new(model),
101            dimension: runtime_dim,
102            model_type,
103        })
104    }
105}
106
107#[async_trait]
108impl EmbeddingBackend for Model2VecBackend {
109    fn embedding_dimension(&self) -> usize {
110        self.dimension
111    }
112
113    /// Model2Vec is **not** BERT — the engine's BERT branch (concurrency
114    /// controller, adaptive batching, timeout) is bypassed. The static
115    /// path dispatches directly because each call is ms-cheap.
116    fn is_bert_based(&self) -> bool {
117        false
118    }
119
120    async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
121        if texts.is_empty() {
122            return Ok(Vec::new());
123        }
124
125        let model = self.model.clone();
126        let dim = self.dimension;
127        let model_type = self.model_type;
128
129        let embeddings = tokio::task::spawn_blocking(move || {
130            model.encode_with_args(&texts, Some(DEFAULT_MAX_TOKENS), DEFAULT_INNER_BATCH)
131        })
132        .await
133        .map_err(|e| anyhow::anyhow!("Model2Vec encode join error: {e}"))?;
134
135        // Sanity check — `encode_with_args` should produce one vector per
136        // input. A mismatch means the underlying API changed under us.
137        for (i, v) in embeddings.iter().enumerate() {
138            if v.len() != dim {
139                return Err(anyhow::anyhow!(
140                    "Model2Vec ({model_type:?}) produced vector at index {i} with dim {} \
141                     (expected {})",
142                    v.len(),
143                    dim
144                ));
145            }
146        }
147
148        debug!(
149            "Model2Vec encoded {} texts (dim={}) for {:?}",
150            embeddings.len(),
151            dim,
152            model_type
153        );
154
155        Ok(embeddings)
156    }
157}