Skip to main content

spark_bert/
embs.rs

1use anyhow::{Error as E, Result};
2use candle_core::Tensor;
3use candle_transformers::models::bert::BertModel;
4use tokenizers::{PaddingParams, Tokenizer};
5use tracing_chrome::ChromeLayerBuilder;
6use tracing_subscriber::prelude::*;
7
8use crate::args::Args;
9
/// Bundles a BERT model with the tokenizer and the arguments it was built from,
/// so that sentence embeddings can be computed through [`Bert::calc_embs`].
pub struct Bert {
    /// The BERT model used for the forward pass; its `device` field is also
    /// where the input tensors are created.
    model: BertModel,
    /// Tokenizer used to encode sentences; its padding settings are adjusted
    /// in place by `calc_embs`.
    tokenizer: Tokenizer,
    /// Arguments the instance was built from; `calc_embs` reads the `tracing`
    /// and `normalize_embeddings` flags.
    args: Args,
}
15
16impl Bert {
17    pub fn build(args: Args) -> Result<Self> {
18        let (model, tokenizer) = args.build_model_and_tokenizer()?;
19        Ok(Self {
20            model,
21            tokenizer,
22            args,
23        })
24    }
25
26    pub fn calc_embs(&mut self, sentences: Vec<&str>, apply_pooling: bool) -> Result<Tensor> {
27        let _guard = if self.args.tracing {
28            println!("tracing...");
29            let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
30            tracing_subscriber::registry().with(chrome_layer).init();
31            Some(guard)
32        } else {
33            None
34        };
35        let start = std::time::Instant::now();
36
37        let device = &self.model.device;
38
39        if let Some(pp) = self.tokenizer.get_padding_mut() {
40            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
41        } else {
42            let pp = PaddingParams {
43                strategy: tokenizers::PaddingStrategy::BatchLongest,
44                ..Default::default()
45            };
46            self.tokenizer.with_padding(Some(pp));
47        }
48        let tokens = self
49            .tokenizer
50            .encode_batch(sentences.to_vec(), true)
51            .map_err(E::msg)?;
52        let token_ids = tokens
53            .iter()
54            .map(|tokens| {
55                let tokens = tokens.get_ids().to_vec();
56                Ok(Tensor::new(tokens.as_slice(), device)?)
57            })
58            .collect::<Result<Vec<_>>>()?;
59        let attention_mask = tokens
60            .iter()
61            .map(|tokens| {
62                let tokens = tokens.get_attention_mask().to_vec();
63                Ok(Tensor::new(tokens.as_slice(), device)?)
64            })
65            .collect::<Result<Vec<_>>>()?;
66
67        let token_ids = Tensor::stack(&token_ids, 0)?;
68        let attention_mask = Tensor::stack(&attention_mask, 0)?;
69        let token_type_ids = token_ids.zeros_like()?;
70        //println!("running inference on batch {:?}", token_ids.shape());
71        let embeddings = self
72            .model
73            .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
74        //println!("generated embeddings {:?}", embeddings.shape());
75        let embeddings = if apply_pooling {
76            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
77            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
78            (embeddings.sum(1)? / (n_tokens as f64))?
79        } else {
80            embeddings
81        };
82        let embeddings = if apply_pooling && self.args.normalize_embeddings {
83            normalize_l2(&embeddings)?
84        } else {
85            embeddings
86        };
87        //println!("Loaded and encoded {:?}", start.elapsed());
88        //println!("pooled embeddings {:?}", embeddings.shape());
89
90        Ok(embeddings)
91    }
92}
93
94// TODO: adapt to 3D vector
95pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
96    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
97}
98
99pub fn convert_to_flatten_vec(embs: &Tensor) -> Result<Vec<f32>> {
100    Ok(embs.flatten_all()?.to_vec1::<f32>()?)
101}