use crate::bert::config::BertConfig;
use crate::bert::layers::{BertEmbeddings, BertEncoder, BertPooler};
use crate::weight_loading::{WeightDataType, WeightFormat, WeightLoadingConfig};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;
use std::io::Read;
use trustformers_core::device::Device;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::{Model, TokenizedInput};
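/// BERT encoder-only model: token embeddings feed a stack of transformer
/// layers, with an optional pooler head, all pinned to a single device.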
#[derive(Debug, Clone)]
pub struct BertModel {
config: BertConfig,
embeddings: BertEmbeddings,
encoder: BertEncoder,
pooler: Option<BertPooler>,
device: Device,
}
impl BertModel {
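/// Builds a CPU-resident model from `config`; a convenience wrapper around
/// [`Self::new_with_device`].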
pub fn new(config: BertConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
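/// Builds the embeddings, encoder, and pooler for `config` on the given
/// `device`. A minimal usage sketch (fenced `ignore` because the crate-level
/// re-export path and a `Default` impl for `BertConfig` are assumptions):
///
/// ```ignore
/// use trustformers::bert::{BertConfig, BertModel}; // assumed re-export path
/// use trustformers_core::device::Device;
///
/// let config = BertConfig::default(); // assumes BertConfig: Default
/// let model = BertModel::new_with_device(config, Device::CPU)
///     .expect("construction succeeds");
/// assert_eq!(model.device(), Device::CPU);
/// ```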
pub fn new_with_device(config: BertConfig, device: Device) -> Result<Self> {
let embeddings = BertEmbeddings::new_with_device(&config, device)?;
let encoder = BertEncoder::new_with_device(&config, device)?;
let pooler = Some(BertPooler::new_with_device(&config, device)?);
Ok(Self {
config,
embeddings,
encoder,
pooler,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
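/// Runs the embedding -> encoder pipeline for a single sequence.
///
/// `input_ids` is treated as one batch of length `S = input_ids.len()`, so
/// `last_hidden_state` comes back with shape `[1, S, hidden_size]`; when
/// provided, the attention mask becomes an f32 tensor of shape `[1, 1, 1, S]`,
/// and `pooler_output` is always `None` from this method. A sketch with
/// hypothetical token ids:
///
/// ```ignore
/// let out = model
///     .forward_with_embeddings(vec![101, 2054, 102], Some(vec![1, 1, 1]), None)
///     .expect("forward succeeds");
/// assert_eq!(out.last_hidden_state.shape()[1], 3); // seq_len
/// ```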
pub fn forward_with_embeddings(
&self,
input_ids: Vec<u32>,
attention_mask: Option<Vec<u8>>,
token_type_ids: Option<Vec<u32>>,
) -> Result<BertModelOutput> {
let batch_size = 1;
let seq_len = input_ids.len();
let hidden_size = self.config.hidden_size;
let embeddings = self.embeddings.forward(input_ids, token_type_ids)?;
// Reshape the embedding output into the [batch, seq_len, hidden_size]
// layout the encoder expects.
let embeddings = match embeddings {
Tensor::F32(arr) => {
let reshaped = arr
.to_shape(IxDyn(&[batch_size, seq_len, hidden_size]))
.map_err(|e| TrustformersError::shape_error(e.to_string()))?
.to_owned();
Tensor::F32(reshaped)
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor type in embeddings",
"BertModel::forward_with_embeddings",
))
},
};
// Lift the 0/1 mask to f32 with shape [1, 1, 1, seq_len] so it broadcasts
// across attention heads and query positions.
let attention_mask_tensor = if let Some(mask) = attention_mask {
let mask_f32: Vec<f32> = mask.iter().map(|&m| m as f32).collect();
let shape = vec![1, 1, 1, mask_f32.len()];
Some(Tensor::F32(
ArrayD::from_shape_vec(IxDyn(&shape), mask_f32)
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
))
} else {
None
};
let encoder_output = self.encoder.forward(embeddings, attention_mask_tensor)?;
// The pooler is not applied on this path; pooler_output stays None.
let pooler_output = None;
Ok(BertModelOutput {
last_hidden_state: encoder_output,
pooler_output,
})
}
}
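/// Output bundle of a BERT forward pass: the final hidden states with shape
/// `[batch, seq_len, hidden_size]`, plus an optional pooled representation.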
#[derive(Debug)]
pub struct BertModelOutput {
pub last_hidden_state: Tensor,
pub pooler_output: Option<Tensor>,
}
impl Model for BertModel {
type Config = BertConfig;
type Input = TokenizedInput;
type Output = BertModelOutput;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
self.forward_with_embeddings(
input.input_ids,
Some(input.attention_mask),
input.token_type_ids,
)
}
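/// Reads the whole weight stream into memory and defers to the heuristic
/// buffer loader below. A sketch of the intended call pattern; the file name
/// is hypothetical:
///
/// ```ignore
/// use std::fs::File;
/// let mut file = File::open("pytorch_model.bin").expect("file exists"); // hypothetical path
/// model.load_pretrained(&mut file).expect("load succeeds");
/// ```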
fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).map_err(|e| {
TrustformersError::weight_load_error(format!("Failed to read model data: {}", e))
})?;
self.load_weights_from_buffer(&buffer)
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
let embeddings_params = self.embeddings.parameter_count();
let encoder_params = self.encoder.parameter_count();
let pooler_params = self.pooler.as_ref().map_or(0, |p| p.parameter_count());
embeddings_params + encoder_params + pooler_params
}
}
impl BertModel {
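/// Extracts whatever tensors can be recognized in `buffer` and routes them to
/// the embeddings, encoder, and pooler loaders.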
fn load_weights_from_buffer(&mut self, buffer: &[u8]) -> Result<()> {
// Default loading configuration; the heuristic extraction below does not
// consume it yet, hence the leading underscore.
let _config = WeightLoadingConfig {
format: Some(WeightFormat::HuggingFaceBin),
lazy_loading: false,
memory_mapped: false,
streaming: false,
device: "cpu".to_string(),
dtype: WeightDataType::Float32,
quantization: None,
cache_dir: None,
verify_checksums: false,
distributed: None,
};
let weights = self.extract_bert_weights(buffer)?;
self.load_embeddings_weights(&weights)?;
self.load_encoder_weights(&weights)?;
self.load_pooler_weights(&weights)?;
Ok(())
}
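/// Builds the HuggingFace-style parameter name/shape table for this config
/// (embedding tables, one block of specs per encoder layer, pooler) and tries
/// to recover a matching tensor from the buffer for each entry; entries that
/// cannot be recovered are silently skipped.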
fn extract_bert_weights(&self, buffer: &[u8]) -> Result<HashMap<String, Tensor>> {
let mut weights = HashMap::new();
// Embedding tables plus the embeddings LayerNorm.
let embedding_specs = vec![
(
"embeddings.word_embeddings.weight",
vec![self.config.vocab_size, self.config.hidden_size],
),
(
"embeddings.position_embeddings.weight",
vec![self.config.max_position_embeddings, self.config.hidden_size],
),
(
"embeddings.token_type_embeddings.weight",
vec![self.config.type_vocab_size, self.config.hidden_size],
),
("embeddings.LayerNorm.weight", vec![self.config.hidden_size]),
("embeddings.LayerNorm.bias", vec![self.config.hidden_size]),
];
// Attention, feed-forward, and LayerNorm parameters for every encoder layer.
for layer_idx in 0..self.config.num_hidden_layers {
let layer_specs = vec![
(
format!("encoder.layer.{}.attention.self.query.weight", layer_idx),
vec![self.config.hidden_size, self.config.hidden_size],
),
(
format!("encoder.layer.{}.attention.self.query.bias", layer_idx),
vec![self.config.hidden_size],
),
(
format!("encoder.layer.{}.attention.self.key.weight", layer_idx),
vec![self.config.hidden_size, self.config.hidden_size],
),
(
format!("encoder.layer.{}.attention.self.key.bias", layer_idx),
vec![self.config.hidden_size],
),
(
format!("encoder.layer.{}.attention.self.value.weight", layer_idx),
vec![self.config.hidden_size, self.config.hidden_size],
),
(
format!("encoder.layer.{}.attention.self.value.bias", layer_idx),
vec![self.config.hidden_size],
),
(
format!("encoder.layer.{}.attention.output.dense.weight", layer_idx),
vec![self.config.hidden_size, self.config.hidden_size],
),
(
format!("encoder.layer.{}.attention.output.dense.bias", layer_idx),
vec![self.config.hidden_size],
),
(
format!(
"encoder.layer.{}.attention.output.LayerNorm.weight",
layer_idx
),
vec![self.config.hidden_size],
),
(
format!(
"encoder.layer.{}.attention.output.LayerNorm.bias",
layer_idx
),
vec![self.config.hidden_size],
),
(
format!("encoder.layer.{}.intermediate.dense.weight", layer_idx),
vec![self.config.intermediate_size, self.config.hidden_size],
),
(
format!("encoder.layer.{}.intermediate.dense.bias", layer_idx),
vec![self.config.intermediate_size],
),
(
format!("encoder.layer.{}.output.dense.weight", layer_idx),
vec![self.config.hidden_size, self.config.intermediate_size],
),
(
format!("encoder.layer.{}.output.dense.bias", layer_idx),
vec![self.config.hidden_size],
),
(
format!("encoder.layer.{}.output.LayerNorm.weight", layer_idx),
vec![self.config.hidden_size],
),
(
format!("encoder.layer.{}.output.LayerNorm.bias", layer_idx),
vec![self.config.hidden_size],
),
];
for (name, shape) in layer_specs {
if let Ok(tensor) = self.extract_tensor_from_buffer(buffer, &name, &shape) {
weights.insert(name, tensor);
}
}
}
for (name, shape) in embedding_specs {
if let Ok(tensor) = self.extract_tensor_from_buffer(buffer, name, &shape) {
weights.insert(name.to_string(), tensor);
}
}
if let Ok(tensor) = self.extract_tensor_from_buffer(
buffer,
"pooler.dense.weight",
&[self.config.hidden_size, self.config.hidden_size],
) {
weights.insert("pooler.dense.weight".to_string(), tensor);
}
if let Ok(tensor) =
self.extract_tensor_from_buffer(buffer, "pooler.dense.bias", &[self.config.hidden_size])
{
weights.insert("pooler.dense.bias".to_string(), tensor);
}
Ok(weights)
}
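/// Heuristic tensor recovery: scans `buffer` in 4-byte steps for a window
/// that decodes into plausible little-endian f32 weights of the expected
/// size, falling back to a small random initialization when no such window
/// exists. A best-effort stand-in for a real format parser, not a faithful
/// deserializer.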
fn extract_tensor_from_buffer(
&self,
buffer: &[u8],
name: &str,
expected_shape: &[usize],
) -> Result<Tensor> {
let total_elements: usize = expected_shape.iter().product();
let expected_size = total_elements * 4; // four bytes per little-endian f32
if buffer.len() < expected_size {
return Err(TrustformersError::weight_load_error(format!(
"Buffer too small for tensor {}",
name
)));
}
// Slide a 4-byte-aligned window across the buffer looking for a span that
// decodes into plausible f32 weights. The inclusive upper bound makes an
// exactly-sized buffer (offset 0 only) eligible too.
for offset in (0..=buffer.len() - expected_size).step_by(4) {
let tensor_data = &buffer[offset..offset + expected_size];
let float_data: Vec<f32> = tensor_data
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
// Loose plausibility check: at least one small finite value and at least
// one nonzero value somewhere in the window.
if float_data.iter().any(|&x| x.is_finite() && x.abs() < 10.0)
&& float_data.iter().any(|&x| x != 0.0)
{
if let Ok(tensor) = Tensor::from_vec(float_data, expected_shape) {
return Ok(tensor);
}
}
}
// No plausible window found: fall back to small random values in roughly
// [-0.01, 0.01], mimicking a fresh initialization.
let random_data: Vec<f32> =
(0..total_elements).map(|_| (fastrand::f32() - 0.5) * 0.02).collect();
Tensor::from_vec(random_data, expected_shape).map_err(|e| {
TrustformersError::weight_load_error(format!(
"Failed to create fallback tensor for {}: {}",
name, e
))
})
}
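// NOTE: the three `load_*_weights` helpers below only log the shapes of the
// tensors they find; copying the values into the corresponding layers is not
// implemented yet.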
fn load_embeddings_weights(&mut self, weights: &HashMap<String, Tensor>) -> Result<()> {
if let Some(word_emb) = weights.get("embeddings.word_embeddings.weight") {
println!("Loaded word embeddings: {:?}", word_emb.shape());
}
if let Some(pos_emb) = weights.get("embeddings.position_embeddings.weight") {
println!("Loaded position embeddings: {:?}", pos_emb.shape());
}
if let Some(token_type_emb) = weights.get("embeddings.token_type_embeddings.weight") {
println!("Loaded token type embeddings: {:?}", token_type_emb.shape());
}
if let Some(ln_weight) = weights.get("embeddings.LayerNorm.weight") {
println!(
"Loaded embeddings LayerNorm weight: {:?}",
ln_weight.shape()
);
}
if let Some(ln_bias) = weights.get("embeddings.LayerNorm.bias") {
println!("Loaded embeddings LayerNorm bias: {:?}", ln_bias.shape());
}
Ok(())
}
fn load_encoder_weights(&mut self, weights: &HashMap<String, Tensor>) -> Result<()> {
for layer_idx in 0..self.config.num_hidden_layers {
let attention_weights = vec![
format!("encoder.layer.{}.attention.self.query.weight", layer_idx),
format!("encoder.layer.{}.attention.self.key.weight", layer_idx),
format!("encoder.layer.{}.attention.self.value.weight", layer_idx),
format!("encoder.layer.{}.attention.output.dense.weight", layer_idx),
];
for weight_name in attention_weights {
if let Some(weight) = weights.get(&weight_name) {
println!("Loaded {}: {:?}", weight_name, weight.shape());
}
}
let ff_weights = vec![
format!("encoder.layer.{}.intermediate.dense.weight", layer_idx),
format!("encoder.layer.{}.output.dense.weight", layer_idx),
];
for weight_name in ff_weights {
if let Some(weight) = weights.get(&weight_name) {
println!("Loaded {}: {:?}", weight_name, weight.shape());
}
}
}
Ok(())
}
fn load_pooler_weights(&mut self, weights: &HashMap<String, Tensor>) -> Result<()> {
if let Some(pooler_weight) = weights.get("pooler.dense.weight") {
println!("Loaded pooler weight: {:?}", pooler_weight.shape());
}
if let Some(pooler_bias) = weights.get("pooler.dense.bias") {
println!("Loaded pooler bias: {:?}", pooler_bias.shape());
}
Ok(())
}
// Inherent duplicate of `Model::get_config`, callable without importing the trait.
#[allow(dead_code)]
fn get_config(&self) -> &BertConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::traits::{Model, TokenizedInput};
fn lcg_next(state: &mut u64) -> u64 {
*state = state.wrapping_mul(6364136223846793005u64).wrapping_add(1442695040888963407u64);
*state
}
fn lcg_token(state: &mut u64, vocab: u32) -> u32 {
(lcg_next(state) >> 33) as u32 % vocab
}
fn tiny_config() -> BertConfig {
BertConfig {
vocab_size: 256,
hidden_size: 32,
num_hidden_layers: 1,
num_attention_heads: 4,
intermediate_size: 128,
hidden_act: "gelu".to_string(),
hidden_dropout_prob: 0.0,
attention_probs_dropout_prob: 0.0,
max_position_embeddings: 16,
type_vocab_size: 2,
initializer_range: 0.02,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: Some("absolute".to_string()),
use_cache: Some(false),
classifier_dropout: None,
}
}
#[test]
fn test_bert_model_new_cpu() {
let cfg = tiny_config();
let model = BertModel::new(cfg).expect("BertModel::new must succeed");
assert_eq!(model.device(), Device::CPU);
}
#[test]
fn test_bert_model_new_with_device_cpu() {
let cfg = tiny_config();
let model = BertModel::new_with_device(cfg, Device::CPU)
.expect("BertModel::new_with_device must succeed");
assert_eq!(model.device(), Device::CPU);
}
#[test]
fn test_bert_model_num_parameters_positive() {
let cfg = tiny_config();
let model = BertModel::new(cfg).expect("BertModel::new must succeed");
assert!(
model.num_parameters() > 0,
"num_parameters must be positive"
);
}
#[test]
fn test_bert_model_larger_config_has_more_params() {
let small = tiny_config();
let big = BertConfig {
hidden_size: 64,
num_hidden_layers: 2,
num_attention_heads: 8,
intermediate_size: 256,
..tiny_config()
};
let m_small = BertModel::new(small).expect("small model must succeed");
let m_big = BertModel::new(big).expect("big model must succeed");
assert!(
m_big.num_parameters() > m_small.num_parameters(),
"Larger config must have more parameters"
);
}
#[test]
fn test_bert_forward_last_hidden_state_shape() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let seq_len = 5usize;
let mut state: u64 = 42;
let input_ids: Vec<u32> =
(0..seq_len).map(|_| lcg_token(&mut state, cfg.vocab_size as u32)).collect();
let output = model
.forward_with_embeddings(input_ids, None, None)
.expect("forward_with_embeddings must succeed");
let shape = output.last_hidden_state.shape();
assert_eq!(shape[0], 1, "batch dimension must be 1");
assert_eq!(shape[1], seq_len, "second dim must equal seq_len");
assert_eq!(
shape[2], cfg.hidden_size,
"third dim must equal hidden_size"
);
}
#[test]
fn test_bert_forward_with_attention_mask() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let seq_len = 4usize;
let input_ids: Vec<u32> = (0..seq_len as u32).collect();
let attention_mask: Vec<u8> = vec![1, 1, 1, 0];
let output = model
.forward_with_embeddings(input_ids, Some(attention_mask), None)
.expect("forward with attention mask must succeed");
let shape = output.last_hidden_state.shape();
assert_eq!(shape[1], seq_len);
assert_eq!(shape[2], cfg.hidden_size);
}
#[test]
fn test_bert_forward_with_token_type_ids() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let seq_len = 6usize;
let input_ids: Vec<u32> = (0..seq_len as u32).collect();
let token_type_ids: Vec<u32> = vec![0, 0, 0, 1, 1, 1];
let output = model
.forward_with_embeddings(input_ids, None, Some(token_type_ids))
.expect("forward with token_type_ids must succeed");
let shape = output.last_hidden_state.shape();
assert_eq!(shape[1], seq_len);
assert_eq!(shape[2], cfg.hidden_size);
}
#[test]
fn test_bert_model_trait_forward() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let seq_len = 4usize;
let input_ids: Vec<u32> = (0..seq_len as u32).collect();
let attention_mask: Vec<u8> = vec![1u8; seq_len];
let input = TokenizedInput::new(input_ids, attention_mask);
let output = model.forward(input).expect("Model::forward must succeed");
let shape = output.last_hidden_state.shape();
assert_eq!(shape[1], seq_len);
assert_eq!(shape[2], cfg.hidden_size);
}
#[test]
fn test_bert_model_get_config() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let returned = model.get_config();
assert_eq!(returned.vocab_size, cfg.vocab_size);
assert_eq!(returned.hidden_size, cfg.hidden_size);
}
#[test]
fn test_bert_forward_outputs_are_finite() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let input_ids: Vec<u32> = vec![1, 2, 3, 4];
let output = model
.forward_with_embeddings(input_ids, None, None)
.expect("forward must succeed");
if let Tensor::F32(arr) = &output.last_hidden_state {
for &v in arr.iter() {
assert!(v.is_finite(), "BERT output must be finite, got {}", v);
}
} else {
panic!("expected an F32 tensor from the forward pass");
}
}
#[test]
fn test_bert_forward_same_input_produces_same_output() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let out1 = model
.forward_with_embeddings(vec![1, 2, 3, 4], None, None)
.expect("first forward must succeed");
let out2 = model
.forward_with_embeddings(vec![1, 2, 3, 4], None, None)
.expect("second forward must succeed");
if let (Tensor::F32(a), Tensor::F32(b)) =
(&out1.last_hidden_state, &out2.last_hidden_state)
{
let all_equal = a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < 1e-6);
assert!(
all_equal,
"BERT model must be deterministic: same input must produce same output"
);
} else {
panic!("expected F32 tensors from both forward passes");
}
}
#[test]
fn test_bert_single_token_forward() {
let cfg = tiny_config();
let model = BertModel::new(cfg.clone()).expect("BertModel::new must succeed");
let input_ids: Vec<u32> = vec![5];
let output = model
.forward_with_embeddings(input_ids, None, None)
.expect("single-token forward must succeed");
let shape = output.last_hidden_state.shape();
assert_eq!(shape[1], 1, "single token input: seq_len must be 1");
assert_eq!(shape[2], cfg.hidden_size);
}
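// Sketch of the loader's failure tolerance: `extract_bert_weights` skips every
// tensor it cannot recover, so `load_pretrained` should succeed even on a
// buffer that contains no recognizable weights. The 16-byte junk buffer is a
// hypothetical input, far smaller than any expected tensor.
#[test]
fn test_bert_load_pretrained_tolerates_unrecognizable_buffer() {
let cfg = tiny_config();
let mut model = BertModel::new(cfg).expect("BertModel::new must succeed");
let mut junk: &[u8] = &[0u8; 16];
model
.load_pretrained(&mut junk)
.expect("load_pretrained must tolerate unrecognizable buffers");
}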
}