use crate::autograd::Tensor;
use crate::models::bert::config::BertConfig;
use crate::models::bert::embeddings::BertEmbeddings;
use crate::models::bert::encoder::BertEncoder;
use crate::nn::{Linear, Module};
/// BERT-based cross-encoder.
///
/// A token sequence is embedded and run through the encoder. The first token's hidden
/// state is then (optionally) pooled and projected to `num_labels` logits.
pub struct CrossEncoder {
    /// Width of each token's hidden state, taken from `BertConfig::hidden_dim`.
    hidden_dim: usize,
    /// Number of classifier outputs.
    num_labels: usize,
    /// Embedding layer; consumes token ids and token-type (segment) ids.
    embeddings: BertEmbeddings,
    /// Encoder applied to the embedded sequence.
    encoder: BertEncoder,
    /// Optional `hidden_dim -> hidden_dim` dense layer, applied (then `tanh`) to the
    /// first-token state before classification.
    pooler: Option<Linear>,
    /// Final projection from `hidden_dim` to `num_labels` logits.
    classifier: Linear,
}
impl CrossEncoder {
    /// Builds a cross-encoder from `config`.
    ///
    /// Every layer is created through its own `new` constructor. Use
    /// [`Self::load_from_reader`] to populate the model with stored weights.
    ///
    /// * `config` - BERT hyper-parameters; `config.hidden_dim` also sizes the pooler and
    ///   the classifier input.
    /// * `num_labels` - output width of the classifier head. `1` yields a single logit,
    ///   which is what [`Self::score`] is designed for.
    /// * `with_pooler` - when `true`, a `hidden_dim -> hidden_dim` dense layer followed by
    ///   `tanh` is applied to the first-token hidden state before classification.
    #[must_use]
    pub fn new(config: &BertConfig, num_labels: usize, with_pooler: bool) -> Self {
        let h = config.hidden_dim;
        Self {
            embeddings: BertEmbeddings::new(config),
            encoder: BertEncoder::new(config),
            pooler: with_pooler.then(|| Linear::new(h, h)),
            classifier: Linear::new(h, num_labels),
            hidden_dim: h,
            num_labels,
        }
    }

    /// Number of logits produced by [`Self::forward`].
    #[must_use]
    pub fn num_labels(&self) -> usize {
        self.num_labels
    }

    /// Mutable access to the embedding layer.
    pub fn embeddings_mut(&mut self) -> &mut BertEmbeddings {
        &mut self.embeddings
    }

    /// Mutable access to the encoder.
    pub fn encoder_mut(&mut self) -> &mut BertEncoder {
        &mut self.encoder
    }

    /// Mutable access to the pooler dense layer, or `None` if the model was built without one.
    pub fn pooler_mut(&mut self) -> Option<&mut Linear> {
        self.pooler.as_mut()
    }

    /// Mutable access to the classifier head.
    pub fn classifier_mut(&mut self) -> &mut Linear {
        &mut self.classifier
    }

    /// Loads weights from `reader` into this model.
    ///
    /// Delegates to `crate::models::bert::load::load_cross_encoder_from_reader`.
    ///
    /// # Errors
    ///
    /// Returns whatever `BertLoadError` the underlying loader reports.
    pub fn load_from_reader(
        &mut self,
        reader: &crate::format::v2::AprV2Reader,
        config: &BertConfig,
    ) -> Result<(), crate::models::bert::load::BertLoadError> {
        crate::models::bert::load::load_cross_encoder_from_reader(self, reader, config)
    }

    /// Runs the model on one token sequence and returns the raw classifier logits.
    ///
    /// The sequence is typically a `[CLS] a [SEP] b [SEP]` pair, with segment ids in
    /// `token_type_ids`. `input_ids` and `token_type_ids` are passed unchanged to the
    /// embedding layer.
    ///
    /// The hidden state of the first token is used as the sequence representation. The
    /// returned tensor has shape `[1, num_labels]`.
    ///
    /// # Panics
    ///
    /// Panics if `input_ids` is empty, because there is no first-token state to classify.
    #[must_use]
    pub fn forward(&self, input_ids: &[u32], token_type_ids: &[u32]) -> Tensor {
        // Fail fast with a descriptive message rather than an opaque index panic from
        // somewhere inside the embedding/encoder/slicing code.
        assert!(
            !input_ids.is_empty(),
            "CrossEncoder::forward: input_ids must contain at least one token"
        );
        let embeddings = self.embeddings.forward(input_ids, token_type_ids);
        // NOTE(review): `None` presumably means "no attention mask", i.e. every position is
        // treated as a real (unpadded) token. Confirm against `BertEncoder::forward`.
        let hidden = self.encoder.forward(&embeddings, None);
        let hidden_data = hidden.data();
        // Assumes the encoder output is row-major with the sequence axis before the hidden
        // axis, so the first `hidden_dim` values are the first token's hidden state.
        let cls = hidden_data[..self.hidden_dim].to_vec();
        let cls_tensor = Tensor::from_vec(cls, &[1, self.hidden_dim]);
        let pooled = match &self.pooler {
            Some(pooler) => tanh(&pooler.forward(&cls_tensor)),
            None => cls_tensor,
        };
        self.classifier.forward(&pooled)
    }

    /// Returns `sigmoid` of the first logit from [`Self::forward`], a relevance score in `[0, 1]`.
    ///
    /// Intended for single-label models (`num_labels == 1`). When the model has more
    /// labels, only the first logit is used. Call [`Self::forward`] to get all of them.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::forward`].
    #[must_use]
    pub fn score(&self, input_ids: &[u32], token_type_ids: &[u32]) -> f32 {
        let logit_tensor = self.forward(input_ids, token_type_ids);
        let logit = logit_tensor.data()[0];
        sigmoid(logit)
    }
}
/// Logistic function `1 / (1 + e^(-x))`.
///
/// Mathematically the result lies in `(0, 1)`. In `f32` it saturates to exactly `0.0`
/// for very negative `x` and `1.0` for very large `x`. NaN input yields NaN.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
/// Element-wise hyperbolic tangent.
///
/// Returns a new tensor with the same shape as `x`; `x` itself is not modified.
fn tanh(x: &Tensor) -> Tensor {
    let activated = x
        .data()
        .iter()
        .copied()
        .map(f32::tanh)
        .collect::<Vec<f32>>();
    Tensor::from_vec(activated, x.shape())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-segment input shaped like `[CLS] a [SEP] b [SEP]`.
    const PAIR_IDS: [u32; 5] = [101, 2024, 102, 3456, 102];
    const PAIR_TYPES: [u32; 5] = [0, 0, 0, 1, 1];

    /// Single-segment input shaped like `[CLS] a [SEP]`.
    const SINGLE_IDS: [u32; 3] = [101, 2024, 102];
    const SINGLE_TYPES: [u32; 3] = [0, 0, 0];

    /// Builds a MiniLM-L6-configured cross-encoder for the given head layout.
    fn minilm(num_labels: usize, with_pooler: bool) -> CrossEncoder {
        CrossEncoder::new(&BertConfig::minilm_l6(), num_labels, with_pooler)
    }

    #[test]
    fn cross_encoder_returns_scalar_logit() {
        let logits = minilm(1, true).forward(&PAIR_IDS, &PAIR_TYPES);
        assert_eq!(logits.shape(), &[1, 1]);
    }

    #[test]
    fn cross_encoder_score_returns_finite() {
        let relevance = minilm(1, true).score(&PAIR_IDS, &PAIR_TYPES);
        assert!(relevance.is_finite());
        assert!((0.0..=1.0).contains(&relevance));
    }

    #[test]
    fn cross_encoder_without_pooler() {
        let logits = minilm(1, false).forward(&SINGLE_IDS, &SINGLE_TYPES);
        assert_eq!(logits.shape(), &[1, 1]);
    }

    #[test]
    fn cross_encoder_num_labels_dimension() {
        let logits = minilm(2, false).forward(&SINGLE_IDS, &SINGLE_TYPES);
        assert_eq!(logits.shape(), &[1, 2]);
    }
}