impl CudaExecutor {
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
fn workspace_ffn_phase(
&mut self,
hidden_buf1: &GpuBuffer<f32>,
hidden_buf2: &GpuBuffer<f32>,
input_staging: &GpuBuffer<f32>,
ffn_gate_buf: &GpuBuffer<f32>,
ffn_up_buf: &GpuBuffer<f32>,
ffn_act_buf: &GpuBuffer<f32>,
layer_idx: usize,
layer_weights: &ValidatedLayerWeights,
hidden_dim: u32,
intermediate_dim: u32,
epsilon: f32,
skip_debug: bool,
profiling: bool,
) -> Result<(), GpuError> {
let timer_rmsnorm2 = if profiling {
self.start_brick_id(trueno::BrickId::RmsNorm)
} else {
None
};
self.rmsnorm_ptr_into(
input_staging,
layer_weights.ffn_norm_ptr,
layer_weights.ffn_norm_len,
hidden_buf1,
hidden_dim,
epsilon,
)?;
if profiling {
self.stop_brick_id(timer_rmsnorm2, 1);
}
self.q8_activation_valid = false;
let use_fused = self.gpu_profile.fused_gate_up
&& layer_weights.ffn_gate_qtype == WeightQuantType::Q4K
&& layer_weights.ffn_up_qtype == WeightQuantType::Q4K;
if use_fused {
self.fused_gate_up_swiglu_hw_dp4a_q4k_gemv_into(
layer_weights.ffn_gate_ptr,
layer_weights.ffn_up_ptr,
hidden_buf1, ffn_act_buf, hidden_dim, intermediate_dim,
)?;
} else {
self.workspace_ffn_gate_up_swiglu_separate(
hidden_buf1, ffn_gate_buf, ffn_up_buf, ffn_act_buf,
layer_weights, intermediate_dim, hidden_dim, profiling,
)?;
}
if !skip_debug && (layer_idx < 4 || (layer_idx >= 10 && layer_idx <= 12)) {
self.debug_check_buf(ffn_act_buf, "SwiGLU", layer_idx)?;
}
let metadata_qtype = layer_weights.ffn_down_qtype;
let metadata_matches = metadata_qtype.matches_size(
layer_weights.ffn_down_len,
hidden_dim as usize,
intermediate_dim as usize,
);
let ffn_down_qtype = if metadata_matches {
metadata_qtype
} else {
WeightQuantType::from_size(
layer_weights.ffn_down_len,
hidden_dim as usize,
intermediate_dim as usize,
)
.unwrap_or(metadata_qtype)
};
self.q8_activation_valid = false;
let timer_ffn_down = if profiling {
self.start_brick_id(trueno::BrickId::DownProjection)
} else {
None
};
self.gemv_dispatch(
ffn_down_qtype,
layer_weights.ffn_down_ptr,
ffn_act_buf, hidden_buf1, hidden_dim, intermediate_dim,
)?;
if profiling {
self.stop_brick_id(timer_ffn_down, 1);
}
if !skip_debug && (layer_idx < 4 || (layer_idx >= 10 && layer_idx <= 12)) {
self.debug_check_buf(hidden_buf1, "FFN down", layer_idx)?;
}
let timer_res2 = if profiling {
self.start_brick_timer("Residual2")
} else {
None
};
self.residual_add_into(input_staging, hidden_buf1, hidden_buf2, hidden_dim)?;
if profiling {
self.stop_brick_timer(timer_res2, 1);
}
self.stream.synchronize()?;
if !skip_debug && (layer_idx < 10 || (layer_idx >= 10 && layer_idx <= 12)) {
self.debug_check_buf(hidden_buf2, "Layer output", layer_idx)?;
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn workspace_ffn_gate_up_swiglu_separate(
&mut self,
hidden_buf1: &GpuBuffer<f32>,
ffn_gate_buf: &GpuBuffer<f32>,
ffn_up_buf: &GpuBuffer<f32>,
ffn_act_buf: &GpuBuffer<f32>,
layer_weights: &ValidatedLayerWeights,
intermediate_dim: u32,
hidden_dim: u32,
profiling: bool,
) -> Result<(), GpuError> {
let timer = if profiling {
self.start_brick_id(trueno::BrickId::GateProjection)
} else {
None
};
self.gemv_dispatch(
layer_weights.ffn_gate_qtype,
layer_weights.ffn_gate_ptr,
hidden_buf1, ffn_gate_buf, intermediate_dim, hidden_dim,
)?;
self.gemv_dispatch(
layer_weights.ffn_up_qtype,
layer_weights.ffn_up_ptr,
hidden_buf1, ffn_up_buf, intermediate_dim, hidden_dim,
)?;
if profiling {
self.stop_brick_id(timer, 1);
}
self.fused_swiglu_into(ffn_gate_buf, ffn_up_buf, ffn_act_buf, intermediate_dim)?;
Ok(())
}
}