use tch::{Kind, Tensor};
use crate::{error::LoftrError, numeric::i64_to_f64};
/// Configuration for the linear attention layer.
///
/// The only state is `eps`, which `forward` adds to the normalization term
/// before taking its reciprocal.
#[derive(Debug, Clone, Copy)]
pub struct LinearAttention {
/// Constant added to the attention normalizer before it is inverted.
eps: f64,
}
impl Default for LinearAttention {
    /// Creates a `LinearAttention` whose `eps` is `1e-6`.
    fn default() -> Self {
        let eps = 1e-6;
        Self { eps }
    }
}
impl LinearAttention {
    /// Computes linear attention over `queries`, `keys` and `values`.
    ///
    /// The inputs are first checked by `validate_attention_tensors`, which
    /// requires rank-4 tensors laid out as `[N, L, H, D]` (queries) and
    /// `[N, S, H, D]` (keys and values). Queries and keys are passed through
    /// the `elu(x) + 1` feature map before the key/value summary and the
    /// normalizer are formed.
    ///
    /// # Arguments
    ///
    /// * `queries`, `keys`, `values` - attention inputs; none of them is
    ///   modified by this call.
    /// * `q_mask` - optional `[N, L]` mask multiplied into the mapped queries.
    /// * `kv_mask` - optional `[N, S]` mask multiplied into the mapped keys
    ///   and into the values.
    ///
    /// # Errors
    ///
    /// Returns `LoftrError` when the input shapes are inconsistent, when a
    /// mask does not match the leading two dimensions of its tensor, when a
    /// mask cannot be moved/cast to the tensor's device/kind, or when the
    /// value length cannot be converted by `i64_to_f64`.
    pub fn forward(
        self,
        queries: &Tensor,
        keys: &Tensor,
        values: &Tensor,
        q_mask: Option<&Tensor>,
        kv_mask: Option<&Tensor>,
    ) -> Result<Tensor, LoftrError> {
        validate_attention_tensors(queries, keys, values)?;

        // `elu() + 1.0` allocates fresh tensors, so the in-place mask
        // multiplications on `q` and `k` below never touch the caller's data.
        let mut q = queries.elu() + 1.0;
        let mut k = keys.elu() + 1.0;

        if let Some(mask) = q_mask {
            q *= expand_attention_mask(mask, queries)?;
        }

        // The values must be masked out-of-place: a `shallow_clone` shares
        // storage with `values`, so an in-place `*=` on it would overwrite the
        // caller's tensor.
        let v = match kv_mask {
            Some(mask) => {
                let mask = expand_attention_mask(mask, keys)?;
                k *= &mask;
                values * &mask
            }
            None => values.shallow_clone(),
        };

        // Values are scaled down by S here and the output is scaled back up at
        // the end. NOTE(review): presumably this keeps the intermediate K^T V
        // product small for low-precision kinds - confirm.
        let value_length = i64_to_f64(values.size()[1], "attention value length")?;
        let v = &v / value_length;

        let kv = Tensor::einsum("nshd,nshv->nhdv", &[&k, &v], None::<&[i64]>);

        // Sum the keys in their own kind so the einsum below sees operands of
        // a single dtype whatever floating kind the inputs use.
        let key_sum = k.sum_dim_intlist([1].as_slice(), false, k.kind());
        let z = (Tensor::einsum("nlhd,nhd->nlh", &[&q, &key_sum], None::<&[i64]>) + self.eps)
            .reciprocal();

        Ok(
            (Tensor::einsum("nlhd,nhdv,nlh->nlhv", &[&q, &kv, &z], None::<&[i64]>) * value_length)
                .contiguous(),
        )
    }
}
/// Configuration for the softmax (full) attention layer.
#[derive(Debug, Clone, Copy)]
pub struct FullAttention {
/// When `true` and `attention_dropout` is positive, `forward` passes the
/// attention weights through `Tensor::dropout`.
use_dropout: bool,
/// Probability handed to `Tensor::dropout` for the attention weights.
attention_dropout: f64,
}
impl Default for FullAttention {
    /// Creates a `FullAttention` with dropout switched off and a stored
    /// dropout probability of `0.1`.
    fn default() -> Self {
        let (use_dropout, attention_dropout) = (false, 0.1);
        Self {
            use_dropout,
            attention_dropout,
        }
    }
}
impl FullAttention {
    /// Computes scaled dot-product (softmax) attention.
    ///
    /// The inputs are first checked by `validate_attention_tensors`, which
    /// requires rank-4 tensors laid out as `[N, L, H, D]` (queries) and
    /// `[N, S, H, D]` (keys and values). Scores are scaled by `1 / sqrt(D)`
    /// and normalized with a softmax over the key axis.
    ///
    /// # Arguments
    ///
    /// * `queries`, `keys`, `values` - attention inputs.
    /// * `q_mask`, `kv_mask` - optional `[N, L]` / `[N, S]` masks. Masking is
    ///   applied only when **both** are supplied; score entries whose query or
    ///   key position is masked out are filled with negative infinity before
    ///   the softmax.
    ///
    /// # Errors
    ///
    /// Returns `LoftrError` when the input shapes are inconsistent, when a
    /// mask has the wrong shape or cannot be converted, when the masked fill
    /// fails, or when the head dimension cannot be converted by `i64_to_f64`.
    pub fn forward(
        self,
        queries: &Tensor,
        keys: &Tensor,
        values: &Tensor,
        q_mask: Option<&Tensor>,
        kv_mask: Option<&Tensor>,
    ) -> Result<Tensor, LoftrError> {
        validate_attention_tensors(queries, keys, values)?;

        let mut qk = Tensor::einsum("nlhd,nshd->nlsh", &[queries, keys], None::<&[i64]>);

        // NOTE(review): a lone `q_mask` or `kv_mask` is ignored; confirm that
        // callers always pass the two together.
        if let (Some(q_mask), Some(kv_mask)) = (q_mask, kv_mask) {
            let q_mask = expand_full_mask(q_mask, queries, true)?;
            let kv_mask = expand_full_mask(kv_mask, keys, false)?;
            let valid = q_mask.logical_and(&kv_mask);
            qk = qk.f_masked_fill(&valid.logical_not(), f64::NEG_INFINITY)?;
        }

        let softmax_temp = 1.0 / i64_to_f64(queries.size()[3], "attention head dim")?.sqrt();

        // Normalize in the scores' own kind so the weights keep the dtype of
        // the inputs and can be contracted with `values` for any floating kind.
        let score_kind = qk.kind();
        let mut attention = (qk * softmax_temp).softmax(2, score_kind);

        if self.use_dropout && self.attention_dropout > 0.0 {
            // NOTE(review): `train` is passed as `false` here; confirm that
            // dropout is meant to stay inactive on this path.
            attention = attention.dropout(self.attention_dropout, false);
        }

        Ok(Tensor::einsum("nlsh,nshd->nlhd", &[&attention, values], None::<&[i64]>).contiguous())
    }
}
fn validate_attention_tensors(
queries: &Tensor,
keys: &Tensor,
values: &Tensor,
) -> Result<(), LoftrError> {
let q_dims = queries.size();
let k_dims = keys.size();
let v_dims = values.size();
if q_dims.len() != 4 || k_dims.len() != 4 || v_dims.len() != 4 {
return Err(LoftrError::InvalidConfig(format!(
"Attention expects [N,L,H,D], [N,S,H,D], [N,S,H,D]; got queries={q_dims:?}, keys={k_dims:?}, values={v_dims:?}"
)));
}
if q_dims[0] != k_dims[0] || q_dims[0] != v_dims[0] {
return Err(LoftrError::InvalidConfig(format!(
"Attention batch mismatch: queries={}, keys={}, values={}",
q_dims[0], k_dims[0], v_dims[0]
)));
}
if k_dims[1] != v_dims[1] || k_dims[2] != v_dims[2] || k_dims[3] != v_dims[3] {
return Err(LoftrError::InvalidConfig(format!(
"Attention key/value mismatch: keys={k_dims:?}, values={v_dims:?}"
)));
}
if q_dims[2] != k_dims[2] || q_dims[3] != k_dims[3] {
return Err(LoftrError::InvalidConfig(format!(
"Attention query/key head mismatch: queries={q_dims:?}, keys={k_dims:?}"
)));
}
Ok(())
}
/// Prepares a 2-D mask for element-wise multiplication with `like`.
///
/// The mask shape must equal the first two dimensions of `like`. On success
/// the mask is moved to `like`'s device, cast to `like`'s kind, and given two
/// trailing singleton dimensions.
///
/// # Errors
///
/// Returns `LoftrError::InvalidConfig` on a shape mismatch, or the converted
/// error from the device/kind conversion.
fn expand_attention_mask(mask: &Tensor, like: &Tensor) -> Result<Tensor, LoftrError> {
    let dims = mask.size();
    let like_dims = like.size();
    let expected = [like_dims[0], like_dims[1]];

    if dims[..] == expected[..] {
        let aligned = mask
            .f_to_device(like.device())?
            .f_to_kind(like.kind())?;
        return Ok(aligned.unsqueeze(-1).unsqueeze(-1));
    }

    Err(LoftrError::InvalidConfig(format!(
        "Attention mask expects {expected:?}; got {dims:?}"
    )))
}
/// Prepares a 2-D mask as a boolean tensor that broadcasts against the
/// attention score tensor.
///
/// The mask shape must equal the first two dimensions of `like`. The mask is
/// moved to `like`'s device and cast to `Kind::Bool`. A query mask gets two
/// trailing singleton dimensions; a key/value mask gets one singleton
/// dimension inserted at index 1 and one appended at the end.
///
/// # Errors
///
/// Returns `LoftrError::InvalidConfig` on a shape mismatch, or the converted
/// error from the device/kind conversion.
fn expand_full_mask(mask: &Tensor, like: &Tensor, is_query: bool) -> Result<Tensor, LoftrError> {
    let dims = mask.size();
    let like_dims = like.size();
    let expected = [like_dims[0], like_dims[1]];

    if dims[..] != expected[..] {
        return Err(LoftrError::InvalidConfig(format!(
            "Attention mask expects {expected:?}; got {dims:?}"
        )));
    }

    let as_bool = mask.f_to_device(like.device())?.f_to_kind(Kind::Bool)?;

    // Only the position of the first inserted axis differs between the two
    // layouts; the second axis is always appended at the end.
    let first_new_axis = if is_query { -1 } else { 1 };
    Ok(as_bool.unsqueeze(first_new_axis).unsqueeze(-1))
}
#[cfg(test)]
mod tests;