use crate::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Backend operations for variable-length ("varlen") attention.
///
/// Instead of one fixed sequence length per call, each method receives the
/// sequence boundaries of the batch as tensors (`cu_seqlens_q`, `cu_seqlens_k`)
/// together with the batch size and the maximum query/key lengths.
///
/// NOTE(review): the tensor layouts (e.g. whether `q`/`k`/`v` are packed as
/// `[total_tokens, num_heads, head_dim]`) and the exact encoding of the
/// `cu_seqlens_*` tensors are not visible from this trait — confirm against
/// the backend implementations before relying on them.
///
/// `R` is the [`Runtime`] the tensors belong to.
// The kernels need the q/k/v tensors, both sequence-boundary tensors and
// several shape scalars; grouping them into a struct would change the public
// trait signature, so the argument-count lint is silenced for these
// declarations instead.
#[allow(clippy::too_many_arguments)]
pub trait VarLenAttentionOps<R: Runtime> {
    /// Forward pass of variable-length attention.
    ///
    /// # Arguments
    ///
    /// * `q`, `k`, `v` - query, key and value tensors.
    /// * `cu_seqlens_q`, `cu_seqlens_k` - sequence boundaries for the queries
    ///   and keys; presumably cumulative offsets of length `batch_size + 1`
    ///   (TODO confirm).
    /// * `batch_size` - number of sequences in the batch.
    /// * `num_heads` - number of attention heads.
    /// * `max_seqlen_q`, `max_seqlen_k` - longest query / key sequence in the
    ///   batch.
    /// * `head_dim` - per-head feature dimension.
    /// * `causal` - presumably enables causal masking (TODO confirm how query
    ///   and key positions are aligned when their lengths differ).
    ///
    /// # Returns
    ///
    /// A pair of tensors, presumably `(output, lse)`: the two values that
    /// [`varlen_attention_bwd`](Self::varlen_attention_bwd) takes as its
    /// `output` and `lse` parameters — confirm against implementations.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the implementation fails; the exact conditions are
    /// backend-specific.
    fn varlen_attention_fwd(
        &self,
        q: &Tensor<R>,
        k: &Tensor<R>,
        v: &Tensor<R>,
        cu_seqlens_q: &Tensor<R>,
        cu_seqlens_k: &Tensor<R>,
        batch_size: usize,
        num_heads: usize,
        max_seqlen_q: usize,
        max_seqlen_k: usize,
        head_dim: usize,
        causal: bool,
    ) -> Result<(Tensor<R>, Tensor<R>)>;

    /// Backward pass of variable-length attention.
    ///
    /// # Arguments
    ///
    /// * `dout` - gradient with respect to the attention output.
    /// * `q`, `k`, `v` - the query, key and value tensors of the forward pass.
    /// * `output`, `lse` - presumably the two tensors returned by
    ///   [`varlen_attention_fwd`](Self::varlen_attention_fwd) (TODO confirm).
    /// * `cu_seqlens_q`, `cu_seqlens_k`, `batch_size`, `num_heads`,
    ///   `max_seqlen_q`, `max_seqlen_k`, `head_dim`, `causal` - the same
    ///   sequence-boundary tensors and shape parameters passed to the forward
    ///   pass.
    ///
    /// # Returns
    ///
    /// Three tensors, presumably the gradients `(dq, dk, dv)` with respect to
    /// `q`, `k` and `v` — confirm against implementations.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the implementation fails; the exact conditions are
    /// backend-specific.
    fn varlen_attention_bwd(
        &self,
        dout: &Tensor<R>,
        q: &Tensor<R>,
        k: &Tensor<R>,
        v: &Tensor<R>,
        output: &Tensor<R>,
        lse: &Tensor<R>,
        cu_seqlens_q: &Tensor<R>,
        cu_seqlens_k: &Tensor<R>,
        batch_size: usize,
        num_heads: usize,
        max_seqlen_q: usize,
        max_seqlen_k: usize,
        head_dim: usize,
        causal: bool,
    ) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
}