post_cortex_embeddings/embeddings/backends/model2vec.rs
1// Copyright (c) 2025 Julius ML
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12
13//! Model2Vec backend — static, multilingual embeddings via `model2vec-rs`.
14//!
15//! Loads `minishlab/potion-*` checkpoints from the HuggingFace Hub (or a
16//! local path), encodes batches of text through the precomputed token-vector
17//! table, and returns L2-normalised vectors. All computation is CPU-only and
18//! orders of magnitude faster than the BERT path — typically ms/text on a
19//! laptop CPU with no GPU.
20
21use anyhow::Result;
22use async_trait::async_trait;
23use model2vec_rs::model::StaticModel;
24use std::sync::Arc;
25use tracing::{debug, info};
26
27use crate::embeddings::backend::EmbeddingBackend;
28use crate::embeddings::config::EmbeddingModelType;
29
/// Token budget handed to `StaticModel::encode_with_args` as its
/// `max_length` argument. The value matches `model2vec-rs`' own built-in
/// default, but it is written out here so the truncation trade-off is
/// visible and auditable at the call site.
const DEFAULT_MAX_TOKENS: usize = 512;

/// Inner batch size handed to `StaticModel::encode_with_args`. Callers such
/// as `LocalEmbeddingEngine` already split work into their own batches, so
/// this only shapes the loop inside `model2vec-rs`.
const DEFAULT_INNER_BATCH: usize = 1_024;
39
40/// Static-embedding backend powered by `model2vec-rs`.
41///
42/// Construct via [`Self::load`] — the call is async and offloads the
43/// (blocking) `StaticModel::from_pretrained` download/load onto
44/// [`tokio::task::spawn_blocking`].
45pub struct Model2VecBackend {
46 model: Arc<StaticModel>,
47 dimension: usize,
48 model_type: EmbeddingModelType,
49}
50
51impl Model2VecBackend {
52 /// Download (if needed) and load a Model2Vec model from the HuggingFace
53 /// Hub. Verifies the runtime embedding dimension against the
54 /// `EmbeddingModelType` constant and returns an error on mismatch —
55 /// silent drift would corrupt every downstream HNSW index.
56 pub async fn load(model_type: EmbeddingModelType) -> Result<Self> {
57 let model_id = model_type.model_id().to_string();
58 info!("Loading Model2Vec backend for model: {}", model_id);
59
60 // `StaticModel::from_pretrained` does blocking I/O (HF Hub fetch
61 // + safetensors mmap) — keep it off the reactor.
62 let model = tokio::task::spawn_blocking({
63 let model_id = model_id.clone();
64 move || -> Result<StaticModel> {
65 StaticModel::from_pretrained(model_id, None, None, None)
66 }
67 })
68 .await
69 .map_err(|e| anyhow::anyhow!("Model2Vec load join error: {e}"))??;
70
71 // Probe the actual dimension by encoding a sentinel — this is the
72 // ground-truth source. If it disagrees with the enum constant the
73 // HNSW index would be sized wrong, so fail loudly.
74 let probe = model.encode_with_args(
75 &["dimension probe".to_string()],
76 Some(DEFAULT_MAX_TOKENS),
77 DEFAULT_INNER_BATCH,
78 );
79 let runtime_dim = probe.first().map(|v| v.len()).unwrap_or(0);
80 let declared_dim = model_type.embedding_dimension();
81 if runtime_dim == 0 {
82 return Err(anyhow::anyhow!(
83 "Model2Vec ({model_id}) produced an empty vector on the dimension probe"
84 ));
85 }
86 if runtime_dim != declared_dim {
87 return Err(anyhow::anyhow!(
88 "Model2Vec ({model_id}) reports dimension {runtime_dim} at runtime but \
89 EmbeddingModelType::{model_type:?}.embedding_dimension() = {declared_dim}; \
90 update the enum constant before shipping or HNSW indices will be sized wrong"
91 ));
92 }
93
94 info!(
95 "Model2Vec model loaded — id: {}, dim: {}",
96 model_id, runtime_dim
97 );
98
99 Ok(Self {
100 model: Arc::new(model),
101 dimension: runtime_dim,
102 model_type,
103 })
104 }
105}
106
107#[async_trait]
108impl EmbeddingBackend for Model2VecBackend {
109 fn embedding_dimension(&self) -> usize {
110 self.dimension
111 }
112
113 /// Model2Vec is **not** BERT — the engine's BERT branch (concurrency
114 /// controller, adaptive batching, timeout) is bypassed. The static
115 /// path dispatches directly because each call is ms-cheap.
116 fn is_bert_based(&self) -> bool {
117 false
118 }
119
120 async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
121 if texts.is_empty() {
122 return Ok(Vec::new());
123 }
124
125 let model = self.model.clone();
126 let dim = self.dimension;
127 let model_type = self.model_type;
128
129 let embeddings = tokio::task::spawn_blocking(move || {
130 model.encode_with_args(&texts, Some(DEFAULT_MAX_TOKENS), DEFAULT_INNER_BATCH)
131 })
132 .await
133 .map_err(|e| anyhow::anyhow!("Model2Vec encode join error: {e}"))?;
134
135 // Sanity check — `encode_with_args` should produce one vector per
136 // input. A mismatch means the underlying API changed under us.
137 for (i, v) in embeddings.iter().enumerate() {
138 if v.len() != dim {
139 return Err(anyhow::anyhow!(
140 "Model2Vec ({model_type:?}) produced vector at index {i} with dim {} \
141 (expected {})",
142 v.len(),
143 dim
144 ));
145 }
146 }
147
148 debug!(
149 "Model2Vec encoded {} texts (dim={}) for {:?}",
150 embeddings.len(),
151 dim,
152 model_type
153 );
154
155 Ok(embeddings)
156 }
157}