use cudarc::driver::sys::CUstream;
use std::os::raw::{c_int, c_void};
// Raw C entry points implemented on the native (CUDA) side. Signatures must stay
// in exact agreement with the C declarations; do not reorder or retype parameters.
extern "C" {
/// C entry point for the vLLM GPTQ → Marlin weight repack (per its name).
///
/// Callers in this file pass `perm_in` as null together with `has_perm == 0`, and
/// write `qweight_out` through a raw device pointer.
///
/// Returns a status code; callers in this file treat any non-zero value as failure.
/// NOTE(review): the meaning of `dev`, which `num_bits` values are accepted, and
/// whether the call is asynchronous on `stream` are defined by the C side — confirm there.
pub fn ferrum_vllm_gptq_marlin_repack(
qweight_in: *const c_void,
perm_in: *const c_void,
qweight_out: *mut c_void,
size_k: c_int,
size_n: c_int,
num_bits: c_int,
has_perm: c_int,
dev: c_int,
stream: CUstream,
) -> c_int;
/// C entry point for the Marlin GEMM with f16 inputs and `u4b8` weights (per its name).
///
/// This declaration has no return value, so no status is reported back to Rust.
/// `launch_marlin_mm_f16_u4b8` passes `-1` for both `thread_k_init` and
/// `thread_n_init`; presumably that requests automatic tile selection — confirm on the C side.
pub fn ferrum_marlin_mm_f16_u4b8(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
c_tmp: *mut c_void,
a_s: *mut c_void,
b_s: *mut c_void,
g_idx: *mut c_void,
perm: *mut c_void,
a_tmp: *mut c_void,
prob_m: c_int,
prob_n: c_int,
prob_k: c_int,
lda: c_int,
workspace: *mut c_void,
has_act_order: bool,
is_k_full: bool,
num_groups: c_int,
group_size: c_int,
dev: c_int,
stream: CUstream,
thread_k_init: c_int,
thread_n_init: c_int,
sms: c_int,
use_atomic_add: bool,
use_fp32_reduce: bool,
);
}
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_marlin_mm_f16_u4b8(
a: *const c_void,
b: *const c_void,
c: *mut c_void,
c_tmp: *mut c_void,
a_s: *mut c_void,
b_s: *mut c_void,
g_idx: *mut c_void,
perm: *mut c_void,
a_tmp: *mut c_void,
prob_m: i32,
prob_n: i32,
prob_k: i32,
lda: i32,
workspace: *mut c_void,
has_act_order: bool,
is_k_full: bool,
num_groups: i32,
group_size: i32,
dev: i32,
stream: CUstream,
sms: i32,
use_atomic_add: bool,
use_fp32_reduce: bool,
) {
ferrum_marlin_mm_f16_u4b8(
a,
b,
c,
c_tmp,
a_s,
b_s,
g_idx,
perm,
a_tmp,
prob_m,
prob_n,
prob_k,
lda,
workspace,
has_act_order,
is_k_full,
num_groups,
group_size,
dev,
stream,
-1, -1, sms,
use_atomic_add,
use_fp32_reduce,
);
}
pub fn load_stacked_gptq_vllm_marlin(
stream: &std::sync::Arc<cudarc::driver::CudaStream>,
qweights: &[&[i32]],
scales_f32: &[&[f32]],
bits: u32,
group_size: usize,
k: usize,
n_per_expert: usize,
) -> candle_core::Result<crate::marlin::MarlinWeight> {
if bits != 4 {
return Err(candle_core::Error::Msg(format!(
"vLLM stacked Marlin: bits={bits} unsupported (only 4)"
)));
}
let num_experts = qweights.len();
if num_experts == 0 || scales_f32.len() != num_experts {
return Err(candle_core::Error::Msg(format!(
"vLLM stacked Marlin: shape mismatch qw={} sc={}",
num_experts,
scales_f32.len()
)));
}
let qw_per = (k / 8) * n_per_expert;
let groups = k / group_size;
let sc_per = groups * n_per_expert;
let total_qw = num_experts * qw_per;
let total_sc = num_experts * sc_per;
let qw_out: cudarc::driver::CudaSlice<i32> = stream
.alloc_zeros::<i32>(total_qw)
.map_err(|err| candle_core::Error::Msg(format!("alloc stacked qw: {err}")))?;
use cudarc::driver::DevicePtr;
let raw_stream = stream.cu_stream();
for e in 0..num_experts {
if qweights[e].len() != qw_per {
return Err(candle_core::Error::Msg(format!(
"vLLM stacked Marlin: qweight[{e}].len()={} expected {qw_per}",
qweights[e].len()
)));
}
let qw_in_dev: cudarc::driver::CudaSlice<i32> = stream
.clone_htod(qweights[e])
.map_err(|err| candle_core::Error::Msg(format!("htod qw[{e}]: {err}")))?;
let (out_base_ptr, _g) = qw_out.device_ptr(stream);
let out_offset_bytes = (e * qw_per * std::mem::size_of::<i32>()) as u64;
let (in_ptr, _ig) = qw_in_dev.device_ptr(stream);
let ret = unsafe {
ferrum_vllm_gptq_marlin_repack(
in_ptr as *const _,
std::ptr::null(),
(out_base_ptr + out_offset_bytes) as *mut _,
k as i32,
n_per_expert as i32,
bits as i32,
0, 0, raw_stream,
)
};
if ret != 0 {
return Err(candle_core::Error::Msg(format!(
"repack expert {e} failed ret={ret}"
)));
}
}
let mut sc_flat_f16: Vec<half::f16> = Vec::with_capacity(total_sc);
for e in 0..num_experts {
if scales_f32[e].len() != sc_per {
return Err(candle_core::Error::Msg(format!(
"vLLM stacked Marlin: scales[{e}].len()={} expected {sc_per}",
scales_f32[e].len()
)));
}
let sc_e_f16: Vec<half::f16> = scales_f32[e]
.iter()
.map(|&x| half::f16::from_f32(x))
.collect();
let sc_e_perm =
crate::marlin::repack_scales_to_marlin(&sc_e_f16, k, n_per_expert, group_size);
sc_flat_f16.extend(sc_e_perm);
}
let sc_dev: cudarc::driver::CudaSlice<half::f16> = stream
.clone_htod(sc_flat_f16.as_slice())
.map_err(|err| candle_core::Error::Msg(format!("htod stacked scales: {err}")))?;
let ws_per_expert = (n_per_expert / 64).max(1) * 16;
let ws_total = num_experts * ws_per_expert;
let workspace: cudarc::driver::CudaSlice<i32> = stream
.alloc_zeros::<i32>(ws_total)
.map_err(|err| candle_core::Error::Msg(format!("alloc workspace: {err}")))?;
stream
.synchronize()
.map_err(|err| candle_core::Error::Msg(format!("sync after repack: {err}")))?;
Ok(crate::marlin::MarlinWeight {
qweight: qw_out,
scales: sc_dev,
workspace,
k,
n: n_per_expert * num_experts, group_size: group_size as i32,
perm: None,
})
}
/// Repacks one 4-bit GPTQ `qweight` matrix into the Marlin layout on `stream`.
///
/// The function expects the following buffer sizes and settings:
///
/// - `qweight_in_dev` must hold at least `(size_k / 8) * size_n` packed `i32`s.
/// - `qweight_out_dev` must have room for at least the same number of `i32`s.
/// - No act-order permutation is applied.
///
/// This function enqueues the repack on `stream` and does not itself synchronize.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// - `size_k` or `size_n` is negative, or `size_k` is not a multiple of 8.
/// - Either slice is shorter than the expected length. Without this check the native
///   kernel would read or write out of bounds.
/// - The native repack reports a non-zero status.
pub fn vllm_gptq_marlin_repack(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    qweight_in_dev: &cudarc::driver::CudaSlice<i32>,
    qweight_out_dev: &mut cudarc::driver::CudaSlice<i32>,
    size_k: i32,
    size_n: i32,
) -> candle_core::Result<()> {
    use cudarc::driver::{DevicePtr, DeviceSlice};

    // Number of 4-bit values packed into one `i32` word.
    const PACK_FACTOR: usize = 8;

    let expected_len = match (usize::try_from(size_k), usize::try_from(size_n)) {
        (Ok(k), Ok(n)) if k % PACK_FACTOR == 0 => (k / PACK_FACTOR) * n,
        _ => {
            return Err(candle_core::Error::Msg(format!(
                "vllm gptq_marlin_repack: invalid shape (size_k={size_k}, size_n={size_n}); \
                 both must be non-negative and size_k a multiple of {PACK_FACTOR}"
            )))
        }
    };
    let (in_len, out_len) = (qweight_in_dev.len(), qweight_out_dev.len());
    if in_len < expected_len || out_len < expected_len {
        return Err(candle_core::Error::Msg(format!(
            "vllm gptq_marlin_repack: buffers too small (in={in_len}, out={out_len}, \
             need {expected_len} i32s for size_k={size_k}, size_n={size_n})"
        )));
    }

    let raw_stream = stream.cu_stream();
    let (in_ptr, _ig) = qweight_in_dev.device_ptr(stream);
    let (out_ptr, _og) = qweight_out_dev.device_ptr(stream);
    // SAFETY: both pointers come from live `CudaSlice<i32>` buffers whose guards (`_ig`,
    // `_og`) are held across the call, and both were just checked to hold at least
    // `expected_len` elements. `perm_in` is null, matching `has_perm == 0`.
    let ret = unsafe {
        ferrum_vllm_gptq_marlin_repack(
            in_ptr as *const _,
            std::ptr::null(),
            out_ptr as *mut _,
            size_k,
            size_n,
            4, // num_bits
            0, // has_perm: no act-order permutation supplied
            0, // dev — NOTE(review): hard-coded ordinal 0; confirm for multi-GPU use
            raw_stream,
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "vllm gptq_marlin_repack failed: ret={ret} (size_k={size_k}, size_n={size_n})"
        )));
    }
    Ok(())
}