reflex/embedding/
bert.rs

1//! Minimal BERT classifier wrapper used by the reranker.
2
3use candle::{DType, Device, Result, Tensor};
4use candle_core as candle;
5use candle_core::IndexOp;
6use candle_nn::{Linear, Module, VarBuilder};
7use candle_transformers::models::bert::{BertModel, Config};
8use std::path::Path;
9
10struct BertForSequenceClassificationImpl {
11    bert: BertModel,
12    classifier: Linear,
13}
14
15impl BertForSequenceClassificationImpl {
16    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
17        let bert = if vb.contains_tensor("bert.embeddings.word_embeddings.weight") {
18            BertModel::load(vb.pp("bert"), config)?
19        } else if vb.contains_tensor("roberta.embeddings.word_embeddings.weight") {
20            BertModel::load(vb.pp("roberta"), config)?
21        } else {
22            BertModel::load(vb.clone(), config)?
23        };
24
25        let hidden_size = config.hidden_size;
26        let classifier = candle_nn::linear(hidden_size, 1, vb.pp("classifier"))?;
27
28        Ok(Self { bert, classifier })
29    }
30
31    fn forward(
32        &self,
33        input_ids: &Tensor,
34        token_type_ids: &Tensor,
35        attention_mask: Option<&Tensor>,
36    ) -> Result<Tensor> {
37        let output = self
38            .bert
39            .forward(input_ids, token_type_ids, attention_mask)?;
40        let cls_token = output.i((.., 0, ..))?;
41        self.classifier.forward(&cls_token)
42    }
43}
44
45#[derive(Clone)]
46/// BERT sequence-classification model that returns a single logit score.
47pub struct BertClassifier(std::sync::Arc<BertForSequenceClassificationImpl>);
48
49impl BertClassifier {
50    /// Loads a model from a directory containing `config.json` and `model.safetensors`.
51    pub fn load<P: AsRef<Path>>(model_dir: P, device: &Device) -> Result<Self> {
52        let model_dir = model_dir.as_ref();
53        let config_path = model_dir.join("config.json");
54        let weights_path = model_dir.join("model.safetensors");
55
56        let config_content = std::fs::read_to_string(config_path)?;
57        let config: Config = serde_json::from_str(&config_content)
58            .map_err(|e| candle::Error::Msg(format!("Failed to parse config: {}", e)))?;
59
60        let vb =
61            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, device)? };
62
63        let model = BertForSequenceClassificationImpl::load(vb, &config)?;
64
65        Ok(Self(std::sync::Arc::new(model)))
66    }
67
68    /// Runs a forward pass and returns logits.
69    pub fn forward(
70        &self,
71        input_ids: &Tensor,
72        token_type_ids: &Tensor,
73        attention_mask: Option<&Tensor>,
74    ) -> Result<Tensor> {
75        self.0.forward(input_ids, token_type_ids, attention_mask)
76    }
77}