#[cfg(feature = "triton-kernels")]
use cudarc::driver::{CudaFunction, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
#[cfg(feature = "triton-kernels")]
use std::sync::Arc;
#[cfg(feature = "triton-kernels")]
use crate::triton_meta::parse_meta;
#[cfg(feature = "triton-kernels")]
use crate::triton_ptx::fused_moe_w4a16_f16_bm16;
/// Padded-token rows handled per thread block: the launcher uses
/// `ceil(num_padded_tokens / BM)` blocks along grid x.
pub const BM: i32 = 16;
/// Output columns handled per thread block: the launcher uses `ceil(n / BN)` blocks
/// along grid y.
pub const BN: i32 = 64;
/// K-dimension block size. Not used by the launcher in this file; presumably it mirrors the
/// K tile the Triton kernel was compiled with — TODO confirm against the kernel source.
#[cfg_attr(not(feature = "triton-kernels"), allow(dead_code))]
pub const BK: i32 = 32;
/// Embedded PTX for the kernel, taken from the generated `fused_moe_w4a16_f16_bm16` module.
/// NOTE(review): the `_bm16` suffix presumably has to agree with `BM` above — confirm.
#[cfg(feature = "triton-kernels")]
pub const FUSED_MOE_W4A16_PTX: &str = fused_moe_w4a16_f16_bm16::PTX;
/// Name string associated with the fused-MoE W4A16 kernel.
///
/// NOTE(review): presumably the entry-point symbol looked up in the embedded PTX — confirm
/// against the code that loads `FUSED_MOE_W4A16_PTX`.
pub fn fn_name() -> &'static str {
    const KERNEL_ENTRY: &str = "fused_moe_w4a16_typed";
    KERNEL_ENTRY
}
#[cfg(feature = "triton-kernels")]
/// Device-resident 4-bit GPTQ weights for a set of experts, with every expert's data
/// concatenated (expert 0 first) into one buffer per tensor. Built by `load_stacked_gptq_raw`.
pub struct TritonStackedGptqWeight {
/// Packed quantized weights: `(k / 8) * n` `i32`s per expert, eight 4-bit values per `i32`.
pub qweight: CudaSlice<i32>,
/// Per-group scales: `(k / group_size) * n` values per expert (converted from `f32` on load).
pub scales: CudaSlice<half::f16>,
/// Packed zero points: `(k / group_size) * (n / 8)` `i32`s per expert.
pub qzeros: CudaSlice<i32>,
/// Number of experts stacked in the buffers above.
pub num_experts: usize,
/// Reduction dimension of each expert's matrix (the input row length).
pub k: usize,
/// Output dimension of each expert's matrix.
pub n: usize,
/// Number of `k` rows sharing one scale/zero point; kept as `i32` because it is passed
/// to the kernel unchanged.
pub group_size: i32,
}
/// Stacks per-expert 4-bit GPTQ tensors into contiguous buffers and uploads them to the
/// device as a `TritonStackedGptqWeight`.
///
/// Experts are concatenated in argument order. For every expert `e` the slices must have
/// exactly these lengths (the layout `launch_fused_moe_w4a16_triton` indexes with):
/// * `qweights[e]`: `(k / 8) * n_per_expert` — eight 4-bit values packed per `i32` along `k`;
/// * `scales_f32[e]`: `(k / group_size) * n_per_expert` — converted to `f16` before upload;
/// * `qzeros[e]`: `(k / group_size) * (n_per_expert / 8)` — eight zero points per `i32`.
///
/// # Errors
/// Returns an error if `bits != 4`; if there are no experts or the three slice lists differ
/// in length; if `group_size == 0`, or `k` is not a multiple of both 8 and `group_size`,
/// or `n_per_expert` is not a multiple of 8; if `group_size` does not fit in `i32`; if any
/// expert's slice has the wrong length; or if a host-to-device copy fails.
#[cfg(feature = "triton-kernels")]
#[allow(clippy::too_many_arguments)]
pub fn load_stacked_gptq_raw(
    stream: &Arc<CudaStream>,
    qweights: &[&[i32]],
    scales_f32: &[&[f32]],
    qzeros: &[&[i32]],
    bits: u32,
    group_size: usize,
    k: usize,
    n_per_expert: usize,
) -> candle_core::Result<TritonStackedGptqWeight> {
    /// Fails on the first expert whose slice length differs from `expected`.
    fn check_lens<T>(what: &str, parts: &[&[T]], expected: usize) -> candle_core::Result<()> {
        match parts.iter().position(|p| p.len() != expected) {
            None => Ok(()),
            Some(e) => Err(candle_core::Error::Msg(format!(
                "TritonStacked: {what}[{e}].len()={} expected {expected}",
                parts[e].len()
            ))),
        }
    }

    if bits != 4 {
        return Err(candle_core::Error::Msg(format!(
            "TritonStackedGptqWeight: only bits=4 supported (got {bits})"
        )));
    }
    let num_experts = qweights.len();
    if num_experts == 0 || scales_f32.len() != num_experts || qzeros.len() != num_experts {
        return Err(candle_core::Error::Msg(format!(
            "TritonStackedGptqWeight: shape mismatch qw={} sc={} qz={}",
            num_experts,
            scales_f32.len(),
            qzeros.len()
        )));
    }
    // The size arithmetic below floor-divides. Without these checks a zero group size would
    // panic, and non-multiples would drop tail rows/columns that the kernel still reads.
    if group_size == 0 || k % 8 != 0 || k % group_size != 0 || n_per_expert % 8 != 0 {
        return Err(candle_core::Error::Msg(format!(
            "TritonStackedGptqWeight: unsupported layout k={k} n={n_per_expert} group_size={group_size} (need group_size > 0 and k % 8 == 0, k % group_size == 0, n % 8 == 0)"
        )));
    }
    // The kernel takes the group size as i32, so reject values that would truncate.
    let group_size_i32 = i32::try_from(group_size).map_err(|_| {
        candle_core::Error::Msg(format!(
            "TritonStackedGptqWeight: group_size={group_size} does not fit in i32"
        ))
    })?;

    let groups = k / group_size;
    let qw_per = (k / 8) * n_per_expert;
    let sc_per = groups * n_per_expert;
    let qz_per = groups * (n_per_expert / 8);
    check_lens("qweight", qweights, qw_per)?;
    check_lens("scales", scales_f32, sc_per)?;
    check_lens("qzeros", qzeros, qz_per)?;

    // Flatten expert-major: expert 0's data first, then expert 1, ...
    let qw_flat: Vec<i32> = qweights.concat();
    let mut sc_flat_f16: Vec<half::f16> = Vec::with_capacity(num_experts * sc_per);
    for sc in scales_f32 {
        sc_flat_f16.extend(sc.iter().map(|&x| half::f16::from_f32(x)));
    }
    let qz_flat: Vec<i32> = qzeros.concat();

    let qw_dev = stream
        .clone_htod(&qw_flat)
        .map_err(|e| candle_core::Error::Msg(format!("triton qw htod: {e}")))?;
    let sc_dev = stream
        .clone_htod(&sc_flat_f16)
        .map_err(|e| candle_core::Error::Msg(format!("triton sc htod: {e}")))?;
    let qz_dev = stream
        .clone_htod(&qz_flat)
        .map_err(|e| candle_core::Error::Msg(format!("triton qz htod: {e}")))?;
    Ok(TritonStackedGptqWeight {
        qweight: qw_dev,
        scales: sc_dev,
        qzeros: qz_dev,
        num_experts,
        k,
        n: n_per_expert,
        group_size: group_size_i32,
    })
}
/// Launches the Triton-compiled fused-MoE W4A16 kernel `func` on `stream`.
///
/// The strides passed to the kernel describe row-major buffers:
/// * `input` with `k` columns;
/// * `qweight` with `n` columns per packed row;
/// * `scales` with `n` columns;
/// * `qzeros` with `n / 8` columns;
/// * `output` with `n` columns.
///
/// Each expert's slab of `weight` starts at a fixed per-expert element offset, which is also
/// passed to the kernel. `sorted_token_ids` and `expert_ids` are forwarded unchanged; their
/// expected layout is defined by the kernel (NOTE(review): presumably the usual MoE
/// align/sort output with one expert id per `BM`-row block — confirm). `num_padded_tokens`
/// sizes the grid along x, and `size_m` is passed as the kernel's count of valid tokens.
///
/// Returns `Ok(())` without launching when there is no work (`num_padded_tokens == 0` or
/// `weight.n == 0`), because CUDA rejects a grid with a zero dimension.
///
/// # Errors
/// Returns an error if `weight` has an invalid layout (`group_size <= 0`, or `k`/`n` not
/// multiples of 8, or `k` not a multiple of `group_size`); if `num_padded_tokens` is
/// negative; if a size does not fit the kernel's `i32` arguments; or if scratch allocation
/// or the launch itself fails.
#[cfg(feature = "triton-kernels")]
#[allow(clippy::too_many_arguments)]
pub fn launch_fused_moe_w4a16_triton(
    stream: &Arc<CudaStream>,
    func: &CudaFunction,
    input: &CudaSlice<half::f16>,
    weight: &TritonStackedGptqWeight,
    output: &mut CudaSlice<half::f16>,
    sorted_token_ids: &CudaSlice<i32>,
    expert_ids: &CudaSlice<i32>,
    num_padded_tokens: i32,
    size_m: i32,
) -> candle_core::Result<()> {
    /// Converts a host-side size to the `i32` the kernel takes, rejecting truncation.
    fn to_i32(what: &str, v: usize) -> candle_core::Result<i32> {
        i32::try_from(v).map_err(|_| {
            candle_core::Error::Msg(format!("triton fused_moe: {what}={v} does not fit in i32"))
        })
    }

    let gs = weight.group_size;
    // The struct's fields are public, so re-check the layout invariants the loader enforces.
    // Otherwise `k / gs` could divide by zero, or the kernel could index past the buffers.
    if gs <= 0 || weight.k % 8 != 0 || weight.k % (gs as usize) != 0 || weight.n % 8 != 0 {
        return Err(candle_core::Error::Msg(format!(
            "triton fused_moe: unsupported weight layout k={} n={} group_size={gs}",
            weight.k, weight.n
        )));
    }
    if num_padded_tokens < 0 {
        return Err(candle_core::Error::Msg(format!(
            "triton fused_moe: negative num_padded_tokens={num_padded_tokens}"
        )));
    }
    if num_padded_tokens == 0 || weight.n == 0 {
        // Nothing to compute, and a zero-sized grid would be rejected by the driver.
        return Ok(());
    }

    let k = to_i32("k", weight.k)?;
    let n = to_i32("n", weight.n)?;
    let groups = weight.k / (gs as usize);
    // Per-expert element counts. They must equal the sizes `load_stacked_gptq_raw` stacks
    // with. qzeros pack 8 columns per i32, hence `groups * (n / 8)`.
    let qw_per_expert = to_i32("qweight per expert", (weight.k / 8) * weight.n)?;
    let s_per_expert = to_i32("scales per expert", groups * weight.n)?;
    let qz_per_expert = to_i32("qzeros per expert", groups * (weight.n / 8))?;

    // Row-major strides: (row stride, column stride) for each operand.
    let (stride_am, stride_ak) = (k, 1i32);
    let (stride_qwk, stride_qwn) = (n, 1i32);
    let (stride_sk, stride_sn) = (n, 1i32);
    let (stride_qzk, stride_qzn) = (n / 8, 1i32);
    let (stride_cm, stride_cn) = (n, 1i32);
    let num_valid_tokens = size_m;

    // 1-byte placeholders for the kernel's two trailing scratch-pointer parameters.
    // NOTE(review): presumably the global/profile scratch args newer Triton versions append;
    // they are allocated per launch — consider caching if this shows up in profiles.
    let global_scratch: CudaSlice<u8> = stream
        .alloc_zeros::<u8>(1)
        .map_err(|e| candle_core::Error::Msg(format!("triton fused_moe scratch: {e}")))?;
    let profile_scratch: CudaSlice<u8> = stream
        .alloc_zeros::<u8>(1)
        .map_err(|e| candle_core::Error::Msg(format!("triton fused_moe profile: {e}")))?;

    let inp = input.slice(..);
    let qw = weight.qweight.slice(..);
    let sc = weight.scales.slice(..);
    let qz = weight.qzeros.slice(..);
    let st = sorted_token_ids.slice(..);
    let eid = expert_ids.slice(..);

    let mut b = stream.launch_builder(func);
    b.arg(&inp);
    b.arg(&qw);
    b.arg(&sc);
    b.arg(&qz);
    b.arg(&st);
    b.arg(&eid);
    b.arg(output);
    b.arg(&num_valid_tokens);
    b.arg(&n);
    b.arg(&k);
    b.arg(&gs);
    b.arg(&qw_per_expert);
    b.arg(&s_per_expert);
    b.arg(&qz_per_expert);
    b.arg(&stride_am);
    b.arg(&stride_ak);
    b.arg(&stride_qwk);
    b.arg(&stride_qwn);
    b.arg(&stride_sk);
    b.arg(&stride_sn);
    b.arg(&stride_qzk);
    b.arg(&stride_qzn);
    b.arg(&stride_cm);
    b.arg(&stride_cn);
    b.arg(&global_scratch);
    b.arg(&profile_scratch);

    let lp = launch_params();
    // Both values are strictly positive here, so the u32 conversions are lossless.
    let blocks_m = (num_padded_tokens as u32).div_ceil(BM as u32);
    let blocks_n = (n as u32).div_ceil(BN as u32);
    let cfg = LaunchConfig {
        grid_dim: (blocks_m, blocks_n, 1),
        block_dim: (lp.num_warps * 32, 1, 1),
        shared_mem_bytes: lp.shared_mem_bytes,
    };
    // SAFETY: the pushed arguments (six device pointers, the output pointer, 17 i32 scalars,
    // then two scratch pointers) must match the compiled kernel's parameter list exactly.
    // NOTE(review): confirm this order against the Triton kernel signature whenever the PTX
    // is regenerated. The layout checks above keep per-expert offsets within the buffers
    // built by `load_stacked_gptq_raw`.
    unsafe { b.launch(cfg) }
        .map_err(|e| candle_core::Error::Msg(format!("triton fused_moe launch: {e}")))?;
    Ok(())
}
#[cfg(feature = "triton-kernels")]
/// Launch geometry taken from the compiled kernel's metadata.
#[derive(Clone, Copy)]
struct LaunchParams {
    /// Warps per thread block; the launcher uses `num_warps * 32` threads per block.
    num_warps: u32,
    /// Dynamic shared memory, in bytes, requested at launch.
    shared_mem_bytes: u32,
}

#[cfg(feature = "triton-kernels")]
/// Returns the launch parameters recorded in the embedded `fused_moe_w4a16_f16_bm16::META`.
///
/// The metadata is a compile-time constant, so it is parsed once on first use and cached.
/// The original parsed it again on every kernel launch.
///
/// # Panics
/// Panics if the embedded metadata cannot be parsed. That is a build defect (mismatched
/// generated module), not a runtime condition. If it happens, the cache stays empty, so
/// every call panics the same way.
fn launch_params() -> LaunchParams {
    static PARAMS: std::sync::OnceLock<LaunchParams> = std::sync::OnceLock::new();
    *PARAMS.get_or_init(|| {
        let meta =
            parse_meta(fused_moe_w4a16_f16_bm16::META).expect("parse fused_moe_w4a16 meta");
        LaunchParams {
            num_warps: meta.num_warps as u32,
            shared_mem_bytes: meta.shared_mem as u32,
        }
    })
}