1use anyhow::{Error as E, Result};
2use candle_core::Tensor;
3use candle_transformers::models::bert::BertModel;
4use tokenizers::{PaddingParams, Tokenizer};
5use tracing_chrome::ChromeLayerBuilder;
6use tracing_subscriber::prelude::*;
7
8use crate::args::Args;
9
10pub struct Bert {
11 model: BertModel,
12 tokenizer: Tokenizer,
13 args: Args,
14}
15
16impl Bert {
17 pub fn build(args: Args) -> Result<Self> {
18 let (model, tokenizer) = args.build_model_and_tokenizer()?;
19 Ok(Self {
20 model,
21 tokenizer,
22 args,
23 })
24 }
25
26 pub fn calc_embs(&mut self, sentences: Vec<&str>, apply_pooling: bool) -> Result<Tensor> {
27 let _guard = if self.args.tracing {
28 println!("tracing...");
29 let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
30 tracing_subscriber::registry().with(chrome_layer).init();
31 Some(guard)
32 } else {
33 None
34 };
35 let start = std::time::Instant::now();
36
37 let device = &self.model.device;
38
39 if let Some(pp) = self.tokenizer.get_padding_mut() {
40 pp.strategy = tokenizers::PaddingStrategy::BatchLongest
41 } else {
42 let pp = PaddingParams {
43 strategy: tokenizers::PaddingStrategy::BatchLongest,
44 ..Default::default()
45 };
46 self.tokenizer.with_padding(Some(pp));
47 }
48 let tokens = self
49 .tokenizer
50 .encode_batch(sentences.to_vec(), true)
51 .map_err(E::msg)?;
52 let token_ids = tokens
53 .iter()
54 .map(|tokens| {
55 let tokens = tokens.get_ids().to_vec();
56 Ok(Tensor::new(tokens.as_slice(), device)?)
57 })
58 .collect::<Result<Vec<_>>>()?;
59 let attention_mask = tokens
60 .iter()
61 .map(|tokens| {
62 let tokens = tokens.get_attention_mask().to_vec();
63 Ok(Tensor::new(tokens.as_slice(), device)?)
64 })
65 .collect::<Result<Vec<_>>>()?;
66
67 let token_ids = Tensor::stack(&token_ids, 0)?;
68 let attention_mask = Tensor::stack(&attention_mask, 0)?;
69 let token_type_ids = token_ids.zeros_like()?;
70 let embeddings = self
72 .model
73 .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
74 let embeddings = if apply_pooling {
76 let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
78 (embeddings.sum(1)? / (n_tokens as f64))?
79 } else {
80 embeddings
81 };
82 let embeddings = if apply_pooling && self.args.normalize_embeddings {
83 normalize_l2(&embeddings)?
84 } else {
85 embeddings
86 };
87 Ok(embeddings)
91 }
92}
93
94pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
96 Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
97}
98
99pub fn convert_to_flatten_vec(embs: &Tensor) -> Result<Vec<f32>> {
100 Ok(embs.flatten_all()?.to_vec1::<f32>()?)
101}