/// Hyperparameter bundle for a BERT-style model.
///
/// This is plain data: every field is public and the type can be cloned,
/// compared for equality, and debug-printed.
///
/// NOTE(review): the per-field descriptions follow the usual BERT naming
/// conventions. The code that consumes these values is not visible from this
/// definition, so confirm the exact semantics against the model implementation.
#[derive(Clone, Debug, PartialEq)]
pub struct BertConfig {
    /// Total hidden width; `head_dim()` divides this by `num_heads`.
    pub hidden_dim: usize,
    /// Number of stacked layers (presumably transformer encoder blocks).
    pub num_layers: usize,
    /// Number of heads that `hidden_dim` is split across.
    pub num_heads: usize,
    /// Intermediate width (presumably the feed-forward inner dimension).
    pub intermediate_dim: usize,
    /// Size of the token vocabulary.
    pub vocab_size: usize,
    /// Number of position embeddings (presumably the longest supported
    /// sequence length).
    pub max_position_embeddings: usize,
    /// Size of the token-type vocabulary (presumably segment ids).
    pub type_vocab_size: usize,
    /// Epsilon used by layer normalization.
    pub layer_norm_eps: f32,
    /// Id of the padding token.
    pub pad_token_id: u32,
}
impl Default for BertConfig {
fn default() -> Self {
Self {
hidden_dim: 768,
num_layers: 12,
num_heads: 12,
intermediate_dim: 3072,
vocab_size: 30522,
max_position_embeddings: 512,
type_vocab_size: 2,
layer_norm_eps: 1e-12,
pad_token_id: 0,
}
}
}
impl BertConfig {
    /// Width of a single head: `hidden_dim / num_heads`.
    ///
    /// Integer division is used, so any remainder is discarded; the result is
    /// exact only when `hidden_dim` is a multiple of `num_heads`.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero (division by zero).
    #[must_use]
    pub const fn head_dim(&self) -> usize {
        self.hidden_dim / self.num_heads
    }

    /// MiniLM-L6 preset: hidden width 384, 6 layers, 12 heads and an
    /// intermediate width of 1536, giving a per-head width of 32.
    ///
    /// Vocabulary size, position-embedding count, token-type count,
    /// layer-norm epsilon and pad token id are the same values the
    /// `Default` implementation uses.
    ///
    /// This is a `const fn`, like [`BertConfig::head_dim`], so the preset can
    /// initialise `const` and `static` items.
    #[must_use]
    pub const fn minilm_l6() -> Self {
        Self {
            hidden_dim: 384,
            num_layers: 6,
            num_heads: 12,
            intermediate_dim: 1536,
            vocab_size: 30522,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must keep the BERT-base hyperparameters,
    /// including the embedding-table sizes and special-token settings.
    #[test]
    fn bert_base_default_hyperparams() {
        let cfg = BertConfig::default();
        assert_eq!(cfg.hidden_dim, 768);
        assert_eq!(cfg.num_layers, 12);
        assert_eq!(cfg.num_heads, 12);
        assert_eq!(cfg.intermediate_dim, 3072);
        assert_eq!(cfg.vocab_size, 30522);
        assert_eq!(cfg.max_position_embeddings, 512);
        assert_eq!(cfg.type_vocab_size, 2);
        assert_eq!(cfg.layer_norm_eps, 1e-12);
        assert_eq!(cfg.pad_token_id, 0);
        assert_eq!(cfg.head_dim(), 64);
    }

    /// The MiniLM-L6 preset must keep every field it sets, not just the
    /// hidden width and layer count.
    #[test]
    fn minilm_l6_preset() {
        let cfg = BertConfig::minilm_l6();
        assert_eq!(cfg.hidden_dim, 384);
        assert_eq!(cfg.num_layers, 6);
        assert_eq!(cfg.num_heads, 12);
        assert_eq!(cfg.intermediate_dim, 1536);
        assert_eq!(cfg.vocab_size, 30522);
        assert_eq!(cfg.max_position_embeddings, 512);
        assert_eq!(cfg.type_vocab_size, 2);
        assert_eq!(cfg.layer_norm_eps, 1e-12);
        assert_eq!(cfg.pad_token_id, 0);
        assert_eq!(cfg.head_dim(), 32);
    }

    /// `head_dim` uses truncating division, so every shipped preset must have
    /// a hidden width that is an exact multiple of its head count.
    #[test]
    fn head_dim_divides_evenly() {
        let presets = [BertConfig::default(), BertConfig::minilm_l6()];
        for cfg in &presets {
            assert_eq!(cfg.hidden_dim % cfg.num_heads, 0);
            assert_eq!(cfg.head_dim() * cfg.num_heads, cfg.hidden_dim);
        }
    }
}