use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use cudarc::driver::{
CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, DeviceRepr, LaunchConfig,
PushKernelArg,
};
use ferrum_types::{FerrumError, Result};
use half::f16;
use super::{CudaBackend, CudaState, BATCHED_SCRATCH_CAP, HOST_STAGING_TOTAL};
use crate::backend::{Backend, BackendPagedKv};
use crate::ptx;
/// Startup-time tuning knobs for the CUDA paged-attention paths, parsed from
/// `FERRUM_*` environment variables.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CudaPagedRuntimeConfig {
    /// `FERRUM_KV_CAPACITY`: minimum KV length assumed when sizing dynamic
    /// shared memory (see [`Self::shared_kv_for`]). Defaults to 512; an
    /// unparsable value leaves the current setting untouched.
    kv_capacity: usize,
    /// `FERRUM_PAGED_FLASH_SPLITS`: explicit split count. `None` when the
    /// variable is unset or does not parse as `usize`.
    paged_flash_splits: Option<usize>,
    /// `FERRUM_SPLIT_K_ATTN`: `"1"` => `Some(true)`, `"0"` => `Some(false)`,
    /// anything else => `None`.
    split_k_attn: Option<bool>,
    /// `FERRUM_FA2_SOURCE`: true only for `1`, `true`, `TRUE`, `on` or `ON`.
    fa2_source: bool,
    /// `FERRUM_PAGED_FLASH`: enabled by default; only the exact value `"0"`
    /// disables it.
    paged_flash: bool,
}

impl CudaPagedRuntimeConfig {
    /// Builds the configuration from the process environment.
    ///
    /// Uses `std::env::vars_os` rather than `std::env::vars`: the latter
    /// panics as soon as *any* variable in the environment (not just ours)
    /// has a non-Unicode name or value. Entries that are not valid Unicode
    /// are skipped, so a malformed `FERRUM_*` value simply keeps its default.
    fn from_env() -> Self {
        let unicode_vars = std::env::vars_os().filter_map(|(name, value)| {
            Some((name.into_string().ok()?, value.into_string().ok()?))
        });
        Self::from_env_vars(unicode_vars)
    }

    /// Builds the configuration from an arbitrary `(name, value)` sequence.
    ///
    /// Unknown names are ignored. Pairs are applied in iteration order, so a
    /// later occurrence of the same name overrides an earlier one.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self {
            kv_capacity: 512,
            paged_flash_splits: None,
            split_k_attn: None,
            fa2_source: false,
            paged_flash: true,
        };
        for (name, value) in vars {
            let value = value.as_ref();
            match name.as_ref() {
                "FERRUM_KV_CAPACITY" => {
                    // Keep the previous value when the number does not parse.
                    if let Ok(kv_capacity) = value.parse::<usize>() {
                        config.kv_capacity = kv_capacity;
                    }
                }
                "FERRUM_PAGED_FLASH_SPLITS" => {
                    config.paged_flash_splits = value.parse::<usize>().ok();
                }
                "FERRUM_SPLIT_K_ATTN" => {
                    config.split_k_attn = match value {
                        "1" => Some(true),
                        "0" => Some(false),
                        _ => None,
                    };
                }
                "FERRUM_FA2_SOURCE" => {
                    config.fa2_source = matches!(value, "1" | "true" | "TRUE" | "on" | "ON");
                }
                "FERRUM_PAGED_FLASH" => config.paged_flash = value != "0",
                _ => {}
            }
        }
        config
    }

    /// Returns the KV length to size shared memory for: the larger of the
    /// configured capacity and `kv_len`, never less than 1.
    fn shared_kv_for(&self, kv_len: usize) -> usize {
        self.kv_capacity.max(kv_len).max(1)
    }
}
#[cfg(test)]
mod tests {
    use super::CudaPagedRuntimeConfig as Config;

    /// Every knob set to a valid, non-default value must be picked up.
    #[test]
    fn cuda_paged_runtime_config_parses_startup_knobs() {
        let env = [
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_PAGED_FLASH_SPLITS", "4"),
            ("FERRUM_SPLIT_K_ATTN", "1"),
            ("FERRUM_FA2_SOURCE", "on"),
            ("FERRUM_PAGED_FLASH", "0"),
        ];
        let parsed = Config::from_env_vars(env);
        let expected = Config {
            kv_capacity: 4096,
            paged_flash_splits: Some(4),
            split_k_attn: Some(true),
            fa2_source: true,
            paged_flash: false,
        };
        assert_eq!(parsed, expected);
        assert_eq!(parsed.shared_kv_for(8192), 8192);
    }

    /// Unparsable or unrecognised values must leave the defaults in place.
    #[test]
    fn cuda_paged_runtime_config_keeps_existing_defaults() {
        let env = [
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_FLASH_SPLITS", "bad"),
            ("FERRUM_SPLIT_K_ATTN", "auto"),
            ("FERRUM_FA2_SOURCE", "trueish"),
        ];
        let parsed = Config::from_env_vars(env);
        let expected = Config {
            kv_capacity: 512,
            paged_flash_splits: None,
            split_k_attn: None,
            fa2_source: false,
            paged_flash: true,
        };
        assert_eq!(parsed, expected);
        assert_eq!(parsed.shared_kv_for(128), 512);
    }
}
/// Returns the process-wide paged-attention configuration.
///
/// The environment is read on the first call only; every later call returns
/// the same cached value.
fn cuda_paged_runtime_config() -> &'static CudaPagedRuntimeConfig {
    static CACHED: OnceLock<CudaPagedRuntimeConfig> = OnceLock::new();
    CACHED.get_or_init(|| CudaPagedRuntimeConfig::from_env())
}
/// Device scratch buffers reused across split-K attention launches.
struct SplitKScratch {
    /// Per-split partial outputs; allocated with `out_capacity` elements.
    partial_out: CudaSlice<f32>,
    /// First per-split statistics buffer; allocated with `ml_capacity`
    /// elements. (Presumably the running softmax max — confirm against the
    /// kernels.)
    partial_m: CudaSlice<f32>,
    /// Second per-split statistics buffer; allocated with `ml_capacity`
    /// elements. (Presumably the running softmax sum — confirm against the
    /// kernels.)
    partial_l: CudaSlice<f32>,
    /// Element count `partial_out` was allocated with.
    out_capacity: usize,
    /// Element count `partial_m` and `partial_l` were each allocated with.
    ml_capacity: usize,
}

// NOTE(review): these impls assert that the contained `CudaSlice` handles may
// be moved to and shared between threads. All access in this file goes
// through the `RwLock` in `SPLIT_K_SCRATCH`; confirm against cudarc's own
// `Send`/`Sync` guarantees for `CudaSlice` (the impls may be redundant).
unsafe impl Send for SplitKScratch {}
unsafe impl Sync for SplitKScratch {}
static SPLIT_K_SCRATCH: std::sync::OnceLock<std::sync::RwLock<HashMap<usize, SplitKScratch>>> =
std::sync::OnceLock::new();
fn split_k_scratch_slots() -> &'static std::sync::RwLock<HashMap<usize, SplitKScratch>> {
SPLIT_K_SCRATCH.get_or_init(|| std::sync::RwLock::new(HashMap::new()))
}
/// Runs `body` with the split-K scratch buffers for device `ordinal`,
/// (re)allocating them on `stream` first if they are missing or too small.
///
/// `out_required` is the minimum element count for the first buffer handed to
/// `body`; `ml_required` is the minimum for each of the other two.
///
/// The whole check-grow-use sequence happens under a single write lock. The
/// previous read-then-write "fast path" released the lock between checking
/// the capacity and using the buffers, so another thread could swap in a
/// smaller allocation in between. Capacities now only ever grow: a regrow
/// keeps the larger of the existing and the requested size per dimension, so
/// a request that needs more of one buffer can no longer shrink the other.
///
/// # Panics
/// Panics if the scratch lock is poisoned or a device allocation fails.
fn with_split_k_scratch<R>(
    stream: &Arc<CudaStream>,
    ordinal: usize,
    out_required: usize,
    ml_required: usize,
    body: impl FnOnce(&mut CudaSlice<f32>, &mut CudaSlice<f32>, &mut CudaSlice<f32>) -> R,
) -> R {
    let mut slots = split_k_scratch_slots()
        .write()
        .expect("SPLIT_K_SCRATCH poisoned");

    // `None` => the existing buffers are big enough; otherwise the sizes to
    // allocate.
    let grow_to = match slots.get(&ordinal) {
        Some(s) if s.out_capacity >= out_required && s.ml_capacity >= ml_required => None,
        Some(s) => Some((
            s.out_capacity.max(out_required),
            s.ml_capacity.max(ml_required),
        )),
        None => Some((out_required, ml_required)),
    };

    if let Some((out_capacity, ml_capacity)) = grow_to {
        // NOTE(review): `alloc` is `unsafe`; presumably because the returned
        // memory is uninitialized, which relies on the phase-1 kernels writing
        // every slot that the reduce kernels later read — confirm.
        let partial_out = unsafe { stream.alloc::<f32>(out_capacity) }
            .expect("SPLIT_K_SCRATCH partial_out alloc");
        let partial_m =
            unsafe { stream.alloc::<f32>(ml_capacity) }.expect("SPLIT_K_SCRATCH partial_m alloc");
        let partial_l =
            unsafe { stream.alloc::<f32>(ml_capacity) }.expect("SPLIT_K_SCRATCH partial_l alloc");
        slots.insert(
            ordinal,
            SplitKScratch {
                partial_out,
                partial_m,
                partial_l,
                out_capacity,
                ml_capacity,
            },
        );
    }

    let s = slots.get_mut(&ordinal).expect("just allocated");
    body(&mut s.partial_out, &mut s.partial_m, &mut s.partial_l)
}
/// Runs paged variable-length attention as a two-kernel split-K pass.
///
/// Phase 1 (`paged_varlen_attn_split_k_phase1_f16`) is launched on a
/// `(num_heads, total_q_tokens, num_splits)` grid and is given the shared
/// split-K scratch buffers; the reduce kernel
/// (`paged_varlen_split_k_reduce_f16`) is then launched on a
/// `(num_heads, total_q_tokens, 1)` grid and is given `out`.
///
/// The split count depends on `max_kv_len` only: 2 up to 384, 4 up to 1024,
/// 8 up to 2048, and 16 beyond that.
///
/// # Errors
/// Returns `FerrumError::model` if either launch fails.
#[allow(clippy::too_many_arguments)]
fn paged_varlen_split_k_dispatch(
    ctx: &mut CudaState,
    q: &CudaSlice<f16>,
    k_pool: &CudaSlice<f16>,
    v_pool: &CudaSlice<f16>,
    out: &mut CudaSlice<f16>,
    cu_seqlens_q: &CudaSlice<u32>,
    pos_offsets: &CudaSlice<u32>,
    block_tables: &CudaSlice<u32>,
    num_seqs: usize,
    total_q_tokens: usize,
    max_kv_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    let num_splits: usize = if max_kv_len <= 384 {
        2
    } else if max_kv_len <= 1024 {
        4
    } else if max_kv_len <= 2048 {
        8
    } else {
        16
    };
    // KV positions handled by one split (ceiling division).
    let kv_per_split = (max_kv_len + num_splits - 1) / num_splits;
    let ml_required = total_q_tokens * num_heads * num_splits;
    let out_required = ml_required * head_dim;
    let scale: f32 = 1.0 / (head_dim as f32).sqrt();
    let stream = ctx.stream.clone();
    let phase1_kernel = ctx.func(
        "paged_varlen_split_k_phase1",
        ptx::PAGED_VARLEN_ATTENTION,
        "paged_varlen_attn_split_k_phase1_f16",
    );
    let reduce_kernel = ctx.func(
        "paged_varlen_split_k_reduce",
        ptx::PAGED_VARLEN_ATTENTION,
        "paged_varlen_split_k_reduce_f16",
    );

    // Scalar kernel arguments; the kernels receive them as 32-bit ints.
    let seq_count = num_seqs as i32;
    let q_head_count = num_heads as i32;
    let kv_head_count = num_kv_heads as i32;
    let head_width = head_dim as i32;
    let blocks_per_seq = max_blocks_per_seq as i32;
    let block_len = block_size as i32;
    let split_count = num_splits as i32;
    // Four bytes of dynamic shared memory per KV position in a split.
    let phase1_shared_bytes = (kv_per_split.max(1) as u32) * 4;

    with_split_k_scratch(
        &stream,
        ctx.ordinal,
        out_required,
        ml_required,
        |partial_out, partial_m, partial_l| {
            let q_view = q.slice(..);
            let k_view = k_pool.slice(..);
            let v_view = v_pool.slice(..);
            let cu_view = cu_seqlens_q.slice(..);
            let pos_view = pos_offsets.slice(..);
            let tables_view = block_tables.slice(..);
            let out_part = partial_out.slice(..);
            let m_part = partial_m.slice(..);
            let l_part = partial_l.slice(..);

            // Argument order below is the kernel's positional ABI.
            let mut phase1 = stream.launch_builder(&phase1_kernel);
            phase1.arg(&q_view);
            phase1.arg(&k_view);
            phase1.arg(&v_view);
            phase1.arg(&cu_view);
            phase1.arg(&pos_view);
            phase1.arg(&tables_view);
            phase1.arg(&out_part);
            phase1.arg(&m_part);
            phase1.arg(&l_part);
            phase1.arg(&seq_count);
            phase1.arg(&q_head_count);
            phase1.arg(&kv_head_count);
            phase1.arg(&head_width);
            phase1.arg(&blocks_per_seq);
            phase1.arg(&block_len);
            phase1.arg(&scale);
            phase1.arg(&split_count);
            let phase1_cfg = LaunchConfig {
                grid_dim: (num_heads as u32, total_q_tokens as u32, num_splits as u32),
                block_dim: (256, 1, 1),
                shared_mem_bytes: phase1_shared_bytes,
            };
            if let Err(e) = unsafe { phase1.launch(phase1_cfg) } {
                return Err(FerrumError::model(format!(
                    "paged_varlen_split_k_phase1: {e}"
                )));
            }

            let out_part2 = partial_out.slice(..);
            let m_part2 = partial_m.slice(..);
            let l_part2 = partial_l.slice(..);
            let mut reduce = stream.launch_builder(&reduce_kernel);
            reduce.arg(&out_part2);
            reduce.arg(&m_part2);
            reduce.arg(&l_part2);
            reduce.arg(out);
            reduce.arg(&q_head_count);
            reduce.arg(&head_width);
            reduce.arg(&split_count);
            let reduce_cfg = LaunchConfig {
                grid_dim: (num_heads as u32, total_q_tokens as u32, 1),
                block_dim: (128, 1, 1),
                shared_mem_bytes: 0,
            };
            if let Err(e) = unsafe { reduce.launch(reduce_cfg) } {
                return Err(FerrumError::model(format!(
                    "paged_varlen_split_k_reduce: {e}"
                )));
            }
            Ok::<(), FerrumError>(())
        },
    )
}
#[allow(clippy::too_many_arguments)]
fn paged_batched_flash_dispatch(
ctx: &mut CudaState,
q: &CudaSlice<f16>,
k_pool: &CudaSlice<f16>,
v_pool: &CudaSlice<f16>,
out: &mut CudaSlice<f16>,
block_tables: &CudaSlice<u32>,
valid_kv_lens: &CudaSlice<u32>,
num_seqs: usize,
max_kv_len: usize,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
block_size: usize,
max_blocks_per_seq: usize,
) -> Result<()> {
let force_splits = cuda_paged_runtime_config().paged_flash_splits;
const SM_TARGET: usize = 128;
let base_grid = num_seqs * num_heads;
let saturated = base_grid >= 2 * SM_TARGET;
let waves = base_grid / SM_TARGET; let num_splits: usize = force_splits.unwrap_or_else(|| {
if saturated {
match (max_kv_len, waves) {
(kv, _) if kv > 2048 => 4,
(kv, _) if kv <= 768 && waves >= 8 => 4,
(kv, _) if kv <= 768 => 1,
_ => 2,
}
} else {
let needed = (SM_TARGET + base_grid - 1) / base_grid;
let by_kv = match max_kv_len {
kv if kv <= 256 => 4,
kv if kv <= 1024 => 8,
_ => 16,
};
needed.max(1).min(by_kv).min(16)
}
});
if num_splits <= 1 {
return paged_batched_decode_single_pass(
ctx,
q,
k_pool,
v_pool,
out,
block_tables,
valid_kv_lens,
num_seqs,
max_kv_len,
num_heads,
num_kv_heads,
head_dim,
block_size,
max_blocks_per_seq,
);
}
let chunk = (max_kv_len + num_splits - 1) / num_splits;
let total_qh = num_seqs * num_heads;
let out_required = total_qh * num_splits * head_dim;
let ml_required = total_qh * num_splits;
let scale: f32 = 1.0 / (head_dim as f32).sqrt();
let stream = ctx.stream.clone();
let phase1 = ctx.func(
"paged_batched_flash_attn",
ptx::PAGED_DECODE_ATTENTION,
"paged_batched_flash_decode_attn_f16",
);
let phase2 = ctx.func(
"paged_batched_flash_reduce",
ptx::PAGED_DECODE_ATTENTION,
"paged_batched_flash_decode_reduce_f16",
);
with_split_k_scratch(
&stream,
ctx.ordinal,
out_required,
ml_required,
|partial_out, partial_m, partial_l| {
let qv = q.slice(..);
let kp = k_pool.slice(..);
let vp = v_pool.slice(..);
let bt = block_tables.slice(..);
let kvl = valid_kv_lens.slice(..);
let pout = partial_out.slice(..);
let pm = partial_m.slice(..);
let pl = partial_l.slice(..);
let nqi = num_heads as i32;
let nkvi = num_kv_heads as i32;
let hdi = head_dim as i32;
let mbps = max_blocks_per_seq as i32;
let bsi = block_size as i32;
let nsp = num_splits as i32;
let mut b1 = stream.launch_builder(&phase1);
b1.arg(&qv);
b1.arg(&kp);
b1.arg(&vp);
b1.arg(&bt);
b1.arg(&kvl);
b1.arg(&pout);
b1.arg(&pm);
b1.arg(&pl);
b1.arg(&nqi);
b1.arg(&nkvi);
b1.arg(&hdi);
b1.arg(&mbps);
b1.arg(&bsi);
b1.arg(&scale);
b1.arg(&nsp);
let safe_kv = cuda_paged_runtime_config().kv_capacity;
let safe_chunk = (safe_kv + num_splits - 1) / num_splits;
let shmem1 = (safe_chunk.max(chunk).max(1) as u32) * 4;
unsafe {
b1.launch(LaunchConfig {
grid_dim: (num_heads as u32, num_seqs as u32, num_splits as u32),
block_dim: (256, 1, 1),
shared_mem_bytes: shmem1,
})
}
.map_err(|e| FerrumError::model(format!("paged_batched_flash phase1: {e}")))?;
let pout2 = partial_out.slice(..);
let pm2 = partial_m.slice(..);
let pl2 = partial_l.slice(..);
let mut b2 = stream.launch_builder(&phase2);
b2.arg(&pout2);
b2.arg(&pm2);
b2.arg(&pl2);
b2.arg(out);
b2.arg(&nqi);
b2.arg(&hdi);
b2.arg(&nsp);
unsafe {
b2.launch(LaunchConfig {
grid_dim: (num_heads as u32, num_seqs as u32, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
})
}
.map_err(|e| FerrumError::model(format!("paged_batched_flash phase2: {e}")))?;
Ok::<(), FerrumError>(())
},
)
}
/// Launches `paged_batched_decode_attn_f16` once on a
/// `(num_heads, num_seqs, 1)` grid with 256 threads per block.
///
/// Dynamic shared memory is four bytes per KV position, sized for the larger
/// of the configured KV capacity and `max_kv_len` (at least one position).
///
/// # Errors
/// Returns `FerrumError::model` if the launch fails.
#[allow(clippy::too_many_arguments)]
fn paged_batched_decode_single_pass(
    ctx: &mut CudaState,
    q: &CudaSlice<f16>,
    k_pool: &CudaSlice<f16>,
    v_pool: &CudaSlice<f16>,
    out: &mut CudaSlice<f16>,
    block_tables: &CudaSlice<u32>,
    valid_kv_lens: &CudaSlice<u32>,
    num_seqs: usize,
    max_kv_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    let kernel = ctx.func(
        "paged_batched_decode_attn",
        ptx::PAGED_DECODE_ATTENTION,
        "paged_batched_decode_attn_f16",
    );
    let scale: f32 = 1.0 / (head_dim as f32).sqrt();
    let stream = ctx.stream.clone();

    let q_view = q.slice(..);
    let k_view = k_pool.slice(..);
    let v_view = v_pool.slice(..);
    let tables_view = block_tables.slice(..);
    let kv_lens_view = valid_kv_lens.slice(..);
    let q_head_count = num_heads as i32;
    let kv_head_count = num_kv_heads as i32;
    let head_width = head_dim as i32;
    let blocks_per_seq = max_blocks_per_seq as i32;
    let block_len = block_size as i32;

    let shared_kv = cuda_paged_runtime_config().shared_kv_for(max_kv_len);
    let shared_bytes = (shared_kv as u32) * 4;
    if shared_bytes > 48 * 1024 {
        // Best-effort opt-in for more than 48 KiB of dynamic shared memory;
        // the result is intentionally ignored here.
        let _ = kernel.set_attribute(
            cudarc::driver::sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_bytes as i32,
        );
    }

    // Argument order below is the kernel's positional ABI.
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(&q_view);
    launch.arg(&k_view);
    launch.arg(&v_view);
    launch.arg(&tables_view);
    launch.arg(&kv_lens_view);
    launch.arg(out);
    launch.arg(&q_head_count);
    launch.arg(&kv_head_count);
    launch.arg(&head_width);
    launch.arg(&blocks_per_seq);
    launch.arg(&block_len);
    launch.arg(&scale);
    let cfg = LaunchConfig {
        grid_dim: (num_heads as u32, num_seqs as u32, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: shared_bytes,
    };
    match unsafe { launch.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(FerrumError::model(format!(
            "paged_batched_decode_attn: {e}"
        ))),
    }
}
impl BackendPagedKv for CudaBackend {
/// The CUDA backend always reports paged-KV support.
fn supports_paged_kv() -> bool {
    true
}
/// The CUDA backend always reports variable-length QKV support.
fn supports_varlen_qkv() -> bool {
    true
}
/// Stages the device addresses of every K/V cache buffer into host arrays and
/// copies those arrays into three device-side `u64` scratch buffers.
///
/// `k_caches` and `v_caches` must each hold `num_layers * m` buffers, indexed
/// as `li * m + i`. Nothing happens when `num_layers` or `m` is zero.
///
/// # Errors
/// Returns `FerrumError::model` when `num_layers` exceeds
/// `MAX_LAYERS_FOR_GRAPH`, `m` exceeds `BATCHED_SCRATCH_CAP`, the slice
/// lengths do not match `num_layers * m`, or a scratch allocation fails.
/// Returns `FerrumError::unsupported` when binding the context to the current
/// thread or one of the host-to-device copies fails.
fn populate_batched_pointers(
ctx: &mut Self::Context,
k_caches: &[&Self::Buffer],
v_caches: &[&Self::Buffer],
num_layers: usize,
m: usize,
) -> Result<()> {
use cudarc::driver::DevicePtr;
if num_layers == 0 || m == 0 {
return Ok(());
}
if num_layers > super::MAX_LAYERS_FOR_GRAPH {
return Err(FerrumError::model(format!(
"populate_batched_pointers: num_layers={num_layers} > MAX_LAYERS_FOR_GRAPH={}",
super::MAX_LAYERS_FOR_GRAPH
)));
}
if m > BATCHED_SCRATCH_CAP {
return Err(FerrumError::model(format!(
"populate_batched_pointers: m={m} > BATCHED_SCRATCH_CAP={BATCHED_SCRATCH_CAP}",
)));
}
if k_caches.len() != num_layers * m || v_caches.len() != num_layers * m {
return Err(FerrumError::model(
"populate_batched_pointers: k/v_caches length != num_layers * m",
));
}
let stream = ctx.stream.clone();
// Each device scratch buffer is allocated once, on first use, with
// HOST_STAGING_TOTAL zeroed u64 slots.
if ctx.batched_scratch_u64_cache.is_none() {
ctx.batched_scratch_u64_cache = Some(
stream
.alloc_zeros::<u64>(HOST_STAGING_TOTAL)
.map_err(|e| FerrumError::model(format!("alloc cache_ptrs: {e}")))?,
);
}
if ctx.batched_scratch_u64_k.is_none() {
ctx.batched_scratch_u64_k = Some(
stream
.alloc_zeros::<u64>(HOST_STAGING_TOTAL)
.map_err(|e| FerrumError::model(format!("alloc k_ptrs: {e}")))?,
);
}
if ctx.batched_scratch_u64_v.is_none() {
ctx.batched_scratch_u64_v = Some(
stream
.alloc_zeros::<u64>(HOST_STAGING_TOTAL)
.map_err(|e| FerrumError::model(format!("alloc v_ptrs: {e}")))?,
);
}
// Host staging layout: every layer owns a row of BATCHED_SCRATCH_CAP slots.
// In the combined `cache_ptrs` array, K rows start at layer index `li` and V
// rows at `li + MAX_LAYERS_FOR_GRAPH`; the separate K and V arrays both use
// the K row offset.
// NOTE(review): assumes the host arrays hold at least
// 2 * MAX_LAYERS_FOR_GRAPH * BATCHED_SCRATCH_CAP entries — confirm where
// HOST_STAGING_TOTAL is defined.
for li in 0..num_layers {
let k_off = li * BATCHED_SCRATCH_CAP;
let v_off = (li + super::MAX_LAYERS_FOR_GRAPH) * BATCHED_SCRATCH_CAP;
for i in 0..m {
let (kp, _) = k_caches[li * m + i].as_f16().device_ptr(&stream);
let (vp, _) = v_caches[li * m + i].as_f16().device_ptr(&stream);
ctx.batched_host_cache_ptrs[k_off + i] = kp;
ctx.batched_host_cache_ptrs[v_off + i] = vp;
ctx.batched_host_k_ptrs[k_off + i] = kp;
ctx.batched_host_v_ptrs[k_off + i] = vp;
}
}
ctx.ctx
.bind_to_thread()
.map_err(|e| FerrumError::unsupported(format!("populate bind_to_thread: {e}")))?;
// All three copies transfer the full staging array, not just the filled rows.
let total_bytes = HOST_STAGING_TOTAL * std::mem::size_of::<u64>();
// NOTE(review): the raw driver copies below read `total_bytes` from each host
// array and write `total_bytes` to each device buffer; soundness relies on
// every one of them being at least HOST_STAGING_TOTAL u64s long — the device
// side is allocated with that size above, the host side is not visible here.
unsafe {
use cudarc::driver::{sys, DevicePtrMut};
let scratch_cache = ctx.batched_scratch_u64_cache.as_mut().unwrap();
let (dst, _g) = scratch_cache.device_ptr_mut(&stream);
let st = sys::cuMemcpyHtoD_v2(
dst,
ctx.batched_host_cache_ptrs.as_ptr() as *const std::ffi::c_void,
total_bytes,
);
if st != sys::CUresult::CUDA_SUCCESS {
return Err(FerrumError::unsupported(format!(
"populate cache_ptrs sync memcpy: {st:?}"
)));
}
let scratch_k = ctx.batched_scratch_u64_k.as_mut().unwrap();
let (dst, _g) = scratch_k.device_ptr_mut(&stream);
let st = sys::cuMemcpyHtoD_v2(
dst,
ctx.batched_host_k_ptrs.as_ptr() as *const std::ffi::c_void,
total_bytes,
);
if st != sys::CUresult::CUDA_SUCCESS {
return Err(FerrumError::unsupported(format!(
"populate k_ptrs sync memcpy: {st:?}"
)));
}
let scratch_v = ctx.batched_scratch_u64_v.as_mut().unwrap();
let (dst, _g) = scratch_v.device_ptr_mut(&stream);
let st = sys::cuMemcpyHtoD_v2(
dst,
ctx.batched_host_v_ptrs.as_ptr() as *const std::ffi::c_void,
total_bytes,
);
if st != sys::CUresult::CUDA_SUCCESS {
return Err(FerrumError::unsupported(format!(
"populate v_ptrs sync memcpy: {st:?}"
)));
}
}
Ok(())
}
/// Paged variable-length attention over `total_q_tokens` query tokens.
///
/// Does nothing when `num_seqs` or `total_q_tokens` is zero. Uses the split-K
/// dispatcher when `FERRUM_SPLIT_K_ATTN` forces it, or — with no override —
/// when there are at most 64 query tokens and either at most 4 sequences or
/// `max_kv_len >= 768`. Otherwise launches `paged_varlen_attn_f16` on a
/// `(num_heads, total_q_tokens, 1)` grid.
///
/// # Errors
/// Returns `FerrumError::model` if a kernel launch fails.
fn paged_varlen_attention(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    cu_seqlens_q: &Self::Buffer,
    pos_offsets: &Self::Buffer,
    block_tables: &Self::Buffer,
    num_seqs: usize,
    total_q_tokens: usize,
    max_kv_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    if num_seqs == 0 || total_q_tokens == 0 {
        return Ok(());
    }

    let heuristic_split_k = total_q_tokens <= 64 && (num_seqs <= 4 || max_kv_len >= 768);
    let use_split_k = cuda_paged_runtime_config()
        .split_k_attn
        .unwrap_or(heuristic_split_k);
    if use_split_k {
        return paged_varlen_split_k_dispatch(
            ctx,
            q.as_f16(),
            k_pool.as_f16(),
            v_pool.as_f16(),
            out.as_f16_mut(),
            cu_seqlens_q.as_u32(),
            pos_offsets.as_u32(),
            block_tables.as_u32(),
            num_seqs,
            total_q_tokens,
            max_kv_len,
            num_heads,
            num_kv_heads,
            head_dim,
            block_size,
            max_blocks_per_seq,
        );
    }

    let kernel = ctx.func(
        "paged_varlen_attn",
        ptx::PAGED_VARLEN_ATTENTION,
        "paged_varlen_attn_f16",
    );
    let scale: f32 = 1.0 / (head_dim as f32).sqrt();
    let stream = ctx.stream.clone();

    let q_view = q.as_f16().slice(..);
    let k_view = k_pool.as_f16().slice(..);
    let v_view = v_pool.as_f16().slice(..);
    let cu_view = cu_seqlens_q.as_u32().slice(..);
    let pos_view = pos_offsets.as_u32().slice(..);
    let tables_view = block_tables.as_u32().slice(..);
    let seq_count = num_seqs as i32;
    let q_head_count = num_heads as i32;
    let kv_head_count = num_kv_heads as i32;
    let head_width = head_dim as i32;
    let blocks_per_seq = max_blocks_per_seq as i32;
    let block_len = block_size as i32;

    // Four bytes of dynamic shared memory per KV position.
    let shared_kv = cuda_paged_runtime_config().shared_kv_for(max_kv_len);
    let shared_bytes = (shared_kv as u32) * 4;
    if shared_bytes > 48 * 1024 {
        // Best-effort opt-in for more than 48 KiB of dynamic shared memory;
        // the result is intentionally ignored here.
        let _ = kernel.set_attribute(
            cudarc::driver::sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_bytes as i32,
        );
    }

    // Argument order below is the kernel's positional ABI.
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(&q_view);
    launch.arg(&k_view);
    launch.arg(&v_view);
    launch.arg(&cu_view);
    launch.arg(&pos_view);
    launch.arg(&tables_view);
    launch.arg(out);
    launch.arg(&seq_count);
    launch.arg(&q_head_count);
    launch.arg(&kv_head_count);
    launch.arg(&head_width);
    launch.arg(&blocks_per_seq);
    launch.arg(&block_len);
    launch.arg(&scale);
    let cfg = LaunchConfig {
        grid_dim: (num_heads as u32, total_q_tokens as u32, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: shared_bytes,
    };
    match unsafe { launch.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(FerrumError::model(format!("paged_varlen_attn: {e}"))),
    }
}
/// Paged variable-length attention through the FlashAttention-2 bindings.
///
/// When the crate is built with the `fa2-source` feature and
/// `FERRUM_FA2_SOURCE` is enabled, the call goes to
/// `fa2_source::paged_varlen_attention_fa2_source`; in every other case it
/// goes to `fa2_ffi::paged_varlen_attention_fa2_ffi`. All arguments are
/// forwarded unchanged, with buffers viewed as their typed slices.
fn paged_varlen_attention_fa2_ffi(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    lse: &mut Self::Buffer,
    cu_seqlens_q: &Self::Buffer,
    seq_lens: &Self::Buffer,
    block_tables: &Self::Buffer,
    num_seqs: usize,
    total_q_tokens: usize,
    max_q_len: usize,
    max_kv_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    let stream = &ctx.stream;

    #[cfg(feature = "fa2-source")]
    if cuda_paged_runtime_config().fa2_source {
        let args = super::fa2_source::Fa2SourcePagedVarlenArgs {
            stream,
            q: q.as_f16(),
            k_pool: k_pool.as_f16(),
            v_pool: v_pool.as_f16(),
            out: out.as_f16_mut(),
            lse: lse.as_f32_mut(),
            cu_seqlens_q: cu_seqlens_q.as_u32(),
            seq_lens: seq_lens.as_u32(),
            block_tables: block_tables.as_u32(),
            num_seqs,
            total_q_tokens,
            max_q_len,
            max_kv_len,
            num_heads,
            num_kv_heads,
            head_dim,
            block_size,
            max_blocks_per_seq,
        };
        return super::fa2_source::paged_varlen_attention_fa2_source(args);
    }

    super::fa2_ffi::paged_varlen_attention_fa2_ffi(
        stream,
        q.as_f16(),
        k_pool.as_f16(),
        v_pool.as_f16(),
        out.as_f16_mut(),
        lse.as_f32_mut(),
        cu_seqlens_q.as_u32(),
        seq_lens.as_u32(),
        block_tables.as_u32(),
        num_seqs,
        total_q_tokens,
        max_q_len,
        max_kv_len,
        num_heads,
        num_kv_heads,
        head_dim,
        block_size,
        max_blocks_per_seq,
    )
}
/// Paged attention entry point for decode (`q_len == 1`) and single-sequence
/// multi-token calls.
///
/// * `q_len == 1`: forwards to `paged_batched_decode_attention` with
///   `max_kv_len = block_size * max_num_blocks_per_seq`.
/// * `q_len != 1`: only `num_seqs == 1` is accepted. The first entry of
///   `context_lens` is copied to the host (this synchronizes the stream) to
///   derive the position offset `final_kv_len - q_len`; attention is then run
///   through `paged_varlen_attention` into a context-owned staging buffer and
///   the result is passed through `transpose_token_to_head` into `out`.
///
/// The staging buffer is moved out of `ctx` while `ctx` is handed to the
/// nested calls and put back afterwards. The earlier implementation kept a
/// raw pointer into `ctx` and dereferenced it as `&mut` while `ctx` itself was
/// mutably borrowed by the same call, which is aliasing undefined behaviour.
///
/// # Errors
/// Returns `FerrumError::model` for unsupported `num_seqs`, when
/// `final_kv_len < q_len`, or when a copy, allocation or launch fails.
#[allow(clippy::too_many_arguments)]
fn paged_decode_attention(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    block_tables: &Self::Buffer,
    context_lens: &Self::Buffer,
    num_seqs: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
    q_len: usize,
) -> Result<()> {
    let max_kv_len = block_size * max_num_blocks_per_seq;
    if q_len == 1 {
        return Self::paged_batched_decode_attention(
            ctx,
            q,
            k_pool,
            v_pool,
            out,
            block_tables,
            context_lens,
            num_seqs,
            max_kv_len,
            num_heads,
            num_kv_heads,
            head_dim,
            block_size,
            max_num_blocks_per_seq,
        );
    }
    if num_seqs != 1 {
        return Err(FerrumError::model(format!(
            "paged_decode_attention(CUDA): q_len={q_len} num_seqs={num_seqs} \
             not supported (caller must split prefill into per-seq calls)"
        )));
    }

    // Read the sequence's KV length back to the host; this blocks on the stream.
    let cl_host: Vec<u32> = {
        let stream = ctx.stream.clone();
        let view = context_lens.as_u32().slice(0..1);
        let mut h = vec![0u32; 1];
        stream
            .memcpy_dtoh(&view, h.as_mut_slice())
            .map_err(|e| FerrumError::model(format!("dtoh context_lens: {e}")))?;
        stream
            .synchronize()
            .map_err(|e| FerrumError::model(format!("dtoh sync: {e}")))?;
        h
    };
    let final_kv_len = cl_host[0] as usize;
    if final_kv_len < q_len {
        return Err(FerrumError::model(format!(
            "paged_decode_attention(CUDA): final_kv_len={final_kv_len} < q_len={q_len}"
        )));
    }
    let pos_offset = (final_kv_len - q_len) as u32;

    // Single-sequence varlen metadata: one segment [0, q_len) at `pos_offset`.
    let mut cu_seqlens_q_buf = <Self as Backend>::alloc_typed(crate::backend::Dtype::U32, 2);
    <Self as Backend>::write_typed::<u32>(ctx, &mut cu_seqlens_q_buf, &[0u32, q_len as u32]);
    let mut pos_offsets_buf = <Self as Backend>::alloc_typed(crate::backend::Dtype::U32, 1);
    <Self as Backend>::write_typed::<u32>(ctx, &mut pos_offsets_buf, &[pos_offset]);

    // Grow (or create) the staging buffer. The `is_none` check also recovers
    // if an earlier call lost the buffer without resetting the capacity.
    let q_n = q_len * num_heads * head_dim;
    if ctx.paged_attn_out_tm.is_none() || ctx.paged_attn_out_tm_capacity < q_n {
        let stream = ctx.stream.clone();
        let n_grown = q_n.next_power_of_two().max(q_n);
        ctx.paged_attn_out_tm = Some(crate::backend::CudaBuf::from_f16(
            stream
                .alloc_zeros::<f16>(n_grown)
                .map_err(|e| FerrumError::model(format!("alloc paged_attn_out_tm: {e}")))?,
        ));
        ctx.paged_attn_out_tm_capacity = n_grown;
    }

    // Move the staging buffer out so `ctx` and the buffer can be borrowed
    // mutably at the same time without aliasing.
    let mut out_tm = ctx
        .paged_attn_out_tm
        .take()
        .expect("paged_attn_out_tm allocated");
    let attn = Self::paged_varlen_attention(
        ctx,
        q,
        k_pool,
        v_pool,
        &mut out_tm,
        &cu_seqlens_q_buf,
        &pos_offsets_buf,
        block_tables,
        1,
        q_len,
        final_kv_len,
        num_heads,
        num_kv_heads,
        head_dim,
        block_size,
        max_num_blocks_per_seq,
    );
    if attn.is_ok() {
        <Self as Backend>::transpose_token_to_head(
            ctx,
            &out_tm,
            out,
            q_len,
            num_heads,
            head_dim,
        );
    }
    // Always hand the buffer back, including on the error path.
    ctx.paged_attn_out_tm = Some(out_tm);
    attn
}
fn paged_batched_decode_attention(
ctx: &mut Self::Context,
q: &Self::Buffer,
k_pool: &Self::Buffer,
v_pool: &Self::Buffer,
out: &mut Self::Buffer,
block_tables: &Self::Buffer,
valid_kv_lens: &Self::Buffer,
num_seqs: usize,
max_kv_len: usize,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
block_size: usize,
max_blocks_per_seq: usize,
) -> Result<()> {
if num_seqs == 0 {
return Ok(());
}
if cuda_paged_runtime_config().paged_flash {
return paged_batched_flash_dispatch(
ctx,
q.as_f16(),
k_pool.as_f16(),
v_pool.as_f16(),
out.as_f16_mut(),
block_tables.as_u32(),
valid_kv_lens.as_u32(),
num_seqs,
max_kv_len,
num_heads,
num_kv_heads,
head_dim,
block_size,
max_blocks_per_seq,
);
}
let func = ctx.func(
"paged_batched_decode_attn",
ptx::PAGED_DECODE_ATTENTION,
"paged_batched_decode_attn_f16",
);
let scale: f32 = 1.0 / (head_dim as f32).sqrt();
let stream = ctx.stream.clone();
let qv = q.as_f16().slice(..);
let kp = k_pool.as_f16().slice(..);
let vp = v_pool.as_f16().slice(..);
let bt = block_tables.as_u32().slice(..);
let kvl = valid_kv_lens.as_u32().slice(..);
let nqi = num_heads as i32;
let nkvi = num_kv_heads as i32;
let hdi = head_dim as i32;
let mbps = max_blocks_per_seq as i32;
let bsi = block_size as i32;
let mut b = stream.launch_builder(&func);
b.arg(&qv);
b.arg(&kp);
b.arg(&vp);
b.arg(&bt);
b.arg(&kvl);
b.arg(out);
b.arg(&nqi);
b.arg(&nkvi);
b.arg(&hdi);
b.arg(&mbps);
b.arg(&bsi);
b.arg(&scale);
let shared_kv = cuda_paged_runtime_config().shared_kv_for(max_kv_len);
let shared_bytes = (shared_kv as u32) * 4;
unsafe {
b.launch(LaunchConfig {
grid_dim: (num_heads as u32, num_seqs as u32, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: shared_bytes,
})
}
.map(|_| ())
.map_err(|e| FerrumError::model(format!("paged_batched_decode_attn: {e}")))
}
/// Launches `split_qkv_norm_rope_into_paged_cache_f16` for one sequence.
///
/// Does nothing when `tokens` is zero. The grid is
/// `(tokens, q_heads + 2 * kv_heads, 1)` with 32 threads per block and no
/// dynamic shared memory. Sizes and offsets are handed to the kernel as
/// 32-bit ints; the two byte offsets stay 64-bit.
///
/// # Errors
/// Returns `FerrumError::model` if the launch fails.
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    qkv_byte_offset: u64,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    q_out_byte_offset: u64,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    block_table: &Self::Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
    cache_len: usize,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    if tokens == 0 {
        return Ok(());
    }
    let kernel = ctx.func(
        "split_qkv_norm_rope_into_paged_cache",
        ptx::SPLIT_QKV_NORM_ROPE_INTO_PAGED_CACHE,
        "split_qkv_norm_rope_into_paged_cache_f16",
    );
    let stream = ctx.stream.clone();

    let n_tokens = tokens as i32;
    let n_q_heads = q_heads as i32;
    let n_kv_heads = kv_heads as i32;
    let head_width = head_dim as i32;
    let first_pos = pos_offset as i32;
    let n_cached = cache_len as i32;
    let block_len = block_size as i32;
    let blocks_per_seq = max_blocks_per_seq as i32;

    // Argument order below is the kernel's positional ABI.
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(qkv);
    launch.arg(&qkv_byte_offset);
    launch.arg(q_norm_w);
    launch.arg(k_norm_w);
    launch.arg(cos);
    launch.arg(sin);
    launch.arg(q_out);
    launch.arg(&q_out_byte_offset);
    launch.arg(cache_k);
    launch.arg(cache_v);
    launch.arg(block_table);
    launch.arg(&n_tokens);
    launch.arg(&n_q_heads);
    launch.arg(&n_kv_heads);
    launch.arg(&head_width);
    launch.arg(&first_pos);
    launch.arg(&eps);
    launch.arg(&qk_mode);
    launch.arg(&n_cached);
    launch.arg(&block_len);
    launch.arg(&blocks_per_seq);

    // One block per (token, head) over the Q heads plus the K and V heads.
    let total_heads = (q_heads + 2 * kv_heads) as u32;
    let cfg = LaunchConfig {
        grid_dim: (tokens as u32, total_heads, 1),
        block_dim: (32, 1, 1),
        shared_mem_bytes: 0,
    };
    match unsafe { launch.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(FerrumError::model(format!(
            "split_qkv_norm_rope_into_paged_cache: {e}"
        ))),
    }
}
/// Launches `split_qkv_norm_rope_into_paged_cache_varlen_f16` for a batch of
/// sequences described by `cu_seqlens_q`, `pos_offsets` and `block_tables`.
///
/// Does nothing when `m_total` or `num_seqs` is zero. The grid is
/// `(m_total, q_heads + 2 * kv_heads, 1)` with 32 threads per block and no
/// dynamic shared memory.
///
/// # Errors
/// Returns `FerrumError::model` if the launch fails.
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache_varlen(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    cu_seqlens_q: &Self::Buffer,
    pos_offsets: &Self::Buffer,
    block_tables: &Self::Buffer,
    num_seqs: usize,
    m_total: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    eps: f32,
    qk_mode: i32,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    if m_total == 0 || num_seqs == 0 {
        return Ok(());
    }
    let kernel = ctx.func(
        "split_qkv_norm_rope_into_paged_cache_varlen",
        ptx::SPLIT_QKV_NORM_ROPE_INTO_PAGED_CACHE,
        "split_qkv_norm_rope_into_paged_cache_varlen_f16",
    );
    let stream = ctx.stream.clone();

    let seq_count = num_seqs as i32;
    let token_count = m_total as i32;
    let n_q_heads = q_heads as i32;
    let n_kv_heads = kv_heads as i32;
    let head_width = head_dim as i32;
    let block_len = block_size as i32;
    let blocks_per_seq = max_blocks_per_seq as i32;

    // Argument order below is the kernel's positional ABI.
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(qkv);
    launch.arg(q_norm_w);
    launch.arg(k_norm_w);
    launch.arg(cos);
    launch.arg(sin);
    launch.arg(q_out);
    launch.arg(cache_k);
    launch.arg(cache_v);
    launch.arg(cu_seqlens_q);
    launch.arg(pos_offsets);
    launch.arg(block_tables);
    launch.arg(&seq_count);
    launch.arg(&token_count);
    launch.arg(&n_q_heads);
    launch.arg(&n_kv_heads);
    launch.arg(&head_width);
    launch.arg(&eps);
    launch.arg(&qk_mode);
    launch.arg(&block_len);
    launch.arg(&blocks_per_seq);

    // One block per (token, head) over the Q heads plus the K and V heads.
    let total_heads = (q_heads + 2 * kv_heads) as u32;
    let cfg = LaunchConfig {
        grid_dim: (m_total as u32, total_heads, 1),
        block_dim: (32, 1, 1),
        shared_mem_bytes: 0,
    };
    match unsafe { launch.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(FerrumError::model(format!(
            "split_qkv_norm_rope_into_paged_cache_varlen: {e}"
        ))),
    }
}
/// Always returns `true`; this override is only compiled when the
/// `vllm-paged-attn-v2` feature is enabled.
#[cfg(feature = "vllm-paged-attn-v2")]
fn supports_vllm_paged_attn() -> bool {
true
}
/// Launches `split_qkv_norm_rope_into_paged_cache_vllm_f16` for one sequence.
///
/// Takes the same arguments and uses the same launch geometry as
/// `split_qkv_norm_rope_into_paged_cache` —
/// `(tokens, q_heads + 2 * kv_heads, 1)` blocks of 32 threads — but loads its
/// kernel from the `SPLIT_QKV_NORM_ROPE_INTO_PAGED_CACHE_VLLM` PTX module.
/// Does nothing when `tokens` is zero.
///
/// # Errors
/// Returns `FerrumError::model` if the launch fails.
#[cfg(feature = "vllm-paged-attn-v2")]
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache_vllm(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    qkv_byte_offset: u64,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    q_out_byte_offset: u64,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    block_table: &Self::Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
    cache_len: usize,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    if tokens == 0 {
        return Ok(());
    }
    let kernel = ctx.func(
        "split_qkv_norm_rope_into_paged_cache_vllm",
        ptx::SPLIT_QKV_NORM_ROPE_INTO_PAGED_CACHE_VLLM,
        "split_qkv_norm_rope_into_paged_cache_vllm_f16",
    );
    let stream = ctx.stream.clone();

    let n_tokens = tokens as i32;
    let n_q_heads = q_heads as i32;
    let n_kv_heads = kv_heads as i32;
    let head_width = head_dim as i32;
    let first_pos = pos_offset as i32;
    let n_cached = cache_len as i32;
    let block_len = block_size as i32;
    let blocks_per_seq = max_blocks_per_seq as i32;

    // Argument order below is the kernel's positional ABI.
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(qkv);
    launch.arg(&qkv_byte_offset);
    launch.arg(q_norm_w);
    launch.arg(k_norm_w);
    launch.arg(cos);
    launch.arg(sin);
    launch.arg(q_out);
    launch.arg(&q_out_byte_offset);
    launch.arg(cache_k);
    launch.arg(cache_v);
    launch.arg(block_table);
    launch.arg(&n_tokens);
    launch.arg(&n_q_heads);
    launch.arg(&n_kv_heads);
    launch.arg(&head_width);
    launch.arg(&first_pos);
    launch.arg(&eps);
    launch.arg(&qk_mode);
    launch.arg(&n_cached);
    launch.arg(&block_len);
    launch.arg(&blocks_per_seq);

    // One block per (token, head) over the Q heads plus the K and V heads.
    let total_heads = (q_heads + 2 * kv_heads) as u32;
    let cfg = LaunchConfig {
        grid_dim: (tokens as u32, total_heads, 1),
        block_dim: (32, 1, 1),
        shared_mem_bytes: 0,
    };
    match unsafe { launch.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(FerrumError::model(format!(
            "split_qkv_norm_rope_into_paged_cache_vllm: {e}"
        ))),
    }
}
/// Launches `split_qkv_norm_rope_into_paged_cache_varlen_vllm_f16` for a
/// batch of sequences.
///
/// Takes the same arguments and uses the same launch geometry as
/// `split_qkv_norm_rope_into_paged_cache_varlen` —
/// `(m_total, q_heads + 2 * kv_heads, 1)` blocks of 32 threads — but loads its
/// kernel from the `SPLIT_QKV_NORM_ROPE_INTO_PAGED_CACHE_VLLM` PTX module.
/// Does nothing when `m_total` or `num_seqs` is zero.
///
/// # Errors
/// Returns `FerrumError::model` if the launch fails.
#[cfg(feature = "vllm-paged-attn-v2")]
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache_varlen_vllm(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    cu_seqlens_q: &Self::Buffer,
    pos_offsets: &Self::Buffer,
    block_tables: &Self::Buffer,
    num_seqs: usize,
    m_total: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    eps: f32,
    qk_mode: i32,
    block_size: usize,
    max_blocks_per_seq: usize,
) -> Result<()> {
    if m_total == 0 || num_seqs == 0 {
        return Ok(());
    }
    let kernel = ctx.func(
        "split_qkv_norm_rope_into_paged_cache_varlen_vllm",
        ptx::SPLIT_QKV_NORM_ROPE_INTO_PAGED_CACHE_VLLM,
        "split_qkv_norm_rope_into_paged_cache_varlen_vllm_f16",
    );
    let stream = ctx.stream.clone();

    let seq_count = num_seqs as i32;
    let token_count = m_total as i32;
    let n_q_heads = q_heads as i32;
    let n_kv_heads = kv_heads as i32;
    let head_width = head_dim as i32;
    let block_len = block_size as i32;
    let blocks_per_seq = max_blocks_per_seq as i32;

    // Argument order below is the kernel's positional ABI.
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(qkv);
    launch.arg(q_norm_w);
    launch.arg(k_norm_w);
    launch.arg(cos);
    launch.arg(sin);
    launch.arg(q_out);
    launch.arg(cache_k);
    launch.arg(cache_v);
    launch.arg(cu_seqlens_q);
    launch.arg(pos_offsets);
    launch.arg(block_tables);
    launch.arg(&seq_count);
    launch.arg(&token_count);
    launch.arg(&n_q_heads);
    launch.arg(&n_kv_heads);
    launch.arg(&head_width);
    launch.arg(&eps);
    launch.arg(&qk_mode);
    launch.arg(&block_len);
    launch.arg(&blocks_per_seq);

    // One block per (token, head) over the Q heads plus the K and V heads.
    let total_heads = (q_heads + 2 * kv_heads) as u32;
    let cfg = LaunchConfig {
        grid_dim: (m_total as u32, total_heads, 1),
        block_dim: (32, 1, 1),
        shared_mem_bytes: 0,
    };
    match unsafe { launch.launch(cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(FerrumError::model(format!(
            "split_qkv_norm_rope_into_paged_cache_varlen_vllm: {e}"
        ))),
    }
}
/// Decode attention through the vLLM paged-attention v2 dispatcher.
///
/// Does nothing when `num_seqs` is zero; otherwise forwards every argument,
/// with buffers viewed as typed slices, to
/// `vllm_paged_attn::dispatch_paged_attention_v2` together with the context's
/// stream and device ordinal.
#[cfg(feature = "vllm-paged-attn-v2")]
#[allow(clippy::too_many_arguments)]
fn paged_decode_attention_v2(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    block_tables: &Self::Buffer,
    context_lens: &Self::Buffer,
    num_seqs: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
    max_seq_len: usize,
) -> Result<()> {
    if num_seqs == 0 {
        return Ok(());
    }
    let launch_stream = ctx.stream.clone();
    let device_ordinal = ctx.ordinal;
    super::vllm_paged_attn::dispatch_paged_attention_v2(
        &launch_stream,
        device_ordinal,
        out.as_f16_mut(),
        q.as_f16(),
        k_pool.as_f16(),
        v_pool.as_f16(),
        block_tables.as_u32(),
        context_lens.as_u32(),
        num_seqs,
        num_heads,
        num_kv_heads,
        head_dim,
        block_size,
        max_num_blocks_per_seq,
        max_seq_len,
    )
}
/// Paged attention for `q_len` query tokens of a single sequence, using the
/// `paged_varlen_attn_vllm_f16` kernel.
///
/// Steps performed:
/// 1. Returns `Ok(())` immediately when `q_len` or `num_seqs` is zero.
/// 2. Reads `context_lens[0]` back to the host (this synchronizes the
///    stream, so the call blocks) and uses it as the final KV length.
/// 3. Uploads `cu_seqlens_q = [0, q_len]` and
///    `pos_offsets = [final_kv_len - q_len]` into freshly allocated buffers.
/// 4. Grows the `ctx.paged_attn_out_tm` scratch buffer when it holds fewer
///    than `q_len * num_heads * head_dim` f16 elements.
/// 5. Launches the kernel on a `(num_heads, q_len)` grid of 256-thread
///    blocks, writing into the scratch buffer.
/// 6. Calls `transpose_token_to_head` to move the scratch contents into `out`.
///
/// # Errors
/// * `num_seqs != 1` (only one sequence is supported by this entry point).
/// * The device-to-host copy or stream synchronization fails.
/// * `context_lens[0] < q_len`.
/// * The scratch allocation fails.
/// * The dynamic shared-memory size (`4 * shared_kv` bytes) does not fit in
///   an `i32` byte count.
/// * The kernel launch fails.
///
/// # Panics
/// Panics if `ctx.paged_attn_out_tm` is still `None` after the capacity
/// check.
#[cfg(feature = "vllm-paged-attn-v2")]
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention_vllm_layout(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    block_tables: &Self::Buffer,
    context_lens: &Self::Buffer,
    num_seqs: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
    q_len: usize,
) -> Result<()> {
    if q_len == 0 || num_seqs == 0 {
        return Ok(());
    }
    if num_seqs != 1 {
        return Err(FerrumError::model(format!(
            "paged_varlen_attention_vllm_layout(CUDA): q_len={q_len} num_seqs={num_seqs} \
             not supported yet"
        )));
    }

    // The KV length lives on the device; fetch element 0 and wait for the
    // copy to finish before using it on the host.
    let final_kv_len = {
        let stream = ctx.stream.clone();
        let view = context_lens.as_u32().slice(0..1);
        let mut host = vec![0u32; 1];
        stream
            .memcpy_dtoh(&view, host.as_mut_slice())
            .map_err(|e| FerrumError::model(format!("dtoh context_lens vllm: {e}")))?;
        stream
            .synchronize()
            .map_err(|e| FerrumError::model(format!("dtoh sync vllm: {e}")))?;
        host[0] as usize
    };
    if final_kv_len < q_len {
        return Err(FerrumError::model(format!(
            "paged_varlen_attention_vllm_layout(CUDA): final_kv_len={final_kv_len} < q_len={q_len}"
        )));
    }

    // `final_kv_len` came from a u32 and `q_len <= final_kv_len`, so both
    // narrowing casts below are lossless.
    let pos_offset = (final_kv_len - q_len) as u32;
    let mut cu_seqlens_q_buf = <Self as Backend>::alloc_typed(crate::backend::Dtype::U32, 2);
    <Self as Backend>::write_typed::<u32>(ctx, &mut cu_seqlens_q_buf, &[0u32, q_len as u32]);
    let mut pos_offsets_buf = <Self as Backend>::alloc_typed(crate::backend::Dtype::U32, 1);
    <Self as Backend>::write_typed::<u32>(ctx, &mut pos_offsets_buf, &[pos_offset]);

    // Grow the scratch buffer the kernel writes into. `.max(q_n)` keeps the
    // size correct even if `next_power_of_two` wraps in a release build.
    let q_n = q_len * num_heads * head_dim;
    if ctx.paged_attn_out_tm_capacity < q_n {
        let stream = ctx.stream.clone();
        let n_grown = q_n.next_power_of_two().max(q_n);
        ctx.paged_attn_out_tm = Some(crate::backend::CudaBuf::from_f16(
            stream.alloc_zeros::<f16>(n_grown).map_err(|e| {
                FerrumError::model(format!("alloc paged_attn_out_tm vllm: {e}"))
            })?,
        ));
        ctx.paged_attn_out_tm_capacity = n_grown;
    }

    let func = ctx.func(
        "paged_varlen_attn_vllm",
        ptx::PAGED_VARLEN_ATTENTION_VLLM,
        "paged_varlen_attn_vllm_f16",
    );
    let scale: f32 = 1.0 / (head_dim as f32).sqrt();
    let stream = ctx.stream.clone();

    // Dynamic shared memory: 4 bytes per `shared_kv` slot (presumably one
    // f32 score per KV position). `shared_kv` is derived from the
    // user-configurable KV capacity, so the byte count is computed with
    // overflow checking instead of wrapping `as` casts; a wrapped value
    // would launch the kernel with too little shared memory.
    let shared_kv = cuda_paged_runtime_config().shared_kv_for(final_kv_len);
    let shared_bytes = match shared_kv.checked_mul(4) {
        Some(bytes) if bytes <= i32::MAX as usize => bytes as u32,
        _ => {
            return Err(FerrumError::model(format!(
                "paged_varlen_attention_vllm_layout(CUDA): shared_kv={shared_kv} needs more \
                 dynamic shared memory than an i32 byte count can express"
            )));
        }
    };
    if shared_bytes > 48 * 1024 {
        // Requests above 48 KiB need an explicit per-function opt-in. The
        // result is intentionally ignored here, as in the original code; the
        // launch below still reports an error if the size is not accepted.
        let _ = func.set_attribute(
            cudarc::driver::sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_bytes as i32,
        );
    }

    // A raw pointer is used so the scratch buffer (a field of `ctx`) can be
    // handed to the kernel and then to `transpose_token_to_head` while `ctx`
    // itself is also passed mutably to that call.
    let out_tm_ptr: *mut crate::backend::CudaBuf = ctx
        .paged_attn_out_tm
        .as_mut()
        .expect("paged_attn_out_tm allocated") as *mut _;

    let qv = q.as_f16().slice(..);
    let kp = k_pool.as_f16().slice(..);
    let vp = v_pool.as_f16().slice(..);
    let csq = cu_seqlens_q_buf.as_u32().slice(..);
    let po = pos_offsets_buf.as_u32().slice(..);
    let bt = block_tables.as_u32().slice(..);
    let ns = num_seqs as i32;
    let nqi = num_heads as i32;
    let nkvi = num_kv_heads as i32;
    let hdi = head_dim as i32;
    let mbps = max_num_blocks_per_seq as i32;
    let bsi = block_size as i32;

    // SAFETY: `out_tm_ptr` was derived just above from the live
    // `ctx.paged_attn_out_tm` value and nothing in between replaces it.
    // NOTE(review): the pointer aliases a field of `ctx` while `ctx` is
    // passed as `&mut` to `transpose_token_to_head`; this relies on that
    // call never touching `paged_attn_out_tm` — confirm. The launch is
    // unsafe because the pushed arguments are not checked against the
    // kernel's signature; the order below is kept exactly as in the original.
    unsafe {
        let mut b = stream.launch_builder(&func);
        b.arg(&qv);
        b.arg(&kp);
        b.arg(&vp);
        b.arg(&csq);
        b.arg(&po);
        b.arg(&bt);
        b.arg(&mut *out_tm_ptr);
        b.arg(&ns);
        b.arg(&nqi);
        b.arg(&nkvi);
        b.arg(&hdi);
        b.arg(&mbps);
        b.arg(&bsi);
        b.arg(&scale);
        b.launch(LaunchConfig {
            grid_dim: (num_heads as u32, q_len as u32, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: shared_bytes,
        })
        .map_err(|e| FerrumError::model(format!("paged_varlen_attn_vllm: {e}")))?;
        <Self as Backend>::transpose_token_to_head(
            ctx,
            &*out_tm_ptr,
            out,
            q_len,
            num_heads,
            head_dim,
        );
    }
    Ok(())
}
/// Batched variable-length paged attention using the
/// `paged_varlen_attn_vllm_f16` kernel.
///
/// The caller supplies the per-sequence metadata buffers (`cu_seqlens_q`,
/// `pos_offsets`, `block_tables`) already on the device. The kernel is
/// launched on a `(num_heads, total_q_tokens)` grid of 256-thread blocks and
/// writes straight into `out`; no transpose step follows.
///
/// `max_kv_len` only sizes the dynamic shared memory: `4 * shared_kv` bytes,
/// where `shared_kv` is `cuda_paged_runtime_config().shared_kv_for(max_kv_len)`.
///
/// Returns `Ok(())` without launching when `num_seqs` or `total_q_tokens` is
/// zero.
///
/// # Errors
/// * The shared-memory byte count does not fit in an `i32`.
/// * The kernel launch fails.
#[cfg(feature = "vllm-paged-attn-v2")]
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention_vllm(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    cu_seqlens_q: &Self::Buffer,
    pos_offsets: &Self::Buffer,
    block_tables: &Self::Buffer,
    num_seqs: usize,
    total_q_tokens: usize,
    max_kv_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
) -> Result<()> {
    if num_seqs == 0 || total_q_tokens == 0 {
        return Ok(());
    }
    let func = ctx.func(
        "paged_varlen_attn_vllm",
        ptx::PAGED_VARLEN_ATTENTION_VLLM,
        "paged_varlen_attn_vllm_f16",
    );
    let scale: f32 = 1.0 / (head_dim as f32).sqrt();
    let stream = ctx.stream.clone();

    // `shared_kv` is derived from the user-configurable KV capacity, so the
    // byte count is computed with overflow checking instead of wrapping `as`
    // casts; a wrapped value would launch the kernel with too little shared
    // memory. Bounding by `i32::MAX` also keeps the `as i32` below lossless.
    let shared_kv = cuda_paged_runtime_config().shared_kv_for(max_kv_len);
    let shared_bytes = match shared_kv.checked_mul(4) {
        Some(bytes) if bytes <= i32::MAX as usize => bytes as u32,
        _ => {
            return Err(FerrumError::model(format!(
                "paged_varlen_attn_vllm: shared_kv={shared_kv} needs more dynamic shared \
                 memory than an i32 byte count can express"
            )));
        }
    };
    if shared_bytes > 48 * 1024 {
        // Requests above 48 KiB need an explicit per-function opt-in. The
        // result is intentionally ignored here, as in the original code; the
        // launch below still reports an error if the size is not accepted.
        let _ = func.set_attribute(
            cudarc::driver::sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_bytes as i32,
        );
    }

    let qv = q.as_f16().slice(..);
    let kp = k_pool.as_f16().slice(..);
    let vp = v_pool.as_f16().slice(..);
    let csq = cu_seqlens_q.as_u32().slice(..);
    let po = pos_offsets.as_u32().slice(..);
    let bt = block_tables.as_u32().slice(..);
    let ns = num_seqs as i32;
    let nqi = num_heads as i32;
    let nkvi = num_kv_heads as i32;
    let hdi = head_dim as i32;
    let mbps = max_num_blocks_per_seq as i32;
    let bsi = block_size as i32;

    // Argument order is kept exactly as in the original launch.
    let mut b = stream.launch_builder(&func);
    b.arg(&qv);
    b.arg(&kp);
    b.arg(&vp);
    b.arg(&csq);
    b.arg(&po);
    b.arg(&bt);
    b.arg(out);
    b.arg(&ns);
    b.arg(&nqi);
    b.arg(&nkvi);
    b.arg(&hdi);
    b.arg(&mbps);
    b.arg(&bsi);
    b.arg(&scale);

    // SAFETY: the launch is unsafe because the pushed arguments are not
    // checked against the kernel's signature. NOTE(review): presumably the
    // list above matches the PTX entry point's parameters — confirm against
    // the kernel source.
    unsafe {
        b.launch(LaunchConfig {
            grid_dim: (num_heads as u32, total_q_tokens as u32, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: shared_bytes,
        })
    }
    .map(|_| ())
    .map_err(|e| FerrumError::model(format!("paged_varlen_attn_vllm: {e}")))
}
/// Tiled variant of the variable-length paged attention launch, using the
/// `paged_varlen_attn_vllm_tiled_q4_f16` kernel.
///
/// In addition to the buffers taken by the untiled launch, the kernel
/// receives `tile_seqs` and `tile_starts` (u32 device buffers) and is
/// launched on a `(num_heads, num_tiles)` grid of 256-thread blocks, writing
/// directly into `out`.
///
/// Dynamic shared memory is `4 * shared_kv * 4` bytes, with `shared_kv` also
/// passed to the kernel as `score_stride` (presumably four score rows of
/// `shared_kv` f32 values per tile — confirm against the kernel source).
///
/// Returns `Ok(())` without launching when `num_tiles` is zero.
///
/// # Errors
/// * The shared-memory byte count does not fit in an `i32`.
/// * The kernel launch fails.
#[cfg(feature = "vllm-paged-attn-v2")]
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention_vllm_tiled_q4(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    cu_seqlens_q: &Self::Buffer,
    pos_offsets: &Self::Buffer,
    block_tables: &Self::Buffer,
    tile_seqs: &Self::Buffer,
    tile_starts: &Self::Buffer,
    num_tiles: usize,
    max_kv_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
) -> Result<()> {
    if num_tiles == 0 {
        return Ok(());
    }
    let func = ctx.func(
        "paged_varlen_attn_vllm",
        ptx::PAGED_VARLEN_ATTENTION_VLLM,
        "paged_varlen_attn_vllm_tiled_q4_f16",
    );
    let scale: f32 = 1.0 / (head_dim as f32).sqrt();
    let stream = ctx.stream.clone();

    // `shared_kv` is derived from the user-configurable KV capacity, so the
    // byte count (4 * shared_kv * 4 = 16 * shared_kv) is computed with
    // overflow checking instead of wrapping `as` casts; a wrapped value would
    // launch the kernel with too little shared memory. Bounding the product
    // by `i32::MAX` also guarantees `shared_kv` itself fits in the i32
    // `score_stride` below.
    let shared_kv = cuda_paged_runtime_config().shared_kv_for(max_kv_len);
    let shared_bytes = match shared_kv.checked_mul(16) {
        Some(bytes) if bytes <= i32::MAX as usize => bytes as u32,
        _ => {
            return Err(FerrumError::model(format!(
                "paged_varlen_attn_vllm_tiled_q4: shared_kv={shared_kv} needs more dynamic \
                 shared memory than an i32 byte count can express"
            )));
        }
    };
    if shared_bytes > 48 * 1024 {
        // Requests above 48 KiB need an explicit per-function opt-in. The
        // result is intentionally ignored here, as in the original code; the
        // launch below still reports an error if the size is not accepted.
        let _ = func.set_attribute(
            cudarc::driver::sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_bytes as i32,
        );
    }

    let qv = q.as_f16().slice(..);
    let kp = k_pool.as_f16().slice(..);
    let vp = v_pool.as_f16().slice(..);
    let csq = cu_seqlens_q.as_u32().slice(..);
    let po = pos_offsets.as_u32().slice(..);
    let bt = block_tables.as_u32().slice(..);
    let ts = tile_seqs.as_u32().slice(..);
    let tst = tile_starts.as_u32().slice(..);
    let nqi = num_heads as i32;
    let nkvi = num_kv_heads as i32;
    let hdi = head_dim as i32;
    let mbps = max_num_blocks_per_seq as i32;
    let bsi = block_size as i32;
    // Lossless: 16 * shared_kv was just checked to fit in an i32.
    let score_stride = shared_kv as i32;

    // Argument order is kept exactly as in the original launch.
    let mut b = stream.launch_builder(&func);
    b.arg(&qv);
    b.arg(&kp);
    b.arg(&vp);
    b.arg(&csq);
    b.arg(&po);
    b.arg(&bt);
    b.arg(&ts);
    b.arg(&tst);
    b.arg(out);
    b.arg(&nqi);
    b.arg(&nkvi);
    b.arg(&hdi);
    b.arg(&mbps);
    b.arg(&bsi);
    b.arg(&score_stride);
    b.arg(&scale);

    // SAFETY: the launch is unsafe because the pushed arguments are not
    // checked against the kernel's signature. NOTE(review): presumably the
    // list above matches the PTX entry point's parameters — confirm against
    // the kernel source.
    unsafe {
        b.launch(LaunchConfig {
            grid_dim: (num_heads as u32, num_tiles as u32, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: shared_bytes,
        })
    }
    .map(|_| ())
    .map_err(|e| FerrumError::model(format!("paged_varlen_attn_vllm_tiled_q4: {e}")))
}
}