use crate::autograd::Tensor;
use crate::models::bert::config::BertConfig;
use crate::models::bert::layer::BertLayer;
/// A stack of [`BertLayer`]s that are applied one after another.
pub struct BertEncoder {
// Layers in the order they are applied by `forward`.
layers: Vec<BertLayer>,
}
impl BertEncoder {
    /// Builds an encoder holding `config.num_layers` layers, each created with
    /// [`BertLayer::new`] from the same `config`.
    #[must_use]
    pub fn new(config: &BertConfig) -> Self {
        let mut layers = Vec::new();
        for _ in 0..config.num_layers {
            layers.push(BertLayer::new(config));
        }
        Self { layers }
    }

    /// Runs `hidden` through every layer in order and returns the final output.
    ///
    /// The same `attn_mask` is handed to each layer unchanged. When the encoder
    /// holds no layers, the result is simply a clone of `hidden`.
    #[must_use]
    pub fn forward(&self, hidden: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
        // Seed the fold with an owned copy so that each step can consume the
        // previous layer's output and hand back the next one.
        self.layers
            .iter()
            .fold(hidden.clone(), |state, layer| layer.forward(&state, attn_mask))
    }

    /// Returns how many layers this encoder holds.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Returns a mutable reference to the layer at position `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not smaller than [`BertEncoder::num_layers`].
    pub fn layer_mut(&mut self, idx: usize) -> &mut BertLayer {
        &mut self.layers[idx]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an encoder from `config`, checks its layer count, and verifies
    /// that a zero-filled `[1, seq_len, hidden_dim]` input comes back with the
    /// same shape.
    fn check_shape_preserved(config: &BertConfig, expected_layers: usize, seq_len: usize) {
        let encoder = BertEncoder::new(config);
        assert_eq!(encoder.num_layers(), expected_layers);

        let hidden_dim = config.hidden_dim;
        let zeros = vec![0.0; seq_len * hidden_dim];
        let input = Tensor::from_vec(zeros, &[1, seq_len, hidden_dim]);
        let out = encoder.forward(&input, None);
        assert_eq!(out.shape(), &[1, seq_len, hidden_dim]);
    }

    /// The `minilm_l6` configuration yields six layers and keeps the input shape.
    #[test]
    fn encoder_preserves_shape() {
        check_shape_preserved(&BertConfig::minilm_l6(), 6, 5);
    }

    /// The default configuration yields twelve layers and keeps the input shape.
    #[test]
    fn encoder_handles_bert_base_dims() {
        check_shape_preserved(&BertConfig::default(), 12, 4);
    }
}