use crate::error::{Error, Result};
use crate::ops::traits::PagedAttentionOps;
use numr::dtype::DType;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;
use super::paged_attention_bwd::paged_attention_bwd_impl;
use super::paged_attention_fwd::{paged_attention_fwd_fp8_impl, paged_attention_fwd_impl};
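
/// Select the tile-size pair for the forward paged-attention kernel from the
/// head dimension and dtype. The pair is presumably (BLOCK_M, BLOCK_N); its
/// exact semantics are defined by the kernels behind
/// [`paged_attention_fwd_impl`]. Only head dims 64 and 128 are supported.
///
/// ```ignore
/// let (block_m, block_n) = fwd_block_config(128, DType::F16)?; // (32, 32)
/// ```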
pub(super) fn fwd_block_config(head_dim: usize, dtype: DType) -> Result<(usize, usize)> {
match (dtype, head_dim) {
        // Larger head dims get smaller M-tiles, likely to bound shared-memory usage.
        (DType::F32, 64) => Ok((64, 32)),
        (DType::F32, 128) => Ok((32, 32)),
        (DType::F16 | DType::BF16, 64) => Ok((64, 32)),
        (DType::F16 | DType::BF16, 128) => Ok((32, 32)),
        (DType::FP8E4M3 | DType::FP8E5M2, 64) => Ok((64, 32)),
        (DType::FP8E4M3 | DType::FP8E5M2, 128) => Ok((32, 32)),
        _ => Err(Error::InvalidArgument {
            arg: "head_dim",
            reason: format!(
                "unsupported head_dim={} (or dtype) for paged attention. Supported: head_dim 64 or 128 with F32, F16, BF16, or FP8",
                head_dim
            ),
        }),
}
}
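
/// Select the tile-size pair for the backward paged-attention kernel. Note
/// there are no FP8 arms here: FP8 inputs fall through to the error case, so
/// only F32, F16, and BF16 can be differentiated.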
pub(super) fn bwd_block_config(head_dim: usize, dtype: DType) -> Result<(usize, usize)> {
match (dtype, head_dim) {
        // Backward tiles are smaller than forward for F32, likely due to the
        // extra register/shared-memory pressure of the gradient kernels.
        (DType::F32, 64) => Ok((32, 32)),
        (DType::F32, 128) => Ok((16, 16)),
        (DType::F16 | DType::BF16, 64) => Ok((64, 32)),
        (DType::F16 | DType::BF16, 128) => Ok((32, 32)),
        _ => Err(Error::InvalidArgument {
            arg: "head_dim",
            reason: format!(
                "unsupported head_dim={} (or dtype) for paged attention backward. Supported: head_dim 64 or 128 with F32, F16, or BF16 (FP8 has no backward)",
                head_dim
            ),
        }),
}
}
impl PagedAttentionOps<CudaRuntime> for CudaClient {
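    /// Forward paged attention over a block-organized KV cache; a thin wrapper
    /// around [`paged_attention_fwd_impl`]. Returns the attention output plus
    /// a second tensor (the `lse` input of `paged_attention_bwd`).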
fn paged_attention_fwd(
&self,
q: &Tensor<CudaRuntime>,
k_blocks: &Tensor<CudaRuntime>,
v_blocks: &Tensor<CudaRuntime>,
block_table: &Tensor<CudaRuntime>,
num_heads: usize,
num_kv_heads: usize,
seq_len_q: usize,
seq_len_k: usize,
head_dim: usize,
block_size: usize,
causal: bool,
) -> Result<(Tensor<CudaRuntime>, Tensor<CudaRuntime>)> {
paged_attention_fwd_impl(
self,
q,
k_blocks,
v_blocks,
block_table,
num_heads,
num_kv_heads,
seq_len_q,
seq_len_k,
head_dim,
block_size,
causal,
)
}
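
    /// FP8 forward paged attention; delegates to
    /// [`paged_attention_fwd_fp8_impl`]. The `*_scale` arguments carry
    /// per-tensor scale factors for Q, K, V, and the output; how each scale is
    /// applied is defined by the kernel implementation.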
fn paged_attention_fwd_fp8(
&self,
q: &Tensor<CudaRuntime>,
k_blocks: &Tensor<CudaRuntime>,
v_blocks: &Tensor<CudaRuntime>,
block_table: &Tensor<CudaRuntime>,
num_heads: usize,
num_kv_heads: usize,
seq_len_q: usize,
seq_len_k: usize,
head_dim: usize,
block_size: usize,
causal: bool,
q_scale: f32,
k_scale: f32,
v_scale: f32,
o_scale: f32,
) -> Result<(Tensor<CudaRuntime>, Tensor<CudaRuntime>)> {
paged_attention_fwd_fp8_impl(
self,
q,
k_blocks,
v_blocks,
block_table,
num_heads,
num_kv_heads,
seq_len_q,
seq_len_k,
head_dim,
block_size,
causal,
q_scale,
k_scale,
v_scale,
o_scale,
)
}
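
    /// Backward paged attention; delegates to [`paged_attention_bwd_impl`].
    /// Consumes the upstream gradient `dout` together with the forward
    /// `output` and `lse`, and returns three gradient tensors, presumably
    /// (dq, dk_blocks, dv_blocks).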
fn paged_attention_bwd(
&self,
dout: &Tensor<CudaRuntime>,
q: &Tensor<CudaRuntime>,
k_blocks: &Tensor<CudaRuntime>,
v_blocks: &Tensor<CudaRuntime>,
output: &Tensor<CudaRuntime>,
lse: &Tensor<CudaRuntime>,
block_table: &Tensor<CudaRuntime>,
num_heads: usize,
num_kv_heads: usize,
seq_len_q: usize,
seq_len_k: usize,
head_dim: usize,
block_size: usize,
causal: bool,
) -> Result<(
Tensor<CudaRuntime>,
Tensor<CudaRuntime>,
Tensor<CudaRuntime>,
)> {
paged_attention_bwd_impl(
self,
dout,
q,
k_blocks,
v_blocks,
output,
lse,
block_table,
num_heads,
num_kv_heads,
seq_len_q,
seq_len_k,
head_dim,
block_size,
causal,
)
}
}