use candle_core::{ModuleT, Tensor, D};
use candle_nn::ops::softmax;
use candle_nn::VarBuilder;
use snafu::{ResultExt, Snafu};
use crate::error::BoxedError;
use crate::layers::attention::self_attention::{SelfAttentionMask, SelfAttentionMaskError};
use crate::layers::attention::{
AttentionLinearBiases, AttentionLinearBiasesConfig, AttentionLinearBiasesError,
AttentionScorer, BuildAttentionScorer,
};
use crate::layers::build_module::BuildModule;
use crate::layers::identity::Identity;
#[derive(Debug)]
pub struct ScaledDotProductAttentionConfig {
dropout: Box<dyn BuildModule>,
linear_biases: Option<AttentionLinearBiasesConfig>,
}
impl ScaledDotProductAttentionConfig {
pub fn dropout(mut self, dropout: Box<dyn BuildModule>) -> Self {
self.dropout = dropout;
self
}
pub fn linear_biases(mut self, linear_biases: Option<AttentionLinearBiasesConfig>) -> Self {
self.linear_biases = linear_biases;
self
}
}
impl Default for ScaledDotProductAttentionConfig {
fn default() -> Self {
Self {
dropout: Box::new(Identity),
linear_biases: None,
}
}
}
impl BuildAttentionScorer for ScaledDotProductAttentionConfig {
fn build(&self, vb: VarBuilder) -> Result<Box<dyn AttentionScorer>, BoxedError> {
Ok(Box::new(ScaledDotProductAttention {
dropout: self.dropout.build(vb.clone()).context(BuildDropoutSnafu)?,
linear_biases: self
.linear_biases
.as_ref()
.map(|linear_biases| linear_biases.build())
.transpose()
.context(BuildAttentionLinearBiasesSnafu)?,
}))
}
}
#[derive(Debug, Snafu)]
pub enum ScaledDotProductAttentionError {
#[snafu(display("Cannot calculate attention linear biases"))]
AttentionLinearBiases { source: AttentionLinearBiasesError },
#[snafu(display("Cannot apply attention mask"))]
AttentionMask { source: SelfAttentionMaskError },
#[snafu(display("Cannot calculate attention scores"))]
AttentionScores { source: candle_core::Error },
#[snafu(display("Cannot weigh representations using attention mask"))]
AttentionWeight { source: candle_core::Error },
#[snafu(display("Cannot build attention linear biases"))]
BuildAttentionLinearBiases { source: AttentionLinearBiasesError },
#[snafu(display("Cannot build dropout module"))]
BuildDropout { source: BoxedError },
#[snafu(display("Cannot apply softmax temperature"))]
Temperature { source: candle_core::Error },
}
pub struct ScaledDotProductAttention {
dropout: Box<dyn ModuleT>,
linear_biases: Option<AttentionLinearBiases>,
}
impl AttentionScorer for ScaledDotProductAttention {
fn forward(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attention_mask: &SelfAttentionMask,
train: bool,
) -> Result<Tensor, BoxedError> {
let query = query.contiguous().context(AttentionScoresSnafu)?;
let mut attn_scores = key
.contiguous()
.and_then(|key| key.transpose(3, 2))
.and_then(|key| query.broadcast_matmul(&key))
.context(AttentionScoresSnafu)?;
let model_width = key.dim(3).context(TemperatureSnafu)?;
let temperature = (model_width as f64).sqrt();
attn_scores = (attn_scores / temperature).context(TemperatureSnafu)?;
if let Some(linear_biases) = &self.linear_biases {
attn_scores = linear_biases
.forward(&attn_scores)
.context(AttentionLinearBiasesSnafu)?;
}
attn_scores = attention_mask
.apply_logit_mask(&attn_scores)
.context(AttentionMaskSnafu)?;
let attn_weights = softmax(&attn_scores, D::Minus1).context(AttentionWeightSnafu)?;
value
.contiguous()
.and_then(|value| attn_weights.broadcast_matmul(&value))
.and_then(|xs| self.dropout.forward_t(&xs, train))
.context(AttentionWeightSnafu)
.boxed()
}
}