use std::sync::Arc;
use cudarc::driver::{CudaSlice, CudaView};
use crate::gpu_backend::cuda_full_layer::{
encode_attn_phase_tq2, CudaAttnModules, CudaFullLayerBuffers, CudaKvCache,
};
use crate::gpu_backend::cuda_graph::{CudaGraph, CudaGraphError};
use super::launchers::{
launch_batched_rmsnorm, launch_fused_gate_up_swiglu_gemm_tq2, launch_gemm_tq2_v7,
};
use super::state::{CudaPrefillBuffers, CudaPrefillModules};
/// Enqueues the feed-forward half of a ternary (TQ2) prefill layer on `graph`'s stream.
///
/// Steps, in stream order:
/// 1. batched RMSNorm of `pb.d_input` into `pb.d_normed` using `d_ffn_norm_weight`;
/// 2. fused gate/up + SwiGLU GEMM from `pb.d_normed` into `pb.d_swiglu`;
/// 3. zero the first `actual_batch_size * hidden_size` elements of `pb.d_normed`;
/// 4. down-projection GEMM from `pb.d_swiglu` into `pb.d_normed`;
/// 5. residual add of `pb.d_normed` into `pb.d_input`.
///
/// # Errors
///
/// Propagates any error from the launch helpers, and wraps a failed memset in
/// `CudaGraphError::DriverError`.
///
/// # Safety
///
/// The contract is not stated in this file. Presumably the caller must guarantee
/// that the device buffers in `pb` and the weight slices are large enough for the
/// dimensions recorded in `pb` — TODO confirm against the launcher documentation.
// Many arguments are inherent to passing every weight buffer explicitly.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn encode_prefill_ffn_phase_ternary(
    graph: &CudaGraph,
    pmods: &CudaPrefillModules,
    d_ffn_norm_weight: &CudaSlice<f32>,
    d_gate_up_weight: &Arc<CudaSlice<u8>>,
    d_down_weight: &Arc<CudaSlice<u8>>,
    pb: &mut CudaPrefillBuffers,
    eps: f32,
) -> Result<(), CudaGraphError> {
    // Kernel dimensions are passed as u32; `as` truncates silently if a
    // dimension ever exceeded u32::MAX (assumed not to happen).
    let batch = pb.actual_batch_size as u32;
    let hidden = pb.hidden_size as u32;
    let intermediate = pb.intermediate_size as u32;

    // Pre-FFN normalisation: d_input -> d_normed.
    launch_batched_rmsnorm(
        graph,
        pmods,
        &pb.d_input,
        d_ffn_norm_weight,
        &mut pb.d_normed,
        hidden,
        batch,
        eps,
    )?;

    // Gate/up projection fused with SwiGLU: d_normed -> d_swiglu.
    launch_fused_gate_up_swiglu_gemm_tq2(
        graph,
        pmods,
        d_gate_up_weight,
        &pb.d_normed,
        &mut pb.d_swiglu,
        intermediate,
        hidden,
        batch,
    )?;

    // Number of live f32 elements in the [batch, hidden] activation buffers.
    let live_elems = pb.actual_batch_size * pb.hidden_size;

    // d_normed is reused as the down-projection output and is cleared first.
    // NOTE(review): presumably the TQ2 GEMM accumulates into its output — confirm.
    {
        let stream = graph.stream_arc();
        let mut down_out = pb.d_normed.slice_mut(0..live_elems);
        stream
            .memset_zeros(&mut down_out)
            .map_err(|e| CudaGraphError::DriverError(format!("zero d_normed tq2 down: {e}")))?;
    }

    // Down projection: d_swiglu -> d_normed.
    launch_gemm_tq2_v7(
        graph,
        pmods,
        d_down_weight,
        &pb.d_swiglu,
        &mut pb.d_normed,
        hidden,
        intermediate,
        batch,
    )?;

    // Residual connection: d_input += d_normed over the live elements.
    graph.launch_residual_add_pub(&mut pb.d_input, &pb.d_normed, live_elems as u32)?;
    Ok(())
}
/// Enqueues one complete ternary (TQ2) transformer layer for a prefill batch.
///
/// Stream order of the work encoded here:
/// 1. batched RMSNorm of `pb.d_input` into `pb.d_normed`;
/// 2. zero `pb.d_qkv`, then the fused QKV GEMM from `pb.d_normed` into `pb.d_qkv`;
/// 3. zero `pb.d_attn_out`;
/// 4. for every token `t` in the batch: copy its hidden row into
///    `st_bufs.d_hidden`, upload `[pos, pos + 1]` and that token's RoPE cos/sin
///    rows, run `encode_attn_phase_tq2`, and copy `st_bufs.d_attn_out` into row
///    `t` of `pb.d_attn_out`;
/// 5. zero `pb.d_normed`, run the attention output projection into it, and add
///    it to `pb.d_input` as a residual;
/// 6. run [`encode_prefill_ffn_phase_ternary`].
///
/// `cos_table` and `sin_table` are indexed by the batch-local token index
/// (`t * (head_dim / 2)`), not by the absolute position `pos_start + t`; both
/// must hold at least `actual_batch_size * (head_dim / 2)` values.
/// NOTE(review): presumably the caller slices the tables so that row 0
/// corresponds to `pos_start` — confirm.
///
/// NOTE(review): inside this function `pb.d_qkv` is written but never read, and
/// the per-token attention call is handed the norm and QKV weights itself.
/// Confirm whether the batched norm / QKV projection (steps 1-2) is still
/// needed. Likewise every row of `pb.d_attn_out` zeroed in step 3 is overwritten
/// by the copy in step 4.
///
/// # Errors
///
/// Returns `CudaGraphError::DriverError` if the RoPE tables are shorter than
/// required (checked before any GPU work is enqueued) or if a memset/memcpy
/// fails, and propagates errors from the launch helpers.
///
/// # Safety
///
/// The contract is not stated in this file. Presumably the caller must guarantee
/// that all device buffers (`pb`, `st_bufs`, `kv`, weights) are sized for the
/// dimensions recorded in `pb` — TODO confirm against the callee documentation.
// Many arguments are inherent to passing every weight buffer explicitly.
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn encode_prefill_layer_ternary(
    graph: &CudaGraph,
    pmods: &CudaPrefillModules,
    attn_mods: &CudaAttnModules,
    d_attn_norm_weight: &CudaSlice<f32>,
    d_fused_qkv_weight: &Arc<CudaSlice<u8>>,
    d_q_norm_weight: &CudaSlice<f32>,
    d_k_norm_weight: &CudaSlice<f32>,
    d_attn_proj_weight: &Arc<CudaSlice<u8>>,
    d_ffn_norm_weight: &CudaSlice<f32>,
    d_gate_up_weight: &Arc<CudaSlice<u8>>,
    d_down_weight: &Arc<CudaSlice<u8>>,
    kv: &mut CudaKvCache,
    layer_idx: usize,
    pos_start: usize,
    pb: &mut CudaPrefillBuffers,
    st_bufs: &mut CudaFullLayerBuffers,
    cos_table: &[f32],
    sin_table: &[f32],
    heads_per_group: usize,
    eps: f32,
) -> Result<(), CudaGraphError> {
    let bs = pb.actual_batch_size;
    let h = pb.hidden_size;
    let nq = pb.nq;
    let nkv = pb.nkv;
    let hd = pb.head_dim;
    let half_dim = hd / 2;
    // Kernel dimensions are passed as u32; `as` truncates silently if a
    // dimension ever exceeded u32::MAX (assumed not to happen).
    let h_u32 = h as u32;
    let bs_u32 = bs as u32;
    let q_width = nq * hd;
    let qkv_total = q_width + 2 * nkv * hd;

    // Fail fast, before anything is enqueued, instead of panicking on an
    // out-of-range slice halfway through the per-token loop.
    let rope_needed = bs.saturating_mul(half_dim);
    if cos_table.len() < rope_needed || sin_table.len() < rope_needed {
        return Err(CudaGraphError::DriverError(format!(
            "rope tables too short for tq2 prefill: need {rope_needed} values \
             (batch {bs} x half_dim {half_dim}), got cos={} sin={}",
            cos_table.len(),
            sin_table.len()
        )));
    }

    // Pre-attention normalisation: d_input -> d_normed.
    launch_batched_rmsnorm(
        graph,
        pmods,
        &pb.d_input,
        d_attn_norm_weight,
        &mut pb.d_normed,
        h_u32,
        bs_u32,
        eps,
    )?;

    // Clear the QKV output before the GEMM writes into it.
    // NOTE(review): presumably the TQ2 GEMM accumulates into its output — confirm.
    {
        let n = bs * qkv_total;
        let mut dst_view = pb.d_qkv.slice_mut(0..n);
        graph
            .stream_arc()
            .memset_zeros(&mut dst_view)
            .map_err(|e| CudaGraphError::DriverError(format!("zero d_qkv tq2: {e}")))?;
    }

    // Fused QKV projection: d_normed -> d_qkv.
    launch_gemm_tq2_v7(
        graph,
        pmods,
        d_fused_qkv_weight,
        &pb.d_normed,
        &mut pb.d_qkv,
        qkv_total as u32,
        h_u32,
        bs_u32,
    )?;

    // Clear the batched attention output before the per-token copies.
    {
        let n = bs * q_width;
        let mut dst_view = pb.d_attn_out.slice_mut(0..n);
        graph
            .stream_arc()
            .memset_zeros(&mut dst_view)
            .map_err(|e| CudaGraphError::DriverError(format!("zero d_attn_out tq2: {e}")))?;
    }

    // Attention is run one token at a time through the single-token buffers.
    for t in 0..bs {
        let pos = pos_start + t;

        // Stage this token's hidden row in the single-token buffer.
        {
            let src_view: CudaView<f32> = pb.d_input.slice(t * h..(t + 1) * h);
            graph
                .stream_arc()
                .memcpy_dtod(&src_view, &mut st_bufs.d_hidden)
                .map_err(|e| CudaGraphError::DriverError(format!("copy hidden tq2 t={t}: {e}")))?;
        }

        // Upload the pair [pos, pos + 1]; the buffer name suggests position and
        // sequence length.
        let pos_seqlen = [pos as u32, (pos + 1) as u32];
        graph
            .stream_arc()
            .memcpy_htod(&pos_seqlen, &mut st_bufs.d_pos_seqlen)
            .map_err(|e| {
                CudaGraphError::DriverError(format!("upload pos_seqlen tq2 t={t}: {e}"))
            })?;

        // RoPE rows for this token; in range thanks to the length check above.
        let rope_off = t * half_dim;
        let cos_row = &cos_table[rope_off..rope_off + half_dim];
        let sin_row = &sin_table[rope_off..rope_off + half_dim];
        graph
            .stream_arc()
            .memcpy_htod(cos_row, &mut st_bufs.d_cos)
            .map_err(|e| CudaGraphError::DriverError(format!("upload cos tq2 t={t}: {e}")))?;
        graph
            .stream_arc()
            .memcpy_htod(sin_row, &mut st_bufs.d_sin)
            .map_err(|e| CudaGraphError::DriverError(format!("upload sin tq2 t={t}: {e}")))?;

        // Single-token attention phase operating on `st_bufs` and the KV cache.
        encode_attn_phase_tq2(
            graph,
            attn_mods,
            d_attn_norm_weight,
            d_fused_qkv_weight,
            d_q_norm_weight,
            d_k_norm_weight,
            kv,
            layer_idx,
            nq,
            nkv,
            hd,
            heads_per_group,
            eps,
            h,
            st_bufs,
        )?;

        // Gather the token's attention output into row `t` of the batch buffer.
        {
            let src_view: CudaView<f32> = st_bufs.d_attn_out.slice(..q_width);
            let mut dst_view = pb.d_attn_out.slice_mut(t * q_width..(t + 1) * q_width);
            graph
                .stream_arc()
                .memcpy_dtod(&src_view, &mut dst_view)
                .map_err(|e| {
                    CudaGraphError::DriverError(format!("copy attn_out tq2 t={t}: {e}"))
                })?;
        }
    }

    // d_normed is reused as the output-projection destination; clear it first.
    {
        let n = bs * h;
        let mut dst_view = pb.d_normed.slice_mut(0..n);
        graph
            .stream_arc()
            .memset_zeros(&mut dst_view)
            .map_err(|e| CudaGraphError::DriverError(format!("zero d_normed tq2 oproj: {e}")))?;
    }

    // Attention output projection: d_attn_out -> d_normed.
    launch_gemm_tq2_v7(
        graph,
        pmods,
        d_attn_proj_weight,
        &pb.d_attn_out,
        &mut pb.d_normed,
        h_u32,
        q_width as u32,
        bs_u32,
    )?;

    // Residual connection: d_input += d_normed.
    let total_oproj = (bs * h) as u32;
    graph.launch_residual_add_pub(&mut pb.d_input, &pb.d_normed, total_oproj)?;

    // Feed-forward half of the layer.
    encode_prefill_ffn_phase_ternary(
        graph,
        pmods,
        d_ffn_norm_weight,
        d_gate_up_weight,
        d_down_weight,
        pb,
        eps,
    )?;
    Ok(())
}