use candle_core::{Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::with_tracing::{linear, Linear};
use super::HiddenActLayer;
/// Intermediate sub-layer of a BERT block: a linear layer followed by the
/// configured hidden activation.
pub(crate) struct BertIntermediate {
    /// Linear layer applied first in `forward`.
    dense: Linear,
    /// Activation applied to the output of `dense`.
    intermediate_act: HiddenActLayer,
    /// Tracing span entered for the duration of each `forward` call.
    span: tracing::Span,
}
impl BertIntermediate {
    /// Builds the layer from `vb` using the sizes and activation in `config`.
    ///
    /// The linear layer is created via
    /// `linear(config.hidden_size, config.intermediate_size, ..)` with its
    /// variables looked up under the `"dense"` prefix of `vb`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while constructing the linear layer.
    pub(crate) fn load(vb: VarBuilder, config: &super::Config) -> Result<Self> {
        // The fallible step runs first so nothing else is created on failure.
        let projection = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
        let activation = HiddenActLayer::new(config.hidden_act);
        let span = tracing::span!(tracing::Level::TRACE, "inter");

        Ok(Self {
            dense: projection,
            intermediate_act: activation,
            span,
        })
    }
}
impl Module for BertIntermediate {
    /// Applies the linear layer to `hidden_states`, then the activation.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by either step.
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Keep the span entered until the function returns.
        let _guard = self.span.enter();

        let projected = self.dense.forward(hidden_states)?;
        let activated = self.intermediate_act.forward(&projected)?;
        Ok(activated)
    }
}