use std::cell::Cell;
use flodl::nn::{Dropout, Embedding, LayerNorm, Linear, Module, Parameter};
use flodl::{
DType, Device, FlowBuilder, Graph, Result, Tensor, TensorError, TensorOptions, Variable,
};
use crate::models::bert::build_extended_attention_mask;
use crate::models::transformer_layer::{LayerNaming, TransformerLayer, TransformerLayerConfig};
use crate::path::{prefix_params, HfPath};
use crate::task_heads::{check_num_labels, default_labels, extract_best_span, logits_to_sorted_labels};
pub use crate::task_heads::{Answer, TokenPrediction};
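/// Hyperparameters for a DistilBERT backbone, mirroring the fields of a
/// Hugging Face `config.json` with `"model_type": "distilbert"`. `num_labels`
/// and `id2label` are only present for fine-tuned classification checkpoints.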
#[derive(Debug, Clone)]
pub struct DistilBertConfig {
pub vocab_size: i64,
pub dim: i64,
pub n_layers: i64,
pub n_heads: i64,
pub hidden_dim: i64,
pub max_position_embeddings: i64,
pub pad_token_id: i64,
pub dropout: f64,
pub attention_dropout: f64,
pub qa_dropout: f64,
pub seq_classif_dropout: f64,
pub sinusoidal_pos_embds: bool,
pub layer_norm_eps: f64,
pub num_labels: Option<i64>,
pub id2label: Option<Vec<String>>,
}
impl DistilBertConfig {
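/// Preset matching the `distilbert-base-uncased` checkpoint: 6 layers,
/// 12 heads, hidden size 768, FFN size 3072, 30522-token vocabulary.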
pub fn distilbert_base_uncased() -> Self {
DistilBertConfig {
vocab_size: 30522,
dim: 768,
n_layers: 6,
n_heads: 12,
hidden_dim: 3072,
max_position_embeddings: 512,
pad_token_id: 0,
dropout: 0.1,
attention_dropout: 0.1,
qa_dropout: 0.1,
seq_classif_dropout: 0.2,
sinusoidal_pos_embds: false,
layer_norm_eps: 1e-12,
num_labels: None,
id2label: None,
}
}
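/// Parses a Hugging Face `config.json` string. Architectural fields
/// (`vocab_size`, `dim`, `n_layers`, `n_heads`, `hidden_dim`,
/// `max_position_embeddings`) are required; everything else falls back to the
/// DistilBERT defaults used by `distilbert_base_uncased`.
///
/// A minimal sketch (the file path is illustrative, not part of this crate):
///
/// ```ignore
/// let json = std::fs::read_to_string("path/to/config.json")?;
/// let cfg = DistilBertConfig::from_json_str(&json)?;
/// ```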
pub fn from_json_str(s: &str) -> Result<Self> {
use crate::config_json::{
optional_bool, optional_f64, optional_i64, parse_id2label, parse_num_labels,
required_i64,
};
let v: serde_json::Value = serde_json::from_str(s)
.map_err(|e| TensorError::new(&format!("config.json parse error: {e}")))?;
let id2label = parse_id2label(&v)?;
let num_labels = parse_num_labels(&v, id2label.as_deref());
Ok(DistilBertConfig {
vocab_size: required_i64(&v, "vocab_size")?,
dim: required_i64(&v, "dim")?,
n_layers: required_i64(&v, "n_layers")?,
n_heads: required_i64(&v, "n_heads")?,
hidden_dim: required_i64(&v, "hidden_dim")?,
max_position_embeddings: required_i64(&v, "max_position_embeddings")?,
pad_token_id: optional_i64(&v, "pad_token_id", 0),
dropout: optional_f64(&v, "dropout", 0.1),
attention_dropout: optional_f64(&v, "attention_dropout", 0.1),
qa_dropout: optional_f64(&v, "qa_dropout", 0.1),
seq_classif_dropout: optional_f64(&v, "seq_classif_dropout", 0.2),
sinusoidal_pos_embds: optional_bool(&v, "sinusoidal_pos_embds", false),
layer_norm_eps: optional_f64(&v, "layer_norm_eps", 1e-12),
num_labels,
id2label,
})
}
}
pub struct DistilBertEmbeddings {
word_embeddings: Embedding,
position_embeddings: Embedding,
layer_norm: LayerNorm,
dropout: Dropout,
}
impl DistilBertEmbeddings {
pub fn on_device(config: &DistilBertConfig, device: Device) -> Result<Self> {
Ok(DistilBertEmbeddings {
word_embeddings: Embedding::on_device_with_padding_idx(
config.vocab_size,
config.dim,
Some(config.pad_token_id),
device,
)?,
position_embeddings: Embedding::on_device(
config.max_position_embeddings,
config.dim,
device,
)?,
layer_norm: LayerNorm::on_device_with_eps(
config.dim,
config.layer_norm_eps,
device,
)?,
dropout: Dropout::new(config.dropout),
})
}
fn position_ids_from_input_ids(input_ids: &Tensor) -> Result<Tensor> {
let shape = input_ids.shape();
assert_eq!(shape.len(), 2, "input_ids must be [B, S], got {shape:?}");
let batch = shape[0];
let seq = shape[1];
let pos = Tensor::arange(
0.0,
seq as f64,
1.0,
TensorOptions { dtype: DType::Int64, device: input_ids.device() },
)?;
pos.reshape(&[1, seq])?.expand(&[batch, seq])
}
}
impl Module for DistilBertEmbeddings {
fn name(&self) -> &str { "distilbert_embeddings" }
fn forward(&self, input: &Variable) -> Result<Variable> {
let pos_ids = Self::position_ids_from_input_ids(&input.data())?;
let pos_var = Variable::new(pos_ids, false);
let word = self.word_embeddings.forward(input)?;
let pe = self.position_embeddings.forward(&pos_var)?;
let summed = word.add(&pe)?;
let ln = self.layer_norm.forward(&summed)?;
self.dropout.forward(&ln)
}
fn parameters(&self) -> Vec<Parameter> {
let mut out = Vec::new();
out.extend(prefix_params("word_embeddings", self.word_embeddings.parameters()));
out.extend(prefix_params("position_embeddings", self.position_embeddings.parameters()));
out.extend(prefix_params("LayerNorm", self.layer_norm.parameters()));
out
}
fn set_training(&self, training: bool) {
self.dropout.set_training(training);
}
}
fn distilbert_layer_config(config: &DistilBertConfig) -> TransformerLayerConfig {
TransformerLayerConfig {
hidden_size: config.dim,
num_attention_heads: config.n_heads,
intermediate_size: config.hidden_dim,
hidden_dropout_prob: config.dropout,
attention_probs_dropout_prob: config.attention_dropout,
layer_norm_eps: config.layer_norm_eps,
}
}
fn distilbert_backbone_flow(
config: &DistilBertConfig,
device: Device,
) -> Result<FlowBuilder> {
let mut fb = FlowBuilder::new()
.input(&["attention_mask"])
.through(DistilBertEmbeddings::on_device(config, device)?)
.tag("distilbert.embeddings");
let layer_root = HfPath::new("distilbert").sub("transformer").sub("layer");
let layer_cfg = distilbert_layer_config(config);
for i in 0..config.n_layers {
let tag = layer_root.sub(i).to_string();
fb = fb
.through(TransformerLayer::on_device(&layer_cfg, LayerNaming::DISTILBERT, device)?)
.tag(&tag)
.using(&["attention_mask"]);
}
Ok(fb)
}
pub struct DistilBertModel;
impl DistilBertModel {
pub fn build(config: &DistilBertConfig) -> Result<Graph> {
Self::on_device(config, Device::CPU)
}
pub fn on_device(config: &DistilBertConfig, device: Device) -> Result<Graph> {
distilbert_backbone_flow(config, device)?.build()
}
}
struct SelectClsLinear {
linear: Linear,
}
impl Module for SelectClsLinear {
fn name(&self) -> &str { "select_cls_linear" }
fn forward(&self, input: &Variable) -> Result<Variable> {
let cls = input.select(1, 0)?;
self.linear.forward(&cls)
}
fn parameters(&self) -> Vec<Parameter> {
self.linear.parameters()
}
}
struct ActivationDropoutLinear {
dropout: Dropout,
linear: Linear,
training: Cell<bool>,
}
impl Module for ActivationDropoutLinear {
fn name(&self) -> &str { "activation_dropout_linear" }
fn forward(&self, input: &Variable) -> Result<Variable> {
let acted = input.relu()?;
let dropped = if self.training.get() {
self.dropout.forward(&acted)?
} else {
acted
};
self.linear.forward(&dropped)
}
fn parameters(&self) -> Vec<Parameter> {
self.linear.parameters()
}
fn set_training(&self, training: bool) {
self.training.set(training);
self.dropout.set_training(training);
}
}
pub struct DistilBertForSequenceClassification {
graph: Graph,
id2label: Vec<String>,
#[cfg(feature = "tokenizer")]
tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl DistilBertForSequenceClassification {
pub fn on_device(
config: &DistilBertConfig,
num_labels: i64,
device: Device,
) -> Result<Self> {
let num_labels = check_num_labels(num_labels)?;
let graph = distilbert_backbone_flow(config, device)?
.through(SelectClsLinear {
linear: Linear::on_device(config.dim, config.dim, device)?,
})
.tag("pre_classifier")
.through(ActivationDropoutLinear {
dropout: Dropout::new(config.seq_classif_dropout),
linear: Linear::on_device(config.dim, num_labels, device)?,
training: Cell::new(true),
})
.tag("classifier")
.build()?;
let id2label = config
.id2label
.clone()
.unwrap_or_else(|| default_labels(num_labels));
Ok(Self {
graph,
id2label,
#[cfg(feature = "tokenizer")]
tokenizer: None,
})
}
pub(crate) fn num_labels_from_config(config: &DistilBertConfig) -> Result<i64> {
config.num_labels.ok_or_else(|| {
TensorError::new(
"DistilBertForSequenceClassification: config.json has no \
`num_labels` (nor `id2label`); cannot infer head size",
)
})
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn labels(&self) -> &[String] { &self.id2label }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
#[cfg(feature = "tokenizer")]
pub fn classify(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Vec<(String, f32)>>> {
let logits = self.forward_from_encoded(enc)?;
logits_to_sorted_labels(&logits, &self.id2label)
}
#[cfg(feature = "tokenizer")]
pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<(String, f32)>>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"DistilBertForSequenceClassification::predict requires a \
tokenizer; use from_pretrained or .with_tokenizer(...) first",
)
})?;
let enc = tok.encode(texts)?;
self.classify(&enc)
}
#[cfg(feature = "tokenizer")]
fn forward_from_encoded(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Variable> {
self.graph.eval();
let mask_f32 = enc.attention_mask.data().to_dtype(DType::Float32)?;
let mask = Variable::new(build_extended_attention_mask(&mask_f32)?, false);
self.graph.forward_multi(&[enc.input_ids.clone(), mask])
}
}
pub struct DistilBertForTokenClassification {
graph: Graph,
id2label: Vec<String>,
#[cfg(feature = "tokenizer")]
tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl DistilBertForTokenClassification {
pub fn on_device(
config: &DistilBertConfig,
num_labels: i64,
device: Device,
) -> Result<Self> {
let num_labels = check_num_labels(num_labels)?;
let graph = distilbert_backbone_flow(config, device)?
.through(Dropout::new(config.dropout))
.through(Linear::on_device(config.dim, num_labels, device)?)
.tag("classifier")
.build()?;
let id2label = config
.id2label
.clone()
.unwrap_or_else(|| default_labels(num_labels));
Ok(Self {
graph,
id2label,
#[cfg(feature = "tokenizer")]
tokenizer: None,
})
}
pub(crate) fn num_labels_from_config(config: &DistilBertConfig) -> Result<i64> {
config.num_labels.ok_or_else(|| {
TensorError::new(
"DistilBertForTokenClassification: config.json has no \
`num_labels` (nor `id2label`); cannot infer head size",
)
})
}
pub fn graph(&self) -> &Graph { &self.graph }
pub fn labels(&self) -> &[String] { &self.id2label }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
#[cfg(feature = "tokenizer")]
pub fn tag(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Vec<TokenPrediction>>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"DistilBertForTokenClassification::tag requires a tokenizer; \
attach one via .with_tokenizer(...) or from_pretrained",
)
})?;
self.graph.eval();
let mask_f32 = enc.attention_mask.data().to_dtype(DType::Float32)?;
let mask = Variable::new(build_extended_attention_mask(&mask_f32)?, false);
let logits = self.graph.forward_multi(&[enc.input_ids.clone(), mask])?;
let probs = logits.softmax(-1)?;
let shape = probs.shape();
assert_eq!(shape.len(), 3, "expected [B, S, num_labels], got {shape:?}");
let batch = shape[0] as usize;
let seq = shape[1] as usize;
let n = shape[2] as usize;
let flat = probs.data().to_f32_vec()?;
let input_ids: Vec<i64> = enc.input_ids.data().to_i64_vec()?;
let attn_ids: Vec<i64> = enc.attention_mask.data().to_i64_vec()?;
let mut out = Vec::with_capacity(batch);
for b in 0..batch {
let mut row = Vec::with_capacity(seq);
for s in 0..seq {
let offset = (b * seq + s) * n;
let slice = &flat[offset..offset + n];
let (argmax, score) = slice
.iter()
.enumerate()
.fold((0usize, f32::NEG_INFINITY), |(bi, bs), (i, &v)| {
if v > bs { (i, v) } else { (bi, bs) }
});
let token_id = input_ids[b * seq + s] as u32;
let token = tok
.inner()
.id_to_token(token_id)
.unwrap_or_else(|| format!("<{token_id}>"));
let attends = attn_ids[b * seq + s] != 0;
row.push(TokenPrediction {
token,
label: self.id2label[argmax].clone(),
score,
attends,
});
}
out.push(row);
}
Ok(out)
}
#[cfg(feature = "tokenizer")]
pub fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<TokenPrediction>>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"DistilBertForTokenClassification::predict requires a \
tokenizer; use from_pretrained or .with_tokenizer(...) first",
)
})?;
let enc = tok.encode(texts)?;
self.tag(&enc)
}
}
pub struct DistilBertForQuestionAnswering {
graph: Graph,
#[cfg(feature = "tokenizer")]
tokenizer: Option<crate::tokenizer::HfTokenizer>,
}
impl DistilBertForQuestionAnswering {
pub fn on_device(config: &DistilBertConfig, device: Device) -> Result<Self> {
let graph = distilbert_backbone_flow(config, device)?
.through(Dropout::new(config.qa_dropout))
.through(Linear::on_device(config.dim, 2, device)?)
.tag("qa_outputs")
.build()?;
Ok(Self {
graph,
#[cfg(feature = "tokenizer")]
tokenizer: None,
})
}
pub fn graph(&self) -> &Graph { &self.graph }
#[cfg(feature = "tokenizer")]
pub fn with_tokenizer(mut self, tok: crate::tokenizer::HfTokenizer) -> Self {
self.tokenizer = Some(tok);
self
}
#[cfg(feature = "tokenizer")]
pub fn answer(&self, question: &str, context: &str) -> Result<Answer> {
let mut out = self.answer_batch(&[(question, context)])?;
Ok(out.pop().expect("answer_batch returns one per input"))
}
#[cfg(feature = "tokenizer")]
pub fn answer_batch(&self, pairs: &[(&str, &str)]) -> Result<Vec<Answer>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"DistilBertForQuestionAnswering::answer requires a tokenizer; \
use from_pretrained or .with_tokenizer(...) first",
)
})?;
let enc = tok.encode_pairs(pairs)?;
self.extract(&enc)
}
#[cfg(feature = "tokenizer")]
pub fn extract(
&self,
enc: &crate::tokenizer::EncodedBatch,
) -> Result<Vec<Answer>> {
let tok = self.tokenizer.as_ref().ok_or_else(|| {
TensorError::new(
"DistilBertForQuestionAnswering::extract requires a tokenizer; \
attach one via .with_tokenizer(...) or from_pretrained",
)
})?;
self.graph.eval();
let mask_f32 = enc.attention_mask.data().to_dtype(DType::Float32)?;
let mask = Variable::new(build_extended_attention_mask(&mask_f32)?, false);
let logits = self.graph.forward_multi(&[enc.input_ids.clone(), mask])?;
extract_best_span(&logits, enc, tok)
}
}
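// The tests below pin parameter-key names and shapes to the layout used by
// Hugging Face DistilBERT safetensors checkpoints, so pretrained weights can
// load key-for-key without remapping.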
#[cfg(test)]
mod tests {
use super::*;
use crate::safetensors_io::expected_from_graph;
fn expected_layer_keys(i: i64) -> Vec<String> {
let suffixes = [
"attention.k_lin.bias",
"attention.k_lin.weight",
"attention.out_lin.bias",
"attention.out_lin.weight",
"attention.q_lin.bias",
"attention.q_lin.weight",
"attention.v_lin.bias",
"attention.v_lin.weight",
"ffn.lin1.bias",
"ffn.lin1.weight",
"ffn.lin2.bias",
"ffn.lin2.weight",
"output_layer_norm.bias",
"output_layer_norm.weight",
"sa_layer_norm.bias",
"sa_layer_norm.weight",
];
suffixes.iter()
.map(|s| format!("distilbert.transformer.layer.{i}.{s}"))
.collect()
}
#[test]
fn distilbert_parameter_keys_match_hf_dotted_form() {
let config = DistilBertConfig::distilbert_base_uncased();
let graph = DistilBertModel::build(&config).unwrap();
let expected = expected_from_graph(&graph);
let mut keys: Vec<String> = expected.iter().map(|p| p.key.clone()).collect();
keys.sort();
let mut want: Vec<String> = vec![
"distilbert.embeddings.LayerNorm.bias".into(),
"distilbert.embeddings.LayerNorm.weight".into(),
"distilbert.embeddings.position_embeddings.weight".into(),
"distilbert.embeddings.word_embeddings.weight".into(),
];
for i in 0..config.n_layers {
want.extend(expected_layer_keys(i));
}
want.sort();
assert_eq!(want.len(), 100, "expected-key list size drift");
assert_eq!(keys, want, "DistilBERT parameter keys must match HF exactly");
}
#[test]
fn distilbert_parameter_shapes_match_base_uncased() {
let config = DistilBertConfig::distilbert_base_uncased();
let graph = DistilBertModel::build(&config).unwrap();
let expected = expected_from_graph(&graph);
let by_key: std::collections::HashMap<&str, &[i64]> = expected
.iter()
.map(|p| (p.key.as_str(), p.shape.as_slice()))
.collect();
assert_eq!(by_key["distilbert.embeddings.word_embeddings.weight"], &[30522, 768]);
assert_eq!(by_key["distilbert.embeddings.position_embeddings.weight"], &[512, 768]);
assert_eq!(by_key["distilbert.embeddings.LayerNorm.weight"], &[768]);
assert_eq!(by_key["distilbert.embeddings.LayerNorm.bias"], &[768]);
for i in 0..config.n_layers {
let p = format!("distilbert.transformer.layer.{i}");
assert_eq!(by_key[&*format!("{p}.attention.q_lin.weight")], &[768, 768]);
assert_eq!(by_key[&*format!("{p}.attention.q_lin.bias")], &[768]);
assert_eq!(by_key[&*format!("{p}.attention.k_lin.weight")], &[768, 768]);
assert_eq!(by_key[&*format!("{p}.attention.v_lin.weight")], &[768, 768]);
assert_eq!(by_key[&*format!("{p}.attention.out_lin.weight")],&[768, 768]);
assert_eq!(by_key[&*format!("{p}.sa_layer_norm.weight")], &[768]);
assert_eq!(by_key[&*format!("{p}.ffn.lin1.weight")], &[3072, 768]);
assert_eq!(by_key[&*format!("{p}.ffn.lin1.bias")], &[3072]);
assert_eq!(by_key[&*format!("{p}.ffn.lin2.weight")], &[768, 3072]);
assert_eq!(by_key[&*format!("{p}.ffn.lin2.bias")], &[768]);
assert_eq!(by_key[&*format!("{p}.output_layer_norm.weight")],&[768]);
}
}
#[test]
fn distilbert_layer_count_scales_with_config() {
for n in [1_i64, 3, 6] {
let config = DistilBertConfig {
n_layers: n,
..DistilBertConfig::distilbert_base_uncased()
};
let graph = DistilBertModel::build(&config).unwrap();
let expected = expected_from_graph(&graph);
let total = expected.len();
let want_total = 4 + 16 * n as usize;
assert_eq!(
total, want_total,
"n_layers={n}: got {total} keys, expected {want_total}",
);
}
}
#[test]
fn seqcls_head_adds_four_keys() {
let config = DistilBertConfig {
num_labels: Some(3),
..DistilBertConfig::distilbert_base_uncased()
};
let head = DistilBertForSequenceClassification::on_device(&config, 3, Device::CPU).unwrap();
let expected = expected_from_graph(head.graph());
let keys: Vec<String> = expected.iter().map(|p| p.key.clone()).collect();
assert_eq!(expected.len(), 100 + 4, "backbone + SeqCls head key count");
assert!(keys.iter().any(|k| k == "pre_classifier.weight"));
assert!(keys.iter().any(|k| k == "pre_classifier.bias"));
assert!(keys.iter().any(|k| k == "classifier.weight"));
assert!(keys.iter().any(|k| k == "classifier.bias"));
}
#[test]
fn tokencls_head_adds_two_keys() {
let config = DistilBertConfig {
num_labels: Some(9),
..DistilBertConfig::distilbert_base_uncased()
};
let head = DistilBertForTokenClassification::on_device(&config, 9, Device::CPU).unwrap();
let expected = expected_from_graph(head.graph());
let keys: Vec<String> = expected.iter().map(|p| p.key.clone()).collect();
assert_eq!(expected.len(), 100 + 2, "backbone + TokenCls head key count");
assert!(keys.iter().any(|k| k == "classifier.weight"));
assert!(keys.iter().any(|k| k == "classifier.bias"));
}
#[test]
fn qa_head_adds_two_keys_shape_2_dim() {
let config = DistilBertConfig::distilbert_base_uncased();
let head = DistilBertForQuestionAnswering::on_device(&config, Device::CPU).unwrap();
let expected = expected_from_graph(head.graph());
let by_key: std::collections::HashMap<&str, &[i64]> = expected
.iter()
.map(|p| (p.key.as_str(), p.shape.as_slice()))
.collect();
assert_eq!(expected.len(), 100 + 2, "backbone + QA head key count");
assert_eq!(by_key["qa_outputs.weight"], &[2, 768]);
assert_eq!(by_key["qa_outputs.bias"], &[2]);
}
#[test]
fn seqcls_num_labels_required() {
let config = DistilBertConfig::distilbert_base_uncased();
let err = DistilBertForSequenceClassification::num_labels_from_config(&config).unwrap_err();
assert!(format!("{err}").contains("num_labels"), "got: {err}");
}
#[test]
fn parses_distilbert_base_uncased_config() {
let json = r#"{
"activation": "gelu",
"architectures": ["DistilBertForMaskedLM"],
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"model_type": "distilbert",
"n_heads": 12,
"n_layers": 6,
"pad_token_id": 0,
"qa_dropout": 0.1,
"seq_classif_dropout": 0.2,
"sinusoidal_pos_embds": false,
"tie_weights_": true,
"vocab_size": 30522
}"#;
let cfg = DistilBertConfig::from_json_str(json).unwrap();
assert_eq!(cfg.vocab_size, 30522);
assert_eq!(cfg.dim, 768);
assert_eq!(cfg.n_layers, 6);
assert_eq!(cfg.n_heads, 12);
assert_eq!(cfg.hidden_dim, 3072);
assert_eq!(cfg.max_position_embeddings, 512);
assert_eq!(cfg.pad_token_id, 0);
assert!((cfg.dropout - 0.1).abs() < 1e-12);
assert!((cfg.attention_dropout - 0.1).abs() < 1e-12);
assert!((cfg.qa_dropout - 0.1).abs() < 1e-12);
assert!((cfg.seq_classif_dropout - 0.2).abs() < 1e-12);
assert!(!cfg.sinusoidal_pos_embds);
assert!((cfg.layer_norm_eps - 1e-12).abs() < 1e-18);
assert!(cfg.num_labels.is_none());
assert!(cfg.id2label.is_none());
}
#[test]
fn parses_cased_distilled_squad_config() {
let json = r#"{
"activation": "gelu",
"architectures": ["DistilBertForQuestionAnswering"],
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"max_position_embeddings": 512,
"model_type": "distilbert",
"n_heads": 12,
"n_layers": 6,
"pad_token_id": 0,
"qa_dropout": 0.1,
"seq_classif_dropout": 0.2,
"sinusoidal_pos_embds": true,
"vocab_size": 28996
}"#;
let cfg = DistilBertConfig::from_json_str(json).unwrap();
assert_eq!(cfg.vocab_size, 28996);
assert!(cfg.sinusoidal_pos_embds);
}
#[test]
fn parses_finetuned_seqcls_config() {
let json = r#"{
"activation": "gelu",
"architectures": ["DistilBertForSequenceClassification"],
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"id2label": {"0": "positive", "1": "neutral", "2": "negative"},
"label2id": {"positive": 0, "neutral": 1, "negative": 2},
"max_position_embeddings": 512,
"model_type": "distilbert",
"n_heads": 12,
"n_layers": 6,
"pad_token_id": 0,
"qa_dropout": 0.1,
"seq_classif_dropout": 0.2,
"sinusoidal_pos_embds": false,
"vocab_size": 119547
}"#;
let cfg = DistilBertConfig::from_json_str(json).unwrap();
assert_eq!(cfg.vocab_size, 119547);
assert_eq!(cfg.num_labels, Some(3));
let labels = cfg.id2label.unwrap();
assert_eq!(labels, vec!["positive", "neutral", "negative"]);
}
#[test]
fn missing_required_field_errors() {
let json = r#"{
"vocab_size": 30522, "dim": 768, "n_heads": 12,
"hidden_dim": 3072, "max_position_embeddings": 512
}"#;
let err = DistilBertConfig::from_json_str(json).unwrap_err();
assert!(format!("{err}").contains("n_layers"), "got: {err}");
}
#[test]
fn preset_roundtrips_through_parser() {
let preset = DistilBertConfig::distilbert_base_uncased();
let json = r#"{
"vocab_size": 30522, "dim": 768, "n_layers": 6, "n_heads": 12,
"hidden_dim": 3072, "max_position_embeddings": 512, "pad_token_id": 0
}"#;
let parsed = DistilBertConfig::from_json_str(json).unwrap();
assert_eq!(preset.vocab_size, parsed.vocab_size);
assert_eq!(preset.dim, parsed.dim);
assert_eq!(preset.n_layers, parsed.n_layers);
assert_eq!(preset.n_heads, parsed.n_heads);
assert_eq!(preset.hidden_dim, parsed.hidden_dim);
assert_eq!(preset.pad_token_id, parsed.pad_token_id);
}
}