use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use cudarc::driver::{
CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, DeviceRepr, LaunchConfig,
PushKernelArg,
};
use ferrum_types::{FerrumError, Result};
use half::f16;
use super::{current_device_ordinal, default_stream, CudaBackend, CudaState};
use crate::backend::{Backend, BackendQuantGguf, BackendQuantMarlin};
use crate::ptx;
/// GPTQ weight storage used by the CUDA backend when the `triton-kernels`
/// feature is enabled: either a Marlin-format weight or a Triton w4a16 weight.
#[cfg(feature = "triton-kernels")]
pub enum GptqStoreCuda {
/// Weight held as a [`crate::marlin::MarlinWeight`].
Marlin(crate::marlin::MarlinWeight),
/// Weight held as a [`crate::triton_w4a16::TritonGptqWeight`].
Triton(crate::triton_w4a16::TritonGptqWeight),
}
/// GPTQ weight storage used by the CUDA backend when the `triton-kernels`
/// feature is disabled: the store is simply a Marlin weight.
#[cfg(not(feature = "triton-kernels"))]
pub type GptqStoreCuda = crate::marlin::MarlinWeight;
/// Returns the cached `triton_int4` flag of the process-wide quant runtime
/// configuration (set by `FERRUM_TRITON_INT4=1`).
fn use_triton_int4() -> bool {
    let config = cuda_quant_runtime_config();
    config.triton_int4
}
/// Returns the cached `vllm_moe` flag of the process-wide quant runtime
/// configuration (set by `FERRUM_VLLM_MOE=1`).
#[cfg(feature = "vllm-moe-marlin")]
pub(crate) fn use_vllm_moe() -> bool {
    let config = cuda_quant_runtime_config();
    config.vllm_moe
}
/// Runtime switches for the CUDA quantized-GEMM paths, parsed from
/// `FERRUM_*` environment variables.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CudaQuantRuntimeConfig {
    /// `FERRUM_TRITON_INT4 == "1"`; default `false`.
    triton_int4: bool,
    /// `FERRUM_VLLM_MOE == "1"`; default `false`.
    vllm_moe: bool,
    /// `FERRUM_VLLM_MARLIN == "1"`; default `false`.
    vllm_marlin: bool,
    /// `FERRUM_VLLM_MARLIN_SMS` parsed as `i32`; default `128`.
    vllm_marlin_sms: i32,
    /// `FERRUM_VLLM_ATOMIC_ADD == "1"`; default `false`.
    vllm_atomic_add: bool,
    /// `FERRUM_VLLM_FP32_REDUCE == "1"`; default `false`.
    vllm_fp32_reduce: bool,
    /// Enabled unless `FERRUM_MOE_FUSED == "0"`; default `true`.
    moe_fused: bool,
    /// `FERRUM_MOE_STREAMS` parsed as `usize`, clamped to at least 1; default `4`.
    moe_streams: usize,
}

impl CudaQuantRuntimeConfig {
    /// Builds the configuration from the process environment.
    ///
    /// Uses `std::env::vars_os` rather than `std::env::vars`: the latter
    /// panics as soon as *any* environment entry is not valid Unicode, even
    /// one that has nothing to do with this crate. Entries whose name or
    /// value cannot be decoded are skipped, which is equivalent to the
    /// variable being unset.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars_os().filter_map(|(name, value)| {
            Some((name.into_string().ok()?, value.into_string().ok()?))
        }))
    }

    /// Builds the configuration from an arbitrary `(name, value)` sequence.
    ///
    /// Unknown names are ignored. Boolean knobs are enabled only by the exact
    /// value `"1"` (except `FERRUM_MOE_FUSED`, which is disabled only by the
    /// exact value `"0"`). Numeric knobs that fail to parse keep the value
    /// they had before. When a name appears more than once, the last
    /// occurrence wins.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self {
            triton_int4: false,
            vllm_moe: false,
            vllm_marlin: false,
            vllm_marlin_sms: 128,
            vllm_atomic_add: false,
            vllm_fp32_reduce: false,
            moe_fused: true,
            moe_streams: 4,
        };
        for (name, value) in vars {
            let value = value.as_ref();
            match name.as_ref() {
                "FERRUM_TRITON_INT4" => config.triton_int4 = value == "1",
                "FERRUM_VLLM_MOE" => config.vllm_moe = value == "1",
                "FERRUM_VLLM_MARLIN" => config.vllm_marlin = value == "1",
                "FERRUM_VLLM_MARLIN_SMS" => {
                    if let Ok(sms) = value.parse::<i32>() {
                        config.vllm_marlin_sms = sms;
                    }
                }
                "FERRUM_VLLM_ATOMIC_ADD" => config.vllm_atomic_add = value == "1",
                "FERRUM_VLLM_FP32_REDUCE" => config.vllm_fp32_reduce = value == "1",
                "FERRUM_MOE_FUSED" => config.moe_fused = value != "0",
                "FERRUM_MOE_STREAMS" => {
                    if let Ok(streams) = value.parse::<usize>() {
                        // A pool of zero streams is meaningless; clamp to one.
                        config.moe_streams = streams.max(1);
                    }
                }
                _ => {}
            }
        }
        config
    }
}
/// Returns the process-wide quant runtime configuration, reading the
/// environment on the first call and reusing that snapshot afterwards.
fn cuda_quant_runtime_config() -> &'static CudaQuantRuntimeConfig {
    static CACHED: OnceLock<CudaQuantRuntimeConfig> = OnceLock::new();
    CACHED.get_or_init(|| CudaQuantRuntimeConfig::from_env())
}
/// A device scratch buffer of `f16` together with the element count it was
/// allocated with. One instance is kept per device ordinal in
/// `MARLIN_GATHER_SCRATCH`.
#[cfg(feature = "marlin")]
struct MarlinGatherScratch {
// Device allocation handed to the gather kernel / GEMM as temporary storage.
buf: CudaSlice<f16>,
// Number of `f16` elements `buf` was allocated with.
capacity: usize, }
// NOTE(review): these impls assert that sharing/moving the wrapped
// `CudaSlice<f16>` across threads is sound; the invariant that justifies this
// is not visible in this file — confirm against the cudarc version in use.
// NOTE(review): the struct is gated on the `marlin` feature but these impls
// are not — confirm the feature is always enabled when this module builds.
unsafe impl Send for MarlinGatherScratch {}
unsafe impl Sync for MarlinGatherScratch {}
/// Lazily-initialised, lock-protected map from device ordinal to that
/// device's gather scratch buffer.
static MARLIN_GATHER_SCRATCH: std::sync::OnceLock<
std::sync::RwLock<HashMap<usize, MarlinGatherScratch>>,
> = std::sync::OnceLock::new();
/// Returns the global scratch-buffer map, creating an empty one on first use.
fn marlin_gather_scratch_slots() -> &'static std::sync::RwLock<HashMap<usize, MarlinGatherScratch>>
{
    MARLIN_GATHER_SCRATCH.get_or_init(Default::default)
}
/// Ensures the gather scratch buffer for the current device ordinal holds at
/// least `required` `f16` elements, allocating on `stream` if it does not.
#[cfg(feature = "marlin")]
pub fn pregrow_marlin_gather_scratch(stream: &Arc<CudaStream>, required: usize) {
    let ordinal = current_device_ordinal();
    pregrow_marlin_gather_scratch_for_ordinal(ordinal, stream, required);
}
/// Ensures the scratch buffer registered for `ordinal` holds at least
/// `required` `f16` elements.
///
/// A shared lock is tried first so that the common "already large enough"
/// case does not take the exclusive lock. Otherwise the exclusive lock is
/// taken, the size is re-checked, and a fresh buffer of exactly `required`
/// elements replaces the old entry.
///
/// # Panics
/// Panics if the lock is poisoned or the device allocation fails.
fn pregrow_marlin_gather_scratch_for_ordinal(
    ordinal: usize,
    stream: &Arc<CudaStream>,
    required: usize,
) {
    let registry = marlin_gather_scratch_slots();

    // The read guard is a temporary and is released at the end of this statement.
    let already_large_enough = registry
        .read()
        .expect("MARLIN_GATHER_SCRATCH poisoned")
        .get(&ordinal)
        .is_some_and(|slot| slot.capacity >= required);
    if already_large_enough {
        return;
    }

    let mut slots = registry.write().expect("MARLIN_GATHER_SCRATCH poisoned");
    // Re-check under the exclusive lock: another thread may have grown it meanwhile.
    let must_allocate = slots
        .get(&ordinal)
        .map_or(true, |slot| slot.capacity < required);
    if !must_allocate {
        return;
    }

    let buf = unsafe { stream.alloc::<f16>(required) }
        .expect("MARLIN_GATHER_SCRATCH pregrow alloc failed");
    slots.insert(
        ordinal,
        MarlinGatherScratch {
            buf,
            capacity: required,
        },
    );
}
/// Runs `body` with mutable access to the scratch buffer registered for
/// `ordinal`, first making sure it holds at least `required` `f16` elements.
///
/// The exclusive lock on the global scratch map is held for the whole
/// duration of `body`.
///
/// # Panics
/// Panics if the lock is poisoned or the device allocation fails.
fn with_marlin_gather_scratch<R>(
    stream: &Arc<CudaStream>,
    ordinal: usize,
    required: usize,
    body: impl FnOnce(&mut CudaSlice<f16>) -> R,
) -> R {
    let registry = marlin_gather_scratch_slots();

    // The read guard is a temporary and is released at the end of this statement.
    let already_large_enough = registry
        .read()
        .expect("MARLIN_GATHER_SCRATCH poisoned")
        .get(&ordinal)
        .is_some_and(|slot| slot.capacity >= required);
    if already_large_enough {
        let mut slots = registry.write().expect("MARLIN_GATHER_SCRATCH poisoned");
        let slot = slots.get_mut(&ordinal).expect("just observed Some");
        return body(&mut slot.buf);
    }

    let mut slots = registry.write().expect("MARLIN_GATHER_SCRATCH poisoned");
    // Re-check under the exclusive lock before replacing the buffer.
    let must_allocate = slots
        .get(&ordinal)
        .map_or(true, |slot| slot.capacity < required);
    if must_allocate {
        let buf =
            unsafe { stream.alloc::<f16>(required) }.expect("MARLIN_GATHER_SCRATCH alloc failed");
        slots.insert(
            ordinal,
            MarlinGatherScratch {
                buf,
                capacity: required,
            },
        );
    }
    let slot = slots.get_mut(&ordinal).expect("just allocated");
    body(&mut slot.buf)
}
/// Fused MoE GEMM phase: groups the active experts into four buckets by token
/// count and issues one `marlin_gemm_moe` call per non-empty bucket.
///
/// Each dispatch tuple is `(expert index, input row offset, output row offset,
/// token count)`; the output row offset is not used here. Per-expert token
/// counts and input row offsets are uploaded once as tables indexed by expert
/// index (sized to the largest index seen plus one). Returns `Ok(())`
/// immediately when `dispatches` is empty.
///
/// # Errors
/// Returns a `FerrumError` if a host-to-device copy or a kernel call fails.
#[cfg(feature = "marlin")]
fn moe_gemm_phase_fused_impl(
    ctx: &mut CudaState,
    input: &CudaSlice<f16>,
    weight: &crate::marlin::MarlinWeight,
    dispatches: &[(usize, usize, usize, usize)],
    n_per_expert: usize,
    output: &mut CudaSlice<f16>,
    _k: usize,
) -> Result<()> {
    let Some(highest_expert) = dispatches.iter().map(|&(expert, ..)| expert).max() else {
        // Nothing to dispatch.
        return Ok(());
    };
    let table_len = highest_expert + 1;

    let mut tokens_by_expert = vec![0i32; table_len];
    let mut input_row_by_expert = vec![0i32; table_len];
    let mut experts_by_bucket: [Vec<i32>; 4] = Default::default();
    for &(expert, input_row, _output_row, m_e) in dispatches {
        debug_assert!(m_e > 0);
        debug_assert!(m_e <= 64);
        tokens_by_expert[expert] = m_e as i32;
        input_row_by_expert[expert] = input_row as i32;
        // Bucket index 0..=3 for token counts 1-16, 17-32, 33-48 and 49+.
        let slot = ((m_e + 15) / 16).clamp(1, 4) - 1;
        experts_by_bucket[slot].push(expert as i32);
    }

    let stream = ctx.stream.clone();
    let row_offsets_dev = stream
        .clone_htod(&input_row_by_expert)
        .map_err(|e| FerrumError::model(format!("htod row_offsets: {e}")))?;
    let tokens_dev = stream
        .clone_htod(&tokens_by_expert)
        .map_err(|e| FerrumError::model(format!("htod tokens: {e}")))?;

    let occupied = experts_by_bucket
        .iter()
        .enumerate()
        .filter(|(_, ids)| !ids.is_empty());
    for (b, ids) in occupied {
        let bucket_rows = ((b + 1) * 16) as i32;
        let active_ids_dev = stream
            .clone_htod(ids)
            .map_err(|e| FerrumError::model(format!("htod active_ids[b={b}]: {e}")))?;
        crate::marlin::marlin_gemm_moe(
            &stream,
            input,
            weight,
            output,
            &row_offsets_dev,
            &tokens_dev,
            Some(&active_ids_dev),
            ids.len() as i32,
            bucket_rows,
            n_per_expert as i32,
            table_len as i32,
        )
        .map_err(|e| FerrumError::model(format!("marlin_gemm_moe (bucket={b}): {e}")))?;
    }
    Ok(())
}
/// Runs a Marlin GEMM of `a` against `weight` into `out`, first gathering the
/// columns of `a` through `weight.perm` when a permutation is present.
///
/// When `FERRUM_VLLM_MARLIN=1` the GEMM is routed to `launch_vllm_marlin`,
/// otherwise to `crate::marlin::marlin_gemm`. With a permutation, the gathered
/// activations live in the per-device scratch buffer (sized `m * weight.k`),
/// and both the gather kernel and the GEMM are issued while that scratch
/// buffer is borrowed.
///
/// # Errors
/// Returns a `FerrumError` if the gather launch or the GEMM fails.
pub fn marlin_gemm_with_perm(
    ctx: &mut CudaState,
    a: &CudaSlice<f16>,
    weight: &crate::marlin::MarlinWeight,
    out: &mut CudaSlice<f16>,
    m: usize,
) -> Result<()> {
    let prefer_vllm = cuda_quant_runtime_config().vllm_marlin;

    let Some(perm) = weight.perm.as_ref() else {
        // No permutation: feed the activations straight into the GEMM.
        return if prefer_vllm {
            launch_vllm_marlin(&ctx.stream, a, weight, out, m)
        } else {
            crate::marlin::marlin_gemm(&ctx.stream, a, weight, out, m as i32)
                .map_err(|e| FerrumError::model(format!("marlin_gemm: {e}")))
        };
    };

    let k = weight.k;
    let stream = ctx.stream.clone();
    let gather_fn = ctx.func("gather_columns", ptx::GATHER_COLUMNS, "gather_columns_f16");
    let rows = m as i32;
    let cols = k as i32;
    let threads_per_block: u32 = 512;
    // One grid row per activation row; enough blocks along y to cover `k` columns.
    let blocks_per_row: u32 = ((k as u32) + threads_per_block - 1) / threads_per_block;
    let launch_cfg = LaunchConfig {
        grid_dim: (m as u32, blocks_per_row, 1),
        block_dim: (threads_per_block, 1, 1),
        shared_mem_bytes: 0,
    };

    with_marlin_gather_scratch(&stream, ctx.ordinal, m * k, |a_gathered| -> Result<()> {
        let mut builder = stream.launch_builder(&gather_fn);
        builder.arg(a);
        builder.arg(perm);
        builder.arg(&mut *a_gathered);
        builder.arg(&rows);
        builder.arg(&cols);
        unsafe { builder.launch(launch_cfg) }
            .map_err(|e| FerrumError::model(format!("gather_columns launch: {e}")))?;

        if prefer_vllm {
            launch_vllm_marlin(&ctx.stream, a_gathered, weight, out, m)
        } else {
            crate::marlin::marlin_gemm(&ctx.stream, a_gathered, weight, out, m as i32)
                .map_err(|e| FerrumError::model(format!("marlin_gemm (perm): {e}")))
        }
    })
}
/// Launches the vLLM Marlin f16 × u4b8 GEMM of `a` against `weight` into `out`
/// on `stream`.
///
/// Before the launch the weight's workspace buffer is zeroed asynchronously on
/// the same stream. The SM count and the atomic-add / fp32-reduce switches
/// come from the cached runtime configuration. Every 1024th call (starting
/// with the first) logs the problem shape to stderr.
///
/// # Errors
/// Returns a `FerrumError` if the workspace memset is rejected by the driver.
/// The GEMM launch itself reports no status through this wrapper.
#[cfg(feature = "vllm-marlin")]
pub fn launch_vllm_marlin(
    stream: &Arc<cudarc::driver::CudaStream>,
    a: &CudaSlice<f16>,
    weight: &crate::marlin::MarlinWeight,
    out: &mut CudaSlice<f16>,
    m: usize,
) -> Result<()> {
    use cudarc::driver::DevicePtr;
    use std::sync::atomic::{AtomicU64, Ordering};

    static VLLM_MARLIN_CALLS: AtomicU64 = AtomicU64::new(0);
    let n = VLLM_MARLIN_CALLS.fetch_add(1, Ordering::Relaxed);
    // Zero is itself a multiple of 1024, so the first call is logged too.
    if n.is_multiple_of(1024) {
        eprintln!(
            "[vllm-marlin] launch #{n} m={m} n={} k={} group_size={}",
            weight.n, weight.k, weight.group_size,
        );
    }

    {
        let (ws_ptr, _g) = weight.workspace.device_ptr(stream);
        let raw_stream = stream.cu_stream();
        // SAFETY: `ws_ptr` was obtained from `weight.workspace` just above and
        // the element count passed is that same buffer's `len()`.
        // NOTE(review): assumes the workspace elements are 32-bit wide, as
        // they are when it is allocated with `alloc_zeros::<i32>`.
        let status = unsafe {
            cudarc::driver::sys::cuMemsetD32Async(ws_ptr, 0, weight.workspace.len(), raw_stream)
        };
        // Previously this status was discarded, so a failed reset went unnoticed.
        if status != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
            return Err(FerrumError::model(format!(
                "vllm-marlin workspace memset failed: {status:?}"
            )));
        }
    }

    let (a_ptr, _g_a) = a.device_ptr(stream);
    let (b_ptr, _g_b) = weight.qweight.device_ptr(stream);
    let (c_ptr, _g_c) = out.device_ptr(stream);
    let (s_ptr, _g_s) = weight.scales.device_ptr(stream);
    let (ws_ptr, _g_w) = weight.workspace.device_ptr(stream);
    let raw_stream = stream.cu_stream();
    let n = weight.n as i32;
    let k = weight.k as i32;
    let group_size = weight.group_size;
    // A non-positive group size is treated as a single group spanning all of `k`.
    let num_groups = if group_size > 0 { k / group_size } else { 1 };
    let runtime_config = cuda_quant_runtime_config();
    let sms = runtime_config.vllm_marlin_sms;
    let use_atomic_add = runtime_config.vllm_atomic_add;
    let use_fp32_reduce = runtime_config.vllm_fp32_reduce;

    // SAFETY (assumed): all pointers come from live device buffers whose
    // guards (`_g_*`) are held until after the call; the argument order must
    // match the FFI declaration in `crate::vllm_marlin` exactly.
    unsafe {
        crate::vllm_marlin::launch_marlin_mm_f16_u4b8(
            a_ptr as *const _,
            b_ptr as *const _,
            c_ptr as *mut _,
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            s_ptr as *mut _,
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            std::ptr::null_mut(),
            m as i32,
            n,
            k,
            k,
            ws_ptr as *mut _,
            false,
            true,
            num_groups,
            group_size,
            0,
            raw_stream as cudarc::driver::sys::CUstream,
            sms,
            use_atomic_add,
            use_fp32_reduce,
        );
    }
    Ok(())
}
/// Stand-in used when the `vllm-marlin` feature is disabled: always fails,
/// explaining that the binary lacks the feature.
#[cfg(not(feature = "vllm-marlin"))]
fn launch_vllm_marlin(
    _stream: &Arc<cudarc::driver::CudaStream>,
    _a: &CudaSlice<f16>,
    _weight: &crate::marlin::MarlinWeight,
    _out: &mut CudaSlice<f16>,
    _m: usize,
) -> Result<()> {
    let reason = "FERRUM_VLLM_MARLIN=1 set but binary not built with --features vllm-marlin";
    Err(FerrumError::model(reason))
}
impl BackendQuantMarlin for CudaBackend {
/// Ensures the gather scratch buffer for the context's device ordinal holds
/// at least `required` elements. Does nothing when the `marlin` feature is
/// disabled.
fn pregrow_marlin_gather_scratch(ctx: &mut Self::Context, required: usize) {
    #[cfg(feature = "marlin")]
    pregrow_marlin_gather_scratch_for_ordinal(ctx.ordinal, &ctx.stream.clone(), required);

    // Without the feature there is no scratch pool; just consume the arguments.
    #[cfg(not(feature = "marlin"))]
    let _ = (ctx, required);
}
/// Loads one 4-bit GPTQ linear layer onto the default CUDA stream and wraps
/// it in a `CudaMarlinLinear`.
///
/// With the `triton-kernels` feature and `FERRUM_TRITON_INT4=1`, the raw
/// `qweight` / `scales` / `qzeros` tensors are uploaded as-is for the Triton
/// path. Otherwise the weight is repacked to the Marlin layout; `qzeros` is
/// not used on that path. When `g_idx` differs from the trivial
/// `i / group_size` assignment, rows of `qweight` are reordered by a stable
/// argsort of `g_idx` and the permutation is uploaded alongside the weight.
///
/// # Errors
/// - `bits != 4`.
/// - `g_idx` is provided and `group_size == 0` (previously an integer
///   division-by-zero panic).
/// - `g_idx` requires reordering but has fewer than `k` entries (previously
///   an out-of-bounds panic).
/// - Any host-to-device copy or device allocation fails.
fn load_gptq(
    qweight: &[i32],
    scales: &[f32],
    qzeros: &[i32],
    g_idx: Option<&[i32]>,
    bias_host: Option<&[f32]>,
    bits: u32,
    group_size: usize,
    k: usize,
    n: usize,
) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
    if bits != 4 {
        return Err(FerrumError::unsupported(format!(
            "CUDA GPTQ: only bits=4 supported (got {bits})"
        )));
    }
    // Only the Triton path reads `qzeros`; keep it "used" when that path is compiled out.
    let _ = qzeros;

    #[cfg(feature = "triton-kernels")]
    if use_triton_int4() {
        let stream = default_stream();
        let scales_f16: Vec<f16> = scales.iter().map(|&x| f16::from_f32(x)).collect();
        let qweight_dev = stream
            .clone_htod(qweight)
            .map_err(|e| FerrumError::model(format!("triton qweight htod: {e}")))?;
        let scales_dev = stream
            .clone_htod(&scales_f16)
            .map_err(|e| FerrumError::model(format!("triton scales htod: {e}")))?;
        let qzeros_dev = stream
            .clone_htod(qzeros)
            .map_err(|e| FerrumError::model(format!("triton qzeros htod: {e}")))?;
        tracing::info!("GPTQ load (triton-rs w4a16): K={k}, N={n}, gs={group_size}");
        let store = GptqStoreCuda::Triton(crate::triton_w4a16::TritonGptqWeight {
            qweight: qweight_dev,
            scales: scales_dev,
            qzeros: qzeros_dev,
            k,
            n,
            group_size: group_size as i32,
        });
        let bias = bias_host.map(<Self as crate::backend::Backend>::from_slice);
        return Ok(Box::new(
            crate::quant_linear::cuda_marlin::CudaMarlinLinear {
                store,
                bias,
                in_features: k,
                out_features: n,
            },
        ));
    }

    let (qweight_for_repack, perm_dev_opt): (Vec<i32>, Option<CudaSlice<i32>>) =
        if let Some(gx) = g_idx {
            // The trivial-assignment test below divides by `group_size`.
            if group_size == 0 {
                return Err(FerrumError::model(
                    "CUDA GPTQ: group_size must be non-zero when g_idx is provided",
                ));
            }
            let is_desc_act = gx
                .iter()
                .enumerate()
                .any(|(i, &g)| g != (i as i32) / group_size as i32);
            if is_desc_act {
                // The argsort below indexes `gx[0..k]`.
                if gx.len() < k {
                    return Err(FerrumError::model(format!(
                        "CUDA GPTQ: g_idx has {} entries but K={k}",
                        gx.len()
                    )));
                }
                // Stable sort keeps rows of the same group in their original order.
                let mut perm: Vec<usize> = (0..k).collect();
                perm.sort_by_key(|&i| gx[i]);
                let permuted_qweight =
                    crate::marlin::permute_gptq_qweight_rows(qweight, &perm, k, n);
                let perm_i32: Vec<i32> = perm.iter().map(|&p| p as i32).collect();
                let stream = default_stream();
                let perm_dev = stream
                    .clone_htod(&perm_i32)
                    .map_err(|e| FerrumError::model(format!("perm htod: {e}")))?;
                tracing::info!(
                    "GPTQ load (Marlin + desc_act perm-aware): K={k} N={n} gs={group_size}"
                );
                (permuted_qweight, Some(perm_dev))
            } else {
                (qweight.to_vec(), None)
            }
        } else {
            (qweight.to_vec(), None)
        };

    let marlin_qweight_i32 = crate::marlin::repack_gptq_to_marlin(&qweight_for_repack, k, n);
    let scales_f16: Vec<f16> = scales.iter().map(|&x| f16::from_f32(x)).collect();
    let marlin_scales_f16 =
        crate::marlin::repack_scales_to_marlin(&scales_f16, k, n, group_size);

    let stream = default_stream();
    let qweight_dev = stream
        .clone_htod(&marlin_qweight_i32)
        .map_err(|e| FerrumError::model(format!("qweight htod: {e}")))?;
    let scales_dev = stream
        .clone_htod(&marlin_scales_f16)
        .map_err(|e| FerrumError::model(format!("scales htod: {e}")))?;

    // Workspace: 16 slots per 128-column tile, at least one tile.
    let max_par = 16usize;
    let ws_len = (n / 128).max(1) * max_par;
    let workspace_dev = stream
        .alloc_zeros::<i32>(ws_len)
        .map_err(|e| FerrumError::model(format!("ws alloc: {e}")))?;

    let marlin_weight = crate::marlin::MarlinWeight {
        qweight: qweight_dev,
        scales: scales_dev,
        workspace: workspace_dev,
        k,
        n,
        group_size: group_size as i32,
        perm: perm_dev_opt,
    };
    #[cfg(feature = "triton-kernels")]
    let store = GptqStoreCuda::Marlin(marlin_weight);
    #[cfg(not(feature = "triton-kernels"))]
    let store: GptqStoreCuda = marlin_weight;

    let bias = bias_host.map(<Self as crate::backend::Backend>::from_slice);
    Ok(Box::new(
        crate::quant_linear::cuda_marlin::CudaMarlinLinear {
            store,
            bias,
            in_features: k,
            out_features: n,
        },
    ))
}
/// Loads a stack of 4-bit GPTQ expert weights as one contiguous Marlin weight
/// and wraps it in a `CudaMarlinExpertStack`.
///
/// With the `vllm-moe-marlin` feature and `FERRUM_VLLM_MOE=1`, loading is
/// delegated to `load_stacked_gptq_vllm_marlin`. The Triton path is not
/// implemented for stacks. Otherwise each expert is repacked to the Marlin
/// layout in parallel and the results are concatenated; the stacked weight
/// reports `n = num_experts * n_per_expert`. `qzeros` is only length-checked.
/// When `g_idx` differs from the trivial `i / group_size` assignment, every
/// expert's rows are reordered by the same stable argsort of `g_idx`.
///
/// # Errors
/// - `bits != 4`, zero experts, or mismatched slice counts.
/// - Triton int4 requested.
/// - `group_size == 0` on the Marlin path (previously a division-by-zero panic).
/// - `g_idx` requires reordering but has fewer than `k` entries (previously
///   an out-of-bounds panic).
/// - The per-expert packed weight or scale size is zero (previously a panic
///   from chunking with size 0).
/// - Any host-to-device copy or device allocation fails.
fn load_gptq_stacked(
    qweights: &[&[i32]],
    scales: &[&[f32]],
    qzeros: &[&[i32]],
    g_idx: Option<&[i32]>,
    bits: u32,
    group_size: usize,
    k: usize,
    n_per_expert: usize,
) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
    if bits != 4 {
        return Err(FerrumError::unsupported(format!(
            "CUDA GPTQ stacked: only bits=4 supported (got {bits})"
        )));
    }
    let num_experts = qweights.len();
    if num_experts == 0 {
        return Err(FerrumError::model("load_gptq_stacked: 0 experts"));
    }
    if scales.len() != num_experts || qzeros.len() != num_experts {
        return Err(FerrumError::model(format!(
            "load_gptq_stacked length mismatch: qw={} sc={} qz={}",
            num_experts,
            scales.len(),
            qzeros.len()
        )));
    }
    let _ = qzeros;

    #[cfg(feature = "vllm-moe-marlin")]
    if use_vllm_moe() {
        let stream = default_stream();
        let mw = crate::vllm_marlin::load_stacked_gptq_vllm_marlin(
            &stream,
            qweights,
            scales,
            bits,
            group_size,
            k,
            n_per_expert,
        )
        .map_err(|e| FerrumError::model(format!("load_stacked_gptq_vllm_marlin: {e}")))?;
        tracing::info!(
            "GPTQ stacked load (vLLM marlin path): {num_experts} experts × N={n_per_expert} × K={k} (gs={group_size})",
        );
        #[cfg(feature = "triton-kernels")]
        let store: GptqStoreCuda = GptqStoreCuda::Marlin(mw);
        #[cfg(not(feature = "triton-kernels"))]
        let store: GptqStoreCuda = mw;
        return Ok(std::sync::Arc::new(
            crate::quant_linear::cuda_marlin_stack::CudaMarlinExpertStack::new(
                std::sync::Arc::new(store),
                num_experts,
                n_per_expert,
                k,
            ),
        ));
    }

    #[cfg(feature = "triton-kernels")]
    if use_triton_int4() {
        return Err(FerrumError::unsupported(
            "load_gptq_stacked: Triton w4a16 path not implemented; \
unset FERRUM_TRITON_INT4 to use Marlin",
        ));
    }

    // Everything below divides by `group_size`.
    if group_size == 0 {
        return Err(FerrumError::model(
            "load_gptq_stacked: group_size must be non-zero",
        ));
    }

    let (perm_dev_opt, perm_for_repack): (Option<CudaSlice<i32>>, Option<Vec<usize>>) =
        if let Some(gx) = g_idx {
            let is_desc_act = gx
                .iter()
                .enumerate()
                .any(|(i, &g)| g != (i as i32) / group_size as i32);
            if is_desc_act {
                // The argsort below indexes `gx[0..k]`.
                if gx.len() < k {
                    return Err(FerrumError::model(format!(
                        "load_gptq_stacked: g_idx has {} entries but K={k}",
                        gx.len()
                    )));
                }
                // Stable sort keeps rows of the same group in their original order.
                let mut perm: Vec<usize> = (0..k).collect();
                perm.sort_by_key(|&i| gx[i]);
                let perm_i32: Vec<i32> = perm.iter().map(|&p| p as i32).collect();
                let stream = default_stream();
                let perm_dev = stream
                    .clone_htod(&perm_i32)
                    .map_err(|e| FerrumError::model(format!("perm htod: {e}")))?;
                (Some(perm_dev), Some(perm))
            } else {
                (None, None)
            }
        } else {
            (None, None)
        };

    use rayon::prelude::*;
    // Eight 4-bit values per i32; one scale per (group, output column).
    let qw_per_expert_i32 = (n_per_expert * k) / 8;
    let sc_per_expert_f16 = (k / group_size) * n_per_expert;
    // Chunking by zero panics, so reject degenerate shapes up front.
    if qw_per_expert_i32 == 0 || sc_per_expert_f16 == 0 {
        return Err(FerrumError::model(format!(
            "load_gptq_stacked: degenerate expert shape (K={k}, N={n_per_expert}, gs={group_size})"
        )));
    }
    let mut packed_qw: Vec<i32> = vec![0i32; num_experts * qw_per_expert_i32];
    let mut packed_sc: Vec<f16> = vec![f16::ZERO; num_experts * sc_per_expert_f16];
    // NOTE(review): `copy_from_slice` panics if a repacked expert does not
    // have exactly the per-expert size computed above — confirm the repack
    // helpers guarantee that for every accepted input.
    packed_qw
        .par_chunks_mut(qw_per_expert_i32)
        .zip(packed_sc.par_chunks_mut(sc_per_expert_f16))
        .enumerate()
        .for_each(|(e, (qw_out, sc_out))| {
            let qw_in: Vec<i32> = if let Some(perm) = &perm_for_repack {
                crate::marlin::permute_gptq_qweight_rows(qweights[e], perm, k, n_per_expert)
            } else {
                qweights[e].to_vec()
            };
            let qw_packed = crate::marlin::repack_gptq_to_marlin(&qw_in, k, n_per_expert);
            qw_out.copy_from_slice(&qw_packed);
            let sc_f16: Vec<f16> = scales[e].iter().map(|&x| f16::from_f32(x)).collect();
            let sc_packed =
                crate::marlin::repack_scales_to_marlin(&sc_f16, k, n_per_expert, group_size);
            sc_out.copy_from_slice(&sc_packed);
        });

    let stream = default_stream();
    let qweight_dev = stream
        .clone_htod(&packed_qw)
        .map_err(|e| FerrumError::model(format!("stacked qweight htod: {e}")))?;
    let scales_dev = stream
        .clone_htod(&packed_sc)
        .map_err(|e| FerrumError::model(format!("stacked scales htod: {e}")))?;

    // Workspace: 16 slots per 128-column tile (at least one tile) for each expert.
    let max_par = 16usize;
    let ws_per_expert = (n_per_expert / 128).max(1) * max_par;
    let ws_len = num_experts * ws_per_expert;
    let workspace_dev = stream
        .alloc_zeros::<i32>(ws_len)
        .map_err(|e| FerrumError::model(format!("stacked ws alloc: {e}")))?;

    let total_n = num_experts * n_per_expert;
    let marlin_weight = crate::marlin::MarlinWeight {
        qweight: qweight_dev,
        scales: scales_dev,
        workspace: workspace_dev,
        k,
        n: total_n,
        group_size: group_size as i32,
        perm: perm_dev_opt,
    };
    tracing::info!(
        "GPTQ stacked load: {} experts × N={n_per_expert} × K={k} (gs={group_size})",
        num_experts
    );
    #[cfg(feature = "triton-kernels")]
    let store: GptqStoreCuda = GptqStoreCuda::Marlin(marlin_weight);
    #[cfg(not(feature = "triton-kernels"))]
    let store: GptqStoreCuda = marlin_weight;
    Ok(std::sync::Arc::new(
        crate::quant_linear::cuda_marlin_stack::CudaMarlinExpertStack::new(
            std::sync::Arc::new(store),
            num_experts,
            n_per_expert,
            k,
        ),
    ))
}
}
/// Runs one MoE GEMM phase over `dispatches`, choosing between three
/// strategies based on the cached runtime configuration:
///
/// 1. `moe_fused` set: delegate to `moe_gemm_phase_fused_impl`.
/// 2. `moe_streams == 1`: issue one strided Marlin GEMM per dispatch,
///    sequentially on the context's stream.
/// 3. otherwise: spread the per-dispatch GEMMs round-robin over the context's
///    stream pool, fenced by an entry event before and exit events after.
///
/// Each dispatch tuple is `(expert index, input row offset, output row
/// offset, row count)`. `k` is only forwarded to the fused path.
///
/// # Errors
/// Returns a `FerrumError` if the store is a Triton weight or a GEMM call fails.
#[cfg(feature = "marlin")]
pub(crate) fn moe_gemm_phase_batched_impl(
ctx: &mut CudaState,
input: &<CudaBackend as crate::backend::Backend>::Buffer,
weight: &GptqStoreCuda,
dispatches: &[(usize, usize, usize, usize)],
n_per_expert: usize,
output: &mut <CudaBackend as crate::backend::Backend>::Buffer,
k: usize,
) -> Result<()> {
// Extract the Marlin weight; with Triton compiled in, the store may be a
// Triton weight, which this function rejects.
#[cfg(feature = "triton-kernels")]
let mw = match weight {
GptqStoreCuda::Marlin(mw) => mw,
GptqStoreCuda::Triton(_) => {
return Err(FerrumError::unsupported(
"moe_gemm_phase_batched: Triton w4a16 not supported",
));
}
};
#[cfg(not(feature = "triton-kernels"))]
let mw: &crate::marlin::MarlinWeight = weight;
let runtime_config = cuda_quant_runtime_config();
// Strategy 1: fused kernel path.
if runtime_config.moe_fused {
return moe_gemm_phase_fused_impl(
ctx,
input.as_f16(),
mw,
dispatches,
n_per_expert,
output.as_f16_mut(),
k,
);
}
let n_streams = runtime_config.moe_streams;
// Strategy 2: everything on the context's own stream, in dispatch order.
if n_streams == 1 {
let default_stream = ctx.stream.clone();
for (expert_idx, in_row_offset, out_row_offset, m) in dispatches {
crate::marlin::marlin_gemm_with_offset_strided(
&default_stream,
input.as_f16(),
*in_row_offset as i32,
mw,
output.as_f16_mut(),
*out_row_offset as i32,
*m as i32,
(expert_idx * n_per_expert) as i32,
n_per_expert as i32,
)
.map_err(|e| FerrumError::model(format!("marlin offset_strided: {e}")))?;
}
let _ = k;
return Ok(());
}
// Strategy 3: fork onto the stream pool, then join back.
let (entry_event, exit_events) = ctx.moe_sync_events();
let pool: Vec<Arc<CudaStream>> = ctx.moe_stream_pool().to_vec();
let default_stream = ctx.stream.clone();
use cudarc::driver::sys as cu;
// Fork: record an event on the context's stream and make every pool stream
// wait on it before running its share of the work.
// NOTE(review): the status codes returned by the raw driver calls in this
// function are discarded — confirm that is intentional.
unsafe {
cu::cuEventRecord(entry_event, default_stream.cu_stream());
}
for stream in &pool {
unsafe {
cu::cuStreamWaitEvent(stream.cu_stream(), entry_event, 0);
}
}
// Round-robin assignment of dispatches to pool streams.
// NOTE(review): indexing with `i % n_streams` assumes the pool holds at
// least `n_streams` streams; a smaller pool would panic here — confirm
// `moe_stream_pool` is sized from the same configuration value.
// NOTE(review): an error returned by `?` inside this loop skips the join
// below, leaving already-queued pool work unsynchronised with the
// context's stream — confirm callers tolerate that.
for (i, (expert_idx, in_row_offset, out_row_offset, m)) in dispatches.iter().enumerate() {
let stream = &pool[i % n_streams];
crate::marlin::marlin_gemm_with_offset_strided(
stream,
input.as_f16(),
*in_row_offset as i32,
mw,
output.as_f16_mut(),
*out_row_offset as i32,
*m as i32,
(expert_idx * n_per_expert) as i32,
n_per_expert as i32,
)
.map_err(|e| FerrumError::model(format!("marlin offset_strided: {e}")))?;
}
let _ = k;
debug_assert_eq!(
exit_events.len(),
pool.len(),
"moe_sync_events exit count != pool size"
);
// Join: record one exit event per pool stream, then make the context's
// stream wait on all of them.
for (i, stream) in pool.iter().enumerate() {
unsafe {
cu::cuEventRecord(exit_events[i], stream.cu_stream());
}
}
for ev in &exit_events {
unsafe {
cu::cuStreamWaitEvent(default_stream.cu_stream(), *ev, 0);
}
}
Ok(())
}
/// Runs one MoE GEMM phase through the vLLM Marlin MoE kernel wrapper.
///
/// The routing tensors (`sorted_token_ids`, `expert_ids`,
/// `num_tokens_past_padded`) are read as `i32` buffers, `input` / `output` as
/// `f16`. The call is made while borrowing the per-device temporary buffer
/// provided by `with_vllm_moe_c_tmp`.
///
/// # Errors
/// Returns a `FerrumError` if the store is a Triton weight or the kernel
/// wrapper reports a failure.
#[cfg(feature = "vllm-moe-marlin")]
pub(crate) fn moe_gemm_phase_vllm_impl(
    ctx: &mut CudaState,
    input: &<CudaBackend as crate::backend::Backend>::Buffer,
    weight: &GptqStoreCuda,
    sorted_token_ids: &<CudaBackend as crate::backend::Backend>::Buffer,
    expert_ids: &<CudaBackend as crate::backend::Backend>::Buffer,
    num_tokens_past_padded: &<CudaBackend as crate::backend::Backend>::Buffer,
    output: &mut <CudaBackend as crate::backend::Backend>::Buffer,
    prob_m: usize,
    n_per_expert: usize,
    k: usize,
    moe_block_size: usize,
    top_k: usize,
) -> Result<()> {
    // Only a Marlin-format store can be fed to this kernel.
    #[cfg(feature = "triton-kernels")]
    let GptqStoreCuda::Marlin(marlin_weight) = weight
    else {
        return Err(FerrumError::unsupported(
            "moe_gemm_phase_vllm: Triton store unsupported",
        ));
    };
    #[cfg(not(feature = "triton-kernels"))]
    let marlin_weight: &crate::marlin::MarlinWeight = weight;

    let stream = ctx.stream.clone();
    let sorted_ids = sorted_token_ids.as_i32();
    let experts = expert_ids.as_i32();
    let padded_token_count = num_tokens_past_padded.as_i32();

    crate::backend::cuda::with_vllm_moe_c_tmp(&stream, ctx.ordinal, |c_tmp| {
        let launched = crate::marlin::marlin_gemm_moe_vllm(
            &stream,
            input.as_f16(),
            marlin_weight,
            output.as_f16_mut(),
            Some(c_tmp),
            sorted_ids,
            experts,
            padded_token_count,
            None,
            moe_block_size as i32,
            top_k as i32,
            false,
            false,
            prob_m as i32,
            n_per_expert as i32,
            k as i32,
        );
        launched.map_err(|e| FerrumError::model(format!("marlin_gemm_moe_vllm: {e}")))
    })
}
// Empty impl: the CUDA backend overrides nothing from `BackendQuantGguf`.
impl BackendQuantGguf for CudaBackend {}
#[cfg(test)]
mod tests {
    use super::CudaQuantRuntimeConfig;

    /// Every recognised variable is honoured, and a stream count of zero is
    /// clamped up to one.
    #[test]
    fn cuda_quant_runtime_config_parses_marlin_and_moe_knobs() {
        let env = [
            ("FERRUM_TRITON_INT4", "1"),
            ("FERRUM_VLLM_MOE", "1"),
            ("FERRUM_VLLM_MARLIN", "1"),
            ("FERRUM_VLLM_MARLIN_SMS", "132"),
            ("FERRUM_VLLM_ATOMIC_ADD", "1"),
            ("FERRUM_VLLM_FP32_REDUCE", "1"),
            ("FERRUM_MOE_FUSED", "0"),
            ("FERRUM_MOE_STREAMS", "0"),
        ];
        let expected = CudaQuantRuntimeConfig {
            triton_int4: true,
            vllm_moe: true,
            vllm_marlin: true,
            vllm_marlin_sms: 132,
            vllm_atomic_add: true,
            vllm_fp32_reduce: true,
            moe_fused: false,
            moe_streams: 1,
        };
        assert_eq!(CudaQuantRuntimeConfig::from_env_vars(env), expected);
    }

    /// Values that are not exactly `"1"` or fail to parse leave the defaults
    /// in place.
    #[test]
    fn cuda_quant_runtime_config_keeps_existing_defaults() {
        let env = [
            ("FERRUM_TRITON_INT4", "true"),
            ("FERRUM_VLLM_MARLIN_SMS", "bad"),
            ("FERRUM_MOE_STREAMS", "bad"),
        ];
        let expected = CudaQuantRuntimeConfig {
            triton_int4: false,
            vllm_moe: false,
            vllm_marlin: false,
            vllm_marlin_sms: 128,
            vllm_atomic_add: false,
            vllm_fp32_reduce: false,
            moe_fused: true,
            moe_streams: 4,
        };
        assert_eq!(CudaQuantRuntimeConfig::from_env_vars(env), expected);
    }
}