use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::tensor::Tensor;
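
/// FlexAttention-style scaled dot-product attention with an optional
/// per-(batch, head) score modification callback.
///
/// `query`, `key`, and `value` must be 4-D tensors shaped
/// `[batch, heads, seq_q, d]`, `[batch, heads, seq_k, d]`, and
/// `[batch, heads, seq_k, d_v]` respectively, all on the same device.
/// Scores are computed as `Q K^T / sqrt(d)`; if `score_mod` is provided it is
/// called once per (batch, head) pair with that pair's `[seq_q, seq_k]` score
/// slice and must return a tensor of the same shape. The (possibly modified)
/// scores are normalized with softmax and applied to `value`, producing a
/// `[batch, heads, seq_q, d_v]` output that participates in autograd.
///
/// # Example
///
/// A minimal sketch mirroring `test_flex_attention_basic` in the tests module
/// below (imports omitted, so the block is not compiled as a doc test):
///
/// ```ignore
/// let q = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 0.0, 0.0, 1.0]), vec![1, 1, 2, 2], false).unwrap();
/// let k = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 0.0, 0.0, 1.0]), vec![1, 1, 2, 2], false).unwrap();
/// let v = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]), vec![1, 1, 2, 2], false).unwrap();
/// // Passing `None` for score_mod leaves `F` unconstrained, hence the explicit turbofish.
/// let out = flex_attention::<f32, fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>>(
///     &q, &k, &v, None,
/// ).unwrap();
/// assert_eq!(out.shape(), &[1, 1, 2, 2]);
/// ```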
pub fn flex_attention<T, F>(
query: &Tensor<T>,
key: &Tensor<T>,
value: &Tensor<T>,
score_mod: Option<F>,
) -> FerrotorchResult<Tensor<T>>
where
T: Float,
F: Fn(&Tensor<T>, usize, usize) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
{
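    // Shape validation: Q, K and V must all be 4-D [batch, heads, seq, dim] tensors.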
let q_shape = query.shape();
let k_shape = key.shape();
let v_shape = value.shape();
if q_shape.len() != 4 || k_shape.len() != 4 || v_shape.len() != 4 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"flex_attention: expected 4-D tensors [batch, heads, seq, dim], \
got Q={:?}, K={:?}, V={:?}",
q_shape, k_shape, v_shape
),
});
}
let batch = q_shape[0];
let heads = q_shape[1];
let n_q = q_shape[2];
let d = q_shape[3];
let n_k = k_shape[2];
let d_v = v_shape[3];
if d == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "flex_attention: head dimension d must be > 0".into(),
});
}
if k_shape[0] != batch || k_shape[1] != heads || k_shape[3] != d {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"flex_attention: Q shape {:?} incompatible with K shape {:?}",
q_shape, k_shape
),
});
}
if v_shape[0] != batch || v_shape[1] != heads || v_shape[2] != n_k {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"flex_attention: K shape {:?} incompatible with V shape {:?}",
k_shape, v_shape
),
});
}
    if query.device() != key.device() {
        return Err(FerrotorchError::DeviceMismatch {
            expected: query.device(),
            got: key.device(),
        });
    }
    if query.device() != value.device() {
        return Err(FerrotorchError::DeviceMismatch {
            expected: query.device(),
            got: value.device(),
        });
    }
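    // Standard scaled dot-product attention scaling factor: 1 / sqrt(d).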
let scale = T::from(1.0 / (d as f64).sqrt()).unwrap();
let device = query.device();
let bh = batch * heads;
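    // Collapse batch and heads into one leading dimension so batched matmul (bmm) can be used.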
let q3 = crate::grad_fns::shape::reshape(query, &[bh as isize, n_q as isize, d as isize])?;
let k3 = crate::grad_fns::shape::reshape(key, &[bh as isize, n_k as isize, d as isize])?;
let v3 = crate::grad_fns::shape::reshape(value, &[bh as isize, n_k as isize, d_v as isize])?;
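    // Raw attention scores: (Q K^T) * scale, shape [batch*heads, seq_q, seq_k].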
let k3_t = k3.transpose(1, 2)?;
let scores3 = crate::grad_fns::linalg::bmm_differentiable(&q3, &k3_t)?;
let scale_t = crate::creation::scalar(scale)?.to(device)?;
let scores3_scaled = crate::grad_fns::arithmetic::mul(&scores3, &scale_t)?;
let scores4 = crate::grad_fns::shape::reshape(
&scores3_scaled,
&[batch as isize, heads as isize, n_q as isize, n_k as isize],
)?;
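    // If a score_mod callback was supplied, let it rewrite each (batch, head) slice of the scores.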
let scores_after_mod = if let Some(ref sm) = score_mod {
let mut per_bh: Vec<Tensor<T>> = Vec::with_capacity(bh);
for b in 0..batch {
for h in 0..heads {
let bh_view = scores4
.narrow(0, b, 1)?
.narrow(1, h, 1)?
.squeeze_t(0)?
.squeeze_t(0)?;
let modified = sm(&bh_view, b, h)?;
if modified.shape() != [n_q, n_k] {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"flex_attention: score_mod returned shape {:?}, expected [{}, {}]",
modified.shape(),
n_q,
n_k
),
});
}
let lifted = modified.unsqueeze_t(0)?.unsqueeze_t(0)?;
per_bh.push(lifted);
}
}
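        // Stitch the per-(batch, head) slices back into a [batch, heads, seq_q, seq_k] tensor.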
let mut head_groups: Vec<Tensor<T>> = Vec::with_capacity(batch);
for b in 0..batch {
let group: Vec<Tensor<T>> = per_bh[b * heads..(b + 1) * heads].to_vec();
let cat_h = crate::grad_fns::shape::cat(&group, 1)?;
head_groups.push(cat_h);
}
crate::grad_fns::shape::cat(&head_groups, 0)?
} else {
scores4
};
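    // Normalize scores into attention weights (softmax) and apply them to the values.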
let weights4 = crate::grad_fns::activation::softmax(&scores_after_mod)?;
let weights3 = crate::grad_fns::shape::reshape(
&weights4,
&[bh as isize, n_q as isize, n_k as isize],
)?;
let output3 = crate::grad_fns::linalg::bmm_differentiable(&weights3, &v3)?;
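    // Restore the [batch, heads, seq_q, d_v] layout for the final output.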
crate::grad_fns::shape::reshape(
&output3,
&[batch as isize, heads as isize, n_q as isize, d_v as isize],
)
}

#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
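
    // Helper: CPU tensor with gradient tracking disabled.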
fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
}
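
    // Helper: CPU tensor with gradient tracking enabled, for autograd tests.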
fn make_tensor_grad(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, true).unwrap()
}
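
    // Smoke test: no score_mod, output keeps the [1, 1, 2, 2] shape.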
#[test]
fn test_flex_attention_basic() {
let q = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let k = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let v = make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
let output = flex_attention::<
f32,
fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>,
>(&q, &k, &v, None)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 2, 2]);
}
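
    // An identity score_mod must leave the attention path and output shape intact.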
#[test]
fn test_flex_attention_with_score_mod() {
let q = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let k = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let v = make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
let output = flex_attention(
&q,
&k,
&v,
Some(|scores: &Tensor<f32>, _b: usize, _h: usize| Ok(scores.clone())),
)
.unwrap();
assert_eq!(output.shape(), &[1, 1, 2, 2]);
}
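
    // Non-4-D inputs are rejected with an error.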
#[test]
fn test_flex_attention_shape_validation() {
let q = make_tensor(vec![1.0, 2.0], vec![2]);
let k = make_tensor(vec![1.0, 2.0], vec![2]);
let v = make_tensor(vec![1.0, 2.0], vec![2]);
let result = flex_attention::<
f32,
fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>,
>(&q, &k, &v, None);
assert!(result.is_err());
}
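
    // A zero head dimension is rejected with an explicit "d must be > 0" error.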
#[test]
fn test_flex_attention_d_zero() {
let q = make_tensor(vec![], vec![1, 1, 2, 0]);
let k = make_tensor(vec![], vec![1, 1, 2, 0]);
let v = make_tensor(vec![], vec![1, 1, 2, 0]);
let result = flex_attention::<
f32,
fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>,
>(&q, &k, &v, None);
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(msg.contains("d must be > 0"));
}
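
    // With requires_grad inputs, the output must carry a grad_fn so backward can reach Q/K/V.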
#[test]
fn test_flex_attention_with_grad() {
let q = make_tensor_grad(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let k = make_tensor_grad(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let v = make_tensor_grad(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
let output = flex_attention::<
f32,
fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>,
>(&q, &k, &v, None)
.unwrap();
assert!(
output.grad_fn().is_some(),
"expected output to have a grad_fn so backward propagates to Q/K/V"
);
}
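
    // Hand-checked values: Q = K = I and scale = 1/sqrt(2) ≈ 0.7071, so the score rows are
    // [0.7071, 0] and [0, 0.7071]; softmax gives ≈ [0.6698, 0.3302] and [0.3302, 0.6698],
    // which mixed with V = [[1, 2], [3, 4]] yields ≈ [[1.660, 2.660], [2.340, 3.340]].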
#[test]
fn test_flex_attention_numerical_value() {
let q = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let k = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let v = make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
let out = flex_attention::<
f32,
fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>,
>(&q, &k, &v, None)
.unwrap();
let data = out.data().unwrap();
let expected = [1.6603, 2.6602, 2.3399, 3.3399];
for (i, (&got, &exp)) in data.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < 1e-3,
"out[{}]: expected {}, got {}",
i,
exp,
got
);
}
}
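
    // Softmax is shift-invariant: adding the same constant to every score must not change the output.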
#[test]
fn test_flex_attention_score_mod_additive_bias() {
let q = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let k = make_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![1, 1, 2, 2]);
let v = make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
let baseline = flex_attention::<
f32,
fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>,
>(&q, &k, &v, None)
.unwrap();
let with_const_bias = flex_attention(
&q,
&k,
&v,
Some(|s: &Tensor<f32>, _b: usize, _h: usize| {
let one = crate::creation::scalar(1.0f32).unwrap();
crate::grad_fns::arithmetic::add(s, &one)
}),
)
.unwrap();
let base_data = baseline.data().unwrap();
let mod_data = with_const_bias.data().unwrap();
for (i, (&b, &m)) in base_data.iter().zip(mod_data.iter()).enumerate() {
assert!(
(b - m).abs() < 1e-5,
"softmax-invariant additive bias should not change output[{}]: base={}, mod={}",
i,
b,
m
);
}
}
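
    // End-to-end backward pass: gradients must reach Q, K and V and keep the input shapes.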
#[test]
fn test_flex_attention_grad_propagates_to_qkv() {
let q = make_tensor_grad(vec![1.0, 0.0, 0.5, 1.0], vec![1, 1, 2, 2]);
let k = make_tensor_grad(vec![0.5, 1.0, 1.0, 0.0], vec![1, 1, 2, 2]);
let v = make_tensor_grad(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
let output = flex_attention::<
f32,
fn(&Tensor<f32>, usize, usize) -> FerrotorchResult<Tensor<f32>>,
>(&q, &k, &v, None)
.unwrap();
let loss = crate::grad_fns::reduction::sum(&output).unwrap();
loss.backward().unwrap();
let gq = q
.grad()
.unwrap()
.expect("query should have a gradient after backward");
let gk = k
.grad()
.unwrap()
.expect("key should have a gradient after backward");
let gv = v
.grad()
.unwrap()
.expect("value should have a gradient after backward");
assert_eq!(gq.shape(), &[1, 1, 2, 2]);
assert_eq!(gk.shape(), &[1, 1, 2, 2]);
assert_eq!(gv.shape(), &[1, 1, 2, 2]);
let gv_data = gv.data().unwrap();
assert!(
gv_data.iter().any(|&x| x != 0.0),
"expected non-zero gradient on V"
);
}
}