#![cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
use cudarc::driver::result as cudarc_result;
use cudarc::driver::sys;
use cudarc::driver::{CudaFunction, CudaSlice, CudaStream, CudaView, LaunchConfig, PushKernelArg};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use tracing::{debug, warn};
/// Owns a CUDA driver graph, the executable instantiated from it, and the
/// stream the executable is uploaded to and launched on.
///
/// Dropping the holder destroys the executable first and the graph second.
struct CuGraphHolder {
    /// Raw graph handle (in this file: the value returned by `end_capture`).
    cu_graph: sys::CUgraph,
    /// Raw executable handle instantiated from `cu_graph`.
    cu_graph_exec: sys::CUgraphExec,
    /// Stream used by both `upload` and `launch`.
    stream: Arc<CudaStream>,
}

impl CuGraphHolder {
    /// Enqueues the instantiated graph on the stored stream.
    ///
    /// # Safety
    /// Forwards directly to the raw driver wrapper.
    /// NOTE(review): presumably every device buffer referenced by the graph
    /// must still be alive at launch time — confirm at the call sites.
    unsafe fn launch(&self) -> Result<(), cudarc::driver::DriverError> {
        let raw_stream = self.stream.cu_stream();
        cudarc_result::graph::launch(self.cu_graph_exec, raw_stream)
    }

    /// Uploads the instantiated graph to the stored stream.
    ///
    /// # Safety
    /// Forwards directly to the raw driver wrapper; the handles held by
    /// `self` must still be valid.
    unsafe fn upload(&self) -> Result<(), cudarc::driver::DriverError> {
        let raw_stream = self.stream.cu_stream();
        cudarc_result::graph::upload(self.cu_graph_exec, raw_stream)
    }
}

impl Drop for CuGraphHolder {
    fn drop(&mut self) {
        let exec = self.cu_graph_exec;
        let graph = self.cu_graph;
        // SAFETY (review): relies on both handles being valid and not
        // destroyed anywhere else; failures are deliberately ignored because
        // nothing can be done about them during drop.
        unsafe {
            let _ = cudarc_result::graph::exec_destroy(exec);
            let _ = cudarc_result::graph::destroy(graph);
        }
    }
}

// The raw handles are pointers, so `Send` has to be asserted by hand.
// NOTE(review): the soundness argument is not visible in this file — confirm
// that moving the handles across threads is acceptable for the driver.
unsafe impl Send for CuGraphHolder {}
use super::cuda_attn_kernels::CUDA_ATTENTION_KERNELS_SRC;
use super::cuda_graph::{compile_or_load_ptx, CudaGraph, CudaGraphError};
/// Handles to the attention kernels used by this module.
///
/// `init_attn_modules` fills every field with the function loaded under the
/// identical kernel name.
pub struct CudaAttnModules {
/// Kernel `fused_qk_norm`.
pub fused_qk_norm: CudaFunction,
/// Kernel `fused_qk_rope`.
pub fused_qk_rope: CudaFunction,
/// Kernel `fused_qk_norm_rope`.
pub fused_qk_norm_rope: CudaFunction,
/// Kernel `fused_kv_store`.
pub fused_kv_store: CudaFunction,
/// Kernel `batched_attn_scores_v2`.
pub batched_attn_scores_v2: CudaFunction,
/// Kernel `batched_softmax`.
pub batched_softmax: CudaFunction,
/// Kernel `batched_attn_weighted_sum`.
pub batched_attn_weighted_sum: CudaFunction,
}
// NOTE(review): manual thread-safety assertions; the justification is not
// visible in this file — confirm against the `CudaFunction` guarantees.
unsafe impl Send for CudaAttnModules {}
unsafe impl Sync for CudaAttnModules {}
/// Device-side key/value cache covering every layer.
///
/// As allocated by `acquire_kv_cache`, each of `k_cache` and `v_cache` holds
/// `n_layers * n_kv * max_seq * head_dim` elements, and layer `i` begins at
/// element offset `i * n_kv * max_seq * head_dim`.
pub struct CudaKvCache {
    /// Key storage. NOTE(review): `u16` elements are presumably
    /// half-precision bit patterns — confirm against the kernels.
    pub k_cache: CudaSlice<u16>,
    /// Value storage, same layout as `k_cache`.
    pub v_cache: CudaSlice<u16>,
    /// Number of layers the cache was sized for.
    pub n_layers: usize,
    /// Number of key/value heads.
    pub n_kv: usize,
    /// Maximum sequence length the cache was sized for.
    pub max_seq: usize,
    /// Per-head dimension.
    pub head_dim: usize,
}

// NOTE(review): manual thread-safety assertions; the justification is not
// visible in this file — confirm against the `CudaSlice` guarantees.
unsafe impl Send for CudaKvCache {}
unsafe impl Sync for CudaKvCache {}

impl CudaKvCache {
    /// Returns the element offset at which layer `layer_idx` starts.
    ///
    /// The product is computed in `usize` and then cast with `as`, so a value
    /// above `u32::MAX` is truncated rather than reported.
    #[inline]
    pub fn layer_offset_elements(&self, layer_idx: usize) -> u32 {
        let offset = layer_idx * self.n_kv * self.max_seq * self.head_dim;
        offset as u32
    }

    /// Returns `true` when the cache was created with exactly these
    /// dimensions.
    pub fn matches(&self, n_layers: usize, n_kv: usize, max_seq: usize, head_dim: usize) -> bool {
        let current = (self.n_layers, self.n_kv, self.max_seq, self.head_dim);
        current == (n_layers, n_kv, max_seq, head_dim)
    }
}
/// Scratch device buffers reused by the full-layer and full-forward paths.
///
/// Element counts below are the ones used by `acquire_full_layer_buffers`.
pub struct CudaFullLayerBuffers {
    /// Hidden state, `hidden_size` elements.
    pub d_hidden: CudaSlice<f32>,
    /// Normalised hidden state / projection scratch, `hidden_size` elements.
    pub d_normed: CudaSlice<f32>,
    /// Fused QKV projection output, `nq * head_dim + 2 * nkv * head_dim` elements.
    pub d_qkv: CudaSlice<f32>,
    /// Output of the fused norm+rope kernel for queries, `nq * head_dim` elements.
    pub d_q_rope: CudaSlice<f32>,
    /// Output of the fused norm+rope kernel for keys, `nkv * head_dim` elements.
    pub d_k_rope: CudaSlice<f32>,
    /// Rotary cosine table, `head_dim / 2` elements.
    pub d_cos: CudaSlice<f32>,
    /// Rotary sine table, `head_dim / 2` elements.
    pub d_sin: CudaSlice<f32>,
    /// Attention scores, `nq * max_seq` elements.
    pub d_scores: CudaSlice<f32>,
    /// Attention output, `nq * head_dim` elements.
    pub d_attn_out: CudaSlice<f32>,
    /// Gate/up projection output, `2 * intermediate_size` elements.
    pub d_gate_up: CudaSlice<f32>,
    /// SwiGLU output, `intermediate_size` elements.
    pub d_swiglu: CudaSlice<f32>,
    /// Two `u32` values; `encode_full_forward` writes `[pos, pos + 1]` here.
    pub d_pos_seqlen: CudaSlice<u32>,
    /// Hidden size the buffers were allocated for.
    pub hidden_size: usize,
    /// Number of query heads the buffers were allocated for.
    pub nq: usize,
    /// Number of key/value heads the buffers were allocated for.
    pub nkv: usize,
    /// Per-head dimension the buffers were allocated for.
    pub head_dim: usize,
    /// Maximum sequence length the buffers were allocated for.
    pub max_seq: usize,
    /// Intermediate (FFN) size the buffers were allocated for.
    pub intermediate_size: usize,
}

// NOTE(review): manual thread-safety assertions; the justification is not
// visible in this file — confirm against the `CudaSlice` guarantees.
unsafe impl Send for CudaFullLayerBuffers {}
unsafe impl Sync for CudaFullLayerBuffers {}

impl CudaFullLayerBuffers {
    /// Returns `true` when the buffers were allocated for exactly these
    /// dimensions, i.e. they can be reused without reallocation.
    pub fn matches(
        &self,
        hidden_size: usize,
        nq: usize,
        nkv: usize,
        head_dim: usize,
        max_seq: usize,
        intermediate_size: usize,
    ) -> bool {
        let allocated_for = (
            self.hidden_size,
            self.nq,
            self.nkv,
            self.head_dim,
            self.max_seq,
            self.intermediate_size,
        );
        allocated_for == (hidden_size, nq, nkv, head_dim, max_seq, intermediate_size)
    }
}
/// Device-resident weights for one transformer layer.
///
/// Field notes describe how `get_or_build_model_weights` populates the
/// struct; other constructors (if any) are not visible in this file.
pub struct CudaCachedLayerWeights {
/// Uploaded from the fused QKV handle/bytes; used as the weight of the
/// single QKV GEMV in `encode_layer_device`.
pub q_weight: Arc<CudaSlice<u8>>,
/// Set to the shared one-byte dummy allocation by `get_or_build_model_weights`.
pub k_weight: Arc<CudaSlice<u8>>,
/// Set to the shared one-byte dummy allocation by `get_or_build_model_weights`.
pub v_weight: Arc<CudaSlice<u8>>,
/// Attention output projection weight.
pub o_weight: Arc<CudaSlice<u8>>,
/// Gate bytes followed by up bytes, uploaded as one buffer.
pub gate_up_weight: Arc<CudaSlice<u8>>,
/// Down projection weight.
pub down_weight: Arc<CudaSlice<u8>>,
/// Norm weight applied before the attention phase.
pub pre_attn_norm: Arc<CudaSlice<f32>>,
/// Norm weight applied after attention, before the feed-forward phase.
pub post_attn_norm: Arc<CudaSlice<f32>>,
/// Weight passed to the fused norm+rope kernel for queries.
pub q_norm: Arc<CudaSlice<f32>>,
/// Weight passed to the fused norm+rope kernel for keys.
pub k_norm: Arc<CudaSlice<f32>>,
}
// NOTE(review): manual thread-safety assertions; the justification is not
// visible in this file — confirm against the `CudaSlice` guarantees.
unsafe impl Send for CudaCachedLayerWeights {}
unsafe impl Sync for CudaCachedLayerWeights {}
/// Cached per-model weight set stored in the module-wide state.
///
/// `get_or_build_model_weights` reuses the cached entry whenever `n_layers`
/// equals the number of layer parameter sets it is given.
pub struct CudaCachedModelWeights {
/// Graph object the weights were uploaded through.
pub graph: Arc<CudaGraph>,
/// One-byte zeroed allocation shared by the unused `k_weight`/`v_weight` slots.
pub dummy_weight: Arc<CudaSlice<u8>>,
/// Per-layer weights, in layer order.
pub layers: Arc<Vec<CudaCachedLayerWeights>>,
/// Number of layers; the only value compared on cache lookup.
pub n_layers: usize,
}
// NOTE(review): manual thread-safety assertions; the justification is not
// visible in this file — confirm against the contained types.
unsafe impl Send for CudaCachedModelWeights {}
unsafe impl Sync for CudaCachedModelWeights {}
/// Module-wide lazily created state; each member sits behind its own mutex.
struct CudaFullLayerState {
/// Loaded attention kernels, `None` until `init_attn_modules` succeeds.
attn_modules: Mutex<Option<Arc<CudaAttnModules>>>,
/// Scratch buffers, (re)allocated by `acquire_full_layer_buffers`.
full_layer_buffers: Mutex<Option<CudaFullLayerBuffers>>,
/// KV cache, (re)allocated by `acquire_kv_cache`.
kv_cache: Mutex<Option<CudaKvCache>>,
/// Uploaded `f32` weights keyed by the caller-supplied `u64` key.
f32_weight_cache: Mutex<HashMap<u64, Arc<CudaSlice<f32>>>>,
/// Cached model weight set built by `get_or_build_model_weights`.
cached_model_weights: Mutex<Option<CudaCachedModelWeights>>,
/// Captured driver graph, as used by `encode_full_forward`:
/// `None` = capture not attempted yet, `Some(None)` = capture attempted
/// and disabled, `Some(Some(_))` = graph ready for replay.
cuda_driver_graph: Mutex<Option<Option<CuGraphHolder>>>,
}
// NOTE(review): manual thread-safety assertions; every field is already
// wrapped in a `Mutex`, the remaining justification is not visible here.
unsafe impl Send for CudaFullLayerState {}
unsafe impl Sync for CudaFullLayerState {}
/// Storage for the lazily created module-wide state.
static FULL_LAYER_STATE: OnceLock<CudaFullLayerState> = OnceLock::new();

/// Returns the module-wide state, creating it with every slot empty on the
/// first call.
fn full_layer_state() -> &'static CudaFullLayerState {
    /// Builds the initial state: nothing loaded, allocated, or captured.
    fn empty_state() -> CudaFullLayerState {
        CudaFullLayerState {
            attn_modules: Mutex::new(None),
            full_layer_buffers: Mutex::new(None),
            kv_cache: Mutex::new(None),
            f32_weight_cache: Mutex::new(HashMap::new()),
            cached_model_weights: Mutex::new(None),
            cuda_driver_graph: Mutex::new(None),
        }
    }

    FULL_LAYER_STATE.get_or_init(empty_state)
}
/// Cached answer to "is profiling output enabled?".
static PROFILE_ENABLED: OnceLock<bool> = OnceLock::new();

/// Returns `true` when the `CUDA_PROFILE` environment variable could be read
/// successfully the first time this function was called; the answer is
/// cached for the rest of the process.
#[inline(always)]
pub(super) fn profiling() -> bool {
    let enabled = PROFILE_ENABLED.get_or_init(|| {
        let lookup = std::env::var("CUDA_PROFILE");
        lookup.is_ok()
    });
    *enabled
}
/// Returns the device copy of an `f32` weight, uploading it on first use.
///
/// The cache is keyed by `key` only; `data` is consulted solely when no entry
/// exists yet.
///
/// The upload itself happens without holding the cache lock. If another
/// thread inserts the same key in the meantime, the already cached buffer is
/// returned and the freshly uploaded one is dropped. An existing entry is
/// never replaced, so a buffer handed out earlier stays the one owned by the
/// cache.
///
/// # Errors
/// * `CudaGraphError::LockPoisoned` if the cache mutex is poisoned.
/// * `CudaGraphError::DriverError` if the host-to-device copy fails.
pub fn get_or_upload_f32_weight(
    graph: &CudaGraph,
    key: u64,
    data: &[f32],
) -> Result<Arc<CudaSlice<f32>>, CudaGraphError> {
    let state = full_layer_state();

    // Fast path: already uploaded.
    {
        let cache = state
            .f32_weight_cache
            .lock()
            .map_err(|_| CudaGraphError::LockPoisoned)?;
        if let Some(existing) = cache.get(&key) {
            return Ok(Arc::clone(existing));
        }
    }

    // Upload outside the lock so other keys are not blocked by the copy.
    let d_slice = graph
        .stream_arc()
        .clone_htod(data)
        .map_err(|e| CudaGraphError::DriverError(format!("clone_htod f32: {e}")))?;
    let uploaded = Arc::new(d_slice);

    let mut cache = state
        .f32_weight_cache
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;
    // `or_insert` keeps a concurrently inserted entry instead of replacing it.
    Ok(Arc::clone(cache.entry(key).or_insert(uploaded)))
}
/// Returns the global graph together with the per-layer device weights,
/// uploading and caching them when no cached set with the same layer count
/// exists.
///
/// Returns `None` if a lock is poisoned, the global graph is unavailable, or
/// any allocation/upload fails. A failure to store the freshly built set in
/// the cache is ignored; the built weights are still returned.
fn get_or_build_model_weights(
    layer_params: &[CudaFullForwardLayerParams<'_>],
) -> Option<(Arc<CudaGraph>, Arc<Vec<CudaCachedLayerWeights>>)> {
    let state = full_layer_state();
    let n_layers = layer_params.len();

    // Cache hit: the only thing compared is the number of layers.
    {
        let guard = state.cached_model_weights.lock().ok()?;
        match guard.as_ref() {
            Some(hit) if hit.n_layers == n_layers => {
                return Some((Arc::clone(&hit.graph), Arc::clone(&hit.layers)));
            }
            _ => {}
        }
    }

    let graph = CudaGraph::global().ok()?;
    // One-byte placeholder shared by every unused `k_weight` / `v_weight` slot.
    let dummy_weight = Arc::new(graph.stream_arc().alloc_zeros::<u8>(1).ok()?);

    // Upload layer by layer; the first failure aborts the whole build.
    let cached = layer_params
        .iter()
        .map(|lp| -> Option<CudaCachedLayerWeights> {
            let q_weight = graph
                .get_or_upload_weight_soa(lp.fused_qkv_handle, lp.fused_qkv_bytes)
                .ok()?;
            let o_weight = graph
                .get_or_upload_weight_soa(lp.attn_proj_handle, lp.attn_proj_bytes)
                .ok()?;

            // Gate and up bytes are concatenated only if an upload is needed.
            let gate = lp.gate_bytes;
            let up = lp.up_bytes;
            let gate_up_weight = graph
                .get_or_upload_weight_soa_lazy(lp.gate_up_handle, || [gate, up].concat())
                .ok()?;

            let down_weight = graph
                .get_or_upload_weight_soa(lp.down_handle, lp.down_bytes)
                .ok()?;
            let pre_attn_norm =
                get_or_upload_f32_weight(&graph, lp.attn_norm_handle, lp.attn_norm_bytes).ok()?;
            let post_attn_norm =
                get_or_upload_f32_weight(&graph, lp.ffn_norm_handle, lp.ffn_norm_bytes).ok()?;
            let q_norm =
                get_or_upload_f32_weight(&graph, lp.q_norm_handle, lp.q_norm_bytes).ok()?;
            let k_norm =
                get_or_upload_f32_weight(&graph, lp.k_norm_handle, lp.k_norm_bytes).ok()?;

            Some(CudaCachedLayerWeights {
                q_weight,
                k_weight: Arc::clone(&dummy_weight),
                v_weight: Arc::clone(&dummy_weight),
                o_weight,
                gate_up_weight,
                down_weight,
                pre_attn_norm,
                post_attn_norm,
                q_norm,
                k_norm,
            })
        })
        .collect::<Option<Vec<_>>>()?;

    let layers = Arc::new(cached);
    let entry = CudaCachedModelWeights {
        graph: Arc::clone(&graph),
        dummy_weight,
        layers: Arc::clone(&layers),
        n_layers,
    };
    // Best effort: a poisoned lock only means the result is not cached.
    if let Ok(mut slot) = state.cached_model_weights.lock() {
        *slot = Some(entry);
    }

    Some((graph, layers))
}
/// Returns the shared attention kernel handles, compiling/loading the module
/// on the first successful call.
///
/// The state lock is held for the whole initialisation, so concurrent callers
/// wait instead of loading the module twice.
///
/// # Errors
/// * `CudaGraphError::LockPoisoned` if the state mutex is poisoned.
/// * Whatever `compile_or_load_ptx` reports.
/// * `CudaGraphError::DriverError` if loading the module or one of its
///   functions fails.
pub fn init_attn_modules(graph: &CudaGraph) -> Result<Arc<CudaAttnModules>, CudaGraphError> {
    let mut slot = full_layer_state()
        .attn_modules
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;
    if let Some(ready) = slot.as_ref() {
        return Ok(Arc::clone(ready));
    }

    let ptx = compile_or_load_ptx(CUDA_ATTENTION_KERNELS_SRC, "attn_kernels")?;
    let module = graph
        .context_arc()
        .load_module(ptx)
        .map_err(|e| CudaGraphError::DriverError(format!("load_module attn: {e}")))?;

    // Looks up one kernel by name and maps the driver error.
    let kernel = |name: &str| -> Result<CudaFunction, CudaGraphError> {
        module
            .load_function(name)
            .map_err(|e| CudaGraphError::DriverError(format!("load_function({name}): {e}")))
    };

    let fused_qk_norm = kernel("fused_qk_norm")?;
    let fused_qk_rope = kernel("fused_qk_rope")?;
    let fused_qk_norm_rope = kernel("fused_qk_norm_rope")?;
    let fused_kv_store = kernel("fused_kv_store")?;
    let batched_attn_scores_v2 = kernel("batched_attn_scores_v2")?;
    let batched_softmax = kernel("batched_softmax")?;
    let batched_attn_weighted_sum = kernel("batched_attn_weighted_sum")?;

    let modules = Arc::new(CudaAttnModules {
        fused_qk_norm,
        fused_qk_rope,
        fused_qk_norm_rope,
        fused_kv_store,
        batched_attn_scores_v2,
        batched_softmax,
        batched_attn_weighted_sum,
    });
    *slot = Some(Arc::clone(&modules));
    Ok(modules)
}
/// Locks the scratch buffers and makes sure they match the requested
/// dimensions, reallocating (zero-initialised) when they do not.
///
/// After a reallocation the captured driver graph slot is reset to `None`
/// (best effort: a poisoned graph lock is ignored).
///
/// The returned guard keeps the buffers locked; on success its content is
/// always `Some`.
///
/// # Errors
/// * `CudaGraphError::LockPoisoned` if the buffer mutex is poisoned.
/// * `CudaGraphError::DriverError` if an allocation fails; the previously
///   stored buffers are left untouched in that case.
fn acquire_full_layer_buffers(
    graph: &CudaGraph,
    hidden_size: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    max_seq: usize,
    intermediate_size: usize,
) -> Result<std::sync::MutexGuard<'static, Option<CudaFullLayerBuffers>>, CudaGraphError> {
    let state = full_layer_state();
    let mut guard = state
        .full_layer_buffers
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;

    let reusable = guard
        .as_ref()
        .is_some_and(|b| b.matches(hidden_size, nq, nkv, head_dim, max_seq, intermediate_size));
    if reusable {
        return Ok(guard);
    }

    let alloc = |n: usize| -> Result<CudaSlice<f32>, CudaGraphError> {
        graph
            .stream_arc()
            .alloc_zeros::<f32>(n)
            .map_err(|e| CudaGraphError::DriverError(format!("alloc_zeros fl({n}): {e}")))
    };
    let alloc_u32 = |n: usize| -> Result<CudaSlice<u32>, CudaGraphError> {
        graph
            .stream_arc()
            .alloc_zeros::<u32>(n)
            .map_err(|e| CudaGraphError::DriverError(format!("alloc_zeros u32({n}): {e}")))
    };

    let qkv_total = nq * head_dim + 2 * nkv * head_dim;
    let half_dim = head_dim / 2;

    // Allocate everything first so a failure leaves the old buffers in place.
    let d_hidden = alloc(hidden_size)?;
    let d_normed = alloc(hidden_size)?;
    let d_qkv = alloc(qkv_total)?;
    let d_q_rope = alloc(nq * head_dim)?;
    let d_k_rope = alloc(nkv * head_dim)?;
    let d_cos = alloc(half_dim)?;
    let d_sin = alloc(half_dim)?;
    let d_scores = alloc(nq * max_seq)?;
    let d_attn_out = alloc(nq * head_dim)?;
    let d_gate_up = alloc(2 * intermediate_size)?;
    let d_swiglu = alloc(intermediate_size)?;
    let d_pos_seqlen = alloc_u32(2)?;

    *guard = Some(CudaFullLayerBuffers {
        d_hidden,
        d_normed,
        d_qkv,
        d_q_rope,
        d_k_rope,
        d_cos,
        d_sin,
        d_scores,
        d_attn_out,
        d_gate_up,
        d_swiglu,
        d_pos_seqlen,
        hidden_size,
        nq,
        nkv,
        head_dim,
        max_seq,
        intermediate_size,
    });

    // The buffers were replaced, so drop whatever graph state was recorded.
    if let Ok(mut captured) = state.cuda_driver_graph.lock() {
        *captured = None;
    }

    Ok(guard)
}
/// Locks the KV cache and makes sure it matches the requested dimensions,
/// reallocating (zero-initialised) when it does not.
///
/// Before allocating, the sizes are validated: the total element count must
/// not overflow `usize`, and the start offset of the last layer must fit in
/// `u32`, because `CudaKvCache::layer_offset_elements` hands layer offsets to
/// the kernels as `u32`.
///
/// After a reallocation the captured driver graph slot is reset to `None`
/// (best effort, mirroring `acquire_full_layer_buffers`): a graph recorded
/// against the previous cache would otherwise be replayed with device
/// pointers that are no longer valid.
///
/// The returned guard keeps the cache locked; on success its content is
/// always `Some`.
///
/// # Errors
/// * `CudaGraphError::LockPoisoned` if the cache mutex is poisoned.
/// * `CudaGraphError::WeightLayoutError` if the requested dimensions are too
///   large to be addressed.
/// * `CudaGraphError::DriverError` if an allocation fails; the previously
///   stored cache is left untouched in that case.
fn acquire_kv_cache(
    graph: &CudaGraph,
    n_layers: usize,
    n_kv: usize,
    max_seq: usize,
    head_dim: usize,
) -> Result<std::sync::MutexGuard<'static, Option<CudaKvCache>>, CudaGraphError> {
    let state = full_layer_state();
    let mut guard = state
        .kv_cache
        .lock()
        .map_err(|_| CudaGraphError::LockPoisoned)?;

    if guard
        .as_ref()
        .is_some_and(|c| c.matches(n_layers, n_kv, max_seq, head_dim))
    {
        return Ok(guard);
    }

    // Validate the geometry before touching the device.
    let layer_elements = n_kv
        .checked_mul(max_seq)
        .and_then(|v| v.checked_mul(head_dim));
    let total_elements = layer_elements.and_then(|v| v.checked_mul(n_layers));
    let last_layer_offset =
        layer_elements.and_then(|v| v.checked_mul(n_layers.saturating_sub(1)));
    let offsets_fit = last_layer_offset.is_some_and(|v| u32::try_from(v).is_ok());
    let Some(total_elements) = total_elements.filter(|_| offsets_fit) else {
        return Err(CudaGraphError::WeightLayoutError(format!(
            "kv cache too large: n_layers={n_layers}, n_kv={n_kv}, max_seq={max_seq}, \
             head_dim={head_dim} exceeds the u32 layer-offset range"
        )));
    };

    let k_cache = graph
        .stream_arc()
        .alloc_zeros::<u16>(total_elements)
        .map_err(|e| CudaGraphError::DriverError(format!("alloc kv k_cache: {e}")))?;
    let v_cache = graph
        .stream_arc()
        .alloc_zeros::<u16>(total_elements)
        .map_err(|e| CudaGraphError::DriverError(format!("alloc kv v_cache: {e}")))?;

    *guard = Some(CudaKvCache {
        k_cache,
        v_cache,
        n_layers,
        n_kv,
        max_seq,
        head_dim,
    });

    // The cache allocations were replaced: any recorded graph refers to the
    // old ones and must not be replayed.
    if let Ok(mut captured) = state.cuda_driver_graph.lock() {
        *captured = None;
    }

    Ok(guard)
}
/// Enqueues the `fused_qk_norm` kernel on the graph's stream.
///
/// Launch shape: `nq + nkv` blocks of 256 threads, no dynamic shared memory.
/// Kernel arguments are pushed in the order of this function's buffer and
/// scalar parameters.
///
/// # Safety
/// The kernel is launched without further checks; the caller must supply
/// buffers that are large enough for whatever the kernel accesses
/// (NOTE(review): the kernel source is not visible here — confirm sizes).
///
/// # Errors
/// `CudaGraphError::DriverError` if the launch is rejected.
#[allow(clippy::too_many_arguments, dead_code)]
unsafe fn launch_fused_qk_norm(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_q_in: &CudaSlice<f32>,
    d_k_in: &CudaSlice<f32>,
    d_q_out: &mut CudaSlice<f32>,
    d_k_out: &mut CudaSlice<f32>,
    d_q_weight: &CudaSlice<f32>,
    d_k_weight: &CudaSlice<f32>,
    nq: u32,
    nkv: u32,
    head_dim: u32,
    eps: f32,
) -> Result<(), CudaGraphError> {
    let total_heads = nq + nkv;
    let stream = graph.stream_arc();

    let mut launch = stream.launch_builder(&mods.fused_qk_norm);
    launch.arg(d_q_in);
    launch.arg(d_k_in);
    launch.arg(d_q_out);
    launch.arg(d_k_out);
    launch.arg(d_q_weight);
    launch.arg(d_k_weight);
    launch.arg(&nq);
    launch.arg(&nkv);
    launch.arg(&head_dim);
    launch.arg(&eps);

    let outcome = launch.launch(LaunchConfig {
        grid_dim: (total_heads, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("fused_qk_norm launch: {e}"))),
    }
}
/// Enqueues the `fused_qk_rope` kernel on the graph's stream.
///
/// Launch shape: grid `(ceil(half_dim / 64), nq + nkv, 1)` with 64 threads
/// per block, no dynamic shared memory.
///
/// # Safety
/// The kernel is launched without further checks; the caller must supply
/// buffers that are large enough for whatever the kernel accesses
/// (NOTE(review): the kernel source is not visible here — confirm sizes).
///
/// # Errors
/// `CudaGraphError::DriverError` if the launch is rejected.
#[allow(clippy::too_many_arguments, dead_code)]
unsafe fn launch_fused_qk_rope(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_q_in: &CudaSlice<f32>,
    d_k_in: &CudaSlice<f32>,
    d_q_out: &mut CudaSlice<f32>,
    d_k_out: &mut CudaSlice<f32>,
    d_cos: &CudaSlice<f32>,
    d_sin: &CudaSlice<f32>,
    nq: u32,
    nkv: u32,
    half_dim: u32,
) -> Result<(), CudaGraphError> {
    let blocks_x = half_dim.div_ceil(64);
    let blocks_y = nq + nkv;
    let stream = graph.stream_arc();

    let mut launch = stream.launch_builder(&mods.fused_qk_rope);
    launch.arg(d_q_in);
    launch.arg(d_k_in);
    launch.arg(d_q_out);
    launch.arg(d_k_out);
    launch.arg(d_cos);
    launch.arg(d_sin);
    launch.arg(&nq);
    launch.arg(&nkv);
    launch.arg(&half_dim);

    let outcome = launch.launch(LaunchConfig {
        grid_dim: (blocks_x, blocks_y, 1),
        block_dim: (64, 1, 1),
        shared_mem_bytes: 0,
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("fused_qk_rope launch: {e}"))),
    }
}
/// Enqueues the `fused_qk_norm_rope` kernel on the graph's stream.
///
/// The key input is passed as a view so callers can point it at a sub-range
/// of a larger buffer. Launch shape: `nq + nkv` blocks of 256 threads, no
/// dynamic shared memory.
///
/// # Safety
/// The kernel is launched without further checks; the caller must supply
/// buffers that are large enough for whatever the kernel accesses
/// (NOTE(review): the kernel source is not visible here — confirm sizes).
///
/// # Errors
/// `CudaGraphError::DriverError` if the launch is rejected.
#[allow(clippy::too_many_arguments)]
unsafe fn launch_fused_qk_norm_rope(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_q_in: &CudaSlice<f32>,
    d_k_in_view: &CudaView<'_, f32>,
    d_q_out: &mut CudaSlice<f32>,
    d_k_out: &mut CudaSlice<f32>,
    d_q_weight: &CudaSlice<f32>,
    d_k_weight: &CudaSlice<f32>,
    d_cos: &CudaSlice<f32>,
    d_sin: &CudaSlice<f32>,
    nq: u32,
    nkv: u32,
    head_dim: u32,
    eps: f32,
) -> Result<(), CudaGraphError> {
    let total_heads = nq + nkv;
    let stream = graph.stream_arc();

    let mut launch = stream.launch_builder(&mods.fused_qk_norm_rope);
    launch.arg(d_q_in);
    launch.arg(d_k_in_view);
    launch.arg(d_q_out);
    launch.arg(d_k_out);
    launch.arg(d_q_weight);
    launch.arg(d_k_weight);
    launch.arg(d_cos);
    launch.arg(d_sin);
    launch.arg(&nq);
    launch.arg(&nkv);
    launch.arg(&head_dim);
    launch.arg(&eps);

    let outcome = launch.launch(LaunchConfig {
        grid_dim: (total_heads, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "fused_qk_norm_rope launch: {e}"
        ))),
    }
}
/// Enqueues the `fused_kv_store` kernel on the graph's stream.
///
/// The position is not passed as a scalar: the kernel receives the device
/// buffer `d_pos_seqlen` instead. Launch shape: grid
/// `(ceil(head_dim / 64), nkv, 1)` with 64 threads per block, no dynamic
/// shared memory.
///
/// # Safety
/// The kernel is launched without further checks; the caller must supply
/// buffers that are large enough for whatever the kernel accesses
/// (NOTE(review): the kernel source is not visible here — confirm sizes).
///
/// # Errors
/// `CudaGraphError::DriverError` if the launch is rejected.
#[allow(clippy::too_many_arguments)]
unsafe fn launch_fused_kv_store(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_k_data: &CudaSlice<f32>,
    d_v_data_view: &CudaView<'_, f32>,
    d_k_cache: &mut CudaSlice<u16>,
    d_v_cache: &mut CudaSlice<u16>,
    head_dim: u32,
    nkv: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
    layer_offset: u32,
) -> Result<(), CudaGraphError> {
    let blocks_x = head_dim.div_ceil(64);
    let stream = graph.stream_arc();

    let mut launch = stream.launch_builder(&mods.fused_kv_store);
    launch.arg(d_k_data);
    launch.arg(d_v_data_view);
    launch.arg(d_k_cache);
    launch.arg(d_v_cache);
    launch.arg(&head_dim);
    launch.arg(&nkv);
    launch.arg(&max_seq);
    launch.arg(d_pos_seqlen);
    launch.arg(&layer_offset);

    let outcome = launch.launch(LaunchConfig {
        grid_dim: (blocks_x, nkv, 1),
        block_dim: (64, 1, 1),
        shared_mem_bytes: 0,
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("fused_kv_store launch: {e}"))),
    }
}
/// Enqueues the `batched_attn_scores_v2` kernel on the graph's stream.
///
/// The grid is sized from `max_seq`, not from the current sequence length:
/// `(n_q, ceil(max_seq / BATCH_STRIDE), 1)` with 128 threads per block. The
/// stride constant is also passed to the kernel as its last argument.
///
/// # Safety
/// The kernel is launched without further checks; the caller must supply
/// buffers that are large enough for whatever the kernel accesses
/// (NOTE(review): the kernel source is not visible here — confirm sizes).
///
/// # Errors
/// `CudaGraphError::DriverError` if the launch is rejected.
#[allow(clippy::too_many_arguments)]
unsafe fn launch_batched_attn_scores_v2(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_queries: &CudaSlice<f32>,
    d_k_cache: &CudaSlice<u16>,
    d_scores: &mut CudaSlice<f32>,
    head_dim: u32,
    n_q: u32,
    n_kv: u32,
    heads_per_group: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
    inv_sqrt_hd: f32,
    cache_layer_offset: u32,
) -> Result<(), CudaGraphError> {
    const BATCH_STRIDE: u32 = 4;

    let blocks_y = max_seq.div_ceil(BATCH_STRIDE);
    let stream = graph.stream_arc();

    let mut launch = stream.launch_builder(&mods.batched_attn_scores_v2);
    launch.arg(d_queries);
    launch.arg(d_k_cache);
    launch.arg(d_scores);
    launch.arg(&head_dim);
    launch.arg(&n_q);
    launch.arg(&n_kv);
    launch.arg(&heads_per_group);
    launch.arg(&max_seq);
    launch.arg(d_pos_seqlen);
    launch.arg(&inv_sqrt_hd);
    launch.arg(&cache_layer_offset);
    launch.arg(&BATCH_STRIDE);

    let outcome = launch.launch(LaunchConfig {
        grid_dim: (n_q, blocks_y, 1),
        block_dim: (128, 1, 1),
        shared_mem_bytes: 0,
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "batched_attn_scores_v2 launch: {e}"
        ))),
    }
}
/// Enqueues the `batched_softmax` kernel on the graph's stream.
///
/// Launch shape: `n_q` blocks of 256 threads, no dynamic shared memory. The
/// kernel receives the device buffer `d_pos_seqlen` rather than a scalar
/// length.
///
/// # Safety
/// The kernel is launched without further checks; the caller must supply
/// buffers that are large enough for whatever the kernel accesses
/// (NOTE(review): the kernel source is not visible here — confirm sizes).
///
/// # Errors
/// `CudaGraphError::DriverError` if the launch is rejected.
unsafe fn launch_batched_softmax(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_scores: &mut CudaSlice<f32>,
    n_q: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
) -> Result<(), CudaGraphError> {
    let stream = graph.stream_arc();

    let mut launch = stream.launch_builder(&mods.batched_softmax);
    launch.arg(d_scores);
    launch.arg(&n_q);
    launch.arg(&max_seq);
    launch.arg(d_pos_seqlen);

    let outcome = launch.launch(LaunchConfig {
        grid_dim: (n_q, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("batched_softmax launch: {e}"))),
    }
}
/// Enqueues the `batched_attn_weighted_sum` kernel on the graph's stream.
///
/// Launch shape: grid `(ceil(head_dim / 64), n_q, 1)` with 64 threads per
/// block, no dynamic shared memory.
///
/// # Safety
/// The kernel is launched without further checks; the caller must supply
/// buffers that are large enough for whatever the kernel accesses
/// (NOTE(review): the kernel source is not visible here — confirm sizes).
///
/// # Errors
/// `CudaGraphError::DriverError` if the launch is rejected.
#[allow(clippy::too_many_arguments)]
unsafe fn launch_batched_attn_weighted_sum(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_scores: &CudaSlice<f32>,
    d_v_cache: &CudaSlice<u16>,
    d_attn_out: &mut CudaSlice<f32>,
    head_dim: u32,
    n_q: u32,
    n_kv: u32,
    heads_per_group: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
    cache_layer_offset: u32,
) -> Result<(), CudaGraphError> {
    let blocks_x = head_dim.div_ceil(64);
    let stream = graph.stream_arc();

    let mut launch = stream.launch_builder(&mods.batched_attn_weighted_sum);
    launch.arg(d_scores);
    launch.arg(d_v_cache);
    launch.arg(d_attn_out);
    launch.arg(&head_dim);
    launch.arg(&n_q);
    launch.arg(&n_kv);
    launch.arg(&heads_per_group);
    launch.arg(&max_seq);
    launch.arg(d_pos_seqlen);
    launch.arg(&cache_layer_offset);

    let outcome = launch.launch(LaunchConfig {
        grid_dim: (blocks_x, n_q, 1),
        block_dim: (64, 1, 1),
        shared_mem_bytes: 0,
    });
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "batched_attn_weighted_sum launch: {e}"
        ))),
    }
}
/// Enqueues the attention half of one layer on the graph's stream.
///
/// Steps, in order:
/// 1. RMS-norm of `bufs.d_hidden` into `bufs.d_normed`.
/// 2. One GEMV with the fused QKV weight into `bufs.d_qkv`.
/// 3. Fused norm+rope: Q is read from the start of `d_qkv`, K from element
///    offset `nq * head_dim`.
/// 4. KV store: K comes from `d_k_rope`, V from `d_qkv` at element offset
///    `(nq + nkv) * head_dim`.
/// 5. Attention scores, softmax, and weighted sum into `bufs.d_attn_out`.
///
/// `_pos` is unused; the kernels are handed `bufs.d_pos_seqlen` instead, so
/// the caller must have filled that buffer beforehand. The maximum sequence
/// length is taken from `bufs.max_seq`.
///
/// # Safety
/// Launches kernels without validating buffer sizes; `bufs`, `kv`, and the
/// weights must be consistent with the dimensions passed in.
///
/// # Errors
/// Returns the first error reported by any launch.
#[allow(clippy::too_many_arguments)]
pub unsafe fn encode_attn_phase(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_pre_norm_weight: &CudaSlice<f32>,
    d_fused_qkv_weight: &Arc<CudaSlice<u8>>,
    d_q_norm_weight: &CudaSlice<f32>,
    d_k_norm_weight: &CudaSlice<f32>,
    kv: &mut CudaKvCache,
    layer_idx: usize,
    _pos: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    hidden_size: usize,
    bufs: &mut CudaFullLayerBuffers,
) -> Result<(), CudaGraphError> {
    // Scalar kernel arguments, narrowed to the width the kernels take.
    let hidden_u32 = hidden_size as u32;
    let q_heads = nq as u32;
    let kv_heads = nkv as u32;
    let head_dim_u32 = head_dim as u32;
    let qkv_rows = (nq * head_dim + 2 * nkv * head_dim) as u32;
    let group_size = heads_per_group as u32;
    let max_seq = bufs.max_seq as u32;
    let score_scale = 1.0f32 / (head_dim as f32).sqrt();
    let cache_offset = kv.layer_offset_elements(layer_idx);

    // Pre-attention norm followed by the fused QKV projection.
    graph.launch_rmsnorm_pub(
        &bufs.d_hidden,
        d_pre_norm_weight,
        &mut bufs.d_normed,
        hidden_u32,
        norm_eps,
    )?;
    graph.launch_gemv_pub(
        d_fused_qkv_weight,
        &bufs.d_normed,
        &mut bufs.d_qkv,
        qkv_rows,
        hidden_u32,
    )?;

    // K starts right after the Q rows inside the fused output.
    let k_start = nq * head_dim;
    let k_segment = bufs.d_qkv.slice(k_start..);
    launch_fused_qk_norm_rope(
        graph,
        mods,
        &bufs.d_qkv,
        &k_segment,
        &mut bufs.d_q_rope,
        &mut bufs.d_k_rope,
        d_q_norm_weight,
        d_k_norm_weight,
        &bufs.d_cos,
        &bufs.d_sin,
        q_heads,
        kv_heads,
        head_dim_u32,
        norm_eps,
    )?;

    // V starts after the Q and K rows.
    let v_start = (nq + nkv) * head_dim;
    let v_segment = bufs.d_qkv.slice(v_start..);
    launch_fused_kv_store(
        graph,
        mods,
        &bufs.d_k_rope,
        &v_segment,
        &mut kv.k_cache,
        &mut kv.v_cache,
        head_dim_u32,
        kv_heads,
        max_seq,
        &bufs.d_pos_seqlen,
        cache_offset,
    )?;

    launch_batched_attn_scores_v2(
        graph,
        mods,
        &bufs.d_q_rope,
        &kv.k_cache,
        &mut bufs.d_scores,
        head_dim_u32,
        q_heads,
        kv_heads,
        group_size,
        max_seq,
        &bufs.d_pos_seqlen,
        score_scale,
        cache_offset,
    )?;
    launch_batched_softmax(
        graph,
        mods,
        &mut bufs.d_scores,
        q_heads,
        max_seq,
        &bufs.d_pos_seqlen,
    )?;
    launch_batched_attn_weighted_sum(
        graph,
        mods,
        &bufs.d_scores,
        &kv.v_cache,
        &mut bufs.d_attn_out,
        head_dim_u32,
        q_heads,
        kv_heads,
        group_size,
        max_seq,
        &bufs.d_pos_seqlen,
        cache_offset,
    )
}
/// Runs one complete transformer layer (attention + feed-forward) on the GPU
/// and writes the updated hidden state back into `hidden[..hidden_size]`.
///
/// Flow:
/// 1. Validate the host slices.
/// 2. Acquire (and if needed reallocate) the scratch buffers and KV cache.
/// 3. Upload `[pos, pos + 1]` into `d_pos_seqlen`, then the hidden state and
///    the rotary tables. The attention kernels read the position from that
///    device buffer, so it has to be refreshed on every call; without the
///    upload they would see zeros or a value left by an earlier call.
/// 4. Enqueue the attention phase, the output projection with residual add,
///    the post-attention norm, gate/up GEMV, SwiGLU, and the down projection
///    with residual add.
/// 5. Synchronise, download the hidden state, synchronise again.
///
/// # Errors
/// * `CudaGraphError::WeightLayoutError` if `hidden`, `rope_cos`, or
///   `rope_sin` is shorter than required.
/// * `CudaGraphError::LockPoisoned` / `CudaGraphError::DriverError` from
///   buffer acquisition, copies, launches, or synchronisation.
#[allow(clippy::too_many_arguments)]
pub fn encode_full_layer(
    graph: &CudaGraph,
    hidden: &mut [f32],
    pos: usize,
    layer_idx: usize,
    d_pre_attn_norm: &CudaSlice<f32>,
    d_fused_qkv_weight: &Arc<CudaSlice<u8>>,
    d_o_weight: &Arc<CudaSlice<u8>>,
    d_q_norm: &CudaSlice<f32>,
    d_k_norm: &CudaSlice<f32>,
    d_post_attn_norm: &CudaSlice<f32>,
    d_gate_up_weight: &Arc<CudaSlice<u8>>,
    d_down_weight: &Arc<CudaSlice<u8>>,
    rope_cos: &[f32],
    rope_sin: &[f32],
    hidden_size: usize,
    intermediate_size: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    max_seq_len: usize,
    n_layers: usize,
    attn_mods: &CudaAttnModules,
) -> Result<(), CudaGraphError> {
    let h = hidden_size;
    let half_dim = head_dim / 2;

    // Host-side input validation.
    if hidden.len() < h {
        return Err(CudaGraphError::WeightLayoutError(format!(
            "hidden too short: need {h}, got {}",
            hidden.len()
        )));
    }
    if rope_cos.len() < half_dim {
        return Err(CudaGraphError::WeightLayoutError(format!(
            "rope_cos too short: need {half_dim}, got {}",
            rope_cos.len()
        )));
    }
    if rope_sin.len() < half_dim {
        return Err(CudaGraphError::WeightLayoutError(format!(
            "rope_sin too short: need {half_dim}, got {}",
            rope_sin.len()
        )));
    }

    let mut fl_guard =
        acquire_full_layer_buffers(graph, h, nq, nkv, head_dim, max_seq_len, intermediate_size)?;
    let bufs = fl_guard
        .as_mut()
        .ok_or_else(|| CudaGraphError::DriverError("full_layer_buffers not allocated".into()))?;
    let mut kv_guard = acquire_kv_cache(graph, n_layers, nkv, max_seq_len, head_dim)?;
    let kv = kv_guard
        .as_mut()
        .ok_or_else(|| CudaGraphError::DriverError("kv_cache not allocated".into()))?;

    // Position and sequence length for the attention kernels; same encoding
    // as `encode_full_forward` uses for this buffer.
    let pos_seqlen_host = [pos as u32, (pos + 1) as u32];
    graph
        .stream_arc()
        .memcpy_htod(&pos_seqlen_host[..], &mut bufs.d_pos_seqlen)
        .map_err(|e| CudaGraphError::DriverError(format!("upload pos_seqlen: {e}")))?;

    graph
        .stream_arc()
        .memcpy_htod(&hidden[..h], &mut bufs.d_hidden)
        .map_err(|e| CudaGraphError::DriverError(format!("upload hidden: {e}")))?;
    graph
        .stream_arc()
        .memcpy_htod(&rope_cos[..half_dim], &mut bufs.d_cos)
        .map_err(|e| CudaGraphError::DriverError(format!("upload cos: {e}")))?;
    graph
        .stream_arc()
        .memcpy_htod(&rope_sin[..half_dim], &mut bufs.d_sin)
        .map_err(|e| CudaGraphError::DriverError(format!("upload sin: {e}")))?;

    // SAFETY (review): buffers and cache were just acquired for exactly the
    // dimensions passed on; the weights are the caller's responsibility.
    unsafe {
        encode_attn_phase(
            graph,
            attn_mods,
            d_pre_attn_norm,
            d_fused_qkv_weight,
            d_q_norm,
            d_k_norm,
            kv,
            layer_idx,
            pos,
            nq,
            nkv,
            head_dim,
            heads_per_group,
            norm_eps,
            h,
            bufs,
        )?;
    }

    let h_u32 = h as u32;
    let inter_u32 = intermediate_size as u32;
    // SAFETY (review): same buffers as above; sizes follow from the
    // dimensions used to allocate them.
    unsafe {
        // Output projection, then residual add into the hidden state.
        let attn_out_rows = (nq * head_dim) as u32;
        graph.launch_gemv_pub(
            d_o_weight,
            &bufs.d_attn_out,
            &mut bufs.d_normed,
            h_u32,
            attn_out_rows,
        )?;
        graph.launch_residual_add_pub(&mut bufs.d_hidden, &bufs.d_normed, h_u32)?;

        // Feed-forward: norm, gate/up, SwiGLU, down, residual add.
        graph.launch_rmsnorm_pub(
            &bufs.d_hidden,
            d_post_attn_norm,
            &mut bufs.d_normed,
            h_u32,
            norm_eps,
        )?;
        graph.launch_gemv_pub(
            d_gate_up_weight,
            &bufs.d_normed,
            &mut bufs.d_gate_up,
            2 * inter_u32,
            h_u32,
        )?;
        graph.launch_swiglu_pub(&bufs.d_gate_up, &mut bufs.d_swiglu, inter_u32)?;
        graph.launch_gemv_pub(
            d_down_weight,
            &bufs.d_swiglu,
            &mut bufs.d_normed,
            h_u32,
            inter_u32,
        )?;
        graph.launch_residual_add_pub(&mut bufs.d_hidden, &bufs.d_normed, h_u32)?;
    }

    graph
        .stream_arc()
        .synchronize()
        .map_err(|e| CudaGraphError::DriverError(format!("fl stream sync: {e}")))?;
    graph
        .stream_arc()
        .memcpy_dtoh(&bufs.d_hidden, &mut hidden[..h])
        .map_err(|e| CudaGraphError::DriverError(format!("download hidden fl: {e}")))?;
    graph
        .stream_arc()
        .synchronize()
        .map_err(|e| CudaGraphError::DriverError(format!("fl D2H sync: {e}")))?;
    Ok(())
}
/// Convenience entry point for a single layer: resolves the global graph,
/// makes sure the kernels and all weights are on the device, and then calls
/// `encode_full_layer`.
///
/// Quantised weights go through the graph's weight cache keyed by the given
/// handle ids; gate and up bytes are concatenated lazily, only when the fused
/// buffer has to be uploaded. Norm weights go through
/// `get_or_upload_f32_weight`.
///
/// # Errors
/// Propagates the first error from graph lookup, kernel loading, weight
/// upload, or `encode_full_layer`.
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_full_layer(
    hidden: &mut [f32],
    pos: usize,
    layer_idx: usize,
    pre_attn_norm_handle_id: u64,
    pre_attn_norm_bytes: &[f32],
    fused_qkv_handle_id: u64,
    fused_qkv_bytes: &[u8],
    o_handle_id: u64,
    o_bytes: &[u8],
    q_norm_handle_id: u64,
    q_norm_bytes: &[f32],
    k_norm_handle_id: u64,
    k_norm_bytes: &[f32],
    post_attn_norm_handle_id: u64,
    post_attn_norm_bytes: &[f32],
    gate_up_handle_id: u64,
    gate_bytes: &[u8],
    up_bytes: &[u8],
    down_handle_id: u64,
    down_bytes: &[u8],
    rope_cos: &[f32],
    rope_sin: &[f32],
    hidden_size: usize,
    intermediate_size: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    max_seq_len: usize,
    n_layers: usize,
) -> Result<(), CudaGraphError> {
    let graph = CudaGraph::global()?;
    let kernels = init_attn_modules(&graph)?;

    // Quantised projection weights.
    let qkv_dev = graph.get_or_upload_weight_soa(fused_qkv_handle_id, fused_qkv_bytes)?;
    let o_dev = graph.get_or_upload_weight_soa(o_handle_id, o_bytes)?;
    let gate_up_dev =
        graph.get_or_upload_weight_soa_lazy(gate_up_handle_id, || [gate_bytes, up_bytes].concat())?;
    let down_dev = graph.get_or_upload_weight_soa(down_handle_id, down_bytes)?;

    // Norm weights.
    let pre_norm_dev =
        get_or_upload_f32_weight(&graph, pre_attn_norm_handle_id, pre_attn_norm_bytes)?;
    let post_norm_dev =
        get_or_upload_f32_weight(&graph, post_attn_norm_handle_id, post_attn_norm_bytes)?;
    let q_norm_dev = get_or_upload_f32_weight(&graph, q_norm_handle_id, q_norm_bytes)?;
    let k_norm_dev = get_or_upload_f32_weight(&graph, k_norm_handle_id, k_norm_bytes)?;

    encode_full_layer(
        &graph,
        hidden,
        pos,
        layer_idx,
        &pre_norm_dev,
        &qkv_dev,
        &o_dev,
        &q_norm_dev,
        &k_norm_dev,
        &post_norm_dev,
        &gate_up_dev,
        &down_dev,
        rope_cos,
        rope_sin,
        hidden_size,
        intermediate_size,
        nq,
        nkv,
        head_dim,
        heads_per_group,
        norm_eps,
        max_seq_len,
        n_layers,
        &kernels,
    )
}
#[allow(clippy::too_many_arguments)]
unsafe fn encode_layer_device(
graph: &CudaGraph,
mods: &CudaAttnModules,
weights: &CudaCachedLayerWeights,
kv: &mut CudaKvCache,
layer_idx: usize,
_pos: usize,
nq: usize,
nkv: usize,
head_dim: usize,
heads_per_group: usize,
norm_eps: f32,
hidden_size: usize,
intermediate_size: usize,
bufs: &mut CudaFullLayerBuffers,
) -> Result<(), CudaGraphError> {
let h_u32 = hidden_size as u32;
let nq_u32 = nq as u32;
let nkv_u32 = nkv as u32;
let hd_u32 = head_dim as u32;
let inter_u32 = intermediate_size as u32;
let qkv_total_rows = (nq * head_dim + 2 * nkv * head_dim) as u32;
let heads_per_group_u32 = heads_per_group as u32;
let max_seq_u32 = bufs.max_seq as u32;
let inv_sqrt_hd = 1.0f32 / (head_dim as f32).sqrt();
let layer_offset = kv.layer_offset_elements(layer_idx);
graph.launch_rmsnorm_pub(
&bufs.d_hidden,
&weights.pre_attn_norm,
&mut bufs.d_normed,
h_u32,
norm_eps,
)?;
graph.launch_gemv_pub(
&weights.q_weight,
&bufs.d_normed,
&mut bufs.d_qkv,
qkv_total_rows,
h_u32,
)?;
let k_offset = nq * head_dim;
let k_in_view = bufs.d_qkv.slice(k_offset..);
launch_fused_qk_norm_rope(
graph,
mods,
&bufs.d_qkv,
&k_in_view,
&mut bufs.d_q_rope,
&mut bufs.d_k_rope,
&weights.q_norm,
&weights.k_norm,
&bufs.d_cos,
&bufs.d_sin,
nq_u32,
nkv_u32,
hd_u32,
norm_eps,
)?;
let v_offset = (nq + nkv) * head_dim;
let v_view = bufs.d_qkv.slice(v_offset..);
launch_fused_kv_store(
graph,
mods,
&bufs.d_k_rope,
&v_view,
&mut kv.k_cache,
&mut kv.v_cache,
hd_u32,
nkv_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
layer_offset,
)?;
launch_batched_attn_scores_v2(
graph,
mods,
&bufs.d_q_rope,
&kv.k_cache,
&mut bufs.d_scores,
hd_u32,
nq_u32,
nkv_u32,
heads_per_group_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
inv_sqrt_hd,
layer_offset,
)?;
launch_batched_softmax(
graph,
mods,
&mut bufs.d_scores,
nq_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
)?;
launch_batched_attn_weighted_sum(
graph,
mods,
&bufs.d_scores,
&kv.v_cache,
&mut bufs.d_attn_out,
hd_u32,
nq_u32,
nkv_u32,
heads_per_group_u32,
max_seq_u32,
&bufs.d_pos_seqlen,
layer_offset,
)?;
let attn_out_rows = (nq * head_dim) as u32;
graph.launch_gemv_residual_pub(
&weights.o_weight,
&bufs.d_attn_out,
&mut bufs.d_hidden,
h_u32,
attn_out_rows,
)?;
graph.launch_rmsnorm_pub(
&bufs.d_hidden,
&weights.post_attn_norm,
&mut bufs.d_normed,
h_u32,
norm_eps,
)?;
graph.launch_fused_gate_up_swiglu_pub(
&weights.gate_up_weight,
&bufs.d_normed,
&mut bufs.d_swiglu,
inter_u32,
h_u32,
)?;
graph.launch_gemv_residual_pub(
&weights.down_weight,
&bufs.d_swiglu,
&mut bufs.d_hidden,
h_u32,
inter_u32,
)?;
Ok(())
}
/// Host-side description of one layer's weights for the full-forward path.
///
/// Each `*_handle` is the cache key under which the accompanying bytes are
/// uploaded by `get_or_build_model_weights`. Despite the `_bytes` suffix the
/// norm fields are `f32` slices.
pub struct CudaFullForwardLayerParams<'a> {
/// Cache key for the pre-attention norm weight.
pub attn_norm_handle: u64,
/// Pre-attention norm weight values.
pub attn_norm_bytes: &'a [f32],
/// Cache key for the fused QKV weight.
pub fused_qkv_handle: u64,
/// Fused QKV weight bytes.
pub fused_qkv_bytes: &'a [u8],
/// Cache key for the query norm weight.
pub q_norm_handle: u64,
/// Query norm weight values.
pub q_norm_bytes: &'a [f32],
/// Cache key for the key norm weight.
pub k_norm_handle: u64,
/// Key norm weight values.
pub k_norm_bytes: &'a [f32],
/// Cache key for the attention output projection weight.
pub attn_proj_handle: u64,
/// Attention output projection weight bytes.
pub attn_proj_bytes: &'a [u8],
/// Cache key for the post-attention (feed-forward) norm weight.
pub ffn_norm_handle: u64,
/// Post-attention norm weight values.
pub ffn_norm_bytes: &'a [f32],
/// Cache key for the fused gate+up weight.
pub gate_up_handle: u64,
/// Gate weight bytes; concatenated in front of `up_bytes` on upload.
pub gate_bytes: &'a [u8],
/// Up weight bytes; appended after `gate_bytes` on upload.
pub up_bytes: &'a [u8],
/// Cache key for the down projection weight.
pub down_handle: u64,
/// Down projection weight bytes.
pub down_bytes: &'a [u8],
}
/// Runs all layers (and optionally the final norm) for one token on the GPU
/// and returns the resulting hidden state of length `hidden_size`.
///
/// Behaviour:
/// * Inputs are validated, kernels loaded, scratch buffers and KV cache
///   acquired, and `[pos, pos + 1]` uploaded to `d_pos_seqlen`.
/// * **Replay path** – if a captured driver graph is available, the hidden
///   state and rotary tables are uploaded, the graph is launched, and the
///   result downloaded.
/// * **Eager path** – otherwise every layer is enqueued directly. Afterwards,
///   if capture has not been attempted yet, the same launch sequence is
///   recorded once with stream capture, instantiated, and uploaded so later
///   calls can take the replay path. Any capture failure only disables
///   replay (`Some(None)`); the eager result is still returned.
///
/// A graph returned by `end_capture` is always either stored in the holder
/// or destroyed; in particular it is destroyed when recording failed even
/// though the driver still produced a graph.
///
/// # Errors
/// * `CudaGraphError::WeightLayoutError` for short inputs or an empty layer
///   list.
/// * `CudaGraphError::LockPoisoned` / `CudaGraphError::DriverError` from
///   buffer acquisition, copies, launches, graph replay, or synchronisation.
#[allow(clippy::too_many_arguments)]
pub fn encode_full_forward(
    graph: &Arc<CudaGraph>,
    hidden_init: &[f32],
    all_layer_weights: &[CudaCachedLayerWeights],
    rope_cos: &[f32],
    rope_sin: &[f32],
    pos: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    hidden_size: usize,
    intermediate_size: usize,
    max_seq_len: usize,
    final_norm_weight: Option<&[f32]>,
    final_norm_handle: u64,
) -> Result<Vec<f32>, CudaGraphError> {
    let h = hidden_size;
    let half_dim = head_dim / 2;
    let n_layers = all_layer_weights.len();
    let h_u32 = h as u32;

    // Host-side input validation.
    if hidden_init.len() < h {
        return Err(CudaGraphError::WeightLayoutError(format!(
            "hidden_init too short: need {h}, got {}",
            hidden_init.len()
        )));
    }
    if rope_cos.len() < half_dim {
        return Err(CudaGraphError::WeightLayoutError(format!(
            "rope_cos too short: need {half_dim}, got {}",
            rope_cos.len()
        )));
    }
    if rope_sin.len() < half_dim {
        return Err(CudaGraphError::WeightLayoutError(format!(
            "rope_sin too short: need {half_dim}, got {}",
            rope_sin.len()
        )));
    }
    if n_layers == 0 {
        return Err(CudaGraphError::WeightLayoutError(
            "encode_full_forward: no layers provided".into(),
        ));
    }

    let attn_mods = init_attn_modules(graph)?;
    let mut fl_guard =
        acquire_full_layer_buffers(graph, h, nq, nkv, head_dim, max_seq_len, intermediate_size)?;
    let bufs = fl_guard
        .as_mut()
        .ok_or_else(|| CudaGraphError::DriverError("full_layer_buffers not allocated".into()))?;
    let mut kv_guard = acquire_kv_cache(graph, n_layers, nkv, max_seq_len, head_dim)?;
    let kv = kv_guard
        .as_mut()
        .ok_or_else(|| CudaGraphError::DriverError("kv_cache not allocated".into()))?;
    let stream = graph.stream_arc();

    // The kernels read position and sequence length from this device buffer,
    // which is what lets one recorded graph serve every position.
    let pos_seqlen_host = [pos as u32, (pos + 1) as u32];
    unsafe {
        graph
            .raw_htod(&pos_seqlen_host, &mut bufs.d_pos_seqlen, 2)
            .map_err(|e| CudaGraphError::DriverError(format!("upload pos_seqlen: {e}")))?;
    }

    // ── Replay path ────────────────────────────────────────────────────────
    {
        let graph_guard = full_layer_state()
            .cuda_driver_graph
            .lock()
            .map_err(|_| CudaGraphError::LockPoisoned)?;
        if let Some(Some(ref holder)) = *graph_guard {
            unsafe {
                graph.raw_htod(&hidden_init[..h], &mut bufs.d_hidden, h)?;
                graph.raw_htod(&rope_cos[..half_dim], &mut bufs.d_cos, half_dim)?;
                graph.raw_htod(&rope_sin[..half_dim], &mut bufs.d_sin, half_dim)?;
            }
            unsafe { holder.launch() }
                .map_err(|e| CudaGraphError::DriverError(format!("graph launch: {e}")))?;
            let mut result = vec![0.0f32; h];
            unsafe {
                graph.raw_dtoh(&bufs.d_hidden, &mut result, h)?;
            }
            stream
                .synchronize()
                .map_err(|e| CudaGraphError::DriverError(format!("fast-path sync: {e}")))?;
            return Ok(result);
        }
    }

    // ── Eager path ─────────────────────────────────────────────────────────
    unsafe {
        graph
            .raw_htod(&hidden_init[..h], &mut bufs.d_hidden, h)
            .map_err(|e| CudaGraphError::DriverError(format!("upload hidden_init: {e}")))?;
        graph
            .raw_htod(&rope_cos[..half_dim], &mut bufs.d_cos, half_dim)
            .map_err(|e| CudaGraphError::DriverError(format!("upload cos ff: {e}")))?;
        graph
            .raw_htod(&rope_sin[..half_dim], &mut bufs.d_sin, half_dim)
            .map_err(|e| CudaGraphError::DriverError(format!("upload sin ff: {e}")))?;
    }
    for (layer_idx, weights) in all_layer_weights.iter().enumerate() {
        unsafe {
            encode_layer_device(
                graph,
                &attn_mods,
                weights,
                kv,
                layer_idx,
                pos,
                nq,
                nkv,
                head_dim,
                heads_per_group,
                norm_eps,
                h,
                intermediate_size,
                bufs,
            )?;
        }
    }
    if let Some(fnorm_data) = final_norm_weight {
        let d_fnorm = get_or_upload_f32_weight(graph, final_norm_handle, fnorm_data)?;
        unsafe {
            graph.launch_rmsnorm_pub(
                &bufs.d_hidden,
                &d_fnorm,
                &mut bufs.d_normed,
                h_u32,
                norm_eps,
            )?;
        }
        // Copy the normalised result back so the download below always reads
        // `d_hidden`.
        stream
            .memcpy_dtod(&bufs.d_normed, &mut bufs.d_hidden)
            .map_err(|e| CudaGraphError::DriverError(format!("dtod normed->hidden: {e}")))?;
    }
    let mut result = vec![0.0f32; h];
    unsafe {
        graph.raw_dtoh(&bufs.d_hidden, &mut result, h)?;
    }
    stream
        .synchronize()
        .map_err(|e| CudaGraphError::DriverError(format!("ff D2H sync: {e}")))?;

    // ── One-time graph capture (best effort) ───────────────────────────────
    if let Ok(mut graph_guard) = full_layer_state().cuda_driver_graph.lock() {
        if graph_guard.is_none() {
            let capture_started = stream
                .begin_capture(sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL)
                .is_ok();
            if !capture_started {
                warn!("CUDA graph: begin_capture failed — running without graph replay");
                *graph_guard = Some(None);
            } else {
                // Record the same launch sequence as the eager pass. The
                // final-norm weight was uploaded above, so the lookup below
                // is served from the cache.
                let record_result = (|| -> Result<(), CudaGraphError> {
                    for (layer_idx, weights) in all_layer_weights.iter().enumerate() {
                        unsafe {
                            encode_layer_device(
                                graph,
                                &attn_mods,
                                weights,
                                kv,
                                layer_idx,
                                pos,
                                nq,
                                nkv,
                                head_dim,
                                heads_per_group,
                                norm_eps,
                                h,
                                intermediate_size,
                                bufs,
                            )?;
                        }
                    }
                    if let Some(fnorm_data) = final_norm_weight {
                        let d_fnorm =
                            get_or_upload_f32_weight(graph, final_norm_handle, fnorm_data)?;
                        unsafe {
                            graph.launch_rmsnorm_pub(
                                &bufs.d_hidden,
                                &d_fnorm,
                                &mut bufs.d_normed,
                                h_u32,
                                norm_eps,
                            )?;
                        }
                        stream
                            .memcpy_dtod(&bufs.d_normed, &mut bufs.d_hidden)
                            .map_err(|e| {
                                CudaGraphError::DriverError(format!("dtod (capture): {e}"))
                            })?;
                    }
                    Ok(())
                })();

                // Capture must be ended even when recording failed.
                let end_result =
                    unsafe { cudarc_result::stream::end_capture(stream.cu_stream()) };
                match end_result {
                    Ok(cu_graph_raw) if cu_graph_raw.is_null() => {
                        warn!("CUDA graph: end_capture returned no graph — disabling replay");
                        *graph_guard = Some(None);
                    }
                    Ok(cu_graph_raw) => match record_result {
                        Err(e) => {
                            // The driver produced a graph although recording
                            // failed: it is unusable, so release it here.
                            warn!(
                                "CUDA graph: recording failed during capture: {e} — disabling replay"
                            );
                            unsafe {
                                let _ = cudarc_result::graph::destroy(cu_graph_raw);
                            }
                            *graph_guard = Some(None);
                        }
                        Ok(()) => {
                            let inst_result = unsafe {
                                let mut exec = std::mem::MaybeUninit::uninit();
                                sys::cuGraphInstantiateWithFlags(
                                    exec.as_mut_ptr(),
                                    cu_graph_raw,
                                    0u64,
                                )
                                .result()
                                .map(|_| exec.assume_init())
                            };
                            match inst_result {
                                Ok(cu_graph_exec) => {
                                    // From here on the holder owns both
                                    // handles and releases them on drop.
                                    let holder = CuGraphHolder {
                                        cu_graph: cu_graph_raw,
                                        cu_graph_exec,
                                        stream: Arc::clone(&stream),
                                    };
                                    match unsafe { holder.upload() } {
                                        Ok(()) => {
                                            *graph_guard = Some(Some(holder));
                                            debug!(
                                                "CUDA graph captured and uploaded successfully"
                                            );
                                        }
                                        Err(e) => {
                                            warn!(
                                                "CUDA graph upload failed: {e} — disabling replay"
                                            );
                                            *graph_guard = Some(None);
                                        }
                                    }
                                }
                                Err(e) => {
                                    warn!("CUDA graph instantiate failed: {e} — disabling replay");
                                    unsafe {
                                        let _ = cudarc_result::graph::destroy(cu_graph_raw);
                                    }
                                    *graph_guard = Some(None);
                                }
                            }
                        }
                    },
                    Err(e) => {
                        warn!("CUDA graph: end_capture error: {e} — disabling replay");
                        *graph_guard = Some(None);
                    }
                }
            }
        }
    }

    Ok(result)
}
/// Best-effort wrapper around `encode_full_forward`.
///
/// Resolves (or builds) the cached device weights for `layer_params`, runs
/// the forward pass, and converts any failure into `None` after logging a
/// warning. When profiling is enabled, timing lines are written to stderr.
///
/// Returns `None` if the weights cannot be prepared or the forward pass
/// fails.
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_full_forward(
    hidden: &[f32],
    layer_params: &[CudaFullForwardLayerParams<'_>],
    rope_cos: &[f32],
    rope_sin: &[f32],
    pos: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    hidden_size: usize,
    intermediate_size: usize,
    max_seq_len: usize,
    final_norm_bytes: Option<&[f32]>,
    final_norm_handle: u64,
) -> Option<Vec<f32>> {
    // Timestamps are only taken when profiling is on.
    let started = profiling().then(std::time::Instant::now);
    let (graph, layer_weights) = get_or_build_model_weights(layer_params)?;
    let weights_ready = profiling().then(std::time::Instant::now);

    if let (Some(t0), Some(t1)) = (started, weights_ready) {
        eprintln!(
            "[cuda-prof] try_ff pos={pos}: weight_lookup={:.3}ms",
            (t1 - t0).as_secs_f64() * 1000.0,
        );
    }

    let outcome = encode_full_forward(
        &graph,
        hidden,
        layer_weights.as_slice(),
        rope_cos,
        rope_sin,
        pos,
        nq,
        nkv,
        head_dim,
        heads_per_group,
        norm_eps,
        hidden_size,
        intermediate_size,
        max_seq_len,
        final_norm_bytes,
        final_norm_handle,
    );

    if let Some(t1) = weights_ready {
        let elapsed = t1.elapsed().as_secs_f64() * 1000.0;
        // The label is derived from the position alone, not from which path
        // actually ran.
        let path = if pos == 0 { "slow" } else { "fast" };
        eprintln!("[cuda-prof] encode_ff pos={pos} path={path}: {elapsed:.1}ms");
    }

    match outcome {
        Ok(values) => Some(values),
        Err(e) => {
            warn!("CUDA full-forward error at pos={pos}: {e}");
            None
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_full_forward_with_gpu_lm_head(
hidden: &[f32],
layer_params: &[CudaFullForwardLayerParams<'_>],
rope_cos: &[f32],
rope_sin: &[f32],
pos: usize,
nq: usize,
nkv: usize,
head_dim: usize,
heads_per_group: usize,
norm_eps: f32,
hidden_size: usize,
intermediate_size: usize,
max_seq_len: usize,
final_norm_bytes: Option<&[f32]>,
final_norm_handle: u64,
lm_head_handle: u64,
lm_head_bytes: &[u8],
vocab_size: usize,
) -> Option<Vec<f32>> {
let normed = try_cuda_full_forward(
hidden,
layer_params,
rope_cos,
rope_sin,
pos,
nq,
nkv,
head_dim,
heads_per_group,
norm_eps,
hidden_size,
intermediate_size,
max_seq_len,
final_norm_bytes,
final_norm_handle,
)?;
let graph = CudaGraph::global().ok()?;
let _t_lm = profiling().then(std::time::Instant::now);
let r = graph.encode_lm_head_gemv(
&normed,
lm_head_handle,
lm_head_bytes,
vocab_size,
hidden_size,
);
if profiling() {
eprintln!(
"[cuda-prof] lm_head pos={pos}: {:.1}ms",
_t_lm.expect("profiling").elapsed().as_secs_f64() * 1000.0
);
}
r.ok()
}
#[cfg(test)]
mod tests;