use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use super::BertLayer;
/// Stack of [`BertLayer`]s that are applied one after another by `BertEncoder::forward`.
pub(crate) struct BertEncoder {
    /// Layers in application order; `load` builds the i-th one from the `layer.{i}`
    /// sub-path of the encoder's `VarBuilder`.
    layers: Vec<BertLayer>,
    /// TRACE-level tracing span named `"encoder"`, entered for the duration of each
    /// `forward` call.
    span: tracing::Span,
}
impl BertEncoder {
    /// Builds the encoder from `config.num_hidden_layers` stacked [`BertLayer`]s.
    ///
    /// The layer at position `index` is loaded from the weights found under the
    /// `layer.{index}` prefix of `vb`, and every layer receives the same `config`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `BertLayer::load`; no further layers are
    /// loaded after a failure.
    pub(crate) fn load(vb: VarBuilder, config: &super::Config) -> Result<Self> {
        let mut layers = Vec::new();
        for index in 0..config.num_hidden_layers {
            let layer_vb = vb.pp(format!("layer.{index}"));
            layers.push(BertLayer::load(layer_vb, config)?);
        }
        // Created only after all layers loaded successfully, as before.
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(Self { layers, span })
    }

    /// Runs `hidden_states` through every layer in order.
    ///
    /// Each layer's output becomes the next layer's input. `attention_mask` and `train`
    /// are passed unchanged to every layer. If the encoder has no layers, a clone of
    /// `hidden_states` is returned.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a layer's `forward`; the remaining layers are
    /// not run.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor> {
        // Keep the "encoder" span entered until the whole stack has been applied.
        let _guard = self.span.enter();
        self.layers
            .iter()
            .try_fold(hidden_states.clone(), |current, layer| {
                layer.forward(&current, attention_mask, train)
            })
    }
}