use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};

pub use super::error::{EmbedError, Error, NewEmbedderError};
use super::{DistributionShift, Embedding, EmbeddingCache};

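/// Where the model weights were fetched from on the Hugging Face Hub:
/// `model.safetensors` is preferred, with `pytorch_model.bin` as a fallback.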
12#[derive(
13 Debug,
14 Clone,
15 Copy,
16 Default,
17 Hash,
18 PartialEq,
19 Eq,
20 serde::Deserialize,
21 serde::Serialize,
22 deserr::Deserr,
23)]
24#[serde(deny_unknown_fields, rename_all = "camelCase")]
25#[deserr(rename_all = camelCase, deny_unknown_fields)]
26enum WeightSource {
27 #[default]
28 Safetensors,
29 Pytorch,
30}
31
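/// User-facing options for the Hugging Face embedder: which model to pull
/// from the Hub, an optional revision to pin it to, an optional score
/// distribution shift, and how to override the model's pooling method.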
32#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
33pub struct EmbedderOptions {
34 pub model: String,
35 pub revision: Option<String>,
36 pub distribution: Option<DistributionShift>,
37 #[serde(default)]
38 pub pooling: OverridePooling,
39}
40
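/// Whether to trust the pooling method declared by the model or force a
/// specific one.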
41#[derive(
42 Debug,
43 Clone,
44 Copy,
45 Default,
46 Hash,
47 PartialEq,
48 Eq,
49 serde::Deserialize,
50 serde::Serialize,
51 utoipa::ToSchema,
52 deserr::Deserr,
53)]
54#[deserr(rename_all = camelCase, deny_unknown_fields)]
55#[serde(rename_all = "camelCase")]
56pub enum OverridePooling {
57 UseModel,
58 ForceCls,
59 #[default]
60 ForceMean,
61}
62
impl EmbedderOptions {
    pub fn new() -> Self {
        Self {
            model: "BAAI/bge-base-en-v1.5".to_string(),
            revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
            distribution: None,
            pooling: OverridePooling::UseModel,
        }
    }
}

impl Default for EmbedderOptions {
    fn default() -> Self {
        Self::new()
    }
}

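/// An embedder backed by a local BERT-style model run through candle.
///
/// A minimal usage sketch (hypothetical caller; the cache capacity of 512 is
/// an arbitrary example value):
///
/// ```ignore
/// let embedder = Embedder::new(EmbedderOptions::default(), 512).unwrap();
/// let embedding = embedder.embed_one("a short text").unwrap();
/// assert_eq!(embedding.len(), embedder.dimensions());
/// ```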
pub struct Embedder {
    model: BertModel,
    tokenizer: Tokenizer,
    options: EmbedderOptions,
    dimensions: usize,
    pooling: Pooling,
    cache: EmbeddingCache,
}

impl std::fmt::Debug for Embedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Embedder")
            .field("model", &self.options.model)
            .field("tokenizer", &self.tokenizer)
            .field("options", &self.options)
            .field("pooling", &self.pooling)
            .finish()
    }
}

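/// Deserialized from a sentence-transformers `1_Pooling/config.json` file.
/// Each flag enables one pooling mode; models may set more than one.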
#[derive(Clone, Copy, serde::Deserialize)]
struct PoolingConfig {
    #[serde(default)]
    pub pooling_mode_cls_token: bool,
    #[serde(default)]
    pub pooling_mode_mean_tokens: bool,
    #[serde(default)]
    pub pooling_mode_max_tokens: bool,
    #[serde(default)]
    pub pooling_mode_mean_sqrt_len_tokens: bool,
    #[serde(default)]
    pub pooling_mode_lasttoken: bool,
}

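/// The pooling strategy used to collapse per-token hidden states into a
/// single sentence embedding.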
#[derive(Debug, Clone, Copy, Default)]
pub enum Pooling {
    #[default]
    Mean,
    Cls,
    Max,
    MeanSqrtLen,
    LastToken,
}

impl Pooling {
    fn override_with(&mut self, pooling: OverridePooling) {
        match pooling {
            OverridePooling::UseModel => {}
            OverridePooling::ForceCls => *self = Pooling::Cls,
            OverridePooling::ForceMean => *self = Pooling::Mean,
        }
    }
}

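// A pooling config may enable several modes at once; pick a single one, with
// CLS taking precedence, then mean, last-token, mean-sqrt-len, and max.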
impl From<PoolingConfig> for Pooling {
    fn from(value: PoolingConfig) -> Self {
        if value.pooling_mode_cls_token {
            Self::Cls
        } else if value.pooling_mode_mean_tokens {
            Self::Mean
        } else if value.pooling_mode_lasttoken {
            Self::LastToken
        } else if value.pooling_mode_mean_sqrt_len_tokens {
            Self::MeanSqrtLen
        } else if value.pooling_mode_max_tokens {
            Self::Max
        } else {
            Self::default()
        }
    }
}

impl Embedder {
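    /// Fetches the model config, tokenizer, weights and optional pooling
    /// config from the Hugging Face Hub, loads the model on the best
    /// available device, and embeds a probe string to learn the embedding
    /// dimensions.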
    pub fn new(
        options: EmbedderOptions,
        cache_cap: usize,
    ) -> std::result::Result<Self, NewEmbedderError> {
        // Prefer CUDA when available, falling back to CPU with a warning.
        let device = match candle_core::Device::cuda_if_available(0) {
            Ok(device) => device,
            Err(error) => {
                tracing::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error);
                candle_core::Device::Cpu
            }
        };
        let repo = match options.revision.clone() {
            Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
            None => Repo::model(options.model.clone()),
        };
        let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = {
            let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
            let api = api.repo(repo);
            let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
            let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
            // Prefer safetensors weights, falling back to a pytorch checkpoint.
            let (weights, source) = {
                api.get("model.safetensors")
                    .map(|filename| (filename, WeightSource::Safetensors))
                    .or_else(|_| {
                        api.get("pytorch_model.bin")
                            .map(|filename| (filename, WeightSource::Pytorch))
                    })
                    .map_err(NewEmbedderError::api_get)?
            };
            // The pooling config is optional: a 404 means the model doesn't ship one.
            let pooling = match api.get("1_Pooling/config.json") {
                Ok(pooling) => Some(pooling),
                Err(hf_hub::api::sync::ApiError::RequestError(error))
                    if matches!(*error, ureq::Error::Status(404, _)) =>
                {
                    None
                }
                Err(error) => return Err(NewEmbedderError::api_get(error)),
            };
            let mut pooling: Pooling = match pooling {
                Some(pooling_filename) => {
                    let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
                        NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
                    })?;

                    let pooling: PoolingConfig =
                        serde_json::from_str(&pooling).map_err(|inner| {
                            NewEmbedderError::deserialize_pooling_config(
                                options.model.clone(),
                                pooling_filename,
                                inner,
                            )
                        })?;
                    pooling.into()
                }
                None => Pooling::default(),
            };

            pooling.override_with(options.pooling);

            (config, tokenizer, weights, source, pooling)
        };

        let config = std::fs::read_to_string(&config_filename)
            .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
        let config: Config = serde_json::from_str(&config).map_err(|inner| {
            NewEmbedderError::deserialize_config(
                options.model.clone(),
                config,
                config_filename,
                inner,
            )
        })?;
        let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
            .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;

        let vb = match weight_source {
            WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
                .map_err(NewEmbedderError::pytorch_weight)?,
            // SAFETY: mmapping the weights is unsafe because the underlying file
            // could be modified while mapped; we assume the Hub cache is not
            // tampered with concurrently.
            WeightSource::Safetensors => unsafe {
                VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
                    .map_err(NewEmbedderError::safetensor_weight)?
            },
        };

        tracing::debug!(model = options.model, weight = ?weight_source, pooling = ?pooling, "model config");

        let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;

        // Pad every batch to its longest sequence so inputs can be stacked.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }

        let mut this = Self {
            model,
            tokenizer,
            options,
            dimensions: 0,
            pooling,
            cache: EmbeddingCache::new(cache_cap),
        };

        // Embed a probe string to discover the output dimensions of the model.
        let embeddings = this
            .embed(vec!["test".into()])
            .map_err(NewEmbedderError::could_not_determine_dimension)?;
        this.dimensions = embeddings.first().unwrap().len();

        Ok(this)
    }

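    /// Embeds the given texts one by one, with one forward pass per text.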
    pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
        texts.into_iter().map(|text| self.embed_one(&text)).collect()
    }

    fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
        match pooling {
            Pooling::Mean => Self::mean_pooling(embeddings),
            Pooling::Cls => Self::cls_pooling(embeddings),
            Pooling::Max => Self::max_pooling(embeddings),
            Pooling::MeanSqrtLen => Self::mean_sqrt_pooling(embeddings),
            Pooling::LastToken => Self::last_token_pooling(embeddings),
        }
    }

    // CLS pooling: take the first token's hidden state as the sentence embedding.
    fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        embeddings.get_on_dim(1, 0).map_err(EmbedError::tensor_value)
    }

    // Sum over the token dimension, normalized by sqrt(n_tokens).
    fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        let (_n_sentence, n_tokens, _hidden_size) =
            embeddings.dims3().map_err(EmbedError::tensor_shape)?;

        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64).sqrt())
            .map_err(EmbedError::tensor_shape)
    }

    // Average the hidden states over the token dimension.
    fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        let (_n_sentence, n_tokens, _hidden_size) =
            embeddings.dims3().map_err(EmbedError::tensor_shape)?;

        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
            .map_err(EmbedError::tensor_shape)
    }

    // Element-wise maximum over the token dimension.
    fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        embeddings.max(1).map_err(EmbedError::tensor_shape)
    }

    // Take the hidden state of the last token as the sentence embedding.
    fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        let (_n_sentence, n_tokens, _hidden_size) =
            embeddings.dims3().map_err(EmbedError::tensor_shape)?;

        embeddings.get_on_dim(1, n_tokens - 1).map_err(EmbedError::tensor_value)
    }

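    /// Tokenizes `text`, truncates it to the first 512 tokens (the usual
    /// BERT position limit), runs a forward pass and pools the result.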
    pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
        let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
        let token_ids = tokens.get_ids();
        let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
        let token_ids =
            Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
        // Add a batch dimension of 1 and use all-zero token type ids (single segment).
        let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
        let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
        let embeddings = self
            .model
            .forward(&token_ids, &token_type_ids, None)
            .map_err(EmbedError::model_forward)?;

        // Reduce the (1, n_tokens, hidden_size) activations to a single vector.
        let embedding = Self::pooling(embeddings, self.pooling)?;

        let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
        let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
        Ok(embedding)
    }

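    /// Embeds each chunk of prompts in turn.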
    pub fn embed_index(
        &self,
        text_chunks: Vec<Vec<String>>,
    ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
        text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
    }

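    /// Hint for how many chunks to process at once: this embedder works one
    /// chunk at a time.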
    pub fn chunk_count_hint(&self) -> usize {
        1
    }

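    /// One prompt per available CPU thread, falling back to 8 when the
    /// parallelism cannot be determined.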
    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
    }

    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

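    /// Returns the configured distribution shift, defaulting to a hard-coded
    /// calibration for the default `BAAI/bge-base-en-v1.5` model.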
    pub fn distribution(&self) -> Option<DistributionShift> {
        self.options.distribution.or_else(|| {
            if self.options.model == "BAAI/bge-base-en-v1.5" {
                Some(DistributionShift {
                    current_mean: ordered_float::OrderedFloat(0.85),
                    current_sigma: ordered_float::OrderedFloat(0.1),
                })
            } else {
                None
            }
        })
    }

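    /// Embeds borrowed texts one by one, like [`Self::embed`].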
    pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
        texts.iter().map(|text| self.embed_one(text)).collect()
    }

    pub(super) fn cache(&self) -> &EmbeddingCache {
        &self.cache
    }
}
373}