use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
use ferrum_types::{FerrumError, Result};
use half::f16;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
// Raw bindings to the precompiled paged-attention launchers. The symbol names
// encode the only specialisation declared here: f16 data, head size 128,
// KV block size 16. Argument semantics are defined by the C/CUDA side, which is
// not visible from this file — NOTE(review): the signatures presumably mirror
// vLLM's `paged_attention_v1` / `paged_attention_v2` launchers; confirm against
// the kernel source before changing any argument order or type.
extern "C" {
    /// "v1" launcher: takes only the output, query, KV cache, block table and
    /// sequence-length pointers plus the scalar dimensions/strides.
    ///
    /// All pointer arguments are passed by the Rust caller as raw device
    /// pointers; `stream` is the raw CUDA stream handle.
    fn ferrum_vllm_paged_attention_v1_f16_h128_b16(
        out: *mut std::ffi::c_void,
        query: *const std::ffi::c_void,
        key_cache: *const std::ffi::c_void,
        value_cache: *const std::ffi::c_void,
        num_kv_heads: i32,
        scale: f32,
        block_tables: *const std::ffi::c_void,
        seq_lens: *const std::ffi::c_void,
        num_seqs: i32,
        num_heads: i32,
        max_num_blocks_per_seq: i32,
        q_stride: i32,
        kv_block_stride: i32,
        kv_head_stride: i32,
        max_seq_len: i32,
        stream: *mut std::ffi::c_void,
    );

    /// "v2" launcher: same arguments as v1 plus three scratch buffers
    /// (`exp_sums`, `max_logits`, `tmp_out`) placed right after `out`.
    ///
    /// The Rust caller sizes the scratch buffers from
    /// `num_seqs * num_heads * max_partitions` (times `head_dim` for `tmp_out`).
    fn ferrum_vllm_paged_attention_v2_f16_h128_b16(
        out: *mut std::ffi::c_void,
        exp_sums: *mut std::ffi::c_void,
        max_logits: *mut std::ffi::c_void,
        tmp_out: *mut std::ffi::c_void,
        query: *const std::ffi::c_void,
        key_cache: *const std::ffi::c_void,
        value_cache: *const std::ffi::c_void,
        num_kv_heads: i32,
        scale: f32,
        block_tables: *const std::ffi::c_void,
        seq_lens: *const std::ffi::c_void,
        num_seqs: i32,
        num_heads: i32,
        max_num_blocks_per_seq: i32,
        q_stride: i32,
        kv_block_stride: i32,
        kv_head_stride: i32,
        max_seq_len: i32,
        stream: *mut std::ffi::c_void,
    );
}
/// Scratch device buffers handed to the v2 paged-attention launcher.
///
/// One instance is stored per `ordinal` key in `PA_SCRATCH`; `ensure_pa_scratch`
/// replaces it whenever a request does not fit the recorded `capacity`.
struct PagedAttnScratch {
/// `f32` buffer of `num_seqs * num_heads * max_partitions` elements (capacity
/// dimensions), passed to the launcher as `exp_sums`.
exp_sums: CudaSlice<f32>,
/// `f32` buffer with the same element count as `exp_sums`, passed as `max_logits`.
max_logits: CudaSlice<f32>,
/// `f16` buffer of `num_seqs * num_heads * max_partitions * head_dim` elements,
/// passed as `tmp_out`.
tmp_out: CudaSlice<f16>,
/// Dimensions the three buffers above were allocated for.
capacity: PagedAttnCapacity,
}
/// Dimensions a set of v2 scratch buffers was allocated for.
///
/// A request fits an existing allocation when every field of the allocation is
/// greater than or equal to the corresponding requested value.
#[derive(Clone, Copy, PartialEq, Eq)]
struct PagedAttnCapacity {
    /// Number of sequences the buffers can hold.
    num_seqs: usize,
    /// Number of query heads the buffers can hold.
    num_heads: usize,
    /// Number of sequence partitions the buffers can hold per (sequence, head).
    max_partitions: usize,
    /// Head dimension used to size the `f16` temporary-output buffer.
    head_dim: usize,
}
static PA_SCRATCH: std::sync::OnceLock<std::sync::RwLock<HashMap<usize, PagedAttnScratch>>> =
std::sync::OnceLock::new();
/// Runtime switches for the paged-attention dispatch, read from the environment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct VllmPagedAttnRuntimeConfig {
    /// Whether the short-sequence v1 path may be used. Defaults to `true`; set
    /// `FERRUM_VLLM_PAGED_ATTN_V1_SHORT=0` to turn it off.
    v1_short: bool,
}

impl VllmPagedAttnRuntimeConfig {
    /// Name of the environment variable controlling `v1_short`.
    const V1_SHORT_VAR: &'static str = "FERRUM_VLLM_PAGED_ATTN_V1_SHORT";

    /// Builds the configuration from the current process environment.
    ///
    /// Uses `std::env::vars_os` rather than `std::env::vars`: the latter panics
    /// as soon as *any* variable in the environment is not valid Unicode, which
    /// would crash on an unrelated variable. Non-Unicode pairs are skipped; they
    /// can never equal the ASCII variable name or the `"0"` value matched below.
    fn from_env() -> Self {
        let unicode_vars = std::env::vars_os().filter_map(|(name, value)| {
            Some((name.into_string().ok()?, value.into_string().ok()?))
        });
        Self::from_env_vars(unicode_vars)
    }

    /// Builds the configuration from an arbitrary list of `(name, value)` pairs.
    ///
    /// Unknown names are ignored. `FERRUM_VLLM_PAGED_ATTN_V1_SHORT` disables
    /// `v1_short` only when its value is exactly `"0"`; any other value enables
    /// it. If the name occurs more than once, the last occurrence wins.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self { v1_short: true };
        for (name, value) in vars {
            if name.as_ref() == Self::V1_SHORT_VAR {
                config.v1_short = value.as_ref() != "0";
            }
        }
        config
    }
}
/// Returns the process-wide runtime configuration.
///
/// The environment is read once, on the first call; later calls return the
/// cached value, so changing the environment afterwards has no effect.
fn vllm_paged_attn_runtime_config() -> &'static VllmPagedAttnRuntimeConfig {
    static RUNTIME_CONFIG: OnceLock<VllmPagedAttnRuntimeConfig> = OnceLock::new();
    let cell = &RUNTIME_CONFIG;
    cell.get_or_init(VllmPagedAttnRuntimeConfig::from_env)
}
/// Returns the global scratch map, creating it empty on first use.
fn pa_scratch_slots() -> &'static std::sync::RwLock<HashMap<usize, PagedAttnScratch>> {
    // `RwLock<HashMap<..>>::default()` is an unlocked lock around an empty map.
    PA_SCRATCH.get_or_init(Default::default)
}
fn pa_v1_short_enabled() -> bool {
vllm_paged_attn_runtime_config().v1_short
}
fn ensure_pa_scratch(
stream: &Arc<CudaStream>,
ordinal: usize,
num_seqs: usize,
num_heads: usize,
max_partitions: usize,
head_dim: usize,
) {
let need = PagedAttnCapacity {
num_seqs,
num_heads,
max_partitions,
head_dim,
};
{
let g = pa_scratch_slots().read().expect("PA_SCRATCH poisoned");
if let Some(s) = g.get(&ordinal) {
if s.capacity.num_seqs >= need.num_seqs
&& s.capacity.num_heads >= need.num_heads
&& s.capacity.max_partitions >= need.max_partitions
&& s.capacity.head_dim >= need.head_dim
{
return;
}
}
}
let cap = PagedAttnCapacity {
num_seqs: need.num_seqs.max(64),
num_heads: need.num_heads,
max_partitions: need.max_partitions.max(8),
head_dim: need.head_dim,
};
let n_floats = cap.num_seqs * cap.num_heads * cap.max_partitions;
let n_halves = cap.num_seqs * cap.num_heads * cap.max_partitions * cap.head_dim;
let exp_sums = unsafe { stream.alloc::<f32>(n_floats) }.expect("PA exp_sums alloc");
let max_logits = unsafe { stream.alloc::<f32>(n_floats) }.expect("PA max_logits alloc");
let tmp_out = unsafe { stream.alloc::<f16>(n_halves) }.expect("PA tmp_out alloc");
let mut w = pa_scratch_slots().write().expect("PA_SCRATCH poisoned");
w.insert(
ordinal,
PagedAttnScratch {
exp_sums,
max_logits,
tmp_out,
capacity: cap,
},
);
}
/// Multiplies `factors` and converts the product to the `i32` the C launchers take.
///
/// Returns an `unsupported` error instead of silently wrapping when the product
/// overflows `usize` or does not fit in `i32`.
fn pa_dims_to_i32(what: &str, factors: &[usize]) -> Result<i32> {
    factors
        .iter()
        .try_fold(1_usize, |acc, &factor| acc.checked_mul(factor))
        .and_then(|product| <i32 as std::convert::TryFrom<usize>>::try_from(product).ok())
        .ok_or_else(|| {
            FerrumError::unsupported(format!(
                "vllm paged_attn_v2: {what} (product of {factors:?}) does not fit in i32"
            ))
        })
}

/// Enqueues a paged-attention launch on `stream`, writing the result into `out`.
///
/// Only the `head_dim == 128`, `block_size == 16`, f16 specialisation is bound.
/// When `max_seq_len <= 512` and the v1 short path is enabled (the default; see
/// `FERRUM_VLLM_PAGED_ATTN_V1_SHORT`), the v1 launcher is used. Otherwise the v2
/// launcher is used with per-`ordinal` scratch buffers sized for
/// `ceil(max_seq_len / 512)` partitions (at least one).
///
/// The launchers receive `scale = 1 / sqrt(head_dim)`,
/// `q_stride = num_heads * head_dim`,
/// `kv_block_stride = num_kv_heads * head_dim * block_size` and
/// `kv_head_stride = head_dim * block_size`.
///
/// An empty batch (`num_seqs == 0` or `num_heads == 0`) returns `Ok(())` without
/// launching anything.
///
/// NOTE(review): buffer lengths are not checked here; the caller must pass
/// buffers laid out the way the launcher expects (presumably vLLM's layout —
/// confirm against the kernel source). The launchers return nothing, so launch
/// failures are not reported by this function.
///
/// # Errors
///
/// Returns `FerrumError::unsupported` when `head_dim` or `block_size` is not the
/// bound specialisation, when `num_kv_heads` is zero or does not divide
/// `num_heads`, or when any dimension/stride does not fit in `i32`.
///
/// # Panics
///
/// Panics if the scratch lock is poisoned or scratch allocation fails (v2 path).
#[allow(clippy::too_many_arguments)]
pub fn dispatch_paged_attention_v2(
    stream: &Arc<CudaStream>,
    ordinal: usize,
    out: &mut CudaSlice<f16>,
    q: &CudaSlice<f16>,
    k_cache: &CudaSlice<f16>,
    v_cache: &CudaSlice<f16>,
    block_tables: &CudaSlice<u32>,
    seq_lens: &CudaSlice<u32>,
    num_seqs: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
    max_seq_len: usize,
) -> Result<()> {
    // Sequences up to this length are handled in one partition.
    const PARTITION_SIZE: usize = 512;

    if head_dim != 128 {
        return Err(FerrumError::unsupported(format!(
            "vllm paged_attn_v2: only head_dim=128 instantiated, got {head_dim}"
        )));
    }
    if block_size != 16 {
        return Err(FerrumError::unsupported(format!(
            "vllm paged_attn_v2: only block_size=16 instantiated, got {block_size}"
        )));
    }
    // Nothing to compute; also avoids asking CUDA for a zero-sized launch.
    if num_seqs == 0 || num_heads == 0 {
        return Ok(());
    }
    // Query heads are grouped onto KV heads; a zero or non-dividing KV head
    // count has no valid mapping (presumably the kernel computes
    // `num_heads / num_kv_heads` — confirm against the kernel source).
    if num_kv_heads == 0 || num_heads % num_kv_heads != 0 {
        return Err(FerrumError::unsupported(format!(
            "vllm paged_attn_v2: num_heads={num_heads} must be a multiple of non-zero num_kv_heads={num_kv_heads}"
        )));
    }

    // The C ABI takes `i32`; convert with range checks instead of `as` so an
    // oversized value is reported rather than wrapped into a bogus stride.
    let num_seqs_i32 = pa_dims_to_i32("num_seqs", &[num_seqs])?;
    let num_heads_i32 = pa_dims_to_i32("num_heads", &[num_heads])?;
    let num_kv_heads_i32 = pa_dims_to_i32("num_kv_heads", &[num_kv_heads])?;
    let max_blocks_i32 = pa_dims_to_i32("max_num_blocks_per_seq", &[max_num_blocks_per_seq])?;
    let max_seq_len_i32 = pa_dims_to_i32("max_seq_len", &[max_seq_len])?;
    let q_stride = pa_dims_to_i32("q_stride", &[num_heads, head_dim])?;
    let kv_block_stride =
        pa_dims_to_i32("kv_block_stride", &[num_kv_heads, head_dim, block_size])?;
    let kv_head_stride = pa_dims_to_i32("kv_head_stride", &[head_dim, block_size])?;

    let scale = 1.0_f32 / (head_dim as f32).sqrt();
    let raw_stream = stream.cu_stream() as *mut std::ffi::c_void;

    let use_v1_short = max_seq_len <= PARTITION_SIZE && pa_v1_short_enabled();
    if use_v1_short {
        // SAFETY: all pointers come from live `CudaSlice`s borrowed for the whole
        // call, and the guards returned next to each pointer stay alive until the
        // end of this block, i.e. until after the launch is enqueued. The caller
        // is responsible for the buffers being large enough for the dimensions
        // passed (not checked here).
        unsafe {
            let (out_dp, _o_recs) = out.device_ptr_mut(stream);
            let (q_dp, _q_recs) = q.device_ptr(stream);
            let (k_dp, _k_recs) = k_cache.device_ptr(stream);
            let (v_dp, _v_recs) = v_cache.device_ptr(stream);
            let (bt_dp, _bt_recs) = block_tables.device_ptr(stream);
            let (sl_dp, _sl_recs) = seq_lens.device_ptr(stream);
            ferrum_vllm_paged_attention_v1_f16_h128_b16(
                out_dp as *mut std::ffi::c_void,
                q_dp as *const std::ffi::c_void,
                k_dp as *const std::ffi::c_void,
                v_dp as *const std::ffi::c_void,
                num_kv_heads_i32,
                scale,
                bt_dp as *const std::ffi::c_void,
                sl_dp as *const std::ffi::c_void,
                num_seqs_i32,
                num_heads_i32,
                max_blocks_i32,
                q_stride,
                kv_block_stride,
                kv_head_stride,
                max_seq_len_i32,
                raw_stream,
            );
        }
        return Ok(());
    }

    let max_partitions = max_seq_len.div_ceil(PARTITION_SIZE).max(1);

    // `ensure_pa_scratch` releases its lock before we take ours, so another
    // thread may replace the entry in between. Re-check the capacity under the
    // lock we launch with, and retry if it no longer fits.
    let mut slots = loop {
        ensure_pa_scratch(
            stream,
            ordinal,
            num_seqs,
            num_heads,
            max_partitions,
            head_dim,
        );
        let guard = pa_scratch_slots().write().expect("PA_SCRATCH poisoned");
        let fits = guard.get(&ordinal).is_some_and(|s| {
            s.capacity.num_seqs >= num_seqs
                && s.capacity.num_heads >= num_heads
                && s.capacity.max_partitions >= max_partitions
                && s.capacity.head_dim >= head_dim
        });
        if fits {
            break guard;
        }
    };
    let scratch = slots
        .get_mut(&ordinal)
        .expect("ensure_pa_scratch must have populated");

    // SAFETY: same argument as the v1 branch for the caller-provided buffers.
    // The scratch buffers were verified above, under the lock still held here,
    // to be sized for at least `num_seqs * num_heads * max_partitions`
    // (times `head_dim` for `tmp_out`) elements.
    unsafe {
        let (out_dp, _o_recs) = out.device_ptr_mut(stream);
        let (es_dp, _es_recs) = scratch.exp_sums.device_ptr_mut(stream);
        let (ml_dp, _ml_recs) = scratch.max_logits.device_ptr_mut(stream);
        let (to_dp, _to_recs) = scratch.tmp_out.device_ptr_mut(stream);
        let (q_dp, _q_recs) = q.device_ptr(stream);
        let (k_dp, _k_recs) = k_cache.device_ptr(stream);
        let (v_dp, _v_recs) = v_cache.device_ptr(stream);
        let (bt_dp, _bt_recs) = block_tables.device_ptr(stream);
        let (sl_dp, _sl_recs) = seq_lens.device_ptr(stream);
        ferrum_vllm_paged_attention_v2_f16_h128_b16(
            out_dp as *mut std::ffi::c_void,
            es_dp as *mut std::ffi::c_void,
            ml_dp as *mut std::ffi::c_void,
            to_dp as *mut std::ffi::c_void,
            q_dp as *const std::ffi::c_void,
            k_dp as *const std::ffi::c_void,
            v_dp as *const std::ffi::c_void,
            num_kv_heads_i32,
            scale,
            bt_dp as *const std::ffi::c_void,
            sl_dp as *const std::ffi::c_void,
            num_seqs_i32,
            num_heads_i32,
            max_blocks_i32,
            q_stride,
            kv_block_stride,
            kv_head_stride,
            max_seq_len_i32,
            raw_stream,
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::VllmPagedAttnRuntimeConfig as Config;

    /// With no variables at all, the v1 short path stays enabled.
    #[test]
    fn vllm_paged_attn_runtime_config_defaults_short_v1_on() {
        let no_vars: [(&str, &str); 0] = [];
        let parsed = Config::from_env_vars(no_vars);
        assert!(parsed.v1_short);
    }

    /// A value of exactly "0" switches the v1 short path off.
    #[test]
    fn vllm_paged_attn_runtime_config_parses_short_v1_opt_out() {
        let opt_out = std::iter::once(("FERRUM_VLLM_PAGED_ATTN_V1_SHORT", "0"));
        let parsed = Config::from_env_vars(opt_out);
        assert!(!parsed.v1_short);
    }
}