use crate::autograd::Tensor;
use crate::format::v2::AprV2Reader;
use crate::models::bert::{BertConfig, BertEmbeddings, BertEncoder, BertLayer, CrossEncoder};
/// Lists, in a fixed order, every tensor name this module's BERT naming contract contains.
///
/// The order is: the five embedding tensors, then sixteen tensors for each of the
/// `config.num_layers` encoder layers, then the two pooler tensors when `with_pooler`
/// is `true`, and finally `{classifier_prefix}.weight` and `{classifier_prefix}.bias`.
#[must_use]
pub fn expected_bert_tensor_names(
    config: &BertConfig,
    with_pooler: bool,
    classifier_prefix: &str,
) -> Vec<String> {
    // Tensors that exist exactly once per model.
    const EMBEDDING_TENSORS: [&str; 5] = [
        "bert.embeddings.word_embeddings.weight",
        "bert.embeddings.position_embeddings.weight",
        "bert.embeddings.token_type_embeddings.weight",
        "bert.embeddings.LayerNorm.weight",
        "bert.embeddings.LayerNorm.bias",
    ];
    // Per-layer tensor names, relative to `bert.encoder.layer.{idx}.`.
    const LAYER_SUFFIXES: [&str; 16] = [
        "attention.self.query.weight",
        "attention.self.query.bias",
        "attention.self.key.weight",
        "attention.self.key.bias",
        "attention.self.value.weight",
        "attention.self.value.bias",
        "attention.output.dense.weight",
        "attention.output.dense.bias",
        "attention.output.LayerNorm.weight",
        "attention.output.LayerNorm.bias",
        "intermediate.dense.weight",
        "intermediate.dense.bias",
        "output.dense.weight",
        "output.dense.bias",
        "output.LayerNorm.weight",
        "output.LayerNorm.bias",
    ];
    const POOLER_TENSORS: [&str; 2] = ["bert.pooler.dense.weight", "bert.pooler.dense.bias"];

    let mut names: Vec<String> = EMBEDDING_TENSORS
        .iter()
        .map(|name| (*name).to_string())
        .collect();

    for idx in 0..config.num_layers {
        names.extend(
            LAYER_SUFFIXES
                .iter()
                .map(|suffix| format!("bert.encoder.layer.{idx}.{suffix}")),
        );
    }

    if with_pooler {
        names.extend(POOLER_TENSORS.iter().map(|name| (*name).to_string()));
    }

    names.extend(
        ["weight", "bias"]
            .iter()
            .map(|part| format!("{classifier_prefix}.{part}")),
    );
    names
}
/// Structured error returned by the loader functions in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BertLoadError {
    /// Name of the tensor the failure relates to, or a `<...>` placeholder when the
    /// failure is not tied to a single tensor.
    pub tensor: String,
    /// Human-readable description of what went wrong.
    pub reason: String,
}

impl std::fmt::Display for BertLoadError {
    /// Renders the error as `BertLoadError(<tensor>: <reason>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { tensor, reason } = self;
        write!(f, "BertLoadError({tensor}: {reason})")
    }
}

impl std::error::Error for BertLoadError {}
/// Reads tensor `name` from `reader` as `f32` values and wraps them in a [`Tensor`]
/// built with `expected_shape`.
///
/// Only the element count is validated: the number of decoded values must equal the
/// product of `expected_shape`. The shape is not compared dimension by dimension.
///
/// # Errors
///
/// Returns a [`BertLoadError`] carrying `name` when:
/// - the reader has no tensor called `name`;
/// - the reader cannot produce `f32` data for it (the entry's dtype is reported);
/// - the product of `expected_shape` does not fit in `usize`;
/// - the decoded element count differs from the expected one.
fn read_tensor(
    reader: &AprV2Reader,
    name: &str,
    expected_shape: &[usize],
) -> Result<Tensor, BertLoadError> {
    // Every failure in this function is attributed to the same tensor name.
    let fail = |reason: String| BertLoadError {
        tensor: name.to_string(),
        reason,
    };

    let entry = reader
        .get_tensor(name)
        .ok_or_else(|| fail("tensor not present in APR file".to_string()))?;
    let data = reader.get_tensor_as_f32(name).ok_or_else(|| {
        fail(format!(
            "get_tensor_as_f32 failed for dtype {:?}",
            entry.dtype
        ))
    })?;

    // A plain `product()` panics in debug builds and silently wraps in release builds
    // when the dimensions multiply past `usize::MAX`; report that case as an error.
    let expected_numel = expected_shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        .ok_or_else(|| {
            fail(format!(
                "expected shape {expected_shape:?} overflows usize element count"
            ))
        })?;

    if data.len() != expected_numel {
        return Err(fail(format!(
            "element count mismatch: got {}, expected {} (shape {:?})",
            data.len(),
            expected_numel,
            expected_shape
        )));
    }
    Ok(Tensor::from_vec(data, expected_shape))
}
/// Fills `embeddings` with the `bert.embeddings.*` tensors found in `reader`.
///
/// Expected element counts come from `config`: the three lookup tables are
/// `[vocab_size | max_position_embeddings | type_vocab_size, hidden_dim]` and the
/// LayerNorm weight and bias are `[hidden_dim]`.
///
/// # Errors
///
/// Returns the first [`BertLoadError`] produced while reading a tensor. Fields are
/// assigned as each tensor is read, so fields loaded before the failure stay overwritten.
pub(crate) fn load_embeddings_from_reader(
    embeddings: &mut BertEmbeddings,
    reader: &AprV2Reader,
    config: &BertConfig,
) -> Result<(), BertLoadError> {
    let hidden = config.hidden_dim;
    // All embedding tensors share the `bert.embeddings.` prefix.
    let fetch = |suffix: &str, shape: &[usize]| {
        read_tensor(reader, &format!("bert.embeddings.{suffix}"), shape)
    };

    embeddings.word_embeddings = fetch("word_embeddings.weight", &[config.vocab_size, hidden])?;
    embeddings.position_embeddings = fetch(
        "position_embeddings.weight",
        &[config.max_position_embeddings, hidden],
    )?;
    embeddings.token_type_embeddings = fetch(
        "token_type_embeddings.weight",
        &[config.type_vocab_size, hidden],
    )?;

    let norm_weight = fetch("LayerNorm.weight", &[hidden])?;
    embeddings.layer_norm.set_weight(norm_weight);
    let norm_bias = fetch("LayerNorm.bias", &[hidden])?;
    embeddings.layer_norm.set_bias(norm_bias);

    Ok(())
}
pub(crate) fn load_layer_from_reader(
layer: &mut BertLayer,
reader: &AprV2Reader,
idx: usize,
config: &BertConfig,
) -> Result<(), BertLoadError> {
let h = config.hidden_dim;
let im = config.intermediate_dim;
let prefix = format!("bert.encoder.layer.{idx}");
for (proj, name) in [("query", "q"), ("key", "k"), ("value", "v")] {
let w_name = format!("{prefix}.attention.self.{proj}.weight");
let b_name = format!("{prefix}.attention.self.{proj}.bias");
let weight = read_tensor(reader, &w_name, &[h, h])?;
let bias = read_tensor(reader, &b_name, &[h])?;
let proj_linear = match name {
"q" => layer.attention_mut().q_proj_mut(),
"k" => layer.attention_mut().k_proj_mut(),
"v" => layer.attention_mut().v_proj_mut(),
_ => unreachable!("only q/k/v iterated"),
};
proj_linear.set_weight(weight);
proj_linear.set_bias(bias);
}
let attn_out_w = read_tensor(
reader,
&format!("{prefix}.attention.output.dense.weight"),
&[h, h],
)?;
let attn_out_b = read_tensor(
reader,
&format!("{prefix}.attention.output.dense.bias"),
&[h],
)?;
layer.attention_mut().out_proj_mut().set_weight(attn_out_w);
layer.attention_mut().out_proj_mut().set_bias(attn_out_b);
layer.attention_norm_mut().set_weight(read_tensor(
reader,
&format!("{prefix}.attention.output.LayerNorm.weight"),
&[h],
)?);
layer.attention_norm_mut().set_bias(read_tensor(
reader,
&format!("{prefix}.attention.output.LayerNorm.bias"),
&[h],
)?);
let intermediate_w = read_tensor(
reader,
&format!("{prefix}.intermediate.dense.weight"),
&[im, h],
)?;
let intermediate_b = read_tensor(reader, &format!("{prefix}.intermediate.dense.bias"), &[im])?;
layer.intermediate_mut().set_weight(intermediate_w);
layer.intermediate_mut().set_bias(intermediate_b);
let output_w = read_tensor(reader, &format!("{prefix}.output.dense.weight"), &[h, im])?;
let output_b = read_tensor(reader, &format!("{prefix}.output.dense.bias"), &[h])?;
layer.output_dense_mut().set_weight(output_w);
layer.output_dense_mut().set_bias(output_b);
layer.output_norm_mut().set_weight(read_tensor(
reader,
&format!("{prefix}.output.LayerNorm.weight"),
&[h],
)?);
layer.output_norm_mut().set_bias(read_tensor(
reader,
&format!("{prefix}.output.LayerNorm.bias"),
&[h],
)?);
Ok(())
}
/// Loads every layer of `encoder` from `reader`, in index order.
///
/// # Errors
///
/// Returns a [`BertLoadError`] tagged `<encoder>` when the encoder's layer count differs
/// from `config.num_layers`; in that case nothing is loaded. Otherwise returns the first
/// error raised while loading a layer, leaving earlier layers already loaded.
pub(crate) fn load_encoder_from_reader(
    encoder: &mut BertEncoder,
    reader: &AprV2Reader,
    config: &BertConfig,
) -> Result<(), BertLoadError> {
    let expected = config.num_layers;
    let actual = encoder.num_layers();
    if actual != expected {
        return Err(BertLoadError {
            tensor: "<encoder>".to_string(),
            reason: format!("encoder has {actual} layers but config says {expected}"),
        });
    }

    (0..expected)
        .try_for_each(|idx| load_layer_from_reader(encoder.layer_mut(idx), reader, idx, config))
}
/// Loads a whole [`CrossEncoder`] from `reader`: embeddings, encoder layers, the pooler
/// (only when the model has one), and the classification head.
///
/// The head is taken from the first of `classifier`, `score`, `rank_head` whose
/// `{prefix}.weight` tensor exists in `reader`. Its weight is expected to hold
/// `num_labels * hidden_dim` elements and its bias `num_labels` elements.
///
/// # Errors
///
/// Returns the first [`BertLoadError`] from any sub-loader or tensor read. Once a head
/// prefix has been selected, a failure reading its weight or bias is returned directly;
/// later prefixes are not tried. When no prefix has a `.weight` tensor, the error is
/// tagged `<classifier head>`. Parts loaded before a failure stay overwritten.
pub(crate) fn load_cross_encoder_from_reader(
    model: &mut CrossEncoder,
    reader: &AprV2Reader,
    config: &BertConfig,
) -> Result<(), BertLoadError> {
    let hidden = config.hidden_dim;

    load_embeddings_from_reader(model.embeddings_mut(), reader, config)?;
    load_encoder_from_reader(model.encoder_mut(), reader, config)?;

    if let Some(pooler) = model.pooler_mut() {
        let weight = read_tensor(reader, "bert.pooler.dense.weight", &[hidden, hidden])?;
        pooler.set_weight(weight);
        let bias = read_tensor(reader, "bert.pooler.dense.bias", &[hidden])?;
        pooler.set_bias(bias);
    }

    let num_labels = model.num_labels();
    // Candidate head prefixes, in priority order.
    let candidates = ["classifier", "score", "rank_head"];
    let selected = candidates
        .iter()
        .find(|prefix| reader.get_tensor(&format!("{prefix}.weight")).is_some());

    match selected {
        Some(prefix) => {
            let weight = read_tensor(reader, &format!("{prefix}.weight"), &[num_labels, hidden])?;
            let bias = read_tensor(reader, &format!("{prefix}.bias"), &[num_labels])?;
            model.classifier_mut().set_weight(weight);
            model.classifier_mut().set_bias(bias);
            Ok(())
        }
        None => Err(BertLoadError {
            tensor: "<classifier head>".to_string(),
            reason: format!(
                "no classifier tensor found; tried prefixes {candidates:?} \
                 (expected one of `classifier.weight`, `score.weight`, `rank_head.weight`)"
            ),
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::v2::{AprV2Metadata, AprV2Reader, AprV2Writer};

    /// Small two-layer configuration so fixtures stay tiny.
    fn tiny_config() -> BertConfig {
        BertConfig {
            vocab_size: 32,
            hidden_dim: 8,
            num_layers: 2,
            num_heads: 2,
            intermediate_dim: 16,
            max_position_embeddings: 16,
            type_vocab_size: 2,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
        }
    }

    /// Adds an f32 tensor called `name` with `shape`, every element set to `value`.
    fn add_filled(w: &mut AprV2Writer, name: &str, shape: &[usize], value: f32) {
        let numel: usize = shape.iter().product();
        w.add_f32_tensor(name, shape.to_vec(), &vec![value; numel]);
    }

    /// Adds every tensor except the classification head: embeddings, all encoder
    /// layers, and (when `with_pooler`) the pooler.
    ///
    /// LayerNorm weights are filled with ones; all other layer tensors with zeros.
    fn add_backbone_tensors(w: &mut AprV2Writer, config: &BertConfig, with_pooler: bool) {
        let h = config.hidden_dim;
        let im = config.intermediate_dim;

        add_filled(
            w,
            "bert.embeddings.word_embeddings.weight",
            &[config.vocab_size, h],
            0.1,
        );
        add_filled(
            w,
            "bert.embeddings.position_embeddings.weight",
            &[config.max_position_embeddings, h],
            0.01,
        );
        add_filled(
            w,
            "bert.embeddings.token_type_embeddings.weight",
            &[config.type_vocab_size, h],
            0.001,
        );
        add_filled(w, "bert.embeddings.LayerNorm.weight", &[h], 1.0);
        add_filled(w, "bert.embeddings.LayerNorm.bias", &[h], 0.0);

        for idx in 0..config.num_layers {
            let p = format!("bert.encoder.layer.{idx}");
            for proj in ["query", "key", "value"] {
                add_filled(w, &format!("{p}.attention.self.{proj}.weight"), &[h, h], 0.0);
                add_filled(w, &format!("{p}.attention.self.{proj}.bias"), &[h], 0.0);
            }
            add_filled(w, &format!("{p}.attention.output.dense.weight"), &[h, h], 0.0);
            add_filled(w, &format!("{p}.attention.output.dense.bias"), &[h], 0.0);
            add_filled(w, &format!("{p}.attention.output.LayerNorm.weight"), &[h], 1.0);
            add_filled(w, &format!("{p}.attention.output.LayerNorm.bias"), &[h], 0.0);
            add_filled(w, &format!("{p}.intermediate.dense.weight"), &[im, h], 0.0);
            add_filled(w, &format!("{p}.intermediate.dense.bias"), &[im], 0.0);
            add_filled(w, &format!("{p}.output.dense.weight"), &[h, im], 0.0);
            add_filled(w, &format!("{p}.output.dense.bias"), &[h], 0.0);
            add_filled(w, &format!("{p}.output.LayerNorm.weight"), &[h], 1.0);
            add_filled(w, &format!("{p}.output.LayerNorm.bias"), &[h], 0.0);
        }

        if with_pooler {
            add_filled(w, "bert.pooler.dense.weight", &[h, h], 0.0);
            add_filled(w, "bert.pooler.dense.bias", &[h], 0.0);
        }
    }

    /// Serializes a complete stub checkpoint: backbone plus a `classifier.*` head
    /// with `num_labels` outputs.
    fn build_stub_bert_apr(config: &BertConfig, with_pooler: bool, num_labels: usize) -> Vec<u8> {
        let h = config.hidden_dim;
        let mut w = AprV2Writer::new(AprV2Metadata::default());
        add_backbone_tensors(&mut w, config, with_pooler);
        add_filled(&mut w, "classifier.weight", &[num_labels, h], 0.0);
        add_filled(&mut w, "classifier.bias", &[num_labels], 0.0);
        w.write().expect("AprV2Writer must produce bytes")
    }

    /// `Display` output must mention both the tensor name and the reason.
    #[test]
    fn bert_load_error_display_includes_tensor_and_reason() {
        let err = BertLoadError {
            tensor: "bert.embeddings.word_embeddings.weight".to_string(),
            reason: "tensor not present in APR file".to_string(),
        };
        let display = format!("{err}");
        assert!(display.contains("bert.embeddings.word_embeddings.weight"));
        assert!(display.contains("tensor not present"));
    }

    /// A complete stub checkpoint loads, and a 3-token forward pass yields a `[1, 1]` output.
    #[test]
    fn falsify_bert_326_phase1_load_full_cross_encoder() {
        let config = tiny_config();
        let bytes = build_stub_bert_apr(&config, true, 1);
        let reader = AprV2Reader::from_bytes(&bytes).expect("AprV2Reader parse");
        let mut model = CrossEncoder::new(&config, 1, true);
        model
            .load_from_reader(&reader, &config)
            .expect("CrossEncoder::load_from_reader must succeed for full BERT-named APR");
        let input_ids = vec![1u32, 2, 3];
        let token_type_ids = vec![0u32, 0, 0];
        let out = model.forward(&input_ids, &token_type_ids);
        assert_eq!(out.shape(), &[1, 1]);
    }

    /// A checkpoint with the backbone but no head tensors must fail with an error
    /// whose reason mentions the classifier.
    #[test]
    fn falsify_bert_326_phase1_missing_classifier_returns_structured_error() {
        let config = tiny_config();
        let mut w = AprV2Writer::new(AprV2Metadata::default());
        // The model below is built without a pooler, so none is written either.
        add_backbone_tensors(&mut w, &config, false);
        let bytes = w.write().expect("AprV2Writer must produce bytes");
        let reader = AprV2Reader::from_bytes(&bytes).expect("AprV2Reader parse");
        let mut model = CrossEncoder::new(&config, 1, false);
        let err = model
            .load_from_reader(&reader, &config)
            .expect_err("loader must report missing classifier");
        assert!(
            err.reason.contains("classifier"),
            "error reason must reference classifier tried prefixes: {err:?}"
        );
    }

    /// The name list has 5 embedding + 16 per-layer + 2 head entries, plus 2 with a pooler.
    #[test]
    fn falsify_bert_326_phase2_expected_names_count_matches_formula() {
        let config = tiny_config();
        let names_with_pooler = expected_bert_tensor_names(&config, true, "classifier");
        let names_without_pooler = expected_bert_tensor_names(&config, false, "classifier");
        let n = config.num_layers;
        assert_eq!(names_with_pooler.len(), 5 + 16 * n + 2 + 2);
        assert_eq!(names_without_pooler.len(), 5 + 16 * n + 2);
    }

    /// The contract helper and the stub checkpoint must name exactly the same tensors,
    /// and the loader must accept that checkpoint.
    #[test]
    fn falsify_bert_326_phase2_contract_matches_loader_reads() {
        let config = tiny_config();
        let bytes = build_stub_bert_apr(&config, true, 1);
        let reader = AprV2Reader::from_bytes(&bytes).expect("AprV2Reader parse");
        let expected = expected_bert_tensor_names(&config, true, "classifier");
        for name in &expected {
            assert!(
                reader.get_tensor(name).is_some(),
                "contract helper named {name:?} but stub APR doesn't contain it"
            );
        }
        let stub_names: Vec<String> = reader
            .tensor_names()
            .iter()
            .map(|s| s.to_string())
            .collect();
        for name in &stub_names {
            assert!(
                expected.contains(name),
                "stub APR contains {name:?} but contract helper doesn't list it"
            );
        }
        let mut model = CrossEncoder::new(&config, 1, true);
        model
            .load_from_reader(&reader, &config)
            .expect("loader must succeed when APR contains exactly the contract names");
    }

    /// `Architecture::Bert.map_name` must return each canonical name unchanged.
    #[test]
    fn falsify_bert_326_phase2_bert_map_name_is_identity() {
        use crate::format::converter_types::Architecture;
        let canonical_names = [
            "bert.embeddings.word_embeddings.weight",
            "bert.embeddings.LayerNorm.bias",
            "bert.encoder.layer.0.attention.self.query.weight",
            "bert.encoder.layer.0.attention.output.LayerNorm.weight",
            "bert.encoder.layer.11.output.dense.bias",
            "bert.pooler.dense.weight",
            "classifier.weight",
            "classifier.bias",
        ];
        for name in canonical_names {
            assert_eq!(
                Architecture::Bert.map_name(name),
                name,
                "bert_map_name must preserve HF tensor names verbatim (identity passthrough)"
            );
        }
    }

    /// Pins the name count for `BertConfig::default()`; the asserted formula
    /// corresponds to a 12-layer configuration with a pooler.
    #[test]
    fn falsify_bert_326_phase2_bert_base_tensor_count() {
        let config = BertConfig::default();
        let names = expected_bert_tensor_names(&config, true, "classifier");
        assert_eq!(names.len(), 5 + 16 * 12 + 2 + 2);
        assert_eq!(names.len(), 201);
    }

    /// A word-embedding tensor with the wrong element count must be rejected with a
    /// structured error that names the tensor.
    #[test]
    fn falsify_bert_326_phase1_shape_mismatch_returns_structured_error() {
        let config = tiny_config();
        let h = config.hidden_dim;
        let mut w = AprV2Writer::new(AprV2Metadata::default());
        // 99 rows instead of `config.vocab_size` (32).
        add_filled(
            &mut w,
            "bert.embeddings.word_embeddings.weight",
            &[99, h],
            0.0,
        );
        let bytes = w.write().expect("AprV2Writer must produce bytes");
        let reader = AprV2Reader::from_bytes(&bytes).expect("AprV2Reader parse");
        let mut emb = BertEmbeddings::new(&config);
        let err = load_embeddings_from_reader(&mut emb, &reader, &config)
            .expect_err("loader must reject shape mismatch");
        assert!(err.reason.contains("element count mismatch"), "{err:?}");
        assert!(err.tensor.contains("word_embeddings"), "{err:?}");
    }
}