1use candle::{DType, Device, Result, Tensor};
4use candle_core as candle;
5use candle_core::IndexOp;
6use candle_nn::{Linear, Module, VarBuilder};
7use candle_transformers::models::bert::{BertModel, Config};
8use std::path::Path;
9
10struct BertForSequenceClassificationImpl {
11 bert: BertModel,
12 classifier: Linear,
13}
14
15impl BertForSequenceClassificationImpl {
16 fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
17 let bert = if vb.contains_tensor("bert.embeddings.word_embeddings.weight") {
18 BertModel::load(vb.pp("bert"), config)?
19 } else if vb.contains_tensor("roberta.embeddings.word_embeddings.weight") {
20 BertModel::load(vb.pp("roberta"), config)?
21 } else {
22 BertModel::load(vb.clone(), config)?
23 };
24
25 let hidden_size = config.hidden_size;
26 let classifier = candle_nn::linear(hidden_size, 1, vb.pp("classifier"))?;
27
28 Ok(Self { bert, classifier })
29 }
30
31 fn forward(
32 &self,
33 input_ids: &Tensor,
34 token_type_ids: &Tensor,
35 attention_mask: Option<&Tensor>,
36 ) -> Result<Tensor> {
37 let output = self
38 .bert
39 .forward(input_ids, token_type_ids, attention_mask)?;
40 let cls_token = output.i((.., 0, ..))?;
41 self.classifier.forward(&cls_token)
42 }
43}
44
45#[derive(Clone)]
46pub struct BertClassifier(std::sync::Arc<BertForSequenceClassificationImpl>);
48
49impl BertClassifier {
50 pub fn load<P: AsRef<Path>>(model_dir: P, device: &Device) -> Result<Self> {
52 let model_dir = model_dir.as_ref();
53 let config_path = model_dir.join("config.json");
54 let weights_path = model_dir.join("model.safetensors");
55
56 let config_content = std::fs::read_to_string(config_path)?;
57 let config: Config = serde_json::from_str(&config_content)
58 .map_err(|e| candle::Error::Msg(format!("Failed to parse config: {}", e)))?;
59
60 let vb =
61 unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, device)? };
62
63 let model = BertForSequenceClassificationImpl::load(vb, &config)?;
64
65 Ok(Self(std::sync::Arc::new(model)))
66 }
67
68 pub fn forward(
70 &self,
71 input_ids: &Tensor,
72 token_type_ids: &Tensor,
73 attention_mask: Option<&Tensor>,
74 ) -> Result<Tensor> {
75 self.0.forward(input_ids, token_type_ids, attention_mask)
76 }
77}