//! Attention operation traits: standard multi-head attention over autograd
//! variables and fused flash-attention kernels over raw tensors.

use crate::error::Result;
use numr::autograd::Var;
use numr::runtime::Runtime;
use numr::tensor::Tensor;

/// Attention operations expressed over autograd variables.
pub trait AttentionOps<R: Runtime> {
    /// Multi-head attention over the query `q`, key `k`, and value `v`
    /// variables, split into `num_heads` heads, with an optional attention
    /// `mask`.
    fn multi_head_attention(
        &self,
        q: &Var<R>,
        k: &Var<R>,
        v: &Var<R>,
        mask: Option<&Var<R>>,
        num_heads: usize,
    ) -> Result<Var<R>>;
}
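
// Illustrative sketch (not part of the crate's public API): a caller that is
// generic over any `AttentionOps` backend. The helper name `self_attention`
// is hypothetical and exists only to show how the trait method is invoked.
#[allow(dead_code)]
fn self_attention<R: Runtime, C: AttentionOps<R>>(
    client: &C,
    q: &Var<R>,
    k: &Var<R>,
    v: &Var<R>,
    mask: Option<&Var<R>>,
    num_heads: usize,
) -> Result<Var<R>> {
    // The backend is responsible for splitting the inputs into `num_heads`
    // heads, applying attention with the optional mask, and merging the
    // heads back into a single output variable.
    client.multi_head_attention(q, k, v, mask, num_heads)
}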

/// Fused flash-attention kernels operating directly on tensors.
#[allow(clippy::too_many_arguments)]
pub trait FlashAttentionOps<R: Runtime> {
    /// Flash-attention forward pass.
    ///
    /// `num_kv_heads` may differ from `num_heads` for grouped key/value
    /// heads, and `window_size` bounds the attention window. Returns the
    /// attention output together with the softmax statistics (`lse`)
    /// consumed by the backward pass.
    fn flash_attention_fwd(
        &self,
        q: &Tensor<R>,
        k: &Tensor<R>,
        v: &Tensor<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        causal: bool,
        window_size: usize,
        kv_seq_len: Option<usize>,
    ) -> Result<(Tensor<R>, Tensor<R>)>;

    /// Flash-attention forward pass for FP8-quantized inputs.
    ///
    /// `q_scale`, `k_scale`, `v_scale`, and `o_scale` are the quantization
    /// scale factors for the respective operands and the output.
    fn flash_attention_fwd_fp8(
        &self,
        q: &Tensor<R>,
        k: &Tensor<R>,
        v: &Tensor<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        causal: bool,
        q_scale: f32,
        k_scale: f32,
        v_scale: f32,
        o_scale: f32,
    ) -> Result<(Tensor<R>, Tensor<R>)>;

    /// Flash-attention backward pass.
    ///
    /// Takes the upstream gradient `dout`, the saved forward inputs,
    /// `output`, and the `lse` statistics, and returns the gradients with
    /// respect to `q`, `k`, and `v`.
    fn flash_attention_bwd(
        &self,
        dout: &Tensor<R>,
        q: &Tensor<R>,
        k: &Tensor<R>,
        v: &Tensor<R>,
        output: &Tensor<R>,
        lse: &Tensor<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        causal: bool,
        window_size: usize,
    ) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;

    /// Flash-attention backward pass for FP8-quantized operands.
    ///
    /// The `*_scale` factors are the quantization scales for the forward
    /// inputs, the upstream gradient, the saved output, and the produced
    /// `dq`/`dk`/`dv` gradients.
    fn flash_attention_bwd_fp8(
        &self,
        dout: &Tensor<R>,
        q: &Tensor<R>,
        k: &Tensor<R>,
        v: &Tensor<R>,
        output: &Tensor<R>,
        lse: &Tensor<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        causal: bool,
        q_scale: f32,
        k_scale: f32,
        v_scale: f32,
        do_scale: f32,
        o_scale: f32,
        dq_scale: f32,
        dk_scale: f32,
        dv_scale: f32,
    ) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
}
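
// Illustrative sketch (not part of the crate's public API): a full
// forward/backward round trip through a `FlashAttentionOps` backend. The
// function name and the specific `causal`/`window_size`/`kv_seq_len` choices
// are assumptions made only for this example.
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
fn flash_attention_grads<R: Runtime, C: FlashAttentionOps<R>>(
    client: &C,
    q: &Tensor<R>,
    k: &Tensor<R>,
    v: &Tensor<R>,
    dout: &Tensor<R>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)> {
    // Forward pass: keep both the output and the softmax statistics (`lse`),
    // which the backward pass needs.
    let (output, lse) = client.flash_attention_fwd(
        q,
        k,
        v,
        num_heads,
        num_kv_heads,
        head_dim,
        true, // causal
        0,    // window_size (assumed here to mean "no sliding window")
        None, // kv_seq_len (assumed: inferred from `k`/`v`)
    )?;

    // Backward pass: feed the upstream gradient together with the saved
    // forward inputs and outputs; returns (dq, dk, dv).
    client.flash_attention_bwd(
        dout,
        q,
        k,
        v,
        &output,
        &lse,
        num_heads,
        num_kv_heads,
        head_dim,
        true, // causal
        0,    // window_size
    )
}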