use crate::cuda::executor::{CudaExecutor, GpuBuffer, GpuError};
/// Largest reduction dimension for which the DP4A GEMV kernels are selected.
const CHUNK_THRESHOLD: u32 = 8192;

/// True when a GEMV whose reduction dimension is `dim` should use the DP4A
/// kernel path: `dim` must be an exact multiple of 256 and must not exceed
/// [`CHUNK_THRESHOLD`].
#[inline]
fn use_dp4a_kernel(dim: u32) -> bool {
    dim % 256 == 0 && dim <= CHUNK_THRESHOLD
}
/// Runs one SwiGLU feed-forward block on the GPU:
/// `down( swiglu(gate(input), up(input)) )`, where each projection is a
/// cached Q4_K GEMV. Each GEMV independently picks the DP4A kernel when its
/// reduction dimension qualifies (see [`use_dp4a_kernel`]).
///
/// If the `TRUE_DP4A` environment variable is `"1"` (read once per process),
/// the whole call is rerouted to [`fused_ffn_swiglu_gpu_true_dp4a`].
///
/// # Errors
/// Propagates any [`GpuError`] returned by the executor calls.
#[allow(clippy::too_many_arguments)]
pub fn fused_ffn_swiglu_gpu(
    executor: &mut CudaExecutor,
    input: &GpuBuffer<f32>,
    ffn_gate_name: &str,
    ffn_up_name: &str,
    ffn_down_name: &str,
    hidden_dim: u32,
    intermediate_dim: u32,
) -> Result<GpuBuffer<f32>, GpuError> {
    // One-time env switch to the Q8-activation path; anything other than
    // exactly "1" (including unset) keeps the default path.
    static TRUE_DP4A_ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    let true_dp4a = *TRUE_DP4A_ENABLED
        .get_or_init(|| matches!(std::env::var("TRUE_DP4A").as_deref(), Ok("1")));
    if true_dp4a {
        return fused_ffn_swiglu_gpu_true_dp4a(
            executor,
            input,
            ffn_gate_name,
            ffn_up_name,
            ffn_down_name,
            hidden_dim,
            intermediate_dim,
        );
    }

    // Kernel choice depends only on the reduction dimension, so decide once
    // for the two projections that reduce over hidden_dim.
    let hidden_uses_dp4a = use_dp4a_kernel(hidden_dim);

    let gate = if hidden_uses_dp4a {
        executor.dp4a_q4k_gemv_cached_async(ffn_gate_name, input, intermediate_dim, hidden_dim)?
    } else {
        executor.q4k_gemv_cached_async(ffn_gate_name, input, intermediate_dim, hidden_dim)?
    };
    let up = if hidden_uses_dp4a {
        executor.dp4a_q4k_gemv_cached_async(ffn_up_name, input, intermediate_dim, hidden_dim)?
    } else {
        executor.q4k_gemv_cached_async(ffn_up_name, input, intermediate_dim, hidden_dim)?
    };

    // Fused SwiGLU activation over the gate/up outputs in a single kernel.
    let activated = executor.fused_swiglu_gpu(&gate, &up, intermediate_dim)?;

    // Down projection reduces over intermediate_dim, so it gets its own
    // DP4A eligibility check.
    if use_dp4a_kernel(intermediate_dim) {
        executor.dp4a_q4k_gemv_cached_async(ffn_down_name, &activated, hidden_dim, intermediate_dim)
    } else {
        executor.q4k_gemv_cached_async(ffn_down_name, &activated, hidden_dim, intermediate_dim)
    }
}
/// Variant of [`fused_ffn_swiglu_gpu`] that quantizes the f32 activations to
/// Q8 and runs every projection as a Q4_K x Q8 GEMV.
///
/// If the `PACKED_DP4A` environment variable is `"1"` (read once per
/// process), the packed DP4A kernel is used for all three GEMVs; otherwise
/// the plain Q4_K x Q8 kernel is used.
///
/// # Errors
/// Propagates any [`GpuError`] returned by the executor calls.
#[allow(clippy::too_many_arguments)]
pub fn fused_ffn_swiglu_gpu_true_dp4a(
    executor: &mut CudaExecutor,
    input: &GpuBuffer<f32>,
    ffn_gate_name: &str,
    ffn_up_name: &str,
    ffn_down_name: &str,
    hidden_dim: u32,
    intermediate_dim: u32,
) -> Result<GpuBuffer<f32>, GpuError> {
    // One-time env switch between the packed and plain Q4_K x Q8 kernels.
    static PACKED_DP4A_ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    let packed = *PACKED_DP4A_ENABLED
        .get_or_init(|| matches!(std::env::var("PACKED_DP4A").as_deref(), Ok("1")));

    // Quantize the f32 input once; both gate and up GEMVs consume it.
    let q8_input = executor.q8_quantize_async(input, hidden_dim)?;

    // Gate and up projections both reduce over hidden_dim.
    let gate = if packed {
        executor.packed_dp4a_q4k_q8_gemv_async(ffn_gate_name, &q8_input, intermediate_dim, hidden_dim)?
    } else {
        executor.q4k_q8_gemv_async(ffn_gate_name, &q8_input, intermediate_dim, hidden_dim)?
    };
    let up = if packed {
        executor.packed_dp4a_q4k_q8_gemv_async(ffn_up_name, &q8_input, intermediate_dim, hidden_dim)?
    } else {
        executor.q4k_q8_gemv_async(ffn_up_name, &q8_input, intermediate_dim, hidden_dim)?
    };

    // Fused SwiGLU on the f32 GEMV outputs, then re-quantize the result so
    // the down projection can also run on Q8 activations.
    let activated = executor.fused_swiglu_gpu(&gate, &up, intermediate_dim)?;
    let q8_activated = executor.q8_quantize_async(&activated, intermediate_dim)?;

    // Down projection reduces over intermediate_dim.
    if packed {
        executor.packed_dp4a_q4k_q8_gemv_async(ffn_down_name, &q8_activated, hidden_dim, intermediate_dim)
    } else {
        executor.q4k_q8_gemv_async(ffn_down_name, &q8_activated, hidden_dim, intermediate_dim)
    }
}
include!("helper.rs");