use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use super::{BertSelfAttention, BertSelfOutput};
/// BERT attention block: a [`BertSelfAttention`] stage whose result is handed,
/// together with the block's input, to a [`BertSelfOutput`] stage.
pub(crate) struct BertAttention {
/// Self-attention stage; loaded from the `self` sub-path of the weight builder.
self_attention: BertSelfAttention,
/// Stage that receives the self-attention result plus the original input;
/// loaded from the `output` sub-path of the weight builder.
self_output: BertSelfOutput,
/// TRACE-level span named `attn`, entered for the duration of each `forward` call.
span: tracing::Span,
}
impl BertAttention {
    /// Builds the attention block from the weights reachable through `vb`.
    ///
    /// The self-attention stage is loaded from the `self` sub-path and the
    /// output stage from the `output` sub-path, in that order. A TRACE-level
    /// `attn` span is created once here and reused by every `forward` call.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by either sub-module's `load`.
    pub(crate) fn load(vb: VarBuilder, config: &super::Config) -> Result<Self> {
        // Struct-literal fields are evaluated top to bottom, so the
        // self-attention weights are still loaded before the output weights
        // and the span is only created once both loads have succeeded.
        Ok(Self {
            self_attention: BertSelfAttention::load(vb.pp("self"), config)?,
            self_output: BertSelfOutput::load(vb.pp("output"), config)?,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    /// Runs the attention block on `hidden_states`.
    ///
    /// `hidden_states`, `attention_mask` and `train` are forwarded to the
    /// self-attention stage; its result is then passed to the output stage
    /// along with the original `hidden_states` and `train`.
    ///
    /// NOTE(review): `train` presumably toggles training-only behaviour such
    /// as dropout in the sub-modules, and the original `hidden_states` is
    /// presumably used for a residual connection; confirm both against
    /// `BertSelfAttention` / `BertSelfOutput`.
    ///
    /// # Errors
    ///
    /// Returns any error raised by either stage.
    pub(crate) fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor> {
        // Keep the span entered until the function returns.
        let _guard = self.span.enter();

        let attended = self
            .self_attention
            .forward(hidden_states, attention_mask, train)?;

        self.self_output.forward(&attended, hidden_states, train)
    }
}