pub mod causal;
pub mod standard;
pub mod varlen;
use hanzo_ml::{DType, Result, Tensor, WithDType};
use std::iter::Sum;
use super::AttnMask;
/// Dispatches CPU flash attention to the kernel matching the batch size and mask.
///
/// The leading dimension of `q` is taken as the batch size `B`:
///
/// * `B <= 1`: [`AttnMask::Causal`] is handled by `causal::run_causal_attn_cpu`
///   (receiving its `kv_offset`); [`AttnMask::None`] and [`AttnMask::Mask`] are
///   handled by `standard::run_flash_attn_cpu`, the latter with the mask tensor.
/// * `B > 1`: the call is forwarded to `flash_attn_via_varlen`, which is only
///   attempted when `q` is `F32`/`F16`, `softcap` is `None`, and the mask is
///   `Causal` or `None`.
///
/// # Errors
///
/// Returns an error when `B > 1` and any of the batched-path requirements above
/// is not met, or when the selected kernel fails.
///
/// # Panics
///
/// Panics if `q` has no dimensions, because the batch size is read by index.
///
/// NOTE(review): on the `B > 1` path neither the type parameter `T` nor the
/// `kv_offset` stored in `AttnMask::Causal` is used — confirm this is intended.
pub fn flash_attn<T>(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    attn_mask: AttnMask,
    max_bias: Option<f32>,
    softcap: Option<f32>,
) -> Result<Tensor>
where
    T: WithDType + Sum + num_traits::real::Real,
{
    let b = q.dims()[0];

    // Single-sequence inputs go straight to the dedicated kernels.
    if b <= 1 {
        return match attn_mask {
            AttnMask::Causal { kv_offset } => causal::run_causal_attn_cpu::<T>(
                q,
                k,
                v,
                softmax_scale,
                kv_offset,
                max_bias,
                softcap,
            ),
            AttnMask::None => standard::run_flash_attn_cpu::<T>(
                q,
                k,
                v,
                None,
                softmax_scale,
                max_bias,
                softcap,
            ),
            AttnMask::Mask(mask) => standard::run_flash_attn_cpu::<T>(
                q,
                k,
                v,
                Some(&mask),
                softmax_scale,
                max_bias,
                softcap,
            ),
        };
    }

    // Batched inputs: validate everything the varlen route cannot handle.
    let dt = q.dtype();
    let mask_name = match &attn_mask {
        AttnMask::Causal { .. } => "Causal",
        AttnMask::None => "None",
        AttnMask::Mask(_) => "Mask(tensor)",
    };
    let dtype_supported = matches!(dt, DType::F32 | DType::F16);
    let mask_supported = matches!(&attn_mask, AttnMask::Causal { .. } | AttnMask::None);

    if !(dtype_supported && softcap.is_none() && mask_supported) {
        hanzo_ml::bail!(
            "CPU flash attention with B>1 requires: f32/f16 dtype, no softcap, \
             and Causal or None mask. Got B={b}, dtype={dt:?}, softcap={softcap:?}, \
             mask={}",
            mask_name
        );
    }

    flash_attn_via_varlen(q, k, v, softmax_scale, &attn_mask, max_bias)
}
fn flash_attn_via_varlen(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
attn_mask: &AttnMask,
max_bias: Option<f32>,
) -> Result<Tensor> {
let q_dims = q.dims();
let k_dims = k.dims();
let (b, s_q, h_q, d) = (q_dims[0], q_dims[1], q_dims[2], q_dims[3]);
let (s_kv, h_kv) = (k_dims[1], k_dims[2]);
let causal = attn_mask.is_causal();
let q_packed = q.contiguous()?.reshape((b * s_q, h_q, d))?;
let k_packed = k.contiguous()?.reshape((b * s_kv, h_kv, d))?;
let v_packed = v.contiguous()?.reshape((b * s_kv, h_kv, d))?;
let device = q.device();
let seqlens_q = Tensor::from_vec(vec![s_q as u32; b], b, device)?;
let seqlens_k = Tensor::from_vec(vec![s_kv as u32; b], b, device)?;
let alibi_slopes = if let Some(mb) = max_bias {
if mb > 0.0 {
let n2 = 2_usize.pow((h_q as f32).log2().ceil() as u32);
let slopes: Vec<f32> = (0..h_q)
.map(|h| 2.0f32.powf(-mb * ((h + 1) as f32) / n2 as f32))
.collect();
Some(Tensor::from_vec(slopes, h_q, device)?)
} else {
None
}
} else {
None
};
let ctx = varlen::flash_attn_varlen_cpu(
&q_packed,
&k_packed,
&v_packed,
alibi_slopes.as_ref(),
&seqlens_q,
&seqlens_k,
s_q,
s_kv,
softmax_scale,
causal,
None,
None,
)?;
ctx.reshape((b, s_q, h_q, d))?.transpose(1, 2)?.contiguous()
}