use std::collections::HashMap;
use flodl::nn::{Dropout, Embedding, LayerNorm, Linear, Module, NamedInputModule, Parameter};
use flodl::{DType, Device, FlowBuilder, Graph, Result, Tensor, TensorError, Variable};
use crate::models::transformer_layer::{LayerNaming, TransformerLayer, TransformerLayerConfig};
use crate::path::{prefix_params, HfPath};
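/// Converts a `[batch, seq_len]` attention mask (1 = attend, 0 = padding)
/// into the additive form BERT's attention expects: `(1 - mask) * -1e4`,
/// reshaped to `[batch, 1, 1, seq_len]` so it broadcasts over heads and
/// query positions. Attended positions add 0 to the pre-softmax scores;
/// padded positions add -1e4, which softmax maps to ~0 probability.
/// Panics if the input is not rank 2.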
pub fn build_extended_attention_mask(mask: &Tensor) -> Result<Tensor> {
let shape = mask.shape();
assert_eq!(shape.len(), 2, "expected [batch, seq_len], got {shape:?}");
let mask_f = mask.to_dtype(DType::Float32)?;
let additive = mask_f.mul_scalar(-1.0)?.add_scalar(1.0)?.mul_scalar(-1e4)?;
additive.reshape(&[shape[0], 1, 1, shape[1]])
}
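/// Hyperparameters for the BERT encoder, mirroring the fields of a Hugging
/// Face `config.json`. Construct via [`BertConfig::bert_base_uncased`] or
/// [`BertConfig::from_json_str`]; `num_labels` and `id2label` are optional
/// task-head metadata consumed by the classification heads below.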
#[derive(Debug, Clone)]
pub struct BertConfig {
pub vocab_size: i64,
pub hidden_size: i64,
pub num_hidden_layers: i64,
pub num_attention_heads: i64,
pub intermediate_size: i64,
pub max_position_embeddings: i64,
pub type_vocab_size: i64,
pub pad_token_id: Option<i64>,
pub layer_norm_eps: f64,
pub hidden_dropout_prob: f64,
pub attention_probs_dropout_prob: f64,
pub num_labels: Option<i64>,
pub id2label: Option<Vec<String>>,
}
impl BertConfig {
pub fn bert_base_uncased() -> Self {
BertConfig {
vocab_size: 30522,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
max_position_embeddings: 512,
type_vocab_size: 2,
pad_token_id: Some(0),
layer_norm_eps: 1e-12,
hidden_dropout_prob: 0.1,
attention_probs_dropout_prob: 0.1,
num_labels: None,
id2label: None,
}
}
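/// Parses a Hugging Face `config.json`. The seven architectural fields
/// (`vocab_size` through `type_vocab_size`) are required; `pad_token_id`
/// may be absent or `null`, and `layer_norm_eps` plus the two dropout
/// probabilities fall back to the BERT defaults (1e-12, 0.1, 0.1).
/// Unrecognized keys are ignored.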
pub fn from_json_str(s: &str) -> Result<Self> {
use crate::config_json::{
optional_f64, optional_i64_or_none, parse_id2label, parse_num_labels, required_i64,
};
let v: serde_json::Value = serde_json::from_str(s)
.map_err(|e| TensorError::new(&format!("config.json parse error: {e}")))?;
let id2label = parse_id2label(&v)?;
let num_labels = parse_num_labels(&v, id2label.as_deref());
Ok(BertConfig {
vocab_size: required_i64(&v, "vocab_size")?,
hidden_size: required_i64(&v, "hidden_size")?,
num_hidden_layers: required_i64(&v, "num_hidden_layers")?,
num_attention_heads: required_i64(&v, "num_attention_heads")?,
intermediate_size: required_i64(&v, "intermediate_size")?,
max_position_embeddings: required_i64(&v, "max_position_embeddings")?,
type_vocab_size: required_i64(&v, "type_vocab_size")?,
pad_token_id: optional_i64_or_none(&v, "pad_token_id"),
layer_norm_eps: optional_f64(&v, "layer_norm_eps", 1e-12),
hidden_dropout_prob: optional_f64(&v, "hidden_dropout_prob", 0.1),
attention_probs_dropout_prob: optional_f64(&v, "attention_probs_dropout_prob", 0.1),
num_labels,
id2label,
})
}
}
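/// Word, position, and token-type embedding tables followed by LayerNorm and
/// dropout. The plain `Module::forward` path applies the word embeddings
/// only; the graph routes `position_ids` and `token_type_ids` through the
/// [`NamedInputModule`] path, which adds whichever of the other two tables
/// its named inputs provide before the LayerNorm.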
pub struct BertEmbeddings {
word_embeddings: Embedding,
position_embeddings: Embedding,
token_type_embeddings: Embedding,
layer_norm: LayerNorm,
dropout: Dropout,
}
impl BertEmbeddings {
pub fn on_device(config: &BertConfig, device: Device) -> Result<Self> {
Ok(BertEmbeddings {
word_embeddings: Embedding::on_device_with_padding_idx(
config.vocab_size,
config.hidden_size,
config.pad_token_id,
device,
)?,
position_embeddings: Embedding::on_device(
config.max_position_embeddings,
config.hidden_size,
device,
)?,
token_type_embeddings: Embedding::on_device(
config.type_vocab_size,
config.hidden_size,
device,
)?,
layer_norm: LayerNorm::on_device_with_eps(
config.hidden_size,
config.layer_norm_eps,
device,
)?,
dropout: Dropout::new(config.hidden_dropout_prob),
})
}
}
impl Module for BertEmbeddings {
fn name(&self) -> &str { "bert_embeddings" }
fn forward(&self, input: &Variable) -> Result<Variable> {
let word = self.word_embeddings.forward(input)?;
let ln = self.layer_norm.forward(&word)?;
self.dropout.forward(&ln)
}
fn parameters(&self) -> Vec<Parameter> {
let mut out = Vec::new();
out.extend(prefix_params("word_embeddings", self.word_embeddings.parameters()));
out.extend(prefix_params("position_embeddings", self.position_embeddings.parameters()));
out.extend(prefix_params("token_type_embeddings", self.token_type_embeddings.parameters()));
out.extend(prefix_params("LayerNorm", self.layer_norm.parameters()));
out
}
fn as_named_input(&self) -> Option<&dyn NamedInputModule> { Some(self) }
fn set_training(&self, training: bool) {
self.dropout.set_training(training);
}
}
impl NamedInputModule for BertEmbeddings {
fn forward_named(
&self,
input: &Variable,
refs: &HashMap<String, Variable>,
) -> Result<Variable> {
let mut summed = self.word_embeddings.forward(input)?;
if let Some(pos) = refs.get("position_ids") {
let pe = self.position_embeddings.forward(pos)?;
summed = summed.add(&pe)?;
}
if let Some(tt) = refs.get("token_type_ids") {
let te = self.token_type_embeddings.forward(tt)?;
summed = summed.add(&te)?;
}
let ln = self.layer_norm.forward(&summed)?;
self.dropout.forward(&ln)
}
}
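/// HF-style pooler: takes the hidden state of the first token ([CLS]),
/// applies a dense projection, and squashes with tanh.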
pub struct BertPooler {
dense: Linear,
}
impl BertPooler {
pub fn on_device(config: &BertConfig, device: Device) -> Result<Self> {
Ok(BertPooler {
dense: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
})
}
}
impl Module for BertPooler {
fn name(&self) -> &str { "bert_pooler" }
fn forward(&self, input: &Variable) -> Result<Variable> {
let cls = input.select(1, 0)?;
let pooled = self.dense.forward(&cls)?;
pooled.tanh()
}
fn parameters(&self) -> Vec<Parameter> {
prefix_params("dense", self.dense.parameters())
}
}
fn bert_layer_config(config: &BertConfig) -> TransformerLayerConfig {
TransformerLayerConfig {
hidden_size: config.hidden_size,
num_attention_heads: config.num_attention_heads,
intermediate_size: config.intermediate_size,
hidden_dropout_prob: config.hidden_dropout_prob,
attention_probs_dropout_prob: config.attention_probs_dropout_prob,
layer_norm_eps: config.layer_norm_eps,
}
}
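/// Assembles the shared backbone: embeddings -> `num_hidden_layers`
/// transformer layers -> optional pooler, each stage tagged with the
/// HF-dotted parameter prefix (`bert.embeddings`, `bert.encoder.layer.{i}`,
/// `bert.pooler`) so checkpoint keys line up exactly.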
fn bert_backbone_flow(
config: &BertConfig,
device: Device,
with_pooler: bool,
) -> Result<FlowBuilder> {
let mut fb = FlowBuilder::new()
.input(&["position_ids", "token_type_ids", "attention_mask"])
.through(BertEmbeddings::on_device(config, device)?)
.tag("bert.embeddings")
.using(&["position_ids", "token_type_ids"]);
let layer_root = HfPath::new("bert").sub("encoder").sub("layer");
let layer_cfg = bert_layer_config(config);
for i in 0..config.num_hidden_layers {
let tag = layer_root.sub(i).to_string();
fb = fb
.through(TransformerLayer::on_device(&layer_cfg, LayerNaming::BERT, device)?)
.tag(&tag)
.using(&["attention_mask"]);
}
if with_pooler {
fb = fb
.through(BertPooler::on_device(config, device)?)
.tag("bert.pooler");
}
Ok(fb)
}
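/// The bare BERT encoder graph. `Graph::forward_multi` expects inputs in the
/// order `[input_ids, position_ids, token_type_ids, attention_mask]`, where
/// the mask is the additive tensor from [`build_extended_attention_mask`].
///
/// A minimal sketch (tensor setup elided; see the shape smoke test below):
///
/// ```ignore
/// let config = BertConfig::bert_base_uncased();
/// let graph = BertModel::build(&config)?;   // CPU, with pooler
/// graph.eval();                             // disable dropout for inference
/// let pooled = graph.forward_multi(&[ids, pos, tt, mask])?;
/// assert_eq!(pooled.shape(), vec![batch, config.hidden_size]);
/// ```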
pub struct BertModel;
impl BertModel {
pub fn build(config: &BertConfig) -> Result<Graph> {
Self::on_device(config, Device::CPU)
}
pub fn on_device(config: &BertConfig, device: Device) -> Result<Graph> {
bert_backbone_flow(config, device, true)?.build()
}
pub fn on_device_without_pooler(config: &BertConfig, device: Device) -> Result<Graph> {
bert_backbone_flow(config, device, false)?.build()
}
}
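// Task heads: thin dropout/linear projections over the shared backbone.
// Tokenizer-driven convenience APIs are gated behind the `tokenizer` feature.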
use crate::task_heads::{check_num_labels, default_labels, extract_best_span, logits_to_sorted_labels};
pub use crate::task_heads::{Answer, TokenPrediction};
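/// BERT backbone (with pooler) plus dropout and a `[num_labels, hidden]`
/// linear head, tagged `classifier` to match HF checkpoint keys. `id2label`
/// from the config names the outputs, falling back to `LABEL_{i}`.
///
/// With the `tokenizer` feature enabled, a minimal sketch (config, tokenizer,
/// and label count are illustrative):
///
/// ```ignore
/// let head = BertForSequenceClassification::on_device(&config, 3, Device::CPU)?
///     .with_tokenizer(tok);
/// let ranked = head.predict(&["what a great movie"])?; // Vec<Vec<(label, score)>>
/// println!("{:?}", ranked[0][0]);                      // top label for input 0
/// ```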
pub struct BertForSequenceClassification {
graph: Graph,
id2label: Vec<String>,
#[cfg(feature = "tokenizer")]
tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl BertForSequenceClassification {
pub fn on_device(
config: &BertConfig,
num_labels: i64,
device: Device,
) -> Result<Self> {
let num_labels = check_num_labels(num_labels)?;
let graph = bert_backbone_flow(config, device, true)?
.through(Dropout::new(config.hidden_dropout_prob))
.through(Linear::on_device(config.hidden_size, num_labels, device)?)
.tag("classifier")
.build()?;
let id2label = config
.id2label
.clone()
.unwrap_or_else(|| default_labels(num_labels));
Ok(Self {
graph,
id2label,
#[cfg(feature = "tokenizer")]
tokenizer: None,
})
}
pub(crate) fn num_labels_from_config(config: &BertConfig) -> Result<i64> {
config.num_labels.ok_or_else(|| {
TensorError::new(
"BertForSequenceClassification: config.json has no `num_labels` \
(nor `id2label`); cannot infer head size",
)
})
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn labels(&self) -> &[String] { &self.id2label }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
#[cfg(feature = "tokenizer")]
pub fn classify(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Vec<(String, f32)>>> {
let logits = self.forward_from_encoded(enc)?;
logits_to_sorted_labels(&logits, &self.id2label)
}
#[cfg(feature = "tokenizer")]
pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<(String, f32)>>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"BertForSequenceClassification::predict requires a tokenizer; \
use from_pretrained or .with_tokenizer(...) first",
)
})?;
let enc = tok.encode(texts)?;
self.classify(&enc)
}
#[cfg(feature = "tokenizer")]
fn forward_from_encoded(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Variable> {
self.graph.eval();
let mask_f32 = enc.attention_mask.data().to_dtype(DType::Float32)?;
let mask = Variable::new(build_extended_attention_mask(&mask_f32)?, false);
self.graph.forward_multi(&[
enc.input_ids.clone(),
enc.position_ids.clone(),
enc.token_type_ids.clone(),
mask,
])
}
}
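/// BERT backbone (no pooler) plus dropout and a per-token linear head,
/// producing `[batch, seq, num_labels]` logits. `tag` decodes them into
/// per-token `TokenPrediction`s; padding rows carry `attends: false` so
/// callers can filter them out.
///
/// ```ignore
/// // With the `tokenizer` feature, sketch only:
/// let preds = head.predict(&["Ada moved to Paris"])?; // Vec<Vec<TokenPrediction>>
/// ```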
pub struct BertForTokenClassification {
graph: Graph,
id2label: Vec<String>,
#[cfg(feature = "tokenizer")]
tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl BertForTokenClassification {
pub fn on_device(
config: &BertConfig,
num_labels: i64,
device: Device,
) -> Result<Self> {
let num_labels = check_num_labels(num_labels)?;
let graph = bert_backbone_flow(config, device, false)?
.through(Dropout::new(config.hidden_dropout_prob))
.through(Linear::on_device(config.hidden_size, num_labels, device)?)
.tag("classifier")
.build()?;
let id2label = config
.id2label
.clone()
.unwrap_or_else(|| default_labels(num_labels));
Ok(Self {
graph,
id2label,
#[cfg(feature = "tokenizer")]
tokenizer: None,
})
}
pub(crate) fn num_labels_from_config(config: &BertConfig) -> Result<i64> {
config.num_labels.ok_or_else(|| {
TensorError::new(
"BertForTokenClassification: config.json has no `num_labels` \
(nor `id2label`); cannot infer head size",
)
})
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn labels(&self) -> &[String] { &self.id2label }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
#[cfg(feature = "tokenizer")]
pub fn tag(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Vec<TokenPrediction>>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"BertForTokenClassification::tag requires a tokenizer; \
attach one via .with_tokenizer(...) or from_pretrained",
)
})?;
self.graph.eval();
let mask_f32 = enc.attention_mask.data().to_dtype(DType::Float32)?;
let mask = Variable::new(build_extended_attention_mask(&mask_f32)?, false);
let logits = self.graph.forward_multi(&[
enc.input_ids.clone(),
enc.position_ids.clone(),
enc.token_type_ids.clone(),
mask,
])?;
let probs = logits.softmax(-1)?;
let shape = probs.shape();
assert_eq!(shape.len(), 3, "expected [B, S, num_labels], got {shape:?}");
let batch = shape[0] as usize;
let seq = shape[1] as usize;
let n = shape[2] as usize;
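// `flat` is row-major [batch, seq, num_labels]; token (b, s) occupies
// flat[(b * seq + s) * n .. (b * seq + s + 1) * n]. The argmax over that
// window picks the best label and score per token.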
let flat = probs.data().to_f32_vec()?;
let input_ids: Vec<i64> = enc.input_ids.data().to_i64_vec()?;
let attn_ids: Vec<i64> = enc.attention_mask.data().to_i64_vec()?;
let mut out = Vec::with_capacity(batch);
for b in 0..batch {
let mut row = Vec::with_capacity(seq);
for s in 0..seq {
let base = (b * seq + s) * n;
let (best_k, &best_p) = flat[base..base + n]
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.expect("n > 0 checked by check_num_labels");
let id = input_ids[b * seq + s] as u32;
let token = tok
.inner()
.id_to_token(id)
.unwrap_or_else(|| format!("<unk_id={id}>"));
row.push(TokenPrediction {
token,
label: self.id2label[best_k].clone(),
score: best_p,
attends: attn_ids[b * seq + s] != 0,
});
}
out.push(row);
}
Ok(out)
}
#[cfg(feature = "tokenizer")]
pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<TokenPrediction>>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"BertForTokenClassification::predict requires a tokenizer; \
use from_pretrained or .with_tokenizer(...) first",
)
})?;
let enc = tok.encode(texts)?;
self.tag(&enc)
}
}
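/// Extractive QA head: BERT backbone (no pooler) plus a `[2, hidden]` linear
/// layer tagged `qa_outputs`, producing start/end logits of shape
/// `[batch, seq, 2]`. `answer`/`answer_batch` encode (question, context)
/// pairs and delegate span selection to `extract_best_span`.
///
/// ```ignore
/// // With the `tokenizer` feature, sketch only:
/// let qa = BertForQuestionAnswering::on_device(&config, Device::CPU)?
///     .with_tokenizer(tok);
/// let ans = qa.answer("Who wrote it?", "The novel was written by Ada.")?;
/// ```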
pub struct BertForQuestionAnswering {
graph: Graph,
#[cfg(feature = "tokenizer")]
tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl BertForQuestionAnswering {
pub fn on_device(config: &BertConfig, device: Device) -> Result<Self> {
let graph = bert_backbone_flow(config, device, false)?
.through(Linear::on_device(config.hidden_size, 2, device)?)
.tag("qa_outputs")
.build()?;
Ok(Self {
graph,
#[cfg(feature = "tokenizer")]
tokenizer: None,
})
}
pub fn graph(&self) -> &Graph { &self.graph }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
#[cfg(feature = "tokenizer")]
pub fn answer(&self, question: &str, context: &str) -> Result<Answer> {
let mut out = self.answer_batch(&[(question, context)])?;
Ok(out.pop().expect("answer_batch returns one per input"))
}
#[cfg(feature = "tokenizer")]
pub fn answer_batch(&self, pairs: &[(&str, &str)]) -> Result<Vec<Answer>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"BertForQuestionAnswering::answer requires a tokenizer; \
use from_pretrained or .with_tokenizer(...) first",
)
})?;
let enc = tok.encode_pairs(pairs)?;
self.extract(&enc)
}
#[cfg(feature = "tokenizer")]
pub fn extract(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Answer>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"BertForQuestionAnswering::extract requires a tokenizer; \
attach one via .with_tokenizer(...) or from_pretrained",
)
})?;
self.graph.eval();
let mask_f32 = enc.attention_mask.data().to_dtype(DType::Float32)?;
let mask = Variable::new(build_extended_attention_mask(&mask_f32)?, false);
let logits = self.graph.forward_multi(&[
enc.input_ids.clone(),
enc.position_ids.clone(),
enc.token_type_ids.clone(),
mask,
])?;
extract_best_span(&logits, enc, tok)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::safetensors_io::expected_from_graph;
use flodl::TensorOptions;
fn expected_layer_keys(i: i64) -> Vec<String> {
let suffixes = [
"attention.output.LayerNorm.bias",
"attention.output.LayerNorm.weight",
"attention.output.dense.bias",
"attention.output.dense.weight",
"attention.self.key.bias",
"attention.self.key.weight",
"attention.self.query.bias",
"attention.self.query.weight",
"attention.self.value.bias",
"attention.self.value.weight",
"intermediate.dense.bias",
"intermediate.dense.weight",
"output.LayerNorm.bias",
"output.LayerNorm.weight",
"output.dense.bias",
"output.dense.weight",
];
suffixes.iter().map(|s| format!("bert.encoder.layer.{i}.{s}")).collect()
}
#[test]
fn bert_parameter_keys_match_hf_dotted_form() {
let config = BertConfig::bert_base_uncased();
let graph = BertModel::build(&config).unwrap();
let expected = expected_from_graph(&graph);
let mut keys: Vec<String> = expected.iter().map(|p| p.key.clone()).collect();
keys.sort();
let mut want: Vec<String> = vec![
"bert.embeddings.LayerNorm.bias".into(),
"bert.embeddings.LayerNorm.weight".into(),
"bert.embeddings.position_embeddings.weight".into(),
"bert.embeddings.token_type_embeddings.weight".into(),
"bert.embeddings.word_embeddings.weight".into(),
];
for i in 0..config.num_hidden_layers {
want.extend(expected_layer_keys(i));
}
want.extend([
"bert.pooler.dense.bias".into(),
"bert.pooler.dense.weight".into(),
]);
want.sort();
assert_eq!(want.len(), 199, "expected-key list size drift");
assert_eq!(keys, want, "BERT parameter keys must match HF exactly");
}
#[test]
fn bert_parameter_shapes_match_bert_base_uncased() {
let config = BertConfig::bert_base_uncased();
let graph = BertModel::build(&config).unwrap();
let expected = expected_from_graph(&graph);
let by_key: std::collections::HashMap<&str, &[i64]> = expected
.iter()
.map(|p| (p.key.as_str(), p.shape.as_slice()))
.collect();
assert_eq!(by_key["bert.embeddings.word_embeddings.weight"], &[30522, 768]);
assert_eq!(by_key["bert.embeddings.position_embeddings.weight"], &[512, 768]);
assert_eq!(by_key["bert.embeddings.token_type_embeddings.weight"], &[2, 768]);
assert_eq!(by_key["bert.embeddings.LayerNorm.weight"], &[768]);
assert_eq!(by_key["bert.embeddings.LayerNorm.bias"], &[768]);
for i in 0..config.num_hidden_layers {
let p = format!("bert.encoder.layer.{i}");
assert_eq!(by_key[&*format!("{p}.attention.self.query.weight")], &[768, 768]);
assert_eq!(by_key[&*format!("{p}.attention.self.query.bias")], &[768]);
assert_eq!(by_key[&*format!("{p}.attention.self.key.weight")], &[768, 768]);
assert_eq!(by_key[&*format!("{p}.attention.self.value.weight")], &[768, 768]);
assert_eq!(by_key[&*format!("{p}.attention.output.dense.weight")], &[768, 768]);
assert_eq!(by_key[&*format!("{p}.attention.output.LayerNorm.weight")], &[768]);
assert_eq!(by_key[&*format!("{p}.intermediate.dense.weight")], &[3072, 768]);
assert_eq!(by_key[&*format!("{p}.intermediate.dense.bias")], &[3072]);
assert_eq!(by_key[&*format!("{p}.output.dense.weight")], &[768, 3072]);
assert_eq!(by_key[&*format!("{p}.output.dense.bias")], &[768]);
assert_eq!(by_key[&*format!("{p}.output.LayerNorm.weight")], &[768]);
}
assert_eq!(by_key["bert.pooler.dense.weight"], &[768, 768]);
assert_eq!(by_key["bert.pooler.dense.bias"], &[768]);
}
#[test]
fn bert_layer_count_scales_with_config() {
for n in [1_i64, 3, 6] {
let config = BertConfig {
num_hidden_layers: n,
..BertConfig::bert_base_uncased()
};
let graph = BertModel::build(&config).unwrap();
let expected = expected_from_graph(&graph);
let total = expected.len();
let want_total = 5 + 16 * n as usize + 2;
assert_eq!(
total, want_total,
"num_hidden_layers={n}: got {total} keys, expected {want_total}",
);
let last_layer_key = format!(
"bert.encoder.layer.{}.attention.self.query.weight", n - 1,
);
assert!(
expected.iter().any(|p| p.key == last_layer_key),
"last layer key {last_layer_key:?} missing from graph keys",
);
}
}
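// One layer, hidden size 16: large enough to exercise every code path,
// small enough that the forward-shape smoke tests stay fast.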
fn tiny_bert_config() -> BertConfig {
BertConfig {
vocab_size: 32,
hidden_size: 16,
num_hidden_layers: 1,
num_attention_heads: 4,
intermediate_size: 32,
max_position_embeddings: 8,
type_vocab_size: 2,
pad_token_id: Some(0),
layer_norm_eps: 1e-12,
hidden_dropout_prob: 0.0,
attention_probs_dropout_prob: 0.0,
num_labels: None,
id2label: None,
}
}
#[test]
fn bert_config_from_json_str_matches_base_preset() {
let json = r#"{
"architectures": ["BertForMaskedLM"],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.6.0.dev0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}"#;
let got = BertConfig::from_json_str(json).unwrap();
let want = BertConfig::bert_base_uncased();
assert_eq!(got.vocab_size, want.vocab_size);
assert_eq!(got.hidden_size, want.hidden_size);
assert_eq!(got.num_hidden_layers, want.num_hidden_layers);
assert_eq!(got.num_attention_heads, want.num_attention_heads);
assert_eq!(got.intermediate_size, want.intermediate_size);
assert_eq!(got.max_position_embeddings, want.max_position_embeddings);
assert_eq!(got.type_vocab_size, want.type_vocab_size);
assert_eq!(got.pad_token_id, want.pad_token_id);
assert!((got.layer_norm_eps - want.layer_norm_eps).abs() < 1e-18);
assert!((got.hidden_dropout_prob - want.hidden_dropout_prob).abs() < 1e-9);
assert!((got.attention_probs_dropout_prob - want.attention_probs_dropout_prob).abs() < 1e-9);
}
#[test]
fn bert_config_from_json_str_rejects_missing_field() {
let json = r#"{
"vocab_size": 30522,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"type_vocab_size": 2
}"#;
let err = BertConfig::from_json_str(json).unwrap_err().to_string();
assert!(err.contains("hidden_size"),
"error must name the missing field: {err}");
assert!(err.contains("missing required integer field"),
"error must explain the failure mode: {err}");
}
#[test]
fn bert_config_from_json_str_pad_token_id_nullable() {
let required_fields = r#"
"vocab_size": 30522,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"type_vocab_size": 2
"#;
let explicit_null = format!(r#"{{ {required_fields}, "pad_token_id": null }}"#);
let absent = format!(r#"{{ {required_fields} }}"#);
let a = BertConfig::from_json_str(&explicit_null).unwrap();
let b = BertConfig::from_json_str(&absent).unwrap();
assert_eq!(a.pad_token_id, None);
assert_eq!(b.pad_token_id, None);
}
#[test]
fn bert_config_from_json_str_parses_task_head_metadata() {
let json = r#"{
"vocab_size": 30522,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"num_labels": 3,
"id2label": { "2": "JOY", "0": "ANGER", "1": "SADNESS" },
"label2label": { "IGNORED": 1 }
}"#;
let c = BertConfig::from_json_str(json).unwrap();
assert_eq!(c.num_labels, Some(3));
assert_eq!(
c.id2label,
Some(vec!["ANGER".to_string(), "SADNESS".to_string(), "JOY".to_string()]),
);
}
#[test]
fn bert_config_num_labels_derived_from_id2label() {
let json = r#"{
"vocab_size": 30522,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"id2label": { "0": "NEGATIVE", "1": "POSITIVE" }
}"#;
let c = BertConfig::from_json_str(json).unwrap();
assert_eq!(c.num_labels, Some(2));
assert_eq!(c.id2label.unwrap(), vec!["NEGATIVE", "POSITIVE"]);
}
#[test]
fn bert_config_without_task_metadata_is_none() {
let c = BertConfig::bert_base_uncased();
assert_eq!(c.num_labels, None);
assert_eq!(c.id2label, None);
}
#[test]
fn bert_config_rejects_non_contiguous_id2label() {
let json = r#"{
"vocab_size": 30522,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"id2label": { "0": "A", "2": "C" }
}"#;
let err = BertConfig::from_json_str(json).unwrap_err().to_string();
assert!(err.contains("contiguous"), "error must call out contiguity: {err}");
}
#[test]
fn bert_config_from_json_str_uses_defaults_for_missing_optional_fields() {
let json = r#"{
"vocab_size": 30522,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"type_vocab_size": 2
}"#;
let c = BertConfig::from_json_str(json).unwrap();
assert!((c.layer_norm_eps - 1e-12).abs() < 1e-18);
assert!((c.hidden_dropout_prob - 0.1).abs() < 1e-9);
assert!((c.attention_probs_dropout_prob - 0.1).abs() < 1e-9);
}
#[test]
fn bert_forward_shape_smoke() {
let config = tiny_bert_config();
let dev = Device::CPU;
let graph = BertModel::on_device(&config, dev).unwrap();
graph.eval();
let batch = 2;
let seq = 4;
let word_ids = Variable::new(
Tensor::from_i64(&[1, 2, 3, 4, 5, 6, 7, 0], &[batch, seq], dev).unwrap(),
false,
);
let position_ids = Variable::new(
Tensor::from_i64(&[0, 1, 2, 3, 0, 1, 2, 3], &[batch, seq], dev).unwrap(),
false,
);
let token_type_ids = Variable::new(
Tensor::from_i64(&[0, 0, 0, 0, 1, 1, 1, 1], &[batch, seq], dev).unwrap(),
false,
);
let mask_flat = Tensor::ones(&[batch, seq], TensorOptions {
dtype: DType::Float32, device: dev,
}).unwrap();
let attention_mask = Variable::new(
build_extended_attention_mask(&mask_flat).unwrap(),
false,
);
let out = graph
.forward_multi(&[word_ids, position_ids, token_type_ids, attention_mask])
.unwrap();
assert_eq!(out.shape(), vec![batch, config.hidden_size]);
}
#[test]
fn extended_attention_mask_shape_and_values() {
let dev = Device::CPU;
let raw = Tensor::from_f32(&[1.0, 1.0, 1.0, 1.0, 1.0, 0.0], &[2, 3], dev).unwrap();
let additive = build_extended_attention_mask(&raw).unwrap();
assert_eq!(additive.shape(), vec![2, 1, 1, 3]);
let values: Vec<f32> = additive.reshape(&[6]).unwrap().to_f32_vec().unwrap();
assert_eq!(values[0], 0.0);
assert_eq!(values[1], 0.0);
assert_eq!(values[2], 0.0);
assert_eq!(values[3], 0.0);
assert_eq!(values[4], 0.0);
assert!((values[5] + 1e4).abs() < 1e-3, "masked position should be ~-1e4, got {}", values[5]);
}
#[test]
fn bert_without_pooler_drops_two_keys() {
let config = BertConfig::bert_base_uncased();
let graph = BertModel::on_device_without_pooler(&config, Device::CPU).unwrap();
let expected = expected_from_graph(&graph);
let keys: Vec<&str> = expected.iter().map(|p| p.key.as_str()).collect();
assert_eq!(expected.len(), 197, "197 backbone keys expected");
assert!(!keys.iter().any(|k| k.starts_with("bert.pooler.")));
}
#[test]
fn sequence_classification_parameter_keys_match_hf() {
let config = BertConfig::bert_base_uncased();
let head = BertForSequenceClassification::on_device(&config, 3, Device::CPU).unwrap();
let expected = expected_from_graph(head.graph());
let mut head_keys: Vec<&str> = expected
.iter()
.map(|p| p.key.as_str())
.filter(|k| !k.starts_with("bert."))
.collect();
head_keys.sort();
assert_eq!(head_keys, vec!["classifier.bias", "classifier.weight"]);
let by_key: std::collections::HashMap<&str, &[i64]> = expected
.iter().map(|p| (p.key.as_str(), p.shape.as_slice())).collect();
assert_eq!(by_key["classifier.weight"], &[3, 768]);
assert_eq!(by_key["classifier.bias"], &[3]);
}
#[test]
fn sequence_classification_labels_from_config_or_fallback() {
let mut cfg = BertConfig::bert_base_uncased();
cfg.num_labels = Some(3);
cfg.id2label = Some(vec!["A".into(), "B".into(), "C".into()]);
let head = BertForSequenceClassification::on_device(&cfg, 3, Device::CPU).unwrap();
assert_eq!(head.labels(), &["A".to_string(), "B".to_string(), "C".to_string()]);
let bare = BertConfig::bert_base_uncased();
let fallback = BertForSequenceClassification::on_device(&bare, 2, Device::CPU).unwrap();
assert_eq!(fallback.labels(), &["LABEL_0".to_string(), "LABEL_1".to_string()]);
}
#[test]
fn sequence_classification_forward_shape_smoke() {
let config = tiny_bert_config();
let dev = Device::CPU;
let head = BertForSequenceClassification::on_device(&config, 5, dev).unwrap();
head.graph().eval();
let batch = 2;
let seq = 4;
let ids = Variable::new(
Tensor::from_i64(&[1, 2, 3, 4, 5, 6, 7, 0], &[batch, seq], dev).unwrap(),
false,
);
let pos = Variable::new(
Tensor::from_i64(&[0, 1, 2, 3, 0, 1, 2, 3], &[batch, seq], dev).unwrap(),
false,
);
let tt = Variable::new(
Tensor::from_i64(&[0; 8], &[batch, seq], dev).unwrap(),
false,
);
let mask_flat = Tensor::ones(&[batch, seq], TensorOptions {
dtype: DType::Float32, device: dev,
}).unwrap();
let mask = Variable::new(build_extended_attention_mask(&mask_flat).unwrap(), false);
let out = head.graph().forward_multi(&[ids, pos, tt, mask]).unwrap();
assert_eq!(out.shape(), vec![batch, 5]);
}
#[test]
fn token_classification_parameter_keys_match_hf() {
let config = BertConfig::bert_base_uncased();
let head = BertForTokenClassification::on_device(&config, 9, Device::CPU).unwrap();
let expected = expected_from_graph(head.graph());
let keys: Vec<&str> = expected.iter().map(|p| p.key.as_str()).collect();
assert!(!keys.iter().any(|k| k.starts_with("bert.pooler.")),
"token classification must not carry pooler params");
assert!(keys.contains(&"classifier.weight"));
assert!(keys.contains(&"classifier.bias"));
let by_key: std::collections::HashMap<&str, &[i64]> = expected
.iter().map(|p| (p.key.as_str(), p.shape.as_slice())).collect();
assert_eq!(by_key["classifier.weight"], &[9, 768]);
assert_eq!(by_key["classifier.bias"], &[9]);
}
#[test]
fn token_classification_forward_shape_smoke() {
let config = tiny_bert_config();
let dev = Device::CPU;
let head = BertForTokenClassification::on_device(&config, 7, dev).unwrap();
head.graph().eval();
let batch = 2;
let seq = 4;
let ids = Variable::new(
Tensor::from_i64(&[1, 2, 3, 4, 5, 6, 7, 0], &[batch, seq], dev).unwrap(),
false,
);
let pos = Variable::new(
Tensor::from_i64(&[0, 1, 2, 3, 0, 1, 2, 3], &[batch, seq], dev).unwrap(),
false,
);
let tt = Variable::new(Tensor::from_i64(&[0; 8], &[batch, seq], dev).unwrap(), false);
let mask_flat = Tensor::ones(&[batch, seq], TensorOptions {
dtype: DType::Float32, device: dev,
}).unwrap();
let mask = Variable::new(build_extended_attention_mask(&mask_flat).unwrap(), false);
let out = head.graph().forward_multi(&[ids, pos, tt, mask]).unwrap();
assert_eq!(out.shape(), vec![batch, seq, 7]);
}
#[test]
fn question_answering_parameter_keys_match_hf() {
let config = BertConfig::bert_base_uncased();
let head = BertForQuestionAnswering::on_device(&config, Device::CPU).unwrap();
let expected = expected_from_graph(head.graph());
let mut head_keys: Vec<&str> = expected
.iter().map(|p| p.key.as_str()).filter(|k| !k.starts_with("bert.")).collect();
head_keys.sort();
assert_eq!(head_keys, vec!["qa_outputs.bias", "qa_outputs.weight"]);
let by_key: std::collections::HashMap<&str, &[i64]> = expected
.iter().map(|p| (p.key.as_str(), p.shape.as_slice())).collect();
assert_eq!(by_key["qa_outputs.weight"], &[2, 768]);
assert_eq!(by_key["qa_outputs.bias"], &[2]);
}
#[test]
fn question_answering_forward_shape_smoke() {
let config = tiny_bert_config();
let dev = Device::CPU;
let head = BertForQuestionAnswering::on_device(&config, dev).unwrap();
head.graph().eval();
let batch = 1;
let seq = 4;
let ids = Variable::new(
Tensor::from_i64(&[1, 2, 3, 4], &[batch, seq], dev).unwrap(),
false,
);
let pos = Variable::new(
Tensor::from_i64(&[0, 1, 2, 3], &[batch, seq], dev).unwrap(),
false,
);
let tt = Variable::new(Tensor::from_i64(&[0; 4], &[batch, seq], dev).unwrap(), false);
let mask_flat = Tensor::ones(&[batch, seq], TensorOptions {
dtype: DType::Float32, device: dev,
}).unwrap();
let mask = Variable::new(build_extended_attention_mask(&mask_flat).unwrap(), false);
let out = head.graph().forward_multi(&[ids, pos, tt, mask]).unwrap();
assert_eq!(out.shape(), vec![batch, seq, 2]);
}
#[test]
fn task_heads_reject_zero_labels() {
let config = BertConfig::bert_base_uncased();
let dev = Device::CPU;
assert!(BertForSequenceClassification::on_device(&config, 0, dev).is_err());
assert!(BertForTokenClassification::on_device(&config, 0, dev).is_err());
}
}