use crate::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Paged attention operations provided by a backend for runtime `R`.
///
/// The trait declares a forward pass, an FP8 forward pass that additionally
/// takes four `f32` scale factors, and a backward pass. All three take their
/// tensors by shared reference and hand back freshly owned tensors wrapped in
/// the crate's [`Result`].
///
/// NOTE(review): tensor layouts, dtypes and the exact meaning of each tuple
/// element are not visible from the signatures alone; the per-method notes
/// below are inferred from parameter names and should be confirmed against
/// the implementations.
// Every method here takes more than ten arguments, which trips clippy's
// `too_many_arguments` lint; the allow covers the whole trait.
#[allow(clippy::too_many_arguments)]
pub trait PagedAttentionOps<R: Runtime> {
    /// Forward pass of paged attention.
    ///
    /// # Arguments
    ///
    /// * `q` - query tensor.
    /// * `k_blocks`, `v_blocks` - key / value tensors stored block-wise
    ///   (presumably the paged KV cache; TODO confirm layout).
    /// * `block_table` - tensor used to locate blocks (presumably a
    ///   logical-to-physical block mapping; TODO confirm).
    /// * `num_heads`, `num_kv_heads` - head counts for queries and for
    ///   keys/values respectively.
    /// * `seq_len_q`, `seq_len_k` - query and key sequence lengths.
    /// * `head_dim` - per-head dimension.
    /// * `block_size` - size of one block (presumably in tokens; TODO confirm).
    /// * `causal` - flag selecting causal behaviour.
    ///
    /// # Returns
    ///
    /// A pair of tensors. NOTE(review): `paged_attention_bwd` consumes an
    /// `output` and an `lse` tensor, so this pair is presumably
    /// `(output, lse)`; verify against an implementation.
    ///
    /// # Errors
    ///
    /// Failure conditions are defined by the implementor.
    fn paged_attention_fwd(
        &self,
        q: &Tensor<R>,
        k_blocks: &Tensor<R>,
        v_blocks: &Tensor<R>,
        block_table: &Tensor<R>,
        num_heads: usize,
        num_kv_heads: usize,
        seq_len_q: usize,
        seq_len_k: usize,
        head_dim: usize,
        block_size: usize,
        causal: bool,
    ) -> Result<(Tensor<R>, Tensor<R>)>;

    /// FP8 flavour of the forward pass.
    ///
    /// Accepts the same tensor and shape arguments as
    /// `paged_attention_fwd`, followed by four `f32` scale factors.
    ///
    /// # Arguments
    ///
    /// * `q`, `k_blocks`, `v_blocks`, `block_table` - same roles as in the
    ///   non-FP8 forward pass.
    /// * `num_heads`, `num_kv_heads`, `seq_len_q`, `seq_len_k`, `head_dim`,
    ///   `block_size`, `causal` - same roles as in the non-FP8 forward pass.
    /// * `q_scale`, `k_scale`, `v_scale`, `o_scale` - scale factors associated
    ///   with Q, K, V and the output (presumably per-tensor quantisation
    ///   scales; TODO confirm whether they multiply or divide).
    ///
    /// # Returns
    ///
    /// A pair of tensors, presumably `(output, lse)` as in the non-FP8
    /// forward pass (TODO confirm).
    ///
    /// # Errors
    ///
    /// Failure conditions are defined by the implementor.
    fn paged_attention_fwd_fp8(
        &self,
        q: &Tensor<R>,
        k_blocks: &Tensor<R>,
        v_blocks: &Tensor<R>,
        block_table: &Tensor<R>,
        num_heads: usize,
        num_kv_heads: usize,
        seq_len_q: usize,
        seq_len_k: usize,
        head_dim: usize,
        block_size: usize,
        causal: bool,
        q_scale: f32,
        k_scale: f32,
        v_scale: f32,
        o_scale: f32,
    ) -> Result<(Tensor<R>, Tensor<R>)>;

    /// Backward pass of paged attention.
    ///
    /// # Arguments
    ///
    /// * `dout` - incoming gradient (presumably with respect to the forward
    ///   output; TODO confirm).
    /// * `q`, `k_blocks`, `v_blocks` - the forward-pass inputs.
    /// * `output`, `lse` - tensors saved from the forward pass (`lse` is
    ///   presumably the log-sum-exp of the attention scores; TODO confirm).
    /// * `block_table` - tensor used to locate blocks, as in the forward pass.
    /// * `num_heads`, `num_kv_heads`, `seq_len_q`, `seq_len_k`, `head_dim`,
    ///   `block_size`, `causal` - same roles as in the forward pass.
    ///
    /// # Returns
    ///
    /// A triple of tensors. NOTE(review): presumably the gradients for
    /// `q`, `k_blocks` and `v_blocks` in that order; verify against an
    /// implementation.
    ///
    /// # Errors
    ///
    /// Failure conditions are defined by the implementor.
    fn paged_attention_bwd(
        &self,
        dout: &Tensor<R>,
        q: &Tensor<R>,
        k_blocks: &Tensor<R>,
        v_blocks: &Tensor<R>,
        output: &Tensor<R>,
        lse: &Tensor<R>,
        block_table: &Tensor<R>,
        num_heads: usize,
        num_kv_heads: usize,
        seq_len_q: usize,
        seq_len_k: usize,
        head_dim: usize,
        block_size: usize,
        causal: bool,
    ) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
}