use crate::autograd::Tensor;
use crate::models::bert::config::BertConfig;
use crate::nn::functional::gelu;
use crate::nn::transformer::MultiHeadAttention;
use crate::nn::{LayerNorm, Linear, Module};
/// A single BERT-style encoder layer: a self-attention sublayer followed by a
/// two-layer feed-forward sublayer, each wrapped in a residual connection and
/// followed by layer normalization.
pub struct BertLayer {
    /// Self-attention module; `forward` invokes it through `forward_self`.
    attention: MultiHeadAttention,
    /// Layer norm applied to `hidden + attention_output`.
    attention_norm: LayerNorm,
    /// First feed-forward projection, built as `hidden_dim -> intermediate_dim`.
    intermediate: Linear,
    /// Second feed-forward projection, built as `intermediate_dim -> hidden_dim`.
    output_dense: Linear,
    /// Layer norm applied to the sum of the attention sublayer's normalized
    /// output and the feed-forward output.
    output_norm: LayerNorm,
}
impl BertLayer {
    /// Builds a layer whose sub-modules are sized from `config`.
    ///
    /// Reads `hidden_dim`, `intermediate_dim`, `num_heads` and
    /// `layer_norm_eps`. Both layer norms normalize over a single axis of
    /// length `hidden_dim` and share the same epsilon.
    #[must_use]
    pub fn new(config: &BertConfig) -> Self {
        let hidden_dim = config.hidden_dim;
        let ffn_dim = config.intermediate_dim;
        let eps = config.layer_norm_eps;

        // Sub-modules are constructed in struct-field order, so any
        // order-dependent initialization inside the constructors is unchanged.
        let attention = MultiHeadAttention::new(hidden_dim, config.num_heads);
        let attention_norm = LayerNorm::with_eps(&[hidden_dim], eps);
        let intermediate = Linear::new(hidden_dim, ffn_dim);
        let output_dense = Linear::new(ffn_dim, hidden_dim);
        let output_norm = LayerNorm::with_eps(&[hidden_dim], eps);

        Self {
            attention,
            attention_norm,
            intermediate,
            output_dense,
            output_norm,
        }
    }

    /// Mutable handle to the self-attention module, so callers can modify it
    /// in place (e.g. to overwrite its parameters).
    pub fn attention_mut(&mut self) -> &mut MultiHeadAttention {
        &mut self.attention
    }

    /// Mutable handle to the layer norm that follows the attention residual.
    pub fn attention_norm_mut(&mut self) -> &mut LayerNorm {
        &mut self.attention_norm
    }

    /// Mutable handle to the first feed-forward projection
    /// (`hidden_dim -> intermediate_dim`).
    pub fn intermediate_mut(&mut self) -> &mut Linear {
        &mut self.intermediate
    }

    /// Mutable handle to the second feed-forward projection
    /// (`intermediate_dim -> hidden_dim`).
    pub fn output_dense_mut(&mut self) -> &mut Linear {
        &mut self.output_dense
    }

    /// Mutable handle to the layer norm that follows the feed-forward residual.
    pub fn output_norm_mut(&mut self) -> &mut LayerNorm {
        &mut self.output_norm
    }

    /// Runs the layer on `hidden` and returns the transformed tensor.
    ///
    /// Computation, in order:
    /// 1. `a = attention.forward_self(hidden, attn_mask)` (first tuple element)
    /// 2. `x = attention_norm(hidden + a)`
    /// 3. `f = output_dense(gelu(intermediate(x)))`
    /// 4. result = `output_norm(x + f)`
    ///
    /// `attn_mask` is passed through unchanged to
    /// `MultiHeadAttention::forward_self`; its expected shape and meaning are
    /// defined by that method. The unit tests in this file feed a
    /// `[1, seq_len, hidden_dim]` input and expect an output of the same shape.
    #[must_use]
    pub fn forward(&self, hidden: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
        // The second element returned by `forward_self` is not needed here.
        let (context, _) = self.attention.forward_self(hidden, attn_mask);
        let post_attention = self.attention_norm.forward(&hidden.add(&context));

        // Feed-forward sublayer: expand, apply GELU, project back.
        let activated = gelu(&self.intermediate.forward(&post_attention));
        let projected = self.output_dense.forward(&activated);

        // The second residual uses the normalized attention output, not `hidden`.
        self.output_norm.forward(&post_attention.add(&projected))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds an all-zero `[1, seq_len, hidden_dim]` tensor through a layer built
    /// from `BertConfig::minilm_l6()` with no attention mask, and asserts that
    /// the output shape equals the input shape.
    fn assert_shape_roundtrip(seq_len: usize) {
        let config = BertConfig::minilm_l6();
        let layer = BertLayer::new(&config);
        let hidden_dim = config.hidden_dim;

        let zeros = Tensor::from_vec(
            vec![0.0; seq_len * hidden_dim],
            &[1, seq_len, hidden_dim],
        );
        let output = layer.forward(&zeros, None);

        assert_eq!(output.shape(), &[1, seq_len, hidden_dim]);
    }

    /// A short sequence (5 positions) keeps its shape.
    #[test]
    fn bert_layer_preserves_shape() {
        assert_shape_roundtrip(5);
    }

    /// A longer sequence (128 positions) keeps its shape.
    #[test]
    fn bert_layer_handles_long_seq() {
        assert_shape_roundtrip(128);
    }
}