use cudarc::driver::{CudaSlice, CudaView, LaunchConfig, PushKernelArg};
use super::super::cuda_graph::{CudaGraph, CudaGraphError};
use super::CudaAttnModules;
/// Enqueues the `fused_qk_norm` kernel on the graph's stream.
///
/// The launch uses `nq + nkv` blocks of 256 threads each and no dynamic
/// shared memory (presumably one block per Q/K head — confirm against the
/// kernel source). Kernel arguments are pushed in this order: `d_q_in`,
/// `d_k_in`, `d_q_out`, `d_k_out`, `d_q_weight`, `d_k_weight`, `nq`, `nkv`,
/// `head_dim`, `eps`.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] when the launch call reports a
/// failure.
///
/// # Safety
///
/// Nothing in this function checks buffer lengths against `nq`, `nkv` or
/// `head_dim`. The caller must guarantee that every device buffer is large
/// enough for whatever the kernel reads and writes, and that this argument
/// list matches the kernel's signature.
// NOTE(review): `dead_code` is allowed here, presumably because no caller is
// currently compiled in — confirm before removing the attribute.
#[allow(clippy::too_many_arguments, dead_code)]
pub(super) unsafe fn launch_fused_qk_norm(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_q_in: &CudaSlice<f32>,
    d_k_in: &CudaSlice<f32>,
    d_q_out: &mut CudaSlice<f32>,
    d_k_out: &mut CudaSlice<f32>,
    d_q_weight: &CudaSlice<f32>,
    d_k_weight: &CudaSlice<f32>,
    nq: u32,
    nkv: u32,
    head_dim: u32,
    eps: f32,
) -> Result<(), CudaGraphError> {
    let blocks = nq + nkv;
    let cfg = LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut builder = stream.launch_builder(&mods.fused_qk_norm);
    // Argument order must mirror the kernel's parameter list.
    builder.arg(d_q_in);
    builder.arg(d_k_in);
    builder.arg(d_q_out);
    builder.arg(d_k_out);
    builder.arg(d_q_weight);
    builder.arg(d_k_weight);
    builder.arg(&nq);
    builder.arg(&nkv);
    builder.arg(&head_dim);
    builder.arg(&eps);

    // SAFETY: buffer sizes and the argument layout are the caller's
    // responsibility, as stated in this function's `# Safety` section.
    let outcome = unsafe { builder.launch(cfg) };
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "fused_qk_norm launch: {e}"
        ))),
    }
}
/// Enqueues the `fused_qk_rope` kernel on the graph's stream.
///
/// The grid is `ceil(half_dim / 64)` blocks along x and `nq + nkv` blocks
/// along y, with 64 threads per block and no dynamic shared memory. Kernel
/// arguments are pushed in this order: `d_q_in`, `d_k_in`, `d_q_out`,
/// `d_k_out`, `d_cos`, `d_sin`, `nq`, `nkv`, `half_dim`.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] when the launch call reports a
/// failure.
///
/// # Safety
///
/// Nothing in this function checks buffer lengths against `nq`, `nkv` or
/// `half_dim`. The caller must guarantee that every device buffer is large
/// enough for whatever the kernel reads and writes, and that this argument
/// list matches the kernel's signature.
// NOTE(review): `dead_code` is allowed here, presumably because no caller is
// currently compiled in — confirm before removing the attribute.
#[allow(clippy::too_many_arguments, dead_code)]
pub(super) unsafe fn launch_fused_qk_rope(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_q_in: &CudaSlice<f32>,
    d_k_in: &CudaSlice<f32>,
    d_q_out: &mut CudaSlice<f32>,
    d_k_out: &mut CudaSlice<f32>,
    d_cos: &CudaSlice<f32>,
    d_sin: &CudaSlice<f32>,
    nq: u32,
    nkv: u32,
    half_dim: u32,
) -> Result<(), CudaGraphError> {
    // 64 threads per block along x; round up so every index below
    // `half_dim` is covered by some thread.
    let blocks_x = half_dim.div_ceil(64);
    let blocks_y = nq + nkv;
    let cfg = LaunchConfig {
        grid_dim: (blocks_x, blocks_y, 1),
        block_dim: (64, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut builder = stream.launch_builder(&mods.fused_qk_rope);
    // Argument order must mirror the kernel's parameter list.
    builder.arg(d_q_in);
    builder.arg(d_k_in);
    builder.arg(d_q_out);
    builder.arg(d_k_out);
    builder.arg(d_cos);
    builder.arg(d_sin);
    builder.arg(&nq);
    builder.arg(&nkv);
    builder.arg(&half_dim);

    // SAFETY: buffer sizes and the argument layout are the caller's
    // responsibility, as stated in this function's `# Safety` section.
    let outcome = unsafe { builder.launch(cfg) };
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "fused_qk_rope launch: {e}"
        ))),
    }
}
/// Enqueues the `fused_qk_norm_rope` kernel on the graph's stream.
///
/// The launch uses `nq + nkv` blocks of 256 threads each and no dynamic
/// shared memory. Unlike the other inputs, the K input is passed as a
/// [`CudaView`] rather than a whole [`CudaSlice`]. Kernel arguments are
/// pushed in this order: `d_q_in`, `d_k_in_view`, `d_q_out`, `d_k_out`,
/// `d_q_weight`, `d_k_weight`, `d_cos`, `d_sin`, `nq`, `nkv`, `head_dim`,
/// `eps`.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] when the launch call reports a
/// failure.
///
/// # Safety
///
/// Nothing in this function checks buffer lengths against `nq`, `nkv` or
/// `head_dim`. The caller must guarantee that every device buffer (and the
/// K view) is large enough for whatever the kernel reads and writes, and
/// that this argument list matches the kernel's signature.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_fused_qk_norm_rope(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_q_in: &CudaSlice<f32>,
    d_k_in_view: &CudaView<'_, f32>,
    d_q_out: &mut CudaSlice<f32>,
    d_k_out: &mut CudaSlice<f32>,
    d_q_weight: &CudaSlice<f32>,
    d_k_weight: &CudaSlice<f32>,
    d_cos: &CudaSlice<f32>,
    d_sin: &CudaSlice<f32>,
    nq: u32,
    nkv: u32,
    head_dim: u32,
    eps: f32,
) -> Result<(), CudaGraphError> {
    let blocks = nq + nkv;
    let cfg = LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut builder = stream.launch_builder(&mods.fused_qk_norm_rope);
    // Argument order must mirror the kernel's parameter list.
    builder.arg(d_q_in);
    builder.arg(d_k_in_view);
    builder.arg(d_q_out);
    builder.arg(d_k_out);
    builder.arg(d_q_weight);
    builder.arg(d_k_weight);
    builder.arg(d_cos);
    builder.arg(d_sin);
    builder.arg(&nq);
    builder.arg(&nkv);
    builder.arg(&head_dim);
    builder.arg(&eps);

    // SAFETY: buffer sizes and the argument layout are the caller's
    // responsibility, as stated in this function's `# Safety` section.
    let outcome = unsafe { builder.launch(cfg) };
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "fused_qk_norm_rope launch: {e}"
        ))),
    }
}
/// Enqueues the `fused_kv_store` kernel on the graph's stream.
///
/// The grid is `ceil(head_dim / 64)` blocks along x and `nkv` blocks along
/// y, with 64 threads per block and no dynamic shared memory. The K/V source
/// data is `f32` while both caches are `u16` buffers (presumably raw 16-bit
/// float bits — confirm against the kernel source). Kernel arguments are
/// pushed in this order: `d_k_data`, `d_v_data_view`, `d_k_cache`,
/// `d_v_cache`, `head_dim`, `nkv`, `max_seq`, `d_pos_seqlen`, `layer_offset`.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] when the launch call reports a
/// failure.
///
/// # Safety
///
/// Nothing in this function checks buffer lengths or validates
/// `layer_offset` / `max_seq` against the cache sizes. The caller must
/// guarantee that every device buffer is large enough for whatever the
/// kernel reads and writes, and that this argument list matches the kernel's
/// signature.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_fused_kv_store(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_k_data: &CudaSlice<f32>,
    d_v_data_view: &CudaView<'_, f32>,
    d_k_cache: &mut CudaSlice<u16>,
    d_v_cache: &mut CudaSlice<u16>,
    head_dim: u32,
    nkv: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
    layer_offset: u32,
) -> Result<(), CudaGraphError> {
    // 64 threads per block along x; round up so every index below
    // `head_dim` is covered by some thread.
    let blocks_x = head_dim.div_ceil(64);
    let cfg = LaunchConfig {
        grid_dim: (blocks_x, nkv, 1),
        block_dim: (64, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut builder = stream.launch_builder(&mods.fused_kv_store);
    // Argument order must mirror the kernel's parameter list.
    builder.arg(d_k_data);
    builder.arg(d_v_data_view);
    builder.arg(d_k_cache);
    builder.arg(d_v_cache);
    builder.arg(&head_dim);
    builder.arg(&nkv);
    builder.arg(&max_seq);
    builder.arg(d_pos_seqlen);
    builder.arg(&layer_offset);

    // SAFETY: buffer sizes and the argument layout are the caller's
    // responsibility, as stated in this function's `# Safety` section.
    let outcome = unsafe { builder.launch(cfg) };
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "fused_kv_store launch: {e}"
        ))),
    }
}
/// Enqueues the `batched_attn_scores_v2` kernel on the graph's stream.
///
/// The grid is `n_q` blocks along x and `ceil(max_seq / BATCH_STRIDE)`
/// blocks along y (`BATCH_STRIDE` is fixed at 4), with 128 threads per block
/// and no dynamic shared memory. The grid is sized from `max_seq`; the
/// `d_pos_seqlen` device buffer is only forwarded to the kernel and is never
/// read on the host here. Kernel arguments are pushed in this order:
/// `d_queries`, `d_k_cache`, `d_scores`, `head_dim`, `n_q`, `n_kv`,
/// `heads_per_group`, `max_seq`, `d_pos_seqlen`, `inv_sqrt_hd`,
/// `cache_layer_offset`, `BATCH_STRIDE`.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] when the launch call reports a
/// failure.
///
/// # Safety
///
/// Nothing in this function checks buffer lengths or validates
/// `cache_layer_offset` / `max_seq` against the cache size. The caller must
/// guarantee that every device buffer is large enough for whatever the
/// kernel reads and writes, and that this argument list matches the kernel's
/// signature.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_batched_attn_scores_v2(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_queries: &CudaSlice<f32>,
    d_k_cache: &CudaSlice<u16>,
    d_scores: &mut CudaSlice<f32>,
    head_dim: u32,
    n_q: u32,
    n_kv: u32,
    heads_per_group: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
    inv_sqrt_hd: f32,
    cache_layer_offset: u32,
) -> Result<(), CudaGraphError> {
    // Used both to size the y dimension of the grid and as the final kernel
    // argument, so the two stay in agreement.
    const BATCH_STRIDE: u32 = 4;

    let blocks_y = max_seq.div_ceil(BATCH_STRIDE);
    let cfg = LaunchConfig {
        grid_dim: (n_q, blocks_y, 1),
        block_dim: (128, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut builder = stream.launch_builder(&mods.batched_attn_scores_v2);
    // Argument order must mirror the kernel's parameter list.
    builder.arg(d_queries);
    builder.arg(d_k_cache);
    builder.arg(d_scores);
    builder.arg(&head_dim);
    builder.arg(&n_q);
    builder.arg(&n_kv);
    builder.arg(&heads_per_group);
    builder.arg(&max_seq);
    builder.arg(d_pos_seqlen);
    builder.arg(&inv_sqrt_hd);
    builder.arg(&cache_layer_offset);
    builder.arg(&BATCH_STRIDE);

    // SAFETY: buffer sizes and the argument layout are the caller's
    // responsibility, as stated in this function's `# Safety` section.
    let outcome = unsafe { builder.launch(cfg) };
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "batched_attn_scores_v2 launch: {e}"
        ))),
    }
}
/// Enqueues the `batched_softmax` kernel on the graph's stream.
///
/// The launch uses `n_q` blocks of 256 threads each and no dynamic shared
/// memory. `d_scores` is passed mutably, so the kernel presumably rewrites
/// it in place — confirm against the kernel source. Kernel arguments are
/// pushed in this order: `d_scores`, `n_q`, `max_seq`, `d_pos_seqlen`.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] when the launch call reports a
/// failure.
///
/// # Safety
///
/// Nothing in this function checks the length of `d_scores` against `n_q`
/// and `max_seq`. The caller must guarantee that every device buffer is
/// large enough for whatever the kernel reads and writes, and that this
/// argument list matches the kernel's signature.
pub(super) unsafe fn launch_batched_softmax(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_scores: &mut CudaSlice<f32>,
    n_q: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
) -> Result<(), CudaGraphError> {
    let cfg = LaunchConfig {
        grid_dim: (n_q, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut builder = stream.launch_builder(&mods.batched_softmax);
    // Argument order must mirror the kernel's parameter list.
    builder.arg(d_scores);
    builder.arg(&n_q);
    builder.arg(&max_seq);
    builder.arg(d_pos_seqlen);

    // SAFETY: buffer sizes and the argument layout are the caller's
    // responsibility, as stated in this function's `# Safety` section.
    let outcome = unsafe { builder.launch(cfg) };
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "batched_softmax launch: {e}"
        ))),
    }
}
/// Enqueues the `batched_attn_weighted_sum` kernel on the graph's stream.
///
/// The grid is `ceil(head_dim / 64)` blocks along x and `n_q` blocks along
/// y, with 64 threads per block and no dynamic shared memory. Kernel
/// arguments are pushed in this order: `d_scores`, `d_v_cache`,
/// `d_attn_out`, `head_dim`, `n_q`, `n_kv`, `heads_per_group`, `max_seq`,
/// `d_pos_seqlen`, `cache_layer_offset`.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] when the launch call reports a
/// failure.
///
/// # Safety
///
/// Nothing in this function checks buffer lengths or validates
/// `cache_layer_offset` / `max_seq` against the cache size. The caller must
/// guarantee that every device buffer is large enough for whatever the
/// kernel reads and writes, and that this argument list matches the kernel's
/// signature.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_batched_attn_weighted_sum(
    graph: &CudaGraph,
    mods: &CudaAttnModules,
    d_scores: &CudaSlice<f32>,
    d_v_cache: &CudaSlice<u16>,
    d_attn_out: &mut CudaSlice<f32>,
    head_dim: u32,
    n_q: u32,
    n_kv: u32,
    heads_per_group: u32,
    max_seq: u32,
    d_pos_seqlen: &CudaSlice<u32>,
    cache_layer_offset: u32,
) -> Result<(), CudaGraphError> {
    // 64 threads per block along x; round up so every index below
    // `head_dim` is covered by some thread.
    let blocks_x = head_dim.div_ceil(64);
    let cfg = LaunchConfig {
        grid_dim: (blocks_x, n_q, 1),
        block_dim: (64, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut builder = stream.launch_builder(&mods.batched_attn_weighted_sum);
    // Argument order must mirror the kernel's parameter list.
    builder.arg(d_scores);
    builder.arg(d_v_cache);
    builder.arg(d_attn_out);
    builder.arg(&head_dim);
    builder.arg(&n_q);
    builder.arg(&n_kv);
    builder.arg(&heads_per_group);
    builder.arg(&max_seq);
    builder.arg(d_pos_seqlen);
    builder.arg(&cache_layer_offset);

    // SAFETY: buffer sizes and the argument layout are the caller's
    // responsibility, as stated in this function's `# Safety` section.
    let outcome = unsafe { builder.launch(cfg) };
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "batched_attn_weighted_sum launch: {e}"
        ))),
    }
}