use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg};
use crate::gpu_backend::cuda_graph::{CudaGraph, CudaGraphError};
use super::state::CudaPrefillModules;
/// Enqueues the `gemm_v7` kernel on the stream owned by `graph`.
///
/// The launch uses a one-dimensional grid of `n_rows.div_ceil(8)` blocks with
/// 256 threads per block and `shared_mem_bytes` set to 0. Kernel arguments are
/// pushed in this order: `d_blocks`, `d_inputs`, `d_outputs`, `n_rows`, `k`,
/// `batch_size`.
///
/// NOTE(review): the divisor 8 presumably matches the number of output rows a
/// single block produces — confirm against the kernel source. With
/// `n_rows == 0` the grid is zero-sized and is handed to the driver as-is.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] containing the driver's message
/// prefixed with `gemm_v7 launch:` when the launch call fails.
///
/// # Safety
///
/// Nothing about the buffers or dimensions is validated here. The caller must
/// pass arguments that satisfy the kernel's requirements; those requirements
/// are defined by the kernel source, which is not visible from this function.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_gemm_v7(
    graph: &CudaGraph,
    mods: &CudaPrefillModules,
    d_blocks: &CudaSlice<u8>,
    d_inputs: &CudaSlice<f32>,
    d_outputs: &mut CudaSlice<f32>,
    n_rows: u32,
    k: u32,
    batch_size: u32,
) -> Result<(), CudaGraphError> {
    let blocks_in_grid = n_rows.div_ceil(8);
    let launch_cfg = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (blocks_in_grid, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut launch = stream.launch_builder(&mods.gemm_v7);
    // Argument order must mirror the kernel's parameter list.
    launch.arg(d_blocks);
    launch.arg(d_inputs);
    launch.arg(d_outputs);
    launch.arg(&n_rows);
    launch.arg(&k);
    launch.arg(&batch_size);

    // SAFETY: argument validity is delegated to the caller by this function's
    // own safety contract.
    match unsafe { launch.launch(launch_cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("gemm_v7 launch: {e}"))),
    }
}
/// Enqueues the `gemm_v7_residual` kernel on the stream owned by `graph`.
///
/// The launch uses a one-dimensional grid of `n_rows.div_ceil(8)` blocks with
/// 256 threads per block and `shared_mem_bytes` set to 0. Kernel arguments are
/// pushed in this order: `d_blocks`, `d_inputs`, `d_outputs`, `n_rows`, `k`,
/// `batch_size`, `d_residual`.
///
/// NOTE(review): judging by the name, the kernel presumably combines
/// `d_residual` with the GEMM result — confirm against the kernel source.
/// With `n_rows == 0` the grid is zero-sized and is handed to the driver
/// as-is.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] containing the driver's message
/// prefixed with `gemm_v7_residual launch:` when the launch call fails.
///
/// # Safety
///
/// Nothing about the buffers or dimensions is validated here. The caller must
/// pass arguments that satisfy the kernel's requirements; those requirements
/// are defined by the kernel source, which is not visible from this function.
#[allow(clippy::too_many_arguments, dead_code)]
pub(super) unsafe fn launch_gemm_v7_residual(
    graph: &CudaGraph,
    mods: &CudaPrefillModules,
    d_blocks: &CudaSlice<u8>,
    d_inputs: &CudaSlice<f32>,
    d_outputs: &mut CudaSlice<f32>,
    n_rows: u32,
    k: u32,
    batch_size: u32,
    d_residual: &CudaSlice<f32>,
) -> Result<(), CudaGraphError> {
    let blocks_in_grid = n_rows.div_ceil(8);
    let launch_cfg = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (blocks_in_grid, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut launch = stream.launch_builder(&mods.gemm_v7_residual);
    // Argument order must mirror the kernel's parameter list.
    launch.arg(d_blocks);
    launch.arg(d_inputs);
    launch.arg(d_outputs);
    launch.arg(&n_rows);
    launch.arg(&k);
    launch.arg(&batch_size);
    launch.arg(d_residual);

    // SAFETY: argument validity is delegated to the caller by this function's
    // own safety contract.
    match unsafe { launch.launch(launch_cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("gemm_v7_residual launch: {e}"))),
    }
}
/// Enqueues the `fused_gate_up_swiglu_gemm` kernel on the stream owned by
/// `graph`.
///
/// The launch uses a one-dimensional grid of `n_rows.div_ceil(8)` blocks with
/// 256 threads per block and `shared_mem_bytes` set to 0. Kernel arguments are
/// pushed in this order: `d_blocks`, `d_inputs`, `d_outputs`, `n_rows`, `k`,
/// `batch_size`.
///
/// NOTE(review): how `n_rows` relates to the gate/up weight layout is decided
/// by the kernel and cannot be determined from this wrapper — confirm with the
/// kernel source. With `n_rows == 0` the grid is zero-sized and is handed to
/// the driver as-is.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] containing the driver's message
/// prefixed with `fused_gate_up_swiglu_gemm launch:` when the launch call
/// fails.
///
/// # Safety
///
/// Nothing about the buffers or dimensions is validated here. The caller must
/// pass arguments that satisfy the kernel's requirements; those requirements
/// are defined by the kernel source, which is not visible from this function.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_fused_gate_up_swiglu_gemm(
    graph: &CudaGraph,
    mods: &CudaPrefillModules,
    d_blocks: &CudaSlice<u8>,
    d_inputs: &CudaSlice<f32>,
    d_outputs: &mut CudaSlice<f32>,
    n_rows: u32,
    k: u32,
    batch_size: u32,
) -> Result<(), CudaGraphError> {
    let blocks_in_grid = n_rows.div_ceil(8);
    let launch_cfg = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (blocks_in_grid, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut launch = stream.launch_builder(&mods.fused_gate_up_swiglu_gemm);
    // Argument order must mirror the kernel's parameter list.
    launch.arg(d_blocks);
    launch.arg(d_inputs);
    launch.arg(d_outputs);
    launch.arg(&n_rows);
    launch.arg(&k);
    launch.arg(&batch_size);

    // SAFETY: argument validity is delegated to the caller by this function's
    // own safety contract.
    match unsafe { launch.launch(launch_cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "fused_gate_up_swiglu_gemm launch: {e}"
        ))),
    }
}
/// Enqueues the `batched_rmsnorm` kernel on the stream owned by `graph`.
///
/// The launch uses a one-dimensional grid of `batch_size` blocks with 256
/// threads per block and `shared_mem_bytes` set to 0. Kernel arguments are
/// pushed in this order: `d_input`, `d_weight`, `d_output`, `n`, `batch_size`,
/// `eps`.
///
/// NOTE(review): one block per batch entry suggests each block normalizes a
/// single row of length `n` — confirm against the kernel source. With
/// `batch_size == 0` the grid is zero-sized and is handed to the driver as-is.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] containing the driver's message
/// prefixed with `batched_rmsnorm launch:` when the launch call fails.
///
/// # Safety
///
/// Nothing about the buffers or dimensions is validated here. The caller must
/// pass arguments that satisfy the kernel's requirements; those requirements
/// are defined by the kernel source, which is not visible from this function.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_batched_rmsnorm(
    graph: &CudaGraph,
    mods: &CudaPrefillModules,
    d_input: &CudaSlice<f32>,
    d_weight: &CudaSlice<f32>,
    d_output: &mut CudaSlice<f32>,
    n: u32,
    batch_size: u32,
    eps: f32,
) -> Result<(), CudaGraphError> {
    let launch_cfg = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (batch_size, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut launch = stream.launch_builder(&mods.batched_rmsnorm);
    // Argument order must mirror the kernel's parameter list.
    launch.arg(d_input);
    launch.arg(d_weight);
    launch.arg(d_output);
    launch.arg(&n);
    launch.arg(&batch_size);
    launch.arg(&eps);

    // SAFETY: argument validity is delegated to the caller by this function's
    // own safety contract.
    match unsafe { launch.launch(launch_cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("batched_rmsnorm launch: {e}"))),
    }
}
/// Enqueues the `gemm_tq2_v7` kernel on the stream owned by `graph`.
///
/// The launch uses a one-dimensional grid of `n_rows.div_ceil(8)` blocks with
/// 256 threads per block and `shared_mem_bytes` set to 0. Kernel arguments are
/// pushed in this order: `d_soa_raw`, `d_inputs`, `d_outputs`, `n_rows`, `k`,
/// `batch_size`.
///
/// NOTE(review): `d_soa_raw` is presumably the weight data in a
/// structure-of-arrays byte layout — confirm the exact format against the
/// kernel source. With `n_rows == 0` the grid is zero-sized and is handed to
/// the driver as-is.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] containing the driver's message
/// prefixed with `gemm_tq2_v7 launch:` when the launch call fails.
///
/// # Safety
///
/// Nothing about the buffers or dimensions is validated here. The caller must
/// pass arguments that satisfy the kernel's requirements; those requirements
/// are defined by the kernel source, which is not visible from this function.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_gemm_tq2_v7(
    graph: &CudaGraph,
    mods: &CudaPrefillModules,
    d_soa_raw: &CudaSlice<u8>,
    d_inputs: &CudaSlice<f32>,
    d_outputs: &mut CudaSlice<f32>,
    n_rows: u32,
    k: u32,
    batch_size: u32,
) -> Result<(), CudaGraphError> {
    let blocks_in_grid = n_rows.div_ceil(8);
    let launch_cfg = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (blocks_in_grid, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut launch = stream.launch_builder(&mods.gemm_tq2_v7);
    // Argument order must mirror the kernel's parameter list.
    launch.arg(d_soa_raw);
    launch.arg(d_inputs);
    launch.arg(d_outputs);
    launch.arg(&n_rows);
    launch.arg(&k);
    launch.arg(&batch_size);

    // SAFETY: argument validity is delegated to the caller by this function's
    // own safety contract.
    match unsafe { launch.launch(launch_cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!("gemm_tq2_v7 launch: {e}"))),
    }
}
/// Enqueues the `gemm_tq2_v7_residual` kernel on the stream owned by `graph`.
///
/// The launch uses a one-dimensional grid of `n_rows.div_ceil(8)` blocks with
/// 256 threads per block and `shared_mem_bytes` set to 0. Kernel arguments are
/// pushed in this order: `d_soa_raw`, `d_inputs`, `d_outputs`, `n_rows`, `k`,
/// `batch_size`, `d_residual`.
///
/// NOTE(review): judging by the name, the kernel presumably combines
/// `d_residual` with the GEMM result — confirm against the kernel source.
/// With `n_rows == 0` the grid is zero-sized and is handed to the driver
/// as-is.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] containing the driver's message
/// prefixed with `gemm_tq2_v7_residual launch:` when the launch call fails.
///
/// # Safety
///
/// Nothing about the buffers or dimensions is validated here. The caller must
/// pass arguments that satisfy the kernel's requirements; those requirements
/// are defined by the kernel source, which is not visible from this function.
#[allow(clippy::too_many_arguments, dead_code)]
pub(super) unsafe fn launch_gemm_tq2_v7_residual(
    graph: &CudaGraph,
    mods: &CudaPrefillModules,
    d_soa_raw: &CudaSlice<u8>,
    d_inputs: &CudaSlice<f32>,
    d_outputs: &mut CudaSlice<f32>,
    n_rows: u32,
    k: u32,
    batch_size: u32,
    d_residual: &CudaSlice<f32>,
) -> Result<(), CudaGraphError> {
    let blocks_in_grid = n_rows.div_ceil(8);
    let launch_cfg = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (blocks_in_grid, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut launch = stream.launch_builder(&mods.gemm_tq2_v7_residual);
    // Argument order must mirror the kernel's parameter list.
    launch.arg(d_soa_raw);
    launch.arg(d_inputs);
    launch.arg(d_outputs);
    launch.arg(&n_rows);
    launch.arg(&k);
    launch.arg(&batch_size);
    launch.arg(d_residual);

    // SAFETY: argument validity is delegated to the caller by this function's
    // own safety contract.
    match unsafe { launch.launch(launch_cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "gemm_tq2_v7_residual launch: {e}"
        ))),
    }
}
/// Enqueues the `fused_gate_up_swiglu_gemm_tq2` kernel on the stream owned by
/// `graph`.
///
/// The launch uses a one-dimensional grid of `n_rows.div_ceil(8)` blocks with
/// 256 threads per block and `shared_mem_bytes` set to 0. Kernel arguments are
/// pushed in this order: `d_soa_raw`, `d_inputs`, `d_outputs`, `n_rows`, `k`,
/// `batch_size`.
///
/// NOTE(review): how `n_rows` relates to the gate/up weight layout, and the
/// exact byte format of `d_soa_raw`, are decided by the kernel and cannot be
/// determined from this wrapper — confirm with the kernel source. With
/// `n_rows == 0` the grid is zero-sized and is handed to the driver as-is.
///
/// # Errors
///
/// Returns [`CudaGraphError::DriverError`] containing the driver's message
/// prefixed with `fused_gate_up_swiglu_gemm_tq2 launch:` when the launch call
/// fails.
///
/// # Safety
///
/// Nothing about the buffers or dimensions is validated here. The caller must
/// pass arguments that satisfy the kernel's requirements; those requirements
/// are defined by the kernel source, which is not visible from this function.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn launch_fused_gate_up_swiglu_gemm_tq2(
    graph: &CudaGraph,
    mods: &CudaPrefillModules,
    d_soa_raw: &CudaSlice<u8>,
    d_inputs: &CudaSlice<f32>,
    d_outputs: &mut CudaSlice<f32>,
    n_rows: u32,
    k: u32,
    batch_size: u32,
) -> Result<(), CudaGraphError> {
    let blocks_in_grid = n_rows.div_ceil(8);
    let launch_cfg = LaunchConfig {
        block_dim: (256, 1, 1),
        grid_dim: (blocks_in_grid, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = graph.stream_arc();
    let mut launch = stream.launch_builder(&mods.fused_gate_up_swiglu_gemm_tq2);
    // Argument order must mirror the kernel's parameter list.
    launch.arg(d_soa_raw);
    launch.arg(d_inputs);
    launch.arg(d_outputs);
    launch.arg(&n_rows);
    launch.arg(&k);
    launch.arg(&batch_size);

    // SAFETY: argument validity is delegated to the caller by this function's
    // own safety contract.
    match unsafe { launch.launch(launch_cfg) } {
        Ok(_) => Ok(()),
        Err(e) => Err(CudaGraphError::DriverError(format!(
            "fused_gate_up_swiglu_gemm_tq2 launch: {e}"
        ))),
    }
}