milli_core/vector/
hf.rs

1use candle_core::Tensor;
2use candle_nn::VarBuilder;
3use candle_transformers::models::bert::{BertModel, Config, DTYPE};
4// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
5use hf_hub::api::sync::Api;
6use hf_hub::{Repo, RepoType};
7use tokenizers::{PaddingParams, Tokenizer};
8
9pub use super::error::{EmbedError, Error, NewEmbedderError};
10use super::{DistributionShift, Embedding, EmbeddingCache};
11
12#[derive(
13    Debug,
14    Clone,
15    Copy,
16    Default,
17    Hash,
18    PartialEq,
19    Eq,
20    serde::Deserialize,
21    serde::Serialize,
22    deserr::Deserr,
23)]
24#[serde(deny_unknown_fields, rename_all = "camelCase")]
25#[deserr(rename_all = camelCase, deny_unknown_fields)]
26enum WeightSource {
27    #[default]
28    Safetensors,
29    Pytorch,
30}
31
32#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
33pub struct EmbedderOptions {
34    pub model: String,
35    pub revision: Option<String>,
36    pub distribution: Option<DistributionShift>,
37    #[serde(default)]
38    pub pooling: OverridePooling,
39}
40
41#[derive(
42    Debug,
43    Clone,
44    Copy,
45    Default,
46    Hash,
47    PartialEq,
48    Eq,
49    serde::Deserialize,
50    serde::Serialize,
51    utoipa::ToSchema,
52    deserr::Deserr,
53)]
54#[deserr(rename_all = camelCase, deny_unknown_fields)]
55#[serde(rename_all = "camelCase")]
56pub enum OverridePooling {
57    UseModel,
58    ForceCls,
59    #[default]
60    ForceMean,
61}
62
63impl EmbedderOptions {
64    pub fn new() -> Self {
65        Self {
66            model: "BAAI/bge-base-en-v1.5".to_string(),
67            revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
68            distribution: None,
69            pooling: OverridePooling::UseModel,
70        }
71    }
72}
73
74impl Default for EmbedderOptions {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80/// Perform embedding of documents and queries
81pub struct Embedder {
82    model: BertModel,
83    tokenizer: Tokenizer,
84    options: EmbedderOptions,
85    dimensions: usize,
86    pooling: Pooling,
87    cache: EmbeddingCache,
88}
89
90impl std::fmt::Debug for Embedder {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        f.debug_struct("Embedder")
93            .field("model", &self.options.model)
94            .field("tokenizer", &self.tokenizer)
95            .field("options", &self.options)
96            .field("pooling", &self.pooling)
97            .finish()
98    }
99}
100
101#[derive(Clone, Copy, serde::Deserialize)]
102struct PoolingConfig {
103    #[serde(default)]
104    pub pooling_mode_cls_token: bool,
105    #[serde(default)]
106    pub pooling_mode_mean_tokens: bool,
107    #[serde(default)]
108    pub pooling_mode_max_tokens: bool,
109    #[serde(default)]
110    pub pooling_mode_mean_sqrt_len_tokens: bool,
111    #[serde(default)]
112    pub pooling_mode_lasttoken: bool,
113}
114
/// Strategy reducing the per-token embeddings of a text to a single vector.
#[derive(Debug, Clone, Copy, Default)]
pub enum Pooling {
    /// Average of all token embeddings (the default).
    #[default]
    Mean,
    /// Embedding of the first token.
    Cls,
    /// Element-wise maximum over the tokens.
    Max,
    /// Sum of the token embeddings divided by the square root of the token count.
    MeanSqrtLen,
    /// Embedding of the last token.
    LastToken,
}
124impl Pooling {
125    fn override_with(&mut self, pooling: OverridePooling) {
126        match pooling {
127            OverridePooling::UseModel => {}
128            OverridePooling::ForceCls => *self = Pooling::Cls,
129            OverridePooling::ForceMean => *self = Pooling::Mean,
130        }
131    }
132}
133
134impl From<PoolingConfig> for Pooling {
135    fn from(value: PoolingConfig) -> Self {
136        if value.pooling_mode_cls_token {
137            Self::Cls
138        } else if value.pooling_mode_mean_tokens {
139            Self::Mean
140        } else if value.pooling_mode_lasttoken {
141            Self::LastToken
142        } else if value.pooling_mode_mean_sqrt_len_tokens {
143            Self::MeanSqrtLen
144        } else if value.pooling_mode_max_tokens {
145            Self::Max
146        } else {
147            Self::default()
148        }
149    }
150}
151
152impl Embedder {
153    pub fn new(
154        options: EmbedderOptions,
155        cache_cap: usize,
156    ) -> std::result::Result<Self, NewEmbedderError> {
157        let device = match candle_core::Device::cuda_if_available(0) {
158            Ok(device) => device,
159            Err(error) => {
160                tracing::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error);
161                candle_core::Device::Cpu
162            }
163        };
164        let repo = match options.revision.clone() {
165            Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
166            None => Repo::model(options.model.clone()),
167        };
168        let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = {
169            let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
170            let api = api.repo(repo);
171            let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
172            let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
173            let (weights, source) = {
174                api.get("model.safetensors")
175                    .map(|filename| (filename, WeightSource::Safetensors))
176                    .or_else(|_| {
177                        api.get("pytorch_model.bin")
178                            .map(|filename| (filename, WeightSource::Pytorch))
179                    })
180                    .map_err(NewEmbedderError::api_get)?
181            };
182            let pooling = match api.get("1_Pooling/config.json") {
183                Ok(pooling) => Some(pooling),
184                Err(hf_hub::api::sync::ApiError::RequestError(error))
185                    if matches!(*error, ureq::Error::Status(404, _,)) =>
186                {
187                    // ignore the error if the file simply doesn't exist
188                    None
189                }
190                Err(error) => return Err(NewEmbedderError::api_get(error)),
191            };
192            let mut pooling: Pooling = match pooling {
193                Some(pooling_filename) => {
194                    let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
195                        NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
196                    })?;
197
198                    let pooling: PoolingConfig =
199                        serde_json::from_str(&pooling).map_err(|inner| {
200                            NewEmbedderError::deserialize_pooling_config(
201                                options.model.clone(),
202                                pooling_filename,
203                                inner,
204                            )
205                        })?;
206                    pooling.into()
207                }
208                None => Pooling::default(),
209            };
210
211            pooling.override_with(options.pooling);
212
213            (config, tokenizer, weights, source, pooling)
214        };
215
216        let config = std::fs::read_to_string(&config_filename)
217            .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
218        let config: Config = serde_json::from_str(&config).map_err(|inner| {
219            NewEmbedderError::deserialize_config(
220                options.model.clone(),
221                config,
222                config_filename,
223                inner,
224            )
225        })?;
226        let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
227            .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
228
229        let vb = match weight_source {
230            WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
231                .map_err(NewEmbedderError::pytorch_weight)?,
232            WeightSource::Safetensors => unsafe {
233                VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
234                    .map_err(NewEmbedderError::safetensor_weight)?
235            },
236        };
237
238        tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
239
240        let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
241
242        if let Some(pp) = tokenizer.get_padding_mut() {
243            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
244        } else {
245            let pp = PaddingParams {
246                strategy: tokenizers::PaddingStrategy::BatchLongest,
247                ..Default::default()
248            };
249            tokenizer.with_padding(Some(pp));
250        }
251
252        let mut this = Self {
253            model,
254            tokenizer,
255            options,
256            dimensions: 0,
257            pooling,
258            cache: EmbeddingCache::new(cache_cap),
259        };
260
261        let embeddings = this
262            .embed(vec!["test".into()])
263            .map_err(NewEmbedderError::could_not_determine_dimension)?;
264        this.dimensions = embeddings.first().unwrap().len();
265
266        Ok(this)
267    }
268
269    pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
270        texts.into_iter().map(|text| self.embed_one(&text)).collect()
271    }
272
273    fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
274        match pooling {
275            Pooling::Mean => Self::mean_pooling(embeddings),
276            Pooling::Cls => Self::cls_pooling(embeddings),
277            Pooling::Max => Self::max_pooling(embeddings),
278            Pooling::MeanSqrtLen => Self::mean_sqrt_pooling(embeddings),
279            Pooling::LastToken => Self::last_token_pooling(embeddings),
280        }
281    }
282
283    fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
284        embeddings.get_on_dim(1, 0).map_err(EmbedError::tensor_value)
285    }
286
287    fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
288        let (_n_sentence, n_tokens, _hidden_size) =
289            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
290
291        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64).sqrt())
292            .map_err(EmbedError::tensor_shape)
293    }
294
295    fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
296        let (_n_sentence, n_tokens, _hidden_size) =
297            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
298
299        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
300            .map_err(EmbedError::tensor_shape)
301    }
302
303    fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
304        embeddings.max(1).map_err(EmbedError::tensor_shape)
305    }
306
307    fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
308        let (_n_sentence, n_tokens, _hidden_size) =
309            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
310
311        embeddings.get_on_dim(1, n_tokens - 1).map_err(EmbedError::tensor_value)
312    }
313
314    pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
315        let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
316        let token_ids = tokens.get_ids();
317        let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
318        let token_ids =
319            Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
320        let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
321        let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
322        let embeddings = self
323            .model
324            .forward(&token_ids, &token_type_ids, None)
325            .map_err(EmbedError::model_forward)?;
326
327        let embedding = Self::pooling(embeddings, self.pooling)?;
328
329        let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
330        let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
331        Ok(embedding)
332    }
333
334    pub fn embed_index(
335        &self,
336        text_chunks: Vec<Vec<String>>,
337    ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
338        text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
339    }
340
341    pub fn chunk_count_hint(&self) -> usize {
342        1
343    }
344
345    pub fn prompt_count_in_chunk_hint(&self) -> usize {
346        std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
347    }
348
349    pub fn dimensions(&self) -> usize {
350        self.dimensions
351    }
352
353    pub fn distribution(&self) -> Option<DistributionShift> {
354        self.options.distribution.or_else(|| {
355            if self.options.model == "BAAI/bge-base-en-v1.5" {
356                Some(DistributionShift {
357                    current_mean: ordered_float::OrderedFloat(0.85),
358                    current_sigma: ordered_float::OrderedFloat(0.1),
359                })
360            } else {
361                None
362            }
363        })
364    }
365
366    pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
367        texts.iter().map(|text| self.embed_one(text)).collect()
368    }
369
370    pub(super) fn cache(&self) -> &EmbeddingCache {
371        &self.cache
372    }
373}