rbert 0.4.0

A simple interface for Bert embeddings
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
//! # rbert
//!
//! A Rust wrapper for [bert sentence transformers](https://arxiv.org/abs/1908.10084) implemented in [Candle](https://github.com/huggingface/candle)
//!
//! ## Usage
//!
//! ```rust, no_run
//! use kalosm_language_model::Embedder;
//! use rbert::*;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     let mut bert = Bert::new().await?;
//!     let sentences = [
//!         "Cats are cool",
//!         "The geopolitical situation is dire",
//!         "Pets are great",
//!         "Napoleon was a tyrant",
//!         "Napoleon was a great general",
//!     ];
//!     let embeddings = bert.embed_batch(sentences).await?;
//!     println!("embeddings {:?}", embeddings);
//!
//!     // Compute the cosine similarity between every pair of sentences
//!     let mut similarities = vec![];
//!     let n_sentences = sentences.len();
//!     for (i, e_i) in embeddings.iter().enumerate() {
//!         for j in (i + 1)..n_sentences {
//!             let e_j = embeddings.get(j).unwrap();
//!             let cosine_similarity = e_j.cosine_similarity(e_i);
//!             similarities.push((cosine_similarity, i, j))
//!         }
//!     }
//!     similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
//!     for &(score, i, j) in similarities.iter() {
//!         println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
//!     }
//!
//!     Ok(())
//! }
//! ```

#![warn(missing_docs)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_core::{IndexOp, Tensor};
use candle_nn::VarBuilder;
use kalosm_common::*;
use kalosm_model_types::ModelLoadingProgress;
use std::sync::{Arc, RwLock};
use tokenizers::{Encoding, PaddingParams, Tokenizer};

mod language_model;
mod raw;
mod source;

pub use crate::language_model::*;
use crate::raw::DTYPE;
pub use crate::raw::{BertModel, Config};
pub use crate::source::*;

/// A builder for a [`Bert`] model
///
/// Obtain one with [`Bert::builder`] (or [`BertBuilder::default`]), optionally override the
/// source and the cache, then finish with [`BertBuilder::build`] or
/// [`BertBuilder::build_with_loading_handler`].
#[derive(Default)]
pub struct BertBuilder {
    /// Describes which config, tokenizer and weight files the model is loaded from.
    source: BertSource,
    /// Cache used to resolve the source files to local file paths.
    cache: kalosm_common::Cache,
}

impl BertBuilder {
    /// Replace the [`BertSource`] the model files are loaded from
    pub fn with_source(self, source: BertSource) -> Self {
        Self { source, ..self }
    }

    /// Set the cache location to use for the model (defaults DATA_DIR/kalosm/cache)
    pub fn with_cache(self, cache: kalosm_common::Cache) -> Self {
        Self { cache, ..self }
    }

    /// Build the model, reporting progress through the default multi-bar loading indicator
    pub async fn build(self) -> Result<Bert, BertLoadingError> {
        let indicator = ModelLoadingProgress::multi_bar_loading_indicator();
        self.build_with_loading_handler(indicator).await
    }

    /// Build the model, forwarding every loading progress update to `loading_handler`
    ///
    /// ```rust, no_run
    /// use kalosm::language::*;
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), anyhow::Error> {
    /// // Create a new bert model with a loading handler
    /// let model = Bert::builder()
    ///     .build_with_loading_handler(|progress| match progress {
    ///         ModelLoadingProgress::Downloading { source, progress } => {
    ///             let progress_percent = (progress.progress * 100) as u32;
    ///             let elapsed = progress.start_time.elapsed().as_secs_f32();
    ///             println!("Downloading file {source} {progress_percent}% ({elapsed}s)");
    ///         }
    ///         ModelLoadingProgress::Loading { progress } => {
    ///             let progress = (progress * 100.0) as u32;
    ///             println!("Loading model {progress}%");
    ///         }
    ///     })
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn build_with_loading_handler(
        self,
        loading_handler: impl FnMut(ModelLoadingProgress) + Send + 'static,
    ) -> Result<Bert, BertLoadingError> {
        Bert::from_builder(self, loading_handler).await
    }
}

/// An error that can occur when loading a Bert model.
///
/// `DownloadingError` and `LoadModel` are created automatically by `?` through their
/// `#[from]` conversions. `LoadTokenizer` and `LoadConfig` have no `#[from]` and must be
/// constructed explicitly (for example with `map_err`).
#[derive(Debug, thiserror::Error)]
pub enum BertLoadingError {
    /// An error that can occur when trying to load a Bert model from huggingface or a local file.
    #[error("Failed to load model from huggingface or local file: {0}")]
    DownloadingError(#[from] CacheError),
    /// An error that can occur when trying to load a Bert model.
    #[error("Failed to load model into device: {0}")]
    LoadModel(#[from] candle_core::Error),
    /// An error that can occur when trying to load the bert tokenizer.
    #[error("Failed to load tokenizer: {0}")]
    LoadTokenizer(tokenizers::Error),
    /// An error that can occur when trying to load the bert config.
    ///
    /// Wraps the `serde_json` error produced when the config text cannot be deserialized.
    #[error("Failed to load config: {0}")]
    LoadConfig(serde_json::Error),
    /// A config was not found
    ///
    /// Returned when the resolved config file cannot be read from disk; the underlying
    /// I/O error is not preserved.
    #[error("Config not found")]
    ConfigNotFound,
}

/// An error that can occur when running a Bert model.
///
/// `Candle` and `Join` are created automatically by `?` through their `#[from]`
/// conversions. `TokenizerError` has no `#[from]` and must be constructed explicitly
/// (for example with `map_err`).
#[derive(Debug, thiserror::Error)]
pub enum BertError {
    /// An error that can occur when trying to run a Bert model.
    #[error("Failed to run model: {0}")]
    Candle(#[from] candle_core::Error),
    /// An error that can occur when tokenizing or detokenizing text.
    ///
    /// Also used when padding a batch of encodings fails.
    #[error("Failed to tokenize: {0}")]
    TokenizerError(tokenizers::Error),
    /// Failed to join the thread that is running the model
    // NOTE(review): no code in this file produces this variant; presumably it is used by
    // the embedder implementation in another module — confirm.
    #[error("Failed to join thread: {0}")]
    Join(#[from] tokio::task::JoinError),
}

/// The pooling strategy to use when embedding text.
///
/// The strategy decides how the per-token outputs of the model are reduced to a single
/// embedding per sentence. The type is a plain, field-less value: it is `Copy`, can be
/// compared with `==`, and can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Pooling {
    /// Take the mean embedding value for all tokens (except padding)
    ///
    /// Token embeddings are multiplied by the attention mask before being summed, and the
    /// pooled vector is L2-normalized.
    Mean,
    /// Take the embedding of the CLS token for each sequence
    ///
    /// This is the model output at token position 0; it is returned as-is, without
    /// normalization.
    CLS,
}

/// A bert embedding model. The main interface for this model is [`EmbedderExt`].
///
/// Every field is held behind an [`Arc`], so cloning a `Bert` only bumps reference counts
/// and all clones share the same model and tokenizer.
///
/// # Example
/// ```rust, no_run
/// use kalosm_language_model::Embedder;
/// use rbert::*;
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let mut bert = Bert::new().await?;
///     let sentences = [
///         "Cats are cool",
///         "The geopolitical situation is dire",
///         "Pets are great",
///         "Napoleon was a tyrant",
///         "Napoleon was a great general",
///     ];
///     let embeddings = bert.embed_batch(sentences).await?;
///     println!("embeddings {:?}", embeddings);
///
///     // Compute the cosine similarity between every pair of sentences
///     let mut similarities = vec![];
///     let n_sentences = sentences.len();
///     for (i, e_i) in embeddings.iter().enumerate() {
///         for j in (i + 1)..n_sentences {
///             let e_j = embeddings.get(j).unwrap();
///             let cosine_similarity = e_j.cosine_similarity(e_i);
///             similarities.push((cosine_similarity, i, j))
///         }
///     }
///     similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
///     for &(score, i, j) in similarities.iter() {
///         println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
///     }
///
///     Ok(())
/// }
/// ```
#[derive(Clone)]
pub struct Bert {
    /// Optional prefix taken from the model source's `search_embedding_prefix`.
    // NOTE(review): presumably prepended to queries when embedding for search; the use
    // site is not in this file — confirm.
    embedding_search_prefix: Arc<Option<String>>,
    /// The loaded model weights, shared between clones.
    model: Arc<BertModel>,
    /// The tokenizer, shared between clones; this file only ever takes the read lock.
    tokenizer: Arc<RwLock<Tokenizer>>,
}

impl Bert {
    /// Create a new [`BertBuilder`]
    pub fn builder() -> BertBuilder {
        BertBuilder::default()
    }

    /// Load a bert model using the default [`BertBuilder`] settings
    pub async fn new() -> Result<Self, BertLoadingError> {
        let builder = Self::builder();
        builder.build().await
    }

    /// Load a bert model from [`BertSource::new_for_search`], with every other setting left
    /// at its default
    pub async fn new_for_search() -> Result<Self, BertLoadingError> {
        let search_source = BertSource::new_for_search();
        let builder = Self::builder().with_source(search_source);
        builder.build().await
    }

    async fn from_builder(
        builder: BertBuilder,
        mut progress_handler: impl FnMut(ModelLoadingProgress) + Send + 'static,
    ) -> Result<Self, BertLoadingError> {
        let BertBuilder { source, cache } = builder;
        let BertSource {
            config,
            tokenizer,
            model,
            search_embedding_prefix,
        } = source;

        let source = format!("Config ({})", config);
        let mut create_progress = ModelLoadingProgress::downloading_progress(source);
        let config_filename = cache
            .get(&config, |progress| {
                progress_handler(create_progress(progress))
            })
            .await?;
        let tokenizer_source = format!("Tokenizer ({})", tokenizer);
        let mut create_progress = ModelLoadingProgress::downloading_progress(tokenizer_source);
        let tokenizer_filename = cache
            .get(&tokenizer, |progress| {
                progress_handler(create_progress(progress))
            })
            .await?;
        let model_source = format!("Model ({})", model);
        let mut create_progress = ModelLoadingProgress::downloading_progress(model_source);
        let weights_filename = cache
            .get(&model, |progress| {
                progress_handler(create_progress(progress))
            })
            .await?;

        let config = std::fs::read_to_string(config_filename)
            .map_err(|_| BertLoadingError::ConfigNotFound)?;
        let config: Config = serde_json::from_str(&config).map_err(BertLoadingError::LoadConfig)?;

        let device = accelerated_device_if_available()?;
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[&weights_filename], DTYPE, &device)? };
        let model = BertModel::load(vb, &config)?;
        let mut tokenizer =
            Tokenizer::from_file(&tokenizer_filename).map_err(BertLoadingError::LoadTokenizer)?;
        tokenizer.with_padding(None);

        Ok(Bert {
            tokenizer: Arc::new(RwLock::new(tokenizer)),
            model: Arc::new(model),
            embedding_search_prefix: Arc::new(search_embedding_prefix),
        })
    }

    /// Embed a batch of sentences, returning one tensor per sentence in input order
    ///
    /// The sentences are tokenized, sorted by token count and split into chunks whose
    /// estimated cost stays within a fixed budget. Each chunk is run through the model
    /// separately and the results are written back to the positions of the original
    /// sentences.
    ///
    /// # Errors
    /// Returns [`BertError::TokenizerError`] if tokenization fails, or whatever error the
    /// model run for a chunk produces.
    ///
    /// # Panics
    /// Panics if the tokenizer lock is poisoned.
    pub(crate) fn embed_batch_raw(
        &self,
        sentences: Vec<&str>,
        pooling: Pooling,
    ) -> Result<Vec<Tensor>, BertError> {
        let hidden = self.model.embedding_dim();
        // Heuristic cost budget for a single chunk (input length * memory per token)
        let budget = hidden * 512usize.pow(2) * 2;

        // The read guard is a temporary, so the lock is released at the end of this statement.
        let encoded = self
            .tokenizer
            .read()
            .unwrap()
            .encode_batch(sentences, true)
            .map_err(BertError::TokenizerError)?;

        // Sentences can differ a lot in length. Sorting by token count puts similar
        // lengths in the same chunk, which keeps padding overhead low. The original
        // position travels with each encoding so the output order can be restored.
        let mut ordered: Vec<(usize, Encoding)> = encoded.into_iter().enumerate().collect();
        ordered.sort_unstable_by_key(|(_, encoding)| encoding.len());

        let mut results: Vec<Option<Tensor>> = vec![None; ordered.len()];
        let mut batches = Vec::new();
        let mut pending_indices = Vec::new();
        let mut pending_encodings: Vec<Encoding> = Vec::new();
        let mut pending_count = 0;
        let mut pending_longest = 0;
        for (original_index, encoding) in ordered {
            let token_count = encoding.get_ids().len();
            // Estimate the cost of the chunk as if this encoding had been added to it.
            pending_longest = pending_longest.max(token_count);
            pending_count += 1;
            let cost = pending_count * (hidden * 8 + hidden * pending_longest.pow(2));
            if cost > budget {
                // Over budget: close the chunk and start a new one with this encoding.
                batches.push((
                    std::mem::take(&mut pending_indices),
                    std::mem::take(&mut pending_encodings),
                ));
                pending_longest = token_count;
                pending_count = 1;
            }
            pending_indices.push(original_index);
            pending_encodings.push(encoding);
        }
        // Whatever is still pending forms the final chunk, even if it is under budget.
        batches.push((pending_indices, pending_encodings));

        for (batch_indices, batch_encodings) in batches {
            let batch_embeddings =
                maybe_autoreleasepool(|| self.embed_batch_raw_inner(batch_encodings, pooling))?;
            for (slot, tensor) in batch_indices.into_iter().zip(batch_embeddings) {
                results[slot] = Some(tensor);
            }
        }
        // Every original index was placed in exactly one chunk, so every slot is filled.
        Ok(results.into_iter().map(|slot| slot.unwrap()).collect())
    }

    /// Run the model over one chunk of tokenized sentences and pool the output into one
    /// embedding per sentence.
    ///
    /// The encodings are padded to the longest one in the chunk, truncated to the model's
    /// maximum sequence length, stacked into a batch and passed through the model together
    /// with their attention masks. The returned vector has one tensor per encoding, in the
    /// same order as `tokens`; an empty `tokens` yields an empty vector.
    ///
    /// # Errors
    /// Returns [`BertError::TokenizerError`] if padding fails and [`BertError::Candle`] if
    /// any tensor operation or the forward pass fails.
    fn embed_batch_raw_inner(
        &self,
        mut tokens: Vec<Encoding>,
        pooling: Pooling,
    ) -> Result<Vec<Tensor>, BertError> {
        if tokens.is_empty() {
            return Ok(Vec::new());
        }
        let device = &self.model.device;
        // Pad every encoding to the longest one in this chunk so they can be stacked.
        let pp = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizers::pad_encodings(&mut tokens, &pp).map_err(BertError::TokenizerError)?;

        let n_sentences = tokens.len();
        let max_seq_len = self.model.max_seq_len();
        // Slice the borrowed ids directly (no intermediate copy), dropping anything past
        // the model's maximum sequence length.
        let token_ids = tokens
            .iter()
            .map(|encoding| {
                let ids = encoding.get_ids();
                Tensor::new(&ids[..max_seq_len.min(ids.len())], device)
            })
            .collect::<candle_core::Result<Vec<_>>>()?;
        let token_ids = Tensor::stack(&token_ids, 0)?;

        // The attention masks are truncated the same way so they line up with the ids.
        let attention_masks = tokens
            .iter()
            .map(|encoding| {
                let mask = encoding.get_attention_mask();
                Tensor::new(&mask[..max_seq_len.min(mask.len())], device)
            })
            .collect::<candle_core::Result<Vec<_>>>()?;
        let attention_mask = Tensor::stack(&attention_masks, 0)?;

        // The token type ids are only used for next sentence prediction. We can just set them to zero for embedding tasks.
        let token_type_ids = token_ids.zeros_like()?;
        let embeddings =
            self.model
                .forward(&token_ids, &token_type_ids, Some(&attention_mask), false)?;

        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;

        match pooling {
            Pooling::Mean => {
                // Take the mean embedding value for all tokens (except padding)
                let embeddings = embeddings.mul(
                    &attention_mask
                        .to_dtype(DTYPE)?
                        .unsqueeze(2)?
                        .broadcast_as(embeddings.shape())?,
                )?;
                // The divisor is the padded length rather than the per-sentence token
                // count; that uniform scale is cancelled by the L2 normalization below.
                let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
                let embeddings = normalize_l2(&embeddings)?;
                Ok(embeddings.chunk(n_sentences, 0)?)
            }
            Pooling::CLS => {
                // Index into the first token of each sentence which is the CLS token that contains the sentence embedding
                let indexed_embeddings = embeddings.i((.., 0, ..))?;
                Ok(indexed_embeddings.chunk(n_sentences, 0)?)
            }
        }
    }
}

/// Divide `v` by its L2 norm taken along dimension 1.
///
/// The norm keeps dimension 1 (with size one) so that it broadcasts against `v`.
// NOTE(review): there is no guard against a zero norm; presumably that yields non-finite
// values — confirm whether callers can pass an all-zero row.
fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
    let squared_sum = v.sqr()?.sum_keepdim(1)?;
    let norm = squared_sum.sqrt()?;
    v.broadcast_div(&norm)
}