Skip to main content

spark_bert/
args.rs

1use anyhow::{Error as E, Result};
2use candle_nn::VarBuilder;
3use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
4use hf_hub::{api::sync::Api, Repo, RepoType};
5use tokenizers::{Tokenizer, TruncationParams};
6
/// Options controlling which BERT checkpoint is loaded and how it is run.
#[derive(Debug, Clone)]
pub struct Args {
    /// Run on CPU rather than on GPU.
    pub cpu: bool,
    /// Enable tracing (generates a trace-timestamp.json file).
    pub tracing: bool,
    /// The model to use, check out available models:
    /// <https://huggingface.co/models?library=sentence-transformers&sort=trending>
    pub model_id: Option<String>,
    /// Hub revision (branch, tag, or ref such as `refs/pr/21`) of `model_id` to load.
    pub revision: Option<String>,
    /// Use the pytorch weights rather than the safetensors ones
    pub use_pth: bool,
    /// L2 normalization for embeddings.
    pub normalize_embeddings: bool,
    /// Use tanh based approximation for Gelu instead of erf implementation.
    pub approximate_gelu: bool,
}
22
23impl Args {
24    pub fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
25        let device = crate::util::device(self.cpu)?;
26        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
27        let default_revision = "refs/pr/21".to_string();
28        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
29            (Some(model_id), Some(revision)) => (model_id, revision),
30            (Some(model_id), None) => (model_id, "main".to_string()),
31            (None, Some(revision)) => (default_model, revision),
32            (None, None) => (default_model, default_revision),
33        };
34
35        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
36        let (config_filename, tokenizer_filename, weights_filename) = {
37            let api = Api::new()?;
38            let api = api.repo(repo);
39            let config = api.get("config.json")?;
40            let tokenizer = api.get("tokenizer.json")?;
41            let weights = if self.use_pth {
42                api.get("pytorch_model.bin")?
43            } else {
44                api.get("model.safetensors")?
45            };
46            (config, tokenizer, weights)
47        };
48        let config = std::fs::read_to_string(config_filename)?;
49        let mut config: Config = serde_json::from_str(&config)?;
50        let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
51        let _ = tokenizer.with_truncation(Some(TruncationParams {
52            max_length: 512,
53            ..Default::default()
54        }));
55        let vb = if self.use_pth {
56            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
57        } else {
58            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
59        };
60        if self.approximate_gelu {
61            config.hidden_act = HiddenAct::GeluApproximate;
62        }
63        let model = BertModel::load(vb, &config)?;
64        Ok((model, tokenizer))
65    }
66}