#![cfg(all(feature = "cuda", feature = "triton-kernels"))]
use cudarc::driver::{CudaFunction, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
use std::sync::Arc;
use crate::triton_meta::parse_meta;
use crate::triton_ptx::w4a16_gptq_f16;
/// Tile extent along M: the launcher sizes the grid's x dimension as `ceil(m / BM)`.
/// NOTE(review): presumably must equal the M block size the kernel was compiled with — confirm.
pub const BM: usize = 64;
/// Tile extent along N: the launcher sizes the grid's y dimension as `ceil(n / BN)`.
/// NOTE(review): presumably must equal the N block size the kernel was compiled with — confirm.
pub const BN: usize = 64;
/// Not referenced by the code visible in this file.
/// NOTE(review): presumably the K block size the kernel was compiled with — confirm.
pub const BK: usize = 32;
/// The `PTX` string constant of `crate::triton_ptx::w4a16_gptq_f16`, re-exported under a local name.
pub const W4A16_PTX: &str = w4a16_gptq_f16::PTX;
/// Device buffers and dimensions of one quantized weight, as consumed by
/// [`launch_w4a16_gptq_triton`].
///
/// NOTE(review): the per-field layout remarks below are derived only from the
/// strides the launcher passes to the kernel; confirm them against the kernel
/// source before relying on them.
pub struct TritonGptqWeight {
/// Quantized weight buffer; the launcher passes it with a row stride of `n`
/// and a column stride of 1.
pub qweight: CudaSlice<i32>,
/// Scale buffer; the launcher passes it with a row stride of `n` and a
/// column stride of 1.
pub scales: CudaSlice<half::f16>,
/// Zero-point buffer; the launcher passes it with a row stride of `n / 8`
/// (integer division) and a column stride of 1.
pub qzeros: CudaSlice<i32>,
/// K dimension; also used as the row stride of the activation input.
pub k: usize,
/// N dimension; also used as the row stride of the output.
pub n: usize,
/// Forwarded to the kernel unchanged as its group-size argument.
pub group_size: i32,
}
/// Launch settings extracted once from the kernel metadata by `launch_params`.
struct LaunchParams {
/// Warp count from the metadata; the launcher multiplies it by 32 to get the
/// block size.
num_warps: u32,
/// Value passed as `shared_mem_bytes` in the launch configuration.
shared_mem: u32,
/// Kernel name from the metadata, leaked to obtain a `'static` lifetime.
fn_name: &'static str,
}
/// Returns the process-wide [`LaunchParams`], building them from
/// `w4a16_gptq_f16::META` the first time this is called.
///
/// # Panics
///
/// Panics if `parse_meta` rejects the embedded metadata.
fn launch_params() -> &'static LaunchParams {
    use std::sync::OnceLock;

    static PARAMS: OnceLock<LaunchParams> = OnceLock::new();

    /// Parses the metadata and converts it into launch parameters.
    fn build() -> LaunchParams {
        let meta = match parse_meta(w4a16_gptq_f16::META) {
            Ok(parsed) => parsed,
            Err(e) => panic!("triton w4a16 meta parse: {e}"),
        };
        let num_warps = meta.num_warps;
        let shared_mem = meta.shared_mem as u32;
        // Leaked on purpose: `build` runs at most once, and the name has to
        // outlive every caller of `fn_name()`.
        let leaked_name: &'static str = Box::leak(meta.name.into_boxed_str());
        LaunchParams {
            num_warps,
            shared_mem,
            fn_name: leaked_name,
        }
    }

    PARAMS.get_or_init(build)
}
/// Enqueues the Triton-compiled W4A16 GPTQ kernel `func` on `stream`.
///
/// The kernel receives, in order: the input, `qweight`, `scales`, `qzeros`
/// and output buffers; `m`, `n`, `k` and the group size; five pairs of
/// (row, column) strides for input, qweight, scales, qzeros and output; and
/// two 1-byte scratch buffers. The grid is `ceil(m / BM) x ceil(n / BN)`
/// blocks of `num_warps * 32` threads.
///
/// When `m == 0` or `weight.n == 0` the output has no elements, so the call
/// returns `Ok(())` without allocating or launching anything.
///
/// # Errors
///
/// Returns `candle_core::Error::Msg` when:
/// - `m` is negative, or `weight.k` / `weight.n` do not fit in an `i32`;
/// - either scratch allocation fails;
/// - the launch itself is rejected by the driver.
///
/// NOTE(review): buffer lengths are not validated against `m`, `n` and `k`;
/// the caller is responsible for passing buffers of matching size.
pub fn launch_w4a16_gptq_triton(
    stream: &Arc<CudaStream>,
    func: &CudaFunction,
    input: &CudaSlice<half::f16>,
    weight: &TritonGptqWeight,
    output: &mut CudaSlice<half::f16>,
    m: i32,
) -> candle_core::Result<()> {
    let gs = weight.group_size;

    // The kernel takes 32-bit signed dimensions. Convert with checks instead
    // of `as`, which would silently truncate large sizes and turn a negative
    // `m` into an enormous grid.
    let invalid_dim = |what: &str| {
        candle_core::Error::Msg(format!(
            "triton w4a16 launch: invalid {what} (m={m}, n={}, k={}, gs={gs})",
            weight.n, weight.k
        ))
    };
    let m_rows = usize::try_from(m).map_err(|_| invalid_dim("m"))?;
    let k = i32::try_from(weight.k).map_err(|_| invalid_dim("k"))?;
    let n = i32::try_from(weight.n).map_err(|_| invalid_dim("n"))?;

    // An empty output needs no work, and a zero-sized grid would be rejected
    // by the driver anyway.
    if m_rows == 0 || weight.n == 0 {
        return Ok(());
    }

    let lp = launch_params();

    // (row, column) strides handed to the kernel for each operand.
    let stride_am = k;
    let stride_ak = 1i32;
    let stride_qwk = n;
    let stride_qwn = 1i32;
    let stride_sk = n;
    let stride_sn = 1i32;
    // Integer division. NOTE(review): assumes `n` is a multiple of 8 — confirm.
    let stride_qzk = n / 8;
    let stride_qzn = 1i32;
    let stride_cm = n;
    let stride_cn = 1i32;

    // Two 1-byte zeroed device buffers passed as the trailing kernel
    // arguments. NOTE(review): presumably the scratch pointers that
    // Triton-generated kernels expect — confirm against the kernel signature.
    let global_scratch: CudaSlice<u8> = stream
        .alloc_zeros::<u8>(1)
        .map_err(|e| candle_core::Error::Msg(format!("triton w4a16 scratch: {e}")))?;
    let profile_scratch: CudaSlice<u8> = stream
        .alloc_zeros::<u8>(1)
        .map_err(|e| candle_core::Error::Msg(format!("triton w4a16 profile: {e}")))?;

    let qw = weight.qweight.slice(..);
    let sc = weight.scales.slice(..);
    let qz = weight.qzeros.slice(..);
    let inp = input.slice(..);

    let mut b = stream.launch_builder(func);
    b.arg(&inp);
    b.arg(&qw);
    b.arg(&sc);
    b.arg(&qz);
    b.arg(output);
    b.arg(&m);
    b.arg(&n);
    b.arg(&k);
    b.arg(&gs);
    b.arg(&stride_am);
    b.arg(&stride_ak);
    b.arg(&stride_qwk);
    b.arg(&stride_qwn);
    b.arg(&stride_sk);
    b.arg(&stride_sn);
    b.arg(&stride_qzk);
    b.arg(&stride_qzn);
    b.arg(&stride_cm);
    b.arg(&stride_cn);
    b.arg(&global_scratch);
    b.arg(&profile_scratch);

    // Both dimensions were checked to fit in an `i32`, so the rounded-up tile
    // counts are far below `u32::MAX` and the casts are lossless.
    let grid_m = ((m_rows + BM - 1) / BM) as u32;
    let grid_n = ((weight.n + BN - 1) / BN) as u32;
    let block_size = lp.num_warps * 32;
    let cfg = LaunchConfig {
        grid_dim: (grid_m, grid_n, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: lp.shared_mem,
    };

    // SAFETY (NOTE(review)): soundness relies on the argument list matching
    // the compiled kernel's signature and on the caller supplying buffers
    // large enough for `m`, `n` and `k`; neither is verified here.
    unsafe { b.launch(cfg) }.map(|_| ()).map_err(|e| {
        candle_core::Error::Msg(format!(
            "triton w4a16 launch: {e} (m={m}, n={n}, k={k}, gs={gs})"
        ))
    })
}
/// Returns the kernel name recorded in the parsed metadata.
///
/// The first call triggers the metadata parse performed by `launch_params`,
/// and therefore panics if that parse fails.
pub fn fn_name() -> &'static str {
    let params = launch_params();
    params.fn_name
}