use std::sync::Arc;
use cudarc::driver::sys;
use cudarc::driver::CudaSlice;
use tracing::warn;
use super::super::cuda_graph::{CudaGraph, CudaGraphError};
use super::{
acquire_full_layer_buffers, acquire_kv_cache, full_layer_state, get_or_build_model_weights,
get_or_upload_f32_weight, init_attn_modules, profiling, CuGraphHolder, CudaAttnModules,
CudaCachedLayerWeights, CudaFullForwardLayerParams, CudaFullLayerBuffers, CudaKvCache,
};
use super::launchers::{
launch_batched_attn_scores_v2, launch_batched_attn_weighted_sum, launch_batched_softmax,
launch_fused_kv_store, launch_fused_qk_norm_rope,
};
/// Runs one transformer layer on the GPU for a single token and copies the result back.
///
/// The first `hidden_size` elements of `hidden` and the first `head_dim / 2` elements of
/// each RoPE table are uploaded into the shared full-layer buffers. The attention phase is
/// delegated to `super::encode_attn_phase`; the remainder of the layer is enqueued here in
/// this order: O-projection GEMV, residual add, RMSNorm, gate/up GEMV, SwiGLU, down GEMV,
/// residual add. After a stream synchronize the device hidden buffer is copied back into
/// `hidden[..hidden_size]`.
///
/// # Errors
///
/// * `CudaGraphError::WeightLayoutError` when `hidden`, `rope_cos` or `rope_sin` is shorter
///   than required.
/// * `CudaGraphError::DriverError` when the shared buffers / KV cache are not allocated, or
///   when an upload, download or synchronize call fails.
/// * Any error returned by the buffer acquisition helpers, the attention phase or a kernel
///   launch is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn encode_full_layer(
    graph: &CudaGraph,
    hidden: &mut [f32],
    pos: usize,
    layer_idx: usize,
    d_pre_attn_norm: &CudaSlice<f32>,
    d_fused_qkv_weight: &Arc<CudaSlice<u8>>,
    d_o_weight: &Arc<CudaSlice<u8>>,
    d_q_norm: &CudaSlice<f32>,
    d_k_norm: &CudaSlice<f32>,
    d_post_attn_norm: &CudaSlice<f32>,
    d_gate_up_weight: &Arc<CudaSlice<u8>>,
    d_down_weight: &Arc<CudaSlice<u8>>,
    rope_cos: &[f32],
    rope_sin: &[f32],
    hidden_size: usize,
    intermediate_size: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    max_seq_len: usize,
    n_layers: usize,
    attn_mods: &CudaAttnModules,
) -> Result<(), CudaGraphError> {
    let rope_half = head_dim / 2;

    // Host-side length validation; every failure shares one message template.
    let too_short = |what: &str, need: usize, got: usize| {
        CudaGraphError::WeightLayoutError(format!("{what} too short: need {need}, got {got}"))
    };
    if hidden.len() < hidden_size {
        return Err(too_short("hidden", hidden_size, hidden.len()));
    }
    if rope_cos.len() < rope_half {
        return Err(too_short("rope_cos", rope_half, rope_cos.len()));
    }
    if rope_sin.len() < rope_half {
        return Err(too_short("rope_sin", rope_half, rope_sin.len()));
    }

    // Shared scratch buffers first, then the KV cache (same acquisition order as before).
    let mut layer_bufs_guard = acquire_full_layer_buffers(
        graph,
        hidden_size,
        nq,
        nkv,
        head_dim,
        max_seq_len,
        intermediate_size,
    )?;
    let bufs = match layer_bufs_guard.as_mut() {
        Some(allocated) => allocated,
        None => {
            return Err(CudaGraphError::DriverError(
                "full_layer_buffers not allocated".into(),
            ))
        }
    };
    let mut cache_guard = acquire_kv_cache(graph, n_layers, nkv, max_seq_len, head_dim)?;
    let kv = match cache_guard.as_mut() {
        Some(allocated) => allocated,
        None => return Err(CudaGraphError::DriverError("kv_cache not allocated".into())),
    };

    // Host -> device uploads of the token state and the RoPE tables.
    let stream = graph.stream_arc();
    stream
        .memcpy_htod(&hidden[..hidden_size], &mut bufs.d_hidden)
        .map_err(|e| CudaGraphError::DriverError(format!("upload hidden: {e}")))?;
    stream
        .memcpy_htod(&rope_cos[..rope_half], &mut bufs.d_cos)
        .map_err(|e| CudaGraphError::DriverError(format!("upload cos: {e}")))?;
    stream
        .memcpy_htod(&rope_sin[..rope_half], &mut bufs.d_sin)
        .map_err(|e| CudaGraphError::DriverError(format!("upload sin: {e}")))?;

    // Attention phase lives in the parent module.
    // SAFETY: NOTE(review) — the callee's contract is not visible here; the buffers and KV
    // cache were acquired above with the same dimensions that are passed in. Confirm.
    unsafe {
        super::encode_attn_phase(
            graph,
            attn_mods,
            d_pre_attn_norm,
            d_fused_qkv_weight,
            d_q_norm,
            d_k_norm,
            kv,
            layer_idx,
            pos,
            nq,
            nkv,
            head_dim,
            heads_per_group,
            norm_eps,
            hidden_size,
            bufs,
        )?;
    }

    let hidden_dim = hidden_size as u32;
    let inter_dim = intermediate_size as u32;

    // SAFETY: NOTE(review) — same assumption as above for the raw kernel launchers.
    unsafe {
        let attn_width = (nq * head_dim) as u32;

        // Output projection of the attention result, then the first residual add.
        graph.launch_gemv_pub(
            d_o_weight,
            &bufs.d_attn_out,
            &mut bufs.d_normed,
            hidden_dim,
            attn_width,
        )?;
        graph.launch_residual_add_pub(&mut bufs.d_hidden, &bufs.d_normed, hidden_dim)?;

        // Feed-forward: norm -> fused gate/up GEMV -> SwiGLU -> down GEMV.
        graph.launch_rmsnorm_pub(
            &bufs.d_hidden,
            d_post_attn_norm,
            &mut bufs.d_normed,
            hidden_dim,
            norm_eps,
        )?;
        graph.launch_gemv_pub(
            d_gate_up_weight,
            &bufs.d_normed,
            &mut bufs.d_gate_up,
            2 * inter_dim,
            hidden_dim,
        )?;
        graph.launch_swiglu_pub(&bufs.d_gate_up, &mut bufs.d_swiglu, inter_dim)?;
        graph.launch_gemv_pub(
            d_down_weight,
            &bufs.d_swiglu,
            &mut bufs.d_normed,
            hidden_dim,
            inter_dim,
        )?;

        // Second residual add folds the feed-forward output into the hidden state.
        graph.launch_residual_add_pub(&mut bufs.d_hidden, &bufs.d_normed, hidden_dim)?;
    }

    // Drain the stream, read the hidden state back, and wait for the copy to land.
    stream
        .synchronize()
        .map_err(|e| CudaGraphError::DriverError(format!("fl stream sync: {e}")))?;
    stream
        .memcpy_dtoh(&bufs.d_hidden, &mut hidden[..hidden_size])
        .map_err(|e| CudaGraphError::DriverError(format!("download hidden fl: {e}")))?;
    stream
        .synchronize()
        .map_err(|e| CudaGraphError::DriverError(format!("fl D2H sync: {e}")))?;

    Ok(())
}
/// Host-facing entry point for a single GPU transformer layer.
///
/// Obtains the global [`CudaGraph`], initialises the attention modules, resolves every
/// weight through the graph's handle-keyed upload helpers, and then forwards to
/// [`encode_full_layer`], which updates `hidden` in place.
///
/// The gate and up projection bytes are concatenated (gate first, then up) inside the
/// closure handed to `get_or_upload_weight_soa_lazy` under `gate_up_handle_id`.
///
/// # Errors
///
/// Propagates any `CudaGraphError` from obtaining the global graph, initialising the
/// attention modules, uploading a weight, or running the layer.
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_full_layer(
    hidden: &mut [f32],
    pos: usize,
    layer_idx: usize,
    pre_attn_norm_handle_id: u64,
    pre_attn_norm_bytes: &[f32],
    fused_qkv_handle_id: u64,
    fused_qkv_bytes: &[u8],
    o_handle_id: u64,
    o_bytes: &[u8],
    q_norm_handle_id: u64,
    q_norm_bytes: &[f32],
    k_norm_handle_id: u64,
    k_norm_bytes: &[f32],
    post_attn_norm_handle_id: u64,
    post_attn_norm_bytes: &[f32],
    gate_up_handle_id: u64,
    gate_bytes: &[u8],
    up_bytes: &[u8],
    down_handle_id: u64,
    down_bytes: &[u8],
    rope_cos: &[f32],
    rope_sin: &[f32],
    hidden_size: usize,
    intermediate_size: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    max_seq_len: usize,
    n_layers: usize,
) -> Result<(), CudaGraphError> {
    let graph = CudaGraph::global()?;
    let modules = init_attn_modules(&graph)?;

    // Byte-typed projection weights, resolved in the same order as before.
    let qkv_dev = graph.get_or_upload_weight_soa(fused_qkv_handle_id, fused_qkv_bytes)?;
    let o_dev = graph.get_or_upload_weight_soa(o_handle_id, o_bytes)?;
    let gate_up_dev = graph
        .get_or_upload_weight_soa_lazy(gate_up_handle_id, || [gate_bytes, up_bytes].concat())?;
    let down_dev = graph.get_or_upload_weight_soa(down_handle_id, down_bytes)?;

    // f32 normalisation weights.
    let pre_norm_dev =
        get_or_upload_f32_weight(&graph, pre_attn_norm_handle_id, pre_attn_norm_bytes)?;
    let post_norm_dev =
        get_or_upload_f32_weight(&graph, post_attn_norm_handle_id, post_attn_norm_bytes)?;
    let q_norm_dev = get_or_upload_f32_weight(&graph, q_norm_handle_id, q_norm_bytes)?;
    let k_norm_dev = get_or_upload_f32_weight(&graph, k_norm_handle_id, k_norm_bytes)?;

    encode_full_layer(
        &graph,
        hidden,
        pos,
        layer_idx,
        &pre_norm_dev,
        &qkv_dev,
        &o_dev,
        &q_norm_dev,
        &k_norm_dev,
        &post_norm_dev,
        &gate_up_dev,
        &down_dev,
        rope_cos,
        rope_sin,
        hidden_size,
        intermediate_size,
        nq,
        nkv,
        head_dim,
        heads_per_group,
        norm_eps,
        max_seq_len,
        n_layers,
        &modules,
    )
}
/// Enqueues one transformer layer that operates purely on device-resident buffers.
///
/// Nothing is copied to or from the host here. The launches are issued in this order:
/// pre-attention RMSNorm, fused QKV GEMV, fused QK-norm + RoPE, KV-cache store, attention
/// scores, softmax, weighted sum, O-projection GEMV with residual, post-attention RMSNorm,
/// fused gate/up + SwiGLU, and down GEMV with residual.
///
/// `_pos` is unused: the launchers receive `bufs.d_pos_seqlen` instead, so the caller is
/// expected to have filled that buffer beforehand.
///
/// # Safety
///
/// NOTE(review): the precise contract is defined by the launchers, which are not visible
/// from this function. Presumably `bufs`, `kv` and `weights` must have been allocated for
/// the dimensions passed in — confirm against the launcher documentation.
///
/// # Errors
///
/// Returns the first error produced by any launch.
#[allow(clippy::too_many_arguments)]
unsafe fn encode_layer_device(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    weights: &CudaCachedLayerWeights,
    kv: &mut CudaKvCache,
    layer_idx: usize,
    _pos: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    hidden_size: usize,
    intermediate_size: usize,
    bufs: &mut CudaFullLayerBuffers,
) -> Result<(), CudaGraphError> {
    // Launch dimensions, narrowed to the u32 the kernels take.
    let hidden_dim = hidden_size as u32;
    let n_q_heads = nq as u32;
    let n_kv_heads = nkv as u32;
    let head_width = head_dim as u32;
    let inter_dim = intermediate_size as u32;
    let qkv_rows = (nq * head_dim + 2 * nkv * head_dim) as u32;
    let group_size = heads_per_group as u32;
    let seq_capacity = bufs.max_seq as u32;
    let score_scale = 1.0f32 / (head_dim as f32).sqrt();
    let cache_base = kv.layer_offset_elements(layer_idx);

    // --- Attention input: norm, then one GEMV producing Q, K and V back to back. ---
    graph.launch_rmsnorm_pub(
        &bufs.d_hidden,
        &weights.pre_attn_norm,
        &mut bufs.d_normed,
        hidden_dim,
        norm_eps,
    )?;
    graph.launch_gemv_pub(
        &weights.q_weight,
        &bufs.d_normed,
        &mut bufs.d_qkv,
        qkv_rows,
        hidden_dim,
    )?;

    // K begins right after the `nq * head_dim` Q elements in the fused output.
    let k_start = nq * head_dim;
    let k_section = bufs.d_qkv.slice(k_start..);
    launch_fused_qk_norm_rope(
        graph,
        mods,
        &bufs.d_qkv,
        &k_section,
        &mut bufs.d_q_rope,
        &mut bufs.d_k_rope,
        &weights.q_norm,
        &weights.k_norm,
        &bufs.d_cos,
        &bufs.d_sin,
        n_q_heads,
        n_kv_heads,
        head_width,
        norm_eps,
    )?;

    // V follows Q and K; the rotated K and raw V are written into this layer's cache.
    let v_start = (nq + nkv) * head_dim;
    let v_section = bufs.d_qkv.slice(v_start..);
    launch_fused_kv_store(
        graph,
        mods,
        &bufs.d_k_rope,
        &v_section,
        &mut kv.k_cache,
        &mut kv.v_cache,
        head_width,
        n_kv_heads,
        seq_capacity,
        &bufs.d_pos_seqlen,
        cache_base,
    )?;

    // --- Attention proper: scores -> softmax -> weighted sum over cached V. ---
    launch_batched_attn_scores_v2(
        graph,
        mods,
        &bufs.d_q_rope,
        &kv.k_cache,
        &mut bufs.d_scores,
        head_width,
        n_q_heads,
        n_kv_heads,
        group_size,
        seq_capacity,
        &bufs.d_pos_seqlen,
        score_scale,
        cache_base,
    )?;
    launch_batched_softmax(
        graph,
        mods,
        &mut bufs.d_scores,
        n_q_heads,
        seq_capacity,
        &bufs.d_pos_seqlen,
    )?;
    launch_batched_attn_weighted_sum(
        graph,
        mods,
        &bufs.d_scores,
        &kv.v_cache,
        &mut bufs.d_attn_out,
        head_width,
        n_q_heads,
        n_kv_heads,
        group_size,
        seq_capacity,
        &bufs.d_pos_seqlen,
        cache_base,
    )?;

    // --- Output projection with the residual folded into the hidden buffer. ---
    let attn_width = (nq * head_dim) as u32;
    graph.launch_gemv_residual_pub(
        &weights.o_weight,
        &bufs.d_attn_out,
        &mut bufs.d_hidden,
        hidden_dim,
        attn_width,
    )?;

    // --- Feed-forward: norm -> fused gate/up + SwiGLU -> down projection + residual. ---
    graph.launch_rmsnorm_pub(
        &bufs.d_hidden,
        &weights.post_attn_norm,
        &mut bufs.d_normed,
        hidden_dim,
        norm_eps,
    )?;
    graph.launch_fused_gate_up_swiglu_pub(
        &weights.gate_up_weight,
        &bufs.d_normed,
        &mut bufs.d_swiglu,
        inter_dim,
        hidden_dim,
    )?;
    graph.launch_gemv_residual_pub(
        &weights.down_weight,
        &bufs.d_swiglu,
        &mut bufs.d_hidden,
        hidden_dim,
        inter_dim,
    )?;

    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn encode_full_forward(
graph: &Arc<CudaGraph>,
hidden_init: &[f32],
all_layer_weights: &[CudaCachedLayerWeights],
rope_cos: &[f32],
rope_sin: &[f32],
pos: usize,
nq: usize,
nkv: usize,
head_dim: usize,
heads_per_group: usize,
norm_eps: f32,
hidden_size: usize,
intermediate_size: usize,
max_seq_len: usize,
final_norm_weight: Option<&[f32]>,
final_norm_handle: u64,
) -> Result<Vec<f32>, CudaGraphError> {
let h = hidden_size;
let half_dim = head_dim / 2;
let n_layers = all_layer_weights.len();
let h_u32 = h as u32;
if hidden_init.len() < h {
return Err(CudaGraphError::WeightLayoutError(format!(
"hidden_init too short: need {h}, got {}",
hidden_init.len()
)));
}
if rope_cos.len() < half_dim {
return Err(CudaGraphError::WeightLayoutError(format!(
"rope_cos too short: need {half_dim}, got {}",
rope_cos.len()
)));
}
if rope_sin.len() < half_dim {
return Err(CudaGraphError::WeightLayoutError(format!(
"rope_sin too short: need {half_dim}, got {}",
rope_sin.len()
)));
}
if n_layers == 0 {
return Err(CudaGraphError::WeightLayoutError(
"encode_full_forward: no layers provided".into(),
));
}
let attn_mods = init_attn_modules(graph)?;
let mut fl_guard =
acquire_full_layer_buffers(graph, h, nq, nkv, head_dim, max_seq_len, intermediate_size)?;
let bufs = fl_guard
.as_mut()
.ok_or_else(|| CudaGraphError::DriverError("full_layer_buffers not allocated".into()))?;
let mut kv_guard = acquire_kv_cache(graph, n_layers, nkv, max_seq_len, head_dim)?;
let kv = kv_guard
.as_mut()
.ok_or_else(|| CudaGraphError::DriverError("kv_cache not allocated".into()))?;
let stream = graph.stream_arc();
let pos_seqlen_host = [pos as u32, (pos + 1) as u32];
unsafe {
graph
.raw_htod(&pos_seqlen_host, &mut bufs.d_pos_seqlen, 2)
.map_err(|e| CudaGraphError::DriverError(format!("upload pos_seqlen: {e}")))?;
}
{
let graph_guard = full_layer_state()
.cuda_driver_graph
.lock()
.map_err(|_| CudaGraphError::LockPoisoned)?;
if let Some(Some(ref holder)) = *graph_guard {
unsafe {
graph.raw_htod(&hidden_init[..h], &mut bufs.d_hidden, h)?;
graph.raw_htod(&rope_cos[..half_dim], &mut bufs.d_cos, half_dim)?;
graph.raw_htod(&rope_sin[..half_dim], &mut bufs.d_sin, half_dim)?;
}
unsafe { holder.launch() }
.map_err(|e| CudaGraphError::DriverError(format!("graph launch: {e}")))?;
let mut result = vec![0.0f32; h];
unsafe { graph.raw_dtoh(&bufs.d_hidden, &mut result, h)? }
stream
.synchronize()
.map_err(|e| CudaGraphError::DriverError(format!("fast-path sync: {e}")))?;
return Ok(result);
}
}
unsafe {
graph
.raw_htod(&hidden_init[..h], &mut bufs.d_hidden, h)
.map_err(|e| CudaGraphError::DriverError(format!("upload hidden_init: {e}")))?;
graph
.raw_htod(&rope_cos[..half_dim], &mut bufs.d_cos, half_dim)
.map_err(|e| CudaGraphError::DriverError(format!("upload cos ff: {e}")))?;
graph
.raw_htod(&rope_sin[..half_dim], &mut bufs.d_sin, half_dim)
.map_err(|e| CudaGraphError::DriverError(format!("upload sin ff: {e}")))?;
}
for (layer_idx, weights) in all_layer_weights.iter().enumerate() {
unsafe {
encode_layer_device(
graph,
&attn_mods,
weights,
kv,
layer_idx,
pos,
nq,
nkv,
head_dim,
heads_per_group,
norm_eps,
h,
intermediate_size,
bufs,
)?;
}
}
if let Some(fnorm_data) = final_norm_weight {
let d_fnorm = get_or_upload_f32_weight(graph, final_norm_handle, fnorm_data)?;
unsafe {
graph.launch_rmsnorm_pub(
&bufs.d_hidden,
&d_fnorm,
&mut bufs.d_normed,
h_u32,
norm_eps,
)?;
}
stream
.memcpy_dtod(&bufs.d_normed, &mut bufs.d_hidden)
.map_err(|e| CudaGraphError::DriverError(format!("dtod normed->hidden: {e}")))?;
}
let mut result = vec![0.0f32; h];
unsafe { graph.raw_dtoh(&bufs.d_hidden, &mut result, h)? }
stream
.synchronize()
.map_err(|e| CudaGraphError::DriverError(format!("ff D2H sync: {e}")))?;
{
if let Ok(ref mut graph_guard) = full_layer_state().cuda_driver_graph.lock() {
if graph_guard.is_none() {
let begin_ok = stream
.begin_capture(sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL)
.is_ok();
if !begin_ok {
warn!("CUDA graph: begin_capture failed — running without graph replay");
**graph_guard = Some(None);
} else {
let record_ok: bool = (|| -> Result<(), CudaGraphError> {
for (layer_idx, weights) in all_layer_weights.iter().enumerate() {
unsafe {
encode_layer_device(
graph,
&attn_mods,
weights,
kv,
layer_idx,
pos,
nq,
nkv,
head_dim,
heads_per_group,
norm_eps,
h,
intermediate_size,
bufs,
)?;
}
}
if let Some(fnorm_data) = final_norm_weight {
let d_fnorm =
get_or_upload_f32_weight(graph, final_norm_handle, fnorm_data)?;
unsafe {
graph.launch_rmsnorm_pub(
&bufs.d_hidden,
&d_fnorm,
&mut bufs.d_normed,
h_u32,
norm_eps,
)?;
}
stream
.memcpy_dtod(&bufs.d_normed, &mut bufs.d_hidden)
.map_err(|e| {
CudaGraphError::DriverError(format!("dtod (capture): {e}"))
})?;
}
Ok(())
})()
.is_ok();
let end_result =
unsafe { cudarc::driver::result::stream::end_capture(stream.cu_stream()) };
match end_result {
Ok(cu_graph_raw) if !cu_graph_raw.is_null() && record_ok => {
let inst_result = unsafe {
let mut exec = std::mem::MaybeUninit::uninit();
sys::cuGraphInstantiateWithFlags(
exec.as_mut_ptr(),
cu_graph_raw,
0u64,
)
.result()
.map(|_| exec.assume_init())
};
match inst_result {
Ok(cu_graph_exec) => {
let holder = CuGraphHolder {
cu_graph: cu_graph_raw,
cu_graph_exec,
stream: Arc::clone(stream),
};
match unsafe { holder.upload() } {
Ok(()) => {
**graph_guard = Some(Some(holder));
tracing::debug!(
"CUDA graph captured and uploaded successfully"
);
}
Err(e) => {
warn!(
"CUDA graph upload failed: {e} — disabling replay"
);
**graph_guard = Some(None);
}
}
}
Err(e) => {
warn!("CUDA graph instantiate failed: {e} — disabling replay");
unsafe {
let _ =
cudarc::driver::result::graph::destroy(cu_graph_raw);
}
**graph_guard = Some(None);
}
}
}
Ok(_) => {
warn!("CUDA graph: end_capture returned no graph — disabling replay");
**graph_guard = Some(None);
}
Err(e) => {
warn!("CUDA graph: end_capture error: {e} — disabling replay");
**graph_guard = Some(None);
}
}
}
}
}
}
Ok(result)
}
/// Best-effort wrapper around [`encode_full_forward`] that reports failure as `None`.
///
/// Resolves the graph and cached per-layer weights via `get_or_build_model_weights`
/// (returning `None` immediately if that yields nothing), runs the full forward pass, and
/// logs any error with `warn!` before discarding it.
///
/// When `profiling()` is true, timing lines prefixed with `[cuda-prof]` are written to
/// stderr. The `path=` label is derived solely from `pos` (`"slow"` for `pos == 0`,
/// `"fast"` otherwise), not from which code path actually ran.
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_full_forward(
    hidden: &[f32],
    layer_params: &[CudaFullForwardLayerParams<'_>],
    rope_cos: &[f32],
    rope_sin: &[f32],
    pos: usize,
    nq: usize,
    nkv: usize,
    head_dim: usize,
    heads_per_group: usize,
    norm_eps: f32,
    hidden_size: usize,
    intermediate_size: usize,
    max_seq_len: usize,
    final_norm_bytes: Option<&[f32]>,
    final_norm_handle: u64,
) -> Option<Vec<f32>> {
    let lookup_started = profiling().then(std::time::Instant::now);
    let (graph, layer_weights) = get_or_build_model_weights(layer_params)?;
    let encode_started = profiling().then(std::time::Instant::now);

    if profiling() {
        let lookup_ms = (encode_started.expect("profiling") - lookup_started.expect("profiling"))
            .as_secs_f64()
            * 1000.0;
        eprintln!("[cuda-prof] try_ff pos={pos}: weight_lookup={lookup_ms:.3}ms");
    }

    let outcome = encode_full_forward(
        &graph,
        hidden,
        &layer_weights,
        rope_cos,
        rope_sin,
        pos,
        nq,
        nkv,
        head_dim,
        heads_per_group,
        norm_eps,
        hidden_size,
        intermediate_size,
        max_seq_len,
        final_norm_bytes,
        final_norm_handle,
    );

    if profiling() {
        let elapsed = encode_started.expect("profiling").elapsed().as_secs_f64() * 1000.0;
        let path = match pos {
            0 => "slow",
            _ => "fast",
        };
        eprintln!("[cuda-prof] encode_ff pos={pos} path={path}: {elapsed:.1}ms");
    }

    match outcome {
        Ok(out) => Some(out),
        Err(e) => {
            warn!("CUDA full-forward error at pos={pos}: {e}");
            None
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_full_forward_with_gpu_lm_head(
hidden: &[f32],
layer_params: &[CudaFullForwardLayerParams<'_>],
rope_cos: &[f32],
rope_sin: &[f32],
pos: usize,
nq: usize,
nkv: usize,
head_dim: usize,
heads_per_group: usize,
norm_eps: f32,
hidden_size: usize,
intermediate_size: usize,
max_seq_len: usize,
final_norm_bytes: Option<&[f32]>,
final_norm_handle: u64,
lm_head_handle: u64,
lm_head_bytes: &[u8],
vocab_size: usize,
) -> Option<Vec<f32>> {
let normed = try_cuda_full_forward(
hidden,
layer_params,
rope_cos,
rope_sin,
pos,
nq,
nkv,
head_dim,
heads_per_group,
norm_eps,
hidden_size,
intermediate_size,
max_seq_len,
final_norm_bytes,
final_norm_handle,
)?;
let graph = CudaGraph::global().ok()?;
let _t_lm = profiling().then(std::time::Instant::now);
let r = graph.encode_lm_head_gemv(
&normed,
lm_head_handle,
lm_head_bytes,
vocab_size,
hidden_size,
);
if profiling() {
eprintln!(
"[cuda-prof] lm_head pos={pos}: {:.1}ms",
_t_lm.expect("profiling").elapsed().as_secs_f64() * 1000.0
);
}
r.ok()
}