use std::fmt::Debug;
use candle_core::Tensor;
use candle_nn::VarBuilder;
mod mask;
pub use mask::{AttentionMask, AttentionMaskError};
mod sdpa;
pub use sdpa::{
ScaledDotProductAttention, ScaledDotProductAttentionConfig, ScaledDotProductAttentionError,
};
mod alibi;
pub use alibi::{AttentionLinearBiases, AttentionLinearBiasesConfig, AttentionLinearBiasesError};
mod self_attention;
pub use self_attention::{
AttentionHeads, CausalMask, CausalMaskError, QkvMode, QkvSplit, SelfAttention,
SelfAttentionConfig, SelfAttentionMask, SelfAttentionMaskError,
};
use crate::error::BoxedError;
use crate::kv_cache::LayerKeyValueCache;
pub trait Attention {
fn forward_t(
&self,
input: &Tensor,
attention_mask: &AttentionMask,
cache: &mut LayerKeyValueCache,
positions: Option<&Tensor>,
train: bool,
use_causal_mask: bool,
) -> Result<Tensor, BoxedError>;
}
pub trait BuildAttention {
fn build(&self, vb: VarBuilder) -> Result<Box<dyn Attention>, BoxedError>;
}
pub trait AttentionScorer {
fn forward(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attention_mask: &SelfAttentionMask,
train: bool,
) -> Result<Tensor, BoxedError>;
}
pub trait BuildAttentionScorer: Debug {
fn build(&self, vb: VarBuilder) -> Result<Box<dyn AttentionScorer>, BoxedError>;
}