/// Masking mode selector for an attention computation.
///
/// NOTE(review): no code in this section consumes this enum; the meaning of each
/// variant below is inferred from its name and should be confirmed at the use site.
/// The variant `None` can be confused with `Option::None` if this enum's variants
/// are glob-imported (`use AttentionMask::*`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AttentionMask {
/// No masking is requested (presumably every query may attend to every key).
None,
/// Causal masking (presumably each query position may only attend to key
/// positions at or before it — confirm against the kernel that reads this).
Causal,
}
/// The four dimensions that describe one attention problem.
///
/// A shape is considered valid by `AttentionShape::validate` only when all four
/// fields are non-zero. `AttentionShape::score_len` multiplies `q_len * k_len`, and
/// `AttentionShape::output_len` multiplies `q_len * d_v`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AttentionShape {
/// Number of query positions; a factor of both `score_len` and `output_len`.
pub q_len: usize,
/// Number of key positions; the second factor of `score_len`.
pub k_len: usize,
/// Key feature width (assumed shared by queries and keys — confirm with callers);
/// not used by the length helpers in this section.
pub d_k: usize,
/// Value feature width; the second factor of `output_len`.
pub d_v: usize,
}
impl AttentionShape {
    /// Reports whether the shape is usable, i.e. whether every one of `q_len`,
    /// `k_len`, `d_k` and `d_v` is non-zero.
    ///
    /// Only the individual dimensions are inspected. This does not check whether
    /// the products computed by [`Self::score_len`] and [`Self::output_len`] fit in
    /// a `usize`; those methods report overflow themselves by returning `None`.
    pub fn validate(&self) -> bool {
        let dims = [self.q_len, self.k_len, self.d_k, self.d_v];
        dims.iter().all(|&dim| dim != 0)
    }

    /// Returns `q_len * k_len`, the element count for the score buffer, or
    /// `None` if the product overflows `usize`.
    pub fn score_len(&self) -> Option<usize> {
        usize::checked_mul(self.q_len, self.k_len)
    }

    /// Returns `q_len * d_v`, the element count for the output buffer, or
    /// `None` if the product overflows `usize`.
    pub fn output_len(&self) -> Option<usize> {
        let &Self { q_len, d_v, .. } = self;
        q_len.checked_mul(d_v)
    }
}
/// Error conditions reported by attention operations.
///
/// NOTE(review): none of these variants is constructed in this section, so the
/// per-variant descriptions below are inferred from the names — confirm them at
/// the sites that return this error. This type derives only `Debug`/`Clone`/`Copy`/
/// `PartialEq`/`Eq` here; if no `Display` + `std::error::Error` impls exist elsewhere
/// in the crate, consider adding them so it composes with `?` and `Box<dyn Error>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AttentionError {
/// Presumably: input dimensions disagree with each other or with an `AttentionShape`.
ShapeMismatch,
/// Presumably: a caller-supplied buffer is shorter than the length the shape requires
/// (compare `AttentionShape::score_len` / `AttentionShape::output_len`).
BufferTooSmall,
/// Presumably: a dimension is invalid, e.g. zero (compare `AttentionShape::validate`).
InvalidDim,
}