use cudarc::driver::{CudaSlice, CudaStream, DevicePtr};
use std::sync::Arc;
use std::sync::OnceLock;
/// Runtime switches for the CUDA Marlin wrappers, sourced from environment variables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CudaMarlinRuntimeConfig {
    /// `true` only when `FERRUM_MARLIN_SKIP_WS_ZERO` is set to exactly `"1"`.
    skip_ws_zero: bool,
}

impl CudaMarlinRuntimeConfig {
    /// Builds the configuration from the current process environment.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars())
    }

    /// Builds the configuration from an arbitrary sequence of `(name, value)` pairs.
    ///
    /// Only `FERRUM_MARLIN_SKIP_WS_ZERO` is inspected. If it occurs more than once the
    /// last occurrence decides; any value other than the literal `"1"` (including
    /// `"true"`) leaves the flag off. With no occurrence the flag defaults to `false`.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        const SKIP_WS_ZERO_VAR: &str = "FERRUM_MARLIN_SKIP_WS_ZERO";

        let skip_ws_zero = vars
            .into_iter()
            .filter(|(name, _)| name.as_ref() == SKIP_WS_ZERO_VAR)
            .last()
            .map_or(false, |(_, value)| value.as_ref() == "1");

        Self { skip_ws_zero }
    }
}
/// Returns the process-wide Marlin runtime configuration.
///
/// The environment is consulted once, on the first call; every later call returns the
/// same cached value.
fn cuda_marlin_runtime_config() -> &'static CudaMarlinRuntimeConfig {
    static CACHED: OnceLock<CudaMarlinRuntimeConfig> = OnceLock::new();
    let load: fn() -> CudaMarlinRuntimeConfig = CudaMarlinRuntimeConfig::from_env;
    CACHED.get_or_init(load)
}
fn skip_ws_zero() -> bool {
cuda_marlin_runtime_config().skip_ws_zero
}
// FFI declarations for the Marlin CUDA entry points. They are only declared when the
// `marlin` feature is enabled. Both functions return an `i32` status; the wrappers in
// this file treat 0 as success and any other value as failure.
#[cfg(feature = "marlin")]
extern "C" {
/// Single Marlin GEMM launch.
///
/// The wrappers in this file pass device pointers for the f16 activations (`A`), the
/// packed `i32` weights (`B`), the f16 output (`C`), the f16 scales (`s`) and the `i32`
/// workspace, together with the problem sizes `prob_m`/`prob_n`/`prob_k`.
/// Every caller in this file passes `dev = 0`, `-1` for `thread_k`, `thread_n`, `sms`
/// and `prob_n_full`, and `16` for `max_par`.
/// NOTE(review): `-1` presumably selects a kernel-side default — confirm against the
/// kernel source.
fn marlin_cuda(
A: *const std::ffi::c_void,
B: *const std::ffi::c_void,
C: *mut std::ffi::c_void,
s: *const std::ffi::c_void,
prob_m: i32,
prob_n: i32,
prob_k: i32,
workspace: *mut std::ffi::c_void,
groupsize: i32,
dev: i32,
stream: cudarc::driver::sys::CUstream,
thread_k: i32,
thread_n: i32,
sms: i32,
max_par: i32,
prob_n_full: i32,
) -> i32;
/// Multi-expert (MoE) Marlin GEMM launch.
///
/// In addition to the operands of `marlin_cuda`, it takes three `i32` device arrays
/// (`a_row_offsets`, `tokens_per_expert`, `active_expert_ids`), an `expert_count`, and
/// per-expert strides for the weights, scales and workspace locks.
/// `marlin_gemm_moe` passes a null `active_expert_ids` when the caller supplies none.
/// NOTE(review): the stride units (`*_int4_*`, `locks_i32_*`) are inferred from the
/// parameter names and the divisors used by the caller — confirm against the kernel.
fn marlin_cuda_moe(
A: *const std::ffi::c_void,
B: *const std::ffi::c_void,
C: *mut std::ffi::c_void,
s: *const std::ffi::c_void,
prob_m: i32,
prob_n: i32,
prob_k: i32,
workspace: *mut std::ffi::c_void,
a_row_offsets: *const i32, tokens_per_expert: *const i32, active_expert_ids: *const i32, expert_count: i32,
b_int4_per_expert: i32,
s_int4_per_expert: i32,
locks_i32_per_expert: i32,
groupsize: i32,
dev: i32,
stream: cudarc::driver::sys::CUstream,
thread_k: i32,
thread_n: i32,
sms: i32,
prob_n_full: i32,
) -> i32;
}
// FFI declarations for the vLLM-derived Marlin MoE kernel and its profiling hooks.
// They are only declared when the `vllm-moe-marlin` feature is enabled.
#[cfg(feature = "vllm-moe-marlin")]
extern "C" {
/// Installs profiling metadata on the native side.
///
/// `configure_vllm_moe_profile_sink` passes NUL-terminated C strings for every pointer
/// argument and keeps them alive only for the duration of the call.
/// NOTE(review): the native side presumably copies the strings — confirm.
fn ferrum_vllm_marlin_moe_set_profile_config(
path: *const std::ffi::c_char,
commit_sha: *const std::ffi::c_char,
env_hash: *const std::ffi::c_char,
model: *const std::ffi::c_char,
concurrency: i32,
runtime_flags_json: *const std::ffi::c_char,
);
/// Clears any profiling configuration previously installed.
fn ferrum_vllm_marlin_moe_clear_profile_config();
/// f16 Marlin MoE GEMM. Returns an `i32` status; the wrapper in this file treats any
/// non-zero value as failure.
///
/// `c_tmp` and `topk_weights` may be null: `marlin_gemm_moe_vllm` passes null when the
/// corresponding optional buffer is absent. The boolean-like `i32` flags
/// (`mul_topk_weights`, `is_ep`, `use_atomic_add`, `use_fp32_reduce`) are passed as 0/1.
fn ferrum_vllm_marlin_moe_f16(
a: *const std::ffi::c_void, b: *const std::ffi::c_void, c: *mut std::ffi::c_void, c_tmp: *mut std::ffi::c_void, b_scales: *const std::ffi::c_void, workspace: *mut std::ffi::c_void, sorted_token_ids: *const i32,
expert_ids: *const i32,
num_tokens_past_padded: *const i32,
topk_weights: *const f32, moe_block_size: i32, top_k: i32,
mul_topk_weights: i32, is_ep: i32, prob_m: i32,
prob_n: i32,
prob_k: i32,
group_size: i32, dev: i32,
stream: cudarc::driver::sys::CUstream,
use_atomic_add: i32,
use_fp32_reduce: i32,
) -> i32;
}
/// Pushes the profiling sink configuration down to the native vLLM Marlin MoE layer.
///
/// When `config.jsonl_path` is `None` the native configuration is cleared and the call
/// succeeds. Otherwise the path and the metadata fields are converted to C strings and
/// handed to the native setter. The runtime flags are serialized to JSON, falling back
/// to `"{}"` if serialization fails.
///
/// # Errors
///
/// Returns `InvalidInput` when any string handed to the native side contains an
/// interior NUL byte.
#[cfg(feature = "vllm-moe-marlin")]
pub fn configure_vllm_moe_profile_sink(
    config: &ferrum_bench_core::ProfileSinkConfig,
) -> std::io::Result<()> {
    use std::ffi::CString;

    let jsonl_path = match &config.jsonl_path {
        Some(p) => p,
        None => {
            unsafe { ferrum_vllm_marlin_moe_clear_profile_config() };
            return Ok(());
        }
    };
    let meta = &config.metadata;

    // Non-UTF-8 path bytes are replaced lossily before the conversion to a C string.
    let c_path = CString::new(jsonl_path.as_os_str().to_string_lossy().into_owned())
        .map_err(profile_cstring_error("profile path"))?;

    // A missing commit SHA is forwarded as an empty string.
    let sha_text = meta.commit_sha.as_deref().unwrap_or_default().to_string();
    let c_commit_sha =
        CString::new(sha_text).map_err(profile_cstring_error("profile commit_sha"))?;

    let c_env_hash =
        CString::new(meta.env_hash.clone()).map_err(profile_cstring_error("env_hash"))?;
    let c_model = CString::new(meta.model.clone()).map_err(profile_cstring_error("model"))?;

    let flags_text = match serde_json::to_string(&meta.runtime_flags) {
        Ok(json) => json,
        Err(_) => "{}".to_string(),
    };
    let c_runtime_flags =
        CString::new(flags_text).map_err(profile_cstring_error("runtime_flags_json"))?;

    // Clamp before narrowing so oversized values saturate instead of wrapping.
    let concurrency = meta.concurrency.min(i32::MAX as u32) as i32;

    // All CStrings above stay alive until the end of this function, i.e. past the call.
    unsafe {
        ferrum_vllm_marlin_moe_set_profile_config(
            c_path.as_ptr(),
            c_commit_sha.as_ptr(),
            c_env_hash.as_ptr(),
            c_model.as_ptr(),
            concurrency,
            c_runtime_flags.as_ptr(),
        );
    }
    Ok(())
}
/// Returns a one-shot converter from `NulError` to an `InvalidInput` I/O error whose
/// message names `field` as the value containing the NUL byte.
#[cfg(feature = "vllm-moe-marlin")]
fn profile_cstring_error(field: &'static str) -> impl FnOnce(std::ffi::NulError) -> std::io::Error {
    move |err| {
        let message = format!("{field} contains NUL byte: {err}");
        std::io::Error::new(std::io::ErrorKind::InvalidInput, message)
    }
}
/// Reports whether this build was compiled with the `marlin` feature.
pub fn is_available() -> bool {
    const MARLIN_COMPILED_IN: bool = cfg!(feature = "marlin");
    MARLIN_COMPILED_IN
}
/// Device-resident buffers and shape metadata for one Marlin-format quantized weight.
pub struct MarlinWeight {
/// Packed quantized weights; passed to the kernels as the `B` operand.
pub qweight: CudaSlice<i32>,
/// f16 scales; passed to the kernels as the `s` / `b_scales` operand.
pub scales: CudaSlice<half::f16>,
/// Scratch buffer handed to the kernels as `workspace`. `marlin_gemm` zeroes all of it
/// before every launch; the per-expert wrappers zero only the selected expert's slice.
pub workspace: CudaSlice<i32>,
/// Reduction dimension; passed to the kernels as `prob_k`.
pub k: usize,
/// Output width. `marlin_gemm` passes it as `prob_n`; the per-expert wrappers treat it
/// as the total width of all stacked experts.
pub n: usize,
/// Quantization group size; forwarded unchanged as the kernels' group-size argument.
pub group_size: i32,
/// Optional permutation buffer. No wrapper visible in this file reads it.
/// NOTE(review): presumably the act-order row permutation — confirm against callers.
pub perm: Option<CudaSlice<i32>>,
}
/// Runs one Marlin GEMM: `output = input × weight` for `m` rows of activations.
///
/// The whole workspace buffer is zeroed on `stream` before the launch. The kernel is
/// invoked with `dev = 0`, `max_par = 16` and `-1` for every tuning parameter.
///
/// # Errors
///
/// Returns an error when `weight.n` or `weight.k` does not fit in `i32`, when zeroing
/// the workspace fails, or when the kernel entry point returns a non-zero status.
#[cfg(feature = "marlin")]
pub fn marlin_gemm(
    stream: &Arc<CudaStream>,
    input: &CudaSlice<half::f16>,
    weight: &MarlinWeight,
    output: &mut CudaSlice<half::f16>,
    m: i32,
) -> candle_core::Result<()> {
    // The kernel takes 32-bit sizes; refuse shapes that would silently truncate.
    if weight.n > i32::MAX as usize || weight.k > i32::MAX as usize {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm: weight shape does not fit in i32 (n={}, k={})",
            weight.n, weight.k
        )));
    }
    let n = weight.n as i32;
    let k = weight.k as i32;
    let raw_stream = stream.cu_stream();

    {
        let (ws_ptr, _guard) = weight.workspace.device_ptr(stream);
        // SAFETY: `ws_ptr` addresses `weight.workspace`, and exactly its own length in
        // 32-bit words is cleared.
        let status = unsafe {
            cudarc::driver::sys::cuMemsetD32Async(ws_ptr, 0, weight.workspace.len(), raw_stream)
        };
        // Previously this status was discarded, so a failed clear went unnoticed.
        if status != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm: zeroing workspace failed: {status:?}"
            )));
        }
    }

    let (a_ptr, _a_guard) = input.device_ptr(stream);
    let (b_ptr, _b_guard) = weight.qweight.device_ptr(stream);
    let (c_ptr, _c_guard) = output.device_ptr(stream);
    let (s_ptr, _s_guard) = weight.scales.device_ptr(stream);
    let (ws_ptr, _ws_guard) = weight.workspace.device_ptr(stream);

    // SAFETY: every pointer comes from a live `CudaSlice` whose guard is held until the
    // end of this function. The caller is responsible for the buffers being large
    // enough for an (m, n, k) problem.
    let ret = unsafe {
        marlin_cuda(
            a_ptr as *const _,
            b_ptr as *const _,
            c_ptr as *mut _,
            s_ptr as *const _,
            m,
            n,
            k,
            ws_ptr as *mut _,
            weight.group_size,
            0,
            raw_stream,
            -1,
            -1,
            -1,
            16,
            -1,
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin_cuda failed: ret={ret} (m={m}, n={n}, k={k}, gs={})",
            weight.group_size
        )));
    }
    Ok(())
}
#[cfg(not(feature = "marlin"))]
pub fn marlin_gemm(
_stream: &Arc<CudaStream>,
_input: &CudaSlice<half::f16>,
_weight: &MarlinWeight,
_output: &mut CudaSlice<half::f16>,
_m: i32,
) -> candle_core::Result<()> {
Err(candle_core::Error::Msg(
"Marlin kernel not available (compile with --features marlin)".into(),
))
}
/// Runs a Marlin GEMM against one expert's slice of a stacked weight.
///
/// `expert_offset` is the expert's first column in the stacked width `weight.n` and
/// `expert_n` is the per-expert width. The expert index is `expert_offset / expert_n`;
/// it is used to offset the weight, scale and workspace pointers. The expert's
/// workspace slice is zeroed on `stream` before the launch.
///
/// # Errors
///
/// Returns an error when:
/// - the offset/width pair falls outside `weight.n`,
/// - `expert_offset` is not a multiple of `expert_n`,
/// - `weight.k` does not fit in `i32`,
/// - `weight.group_size` is neither positive nor `-1`,
/// - zeroing the workspace fails, or
/// - the kernel entry point returns a non-zero status.
#[cfg(feature = "marlin")]
pub fn marlin_gemm_with_offset(
    stream: &Arc<CudaStream>,
    input: &CudaSlice<half::f16>,
    weight: &MarlinWeight,
    output: &mut CudaSlice<half::f16>,
    m: i32,
    expert_offset: i32,
    expert_n: i32,
) -> candle_core::Result<()> {
    use cudarc::driver::DevicePtr;

    const MAX_PAR: usize = 16;

    // Compare in `usize` so `expert_offset + expert_n` cannot overflow `i32`.
    if expert_offset < 0
        || expert_n <= 0
        || expert_offset as usize + expert_n as usize > weight.n
    {
        return Err(candle_core::Error::Msg(format!(
            "marlin offset out of range: offset={expert_offset} n={expert_n} stacked_n={}",
            weight.n
        )));
    }
    // The per-expert pointer offsets below are derived from `offset / n`; a misaligned
    // offset would silently address a different expert than the one requested.
    if expert_offset % expert_n != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin offset not expert-aligned: offset={expert_offset} is not a multiple of n={expert_n}"
        )));
    }
    if weight.k > i32::MAX as usize {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm_with_offset: k={} does not fit in i32",
            weight.k
        )));
    }
    let n = expert_n;
    let k = weight.k as i32;
    let k_us = weight.k;
    let n_per = expert_n as usize;
    let expert_idx = (expert_offset / expert_n) as usize;

    // A zero group size would divide by zero, and a negative one would turn into a huge
    // `usize` and yield zero groups.
    // NOTE(review): `-1` is treated as "one scale group per column", following the
    // upstream Marlin convention — confirm for the kernel build in use.
    let num_groups = match weight.group_size {
        gs if gs > 0 => k_us / gs as usize,
        -1 => 1,
        gs => {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm_with_offset: invalid group_size={gs}"
            )))
        }
    };

    let ws_per_expert = (n_per / 128).max(1) * MAX_PAR;
    let ws_offset_bytes = expert_idx * ws_per_expert * std::mem::size_of::<i32>();
    // Eight 4-bit values per packed i32 word.
    let qw_per_expert_i32 = (n_per * k_us) / 8;
    let qw_offset_bytes = expert_idx * qw_per_expert_i32 * std::mem::size_of::<i32>();
    let sc_per_expert_f16 = num_groups * n_per;
    let scales_offset_bytes = expert_idx * sc_per_expert_f16 * std::mem::size_of::<half::f16>();

    let raw_stream = stream.cu_stream();
    {
        let (ws_ptr, _g) = weight.workspace.device_ptr(stream);
        // SAFETY: clears `ws_per_expert` words starting at this expert's slice of the
        // workspace. The caller must have sized the workspace for every stacked expert.
        let status = unsafe {
            cudarc::driver::sys::cuMemsetD32Async(
                ws_ptr + ws_offset_bytes as u64,
                0,
                ws_per_expert,
                raw_stream,
            )
        };
        // Previously this status was discarded, so a failed clear went unnoticed.
        if status != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm_with_offset: zeroing workspace failed: {status:?}"
            )));
        }
    }

    let (a_ptr, _a_guard) = input.device_ptr(stream);
    let (b_ptr_full, _b_guard) = weight.qweight.device_ptr(stream);
    let (c_ptr, _c_guard) = output.device_ptr(stream);
    let (s_ptr_full, _s_guard) = weight.scales.device_ptr(stream);
    let (ws_ptr_full, _ws_guard) = weight.workspace.device_ptr(stream);
    let b_ptr = b_ptr_full + qw_offset_bytes as u64;
    let s_ptr = s_ptr_full + scales_offset_bytes as u64;
    let ws_ptr = ws_ptr_full + ws_offset_bytes as u64;

    // SAFETY: every base pointer comes from a live `CudaSlice` whose guard is held until
    // the end of this function. The offsets select one expert's slice within them.
    let ret = unsafe {
        marlin_cuda(
            a_ptr as *const _,
            b_ptr as *const _,
            c_ptr as *mut _,
            s_ptr as *const _,
            m,
            n,
            k,
            ws_ptr as *mut _,
            weight.group_size,
            0,
            raw_stream,
            -1,
            -1,
            -1,
            16,
            -1,
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin_cuda (offset) failed ret={ret} m={m} n={n} k={k} offset={expert_offset}"
        )));
    }
    Ok(())
}
#[cfg(not(feature = "marlin"))]
pub fn marlin_gemm_with_offset(
_stream: &Arc<CudaStream>,
_input: &CudaSlice<half::f16>,
_weight: &MarlinWeight,
_output: &mut CudaSlice<half::f16>,
_m: i32,
_expert_offset: i32,
_expert_n: i32,
) -> candle_core::Result<()> {
Err(candle_core::Error::Msg(
"Marlin kernel not available (compile with --features marlin)".into(),
))
}
/// Runs a per-expert Marlin GEMM on a row window of shared input/output buffers.
///
/// Behaves like the per-expert offset GEMM, but reads activations starting at row
/// `in_row_offset` (rows of `weight.k` f16 values) and writes results starting at row
/// `out_row_offset` (rows of `expert_n` f16 values). The expert's workspace slice is
/// zeroed first unless the `FERRUM_MARLIN_SKIP_WS_ZERO` runtime flag is on.
///
/// # Errors
///
/// Returns an error when:
/// - the offset/width pair falls outside `weight.n`,
/// - `expert_offset` is not a multiple of `expert_n`,
/// - a row offset is negative,
/// - `weight.k` does not fit in `i32`,
/// - `weight.group_size` is neither positive nor `-1`,
/// - zeroing the workspace fails, or
/// - the kernel entry point returns a non-zero status.
#[cfg(feature = "marlin")]
#[allow(clippy::too_many_arguments)]
pub fn marlin_gemm_with_offset_strided(
    stream: &Arc<CudaStream>,
    input: &CudaSlice<half::f16>,
    in_row_offset: i32,
    weight: &MarlinWeight,
    output: &mut CudaSlice<half::f16>,
    out_row_offset: i32,
    m: i32,
    expert_offset: i32,
    expert_n: i32,
) -> candle_core::Result<()> {
    use cudarc::driver::DevicePtr;

    const MAX_PAR: usize = 16;

    // Compare in `usize` so `expert_offset + expert_n` cannot overflow `i32`.
    if expert_offset < 0
        || expert_n <= 0
        || expert_offset as usize + expert_n as usize > weight.n
    {
        return Err(candle_core::Error::Msg(format!(
            "marlin offset out of range: offset={expert_offset} n={expert_n} stacked_n={}",
            weight.n
        )));
    }
    // The per-expert pointer offsets below are derived from `offset / n`; a misaligned
    // offset would silently address a different expert than the one requested.
    if expert_offset % expert_n != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin offset not expert-aligned: offset={expert_offset} is not a multiple of n={expert_n}"
        )));
    }
    // A negative row offset would wrap to an enormous byte offset when cast to `usize`.
    if in_row_offset < 0 || out_row_offset < 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin row offset must be non-negative: in_row_offset={in_row_offset} \
             out_row_offset={out_row_offset}"
        )));
    }
    if weight.k > i32::MAX as usize {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm_with_offset_strided: k={} does not fit in i32",
            weight.k
        )));
    }
    let n = expert_n;
    let k = weight.k as i32;
    let k_us = weight.k;
    let n_per = expert_n as usize;
    let expert_idx = (expert_offset / expert_n) as usize;

    // A zero group size would divide by zero, and a negative one would turn into a huge
    // `usize` and yield zero groups.
    // NOTE(review): `-1` is treated as "one scale group per column", following the
    // upstream Marlin convention — confirm for the kernel build in use.
    let num_groups = match weight.group_size {
        gs if gs > 0 => k_us / gs as usize,
        -1 => 1,
        gs => {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm_with_offset_strided: invalid group_size={gs}"
            )))
        }
    };

    let f16_bytes = std::mem::size_of::<half::f16>();
    let ws_per_expert = (n_per / 128).max(1) * MAX_PAR;
    let ws_offset_bytes = expert_idx * ws_per_expert * std::mem::size_of::<i32>();
    // Eight 4-bit values per packed i32 word.
    let qw_per_expert_i32 = (n_per * k_us) / 8;
    let qw_offset_bytes = expert_idx * qw_per_expert_i32 * std::mem::size_of::<i32>();
    let sc_per_expert_f16 = num_groups * n_per;
    let scales_offset_bytes = expert_idx * sc_per_expert_f16 * f16_bytes;
    let in_offset_bytes = in_row_offset as usize * k_us * f16_bytes;
    let out_offset_bytes = out_row_offset as usize * n_per * f16_bytes;

    let raw_stream = stream.cu_stream();
    if !skip_ws_zero() {
        let (ws_ptr, _g) = weight.workspace.device_ptr(stream);
        // SAFETY: clears `ws_per_expert` words starting at this expert's slice of the
        // workspace. The caller must have sized the workspace for every stacked expert.
        let status = unsafe {
            cudarc::driver::sys::cuMemsetD32Async(
                ws_ptr + ws_offset_bytes as u64,
                0,
                ws_per_expert,
                raw_stream,
            )
        };
        // Previously this status was discarded, so a failed clear went unnoticed.
        if status != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm_with_offset_strided: zeroing workspace failed: {status:?}"
            )));
        }
    }

    let (a_ptr, _a_guard) = input.device_ptr(stream);
    let (b_ptr_full, _b_guard) = weight.qweight.device_ptr(stream);
    let (c_ptr, _c_guard) = output.device_ptr(stream);
    let (s_ptr_full, _s_guard) = weight.scales.device_ptr(stream);
    let (ws_ptr_full, _ws_guard) = weight.workspace.device_ptr(stream);
    let a_ptr_off = a_ptr + in_offset_bytes as u64;
    let b_ptr = b_ptr_full + qw_offset_bytes as u64;
    let c_ptr_off = c_ptr + out_offset_bytes as u64;
    let s_ptr = s_ptr_full + scales_offset_bytes as u64;
    let ws_ptr = ws_ptr_full + ws_offset_bytes as u64;

    // SAFETY: every base pointer comes from a live `CudaSlice` whose guard is held until
    // the end of this function. The offsets select one expert's slice and one row window.
    let ret = unsafe {
        marlin_cuda(
            a_ptr_off as *const _,
            b_ptr as *const _,
            c_ptr_off as *mut _,
            s_ptr as *const _,
            m,
            n,
            k,
            ws_ptr as *mut _,
            weight.group_size,
            0,
            raw_stream,
            -1,
            -1,
            -1,
            16,
            -1,
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin_cuda (offset_strided) failed ret={ret} m={m} n={n} k={k} \
             expert_offset={expert_offset} in_row_offset={in_row_offset} \
             out_row_offset={out_row_offset}"
        )));
    }
    Ok(())
}
#[cfg(not(feature = "marlin"))]
#[allow(clippy::too_many_arguments)]
pub fn marlin_gemm_with_offset_strided(
_stream: &Arc<CudaStream>,
_input: &CudaSlice<half::f16>,
_in_row_offset: i32,
_weight: &MarlinWeight,
_output: &mut CudaSlice<half::f16>,
_out_row_offset: i32,
_m: i32,
_expert_offset: i32,
_expert_n: i32,
) -> candle_core::Result<()> {
Err(candle_core::Error::Msg(
"Marlin kernel not available (compile with --features marlin)".into(),
))
}
/// Launches the multi-expert Marlin GEMM over a stacked weight in a single call.
///
/// Returns `Ok(())` without launching when `expert_count <= 0`. `prob_m` must be one of
/// 16, 32, 48 or 64. Per-expert strides for the packed weights, scales and workspace
/// locks are derived from `n_per_expert`, `weight.k` and `weight.group_size`. When
/// `active_expert_ids` is `None` a null pointer is passed. This wrapper does not zero
/// the workspace itself.
///
/// # Errors
///
/// Returns an error when:
/// - `prob_m` is invalid,
/// - `n_per_expert` is not positive,
/// - the stacked width cannot hold `num_experts_global` experts,
/// - `weight.k` or a derived stride does not fit in `i32`,
/// - `weight.group_size` is neither positive nor `-1`, or
/// - the kernel entry point returns a non-zero status.
#[cfg(feature = "marlin")]
#[allow(clippy::too_many_arguments)]
pub fn marlin_gemm_moe(
    stream: &Arc<CudaStream>,
    input: &CudaSlice<half::f16>,
    weight: &MarlinWeight,
    output: &mut CudaSlice<half::f16>,
    a_row_offsets: &CudaSlice<i32>,
    tokens_per_expert: &CudaSlice<i32>,
    active_expert_ids: Option<&CudaSlice<i32>>,
    expert_count: i32,
    prob_m: i32,
    n_per_expert: i32,
    num_experts_global: i32,
) -> candle_core::Result<()> {
    use cudarc::driver::DevicePtr;

    const MAX_PAR: usize = 16;

    if expert_count <= 0 {
        return Ok(());
    }
    if prob_m <= 0 || prob_m > 64 || prob_m % 16 != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm_moe: prob_m must be in {{16, 32, 48, 64}}, got {prob_m}"
        )));
    }

    let n = n_per_expert;
    // Widen to i64 so the product cannot overflow. Negative widths are rejected here
    // too, because they would otherwise become a huge `usize` further down.
    let required_n = i64::from(num_experts_global) * i64::from(n);
    if n <= 0 || (weight.n as i64) < required_n {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm_moe: stacked weight N={} too small for E_global={num_experts_global} × n_per={n}",
            weight.n
        )));
    }
    if weight.k > i32::MAX as usize {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm_moe: k={} does not fit in i32",
            weight.k
        )));
    }
    let k = weight.k as i32;
    let n_per = n as usize;
    let k_us = weight.k;

    // A zero group size would divide by zero, and a negative one would turn into a huge
    // `usize` and yield zero groups.
    // NOTE(review): `-1` is treated as "one scale group per column", following the
    // upstream Marlin convention — confirm for the kernel build in use.
    let groups = match weight.group_size {
        gs if gs > 0 => k_us / gs as usize,
        -1 => 1,
        gs => {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm_moe: invalid group_size={gs}"
            )))
        }
    };

    // The kernel takes 32-bit strides; refuse values that would silently truncate.
    let stride_i32 = |what: &str, value: usize| -> candle_core::Result<i32> {
        if value > i32::MAX as usize {
            Err(candle_core::Error::Msg(format!(
                "marlin_gemm_moe: {what}={value} does not fit in i32"
            )))
        } else {
            Ok(value as i32)
        }
    };
    let b_int4_per_expert = stride_i32("b_int4_per_expert", (n_per * k_us) / 32)?;
    let s_int4_per_expert = stride_i32("s_int4_per_expert", (groups * n_per) / 8)?;
    let locks_i32_per_expert =
        stride_i32("locks_i32_per_expert", (n_per / 128).max(1) * MAX_PAR)?;

    let raw_stream = stream.cu_stream();
    let (a_ptr, _ag) = input.device_ptr(stream);
    let (b_ptr, _bg) = weight.qweight.device_ptr(stream);
    let (c_ptr, _cg) = output.device_ptr(stream);
    let (s_ptr, _sg) = weight.scales.device_ptr(stream);
    let (ws_ptr, _wg) = weight.workspace.device_ptr(stream);
    let (off_ptr, _og) = a_row_offsets.device_ptr(stream);
    let (tok_ptr, _tg) = tokens_per_expert.device_ptr(stream);
    // Keep the optional guard alive across the launch; a missing slice becomes null.
    let act_ptr_opt = active_expert_ids.map(|s| s.device_ptr(stream));
    let act_raw: u64 = match &act_ptr_opt {
        Some((p, _)) => *p,
        None => 0,
    };

    // SAFETY: every pointer comes from a live `CudaSlice` whose guard is held until the
    // end of this function. `act_raw` is either such a pointer or null.
    let ret = unsafe {
        marlin_cuda_moe(
            a_ptr as *const _,
            b_ptr as *const _,
            c_ptr as *mut _,
            s_ptr as *const _,
            prob_m,
            n,
            k,
            ws_ptr as *mut _,
            off_ptr as *const _,
            tok_ptr as *const _,
            act_raw as *const _,
            expert_count,
            b_int4_per_expert,
            s_int4_per_expert,
            locks_i32_per_expert,
            weight.group_size,
            0,
            raw_stream,
            -1,
            -1,
            -1,
            n,
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin_cuda_moe failed: ret={ret} (prob_m={prob_m}, n={n}, k={k}, \
             experts={expert_count}, gs={})",
            weight.group_size
        )));
    }
    Ok(())
}
#[cfg(not(feature = "marlin"))]
#[allow(clippy::too_many_arguments)]
pub fn marlin_gemm_moe(
_stream: &Arc<CudaStream>,
_input: &CudaSlice<half::f16>,
_weight: &MarlinWeight,
_output: &mut CudaSlice<half::f16>,
_a_row_offsets: &CudaSlice<i32>,
_tokens_per_expert: &CudaSlice<i32>,
_active_expert_ids: Option<&CudaSlice<i32>>,
_expert_count: i32,
_prob_m: i32,
_n_per_expert: i32,
_num_experts_global: i32,
) -> candle_core::Result<()> {
Err(candle_core::Error::Msg(
"Marlin kernel not available (compile with --features marlin)".into(),
))
}
/// Launches the vLLM-derived f16 Marlin MoE GEMM.
///
/// `c_tmp` and `topk_weights` are optional; a missing buffer is passed as a null
/// pointer. When `c_tmp` is absent the kernel is asked to use atomic adds
/// (`use_atomic_add = 1`, `use_fp32_reduce = 0`); when it is present the flags are
/// reversed. The kernel is always invoked with `dev = 0`.
///
/// # Errors
///
/// Returns an error when the kernel entry point returns a non-zero status.
#[cfg(feature = "vllm-moe-marlin")]
#[allow(clippy::too_many_arguments)]
pub fn marlin_gemm_moe_vllm(
    stream: &Arc<CudaStream>,
    input: &CudaSlice<half::f16>,
    weight: &MarlinWeight,
    output: &mut CudaSlice<half::f16>,
    c_tmp: Option<&mut CudaSlice<f32>>,
    sorted_token_ids: &CudaSlice<i32>,
    expert_ids: &CudaSlice<i32>,
    num_tokens_past_padded: &CudaSlice<i32>,
    topk_weights: Option<&CudaSlice<f32>>,
    moe_block_size: i32,
    top_k: i32,
    mul_topk_weights: bool,
    is_ep: bool,
    prob_m: i32,
    prob_n: i32,
    prob_k: i32,
) -> candle_core::Result<()> {
    use cudarc::driver::DevicePtr;

    let raw_stream = stream.cu_stream();
    let (a_ptr, _ag) = input.device_ptr(stream);
    let (b_ptr, _bg) = weight.qweight.device_ptr(stream);
    let (c_ptr, _cg) = output.device_ptr(stream);
    let (s_ptr, _sg) = weight.scales.device_ptr(stream);
    let (ws_ptr, _wg) = weight.workspace.device_ptr(stream);
    let (st_ptr, _stg) = sorted_token_ids.device_ptr(stream);
    let (eid_ptr, _eidg) = expert_ids.device_ptr(stream);
    let (npp_ptr, _nppg) = num_tokens_past_padded.device_ptr(stream);

    // Hold the optional buffers' guards until after the launch, like every other buffer
    // above. Previously these two guards were dropped as soon as the raw pointer was
    // extracted, i.e. before the kernel was enqueued.
    let c_tmp_dev = c_tmp.as_deref().map(|buf| buf.device_ptr(stream));
    let c_tmp_ptr: *mut std::ffi::c_void = match &c_tmp_dev {
        Some((ptr, _)) => *ptr as *mut std::ffi::c_void,
        None => std::ptr::null_mut(),
    };
    let topk_dev = topk_weights.map(|w| w.device_ptr(stream));
    let topk_w_ptr: *const f32 = match &topk_dev {
        Some((ptr, _)) => *ptr as *const f32,
        None => std::ptr::null(),
    };

    // Without a scratch buffer the reduction goes through atomic adds; with one it goes
    // through the fp32 scratch path.
    let has_c_tmp = !c_tmp_ptr.is_null();
    let use_atomic_add = if has_c_tmp { 0 } else { 1 };
    let use_fp32_reduce = if has_c_tmp { 1 } else { 0 };

    // SAFETY: every non-null pointer comes from a live `CudaSlice` whose guard is held
    // until the end of this function. The caller is responsible for the buffers matching
    // the problem shape.
    let ret = unsafe {
        ferrum_vllm_marlin_moe_f16(
            a_ptr as *const _,
            b_ptr as *const _,
            c_ptr as *mut _,
            c_tmp_ptr,
            s_ptr as *const _,
            ws_ptr as *mut _,
            st_ptr as *const _,
            eid_ptr as *const _,
            npp_ptr as *const _,
            topk_w_ptr,
            moe_block_size,
            top_k,
            if mul_topk_weights { 1 } else { 0 },
            if is_ep { 1 } else { 0 },
            prob_m,
            prob_n,
            prob_k,
            weight.group_size,
            0,
            raw_stream,
            use_atomic_add,
            use_fp32_reduce,
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "ferrum_vllm_marlin_moe_f16 failed: ret={ret} (m={prob_m}, n={prob_n}, k={prob_k})"
        )));
    }
    Ok(())
}
#[cfg(not(feature = "vllm-moe-marlin"))]
#[allow(clippy::too_many_arguments)]
pub fn marlin_gemm_moe_vllm(
_stream: &Arc<CudaStream>,
_input: &CudaSlice<half::f16>,
_weight: &MarlinWeight,
_output: &mut CudaSlice<half::f16>,
_c_tmp: Option<&mut CudaSlice<f32>>,
_sorted_token_ids: &CudaSlice<i32>,
_expert_ids: &CudaSlice<i32>,
_num_tokens_past_padded: &CudaSlice<i32>,
_topk_weights: Option<&CudaSlice<f32>>,
_moe_block_size: i32,
_top_k: i32,
_mul_topk_weights: bool,
_is_ep: bool,
_prob_m: i32,
_prob_n: i32,
_prob_k: i32,
) -> candle_core::Result<()> {
Err(candle_core::Error::Msg(
"vLLM marlin_moe_wna16 not built — compile with --features vllm-moe-marlin".into(),
))
}
/// Reorders the logical rows of a GPTQ-packed 4-bit weight matrix.
///
/// `qweight_gptq` holds `(k / 8) * n` words in row-major order. Each word packs eight
/// consecutive rows of one column, with the lowest nibble being the first row. Output
/// row `i` is taken from input row `perm[i]`, and the result is packed the same way.
///
/// In debug builds the function asserts `perm.len() == k` and
/// `qweight_gptq.len() == (k / 8) * n`.
pub fn permute_gptq_qweight_rows(
    qweight_gptq: &[i32],
    perm: &[usize],
    k: usize,
    n: usize,
) -> Vec<i32> {
    debug_assert_eq!(perm.len(), k);
    debug_assert_eq!(qweight_gptq.len(), (k / 8) * n);

    const NIBBLES_PER_WORD: usize = 8;
    let packed_rows = k / NIBBLES_PER_WORD;
    let packed_len = packed_rows * n;

    // Stage 1: expand every packed word into eight one-nibble-per-byte rows.
    let mut unpacked = vec![0u8; k * n];
    for (idx, &word) in qweight_gptq[..packed_len].iter().enumerate() {
        let (word_row, col) = (idx / n, idx % n);
        let bits = word as u32;
        for slot in 0..NIBBLES_PER_WORD {
            let row = word_row * NIBBLES_PER_WORD + slot;
            unpacked[row * n + col] = ((bits >> (slot * 4)) & 0xF) as u8;
        }
    }

    // Stage 2: gather whole rows according to `perm`.
    let mut gathered = vec![0u8; k * n];
    for dst_row in 0..k {
        let src_row = perm[dst_row];
        let src = &unpacked[src_row * n..(src_row + 1) * n];
        gathered[dst_row * n..(dst_row + 1) * n].copy_from_slice(src);
    }

    // Stage 3: fold each group of eight rows back into one word per column.
    (0..packed_len)
        .map(|idx| {
            let (word_row, col) = (idx / n, idx % n);
            let word = (0..NIBBLES_PER_WORD).fold(0u32, |acc, slot| {
                let nibble = gathered[(word_row * NIBBLES_PER_WORD + slot) * n + col] as u32;
                acc | (nibble << (slot * 4))
            });
            word as i32
        })
        .collect()
}
/// Repacks a GPTQ-packed 4-bit weight matrix into the Marlin tile layout.
///
/// The work is done in four data-parallel stages:
/// 1. unpack each word into eight one-nibble-per-byte rows of a `k × n` matrix;
/// 2. regroup the matrix into 16×16 tiles, each tile stored contiguously;
/// 3. apply the 1024-entry Marlin permutation to every 1024-byte block;
/// 4. pack every eight consecutive nibbles into one `i32`.
///
/// Returns an empty vector when `k` or `n` is zero.
///
/// # Panics
///
/// Panics via out-of-bounds indexing in at least these cases:
/// - `k` is not a multiple of 8;
/// - `k` is not a multiple of 16 while `n >= 16`;
/// - `k * n` is not a multiple of 1024.
///
/// NOTE(review): `n` presumably also has to be a multiple of 16 for a meaningful
/// result, since columns past the last full tile are left as zero — confirm with
/// callers.
pub fn repack_gptq_to_marlin(qweight_gptq: &[i32], k: usize, n: usize) -> Vec<i32> {
    use rayon::prelude::*;

    const TILE: usize = 16;
    const PERM_BLOCK: usize = 1024;

    // Same precondition style as the row-permutation helper: a short input would
    // otherwise be truncated silently by `zip` below, leaving rows as zeros.
    debug_assert!(qweight_gptq.len() >= (k / 8) * n);

    // An empty matrix has nothing to repack. Without this guard `n == 0` would hand a
    // zero chunk size to rayon, which panics.
    if k == 0 || n == 0 {
        return Vec::new();
    }

    // Stage 1: one packed input row expands to eight unpacked rows.
    let mut kn = vec![0u8; k * n];
    kn.par_chunks_mut(8 * n)
        .zip(qweight_gptq.par_chunks(n))
        .for_each(|(kn_block, qw_row)| {
            for col in 0..n {
                let packed = qw_row[col];
                for i in 0..8 {
                    // The mask discards the sign bits an arithmetic shift drags in.
                    kn_block[i * n + col] = ((packed >> (i * 4)) & 0xF) as u8;
                }
            }
        });

    // Stage 2: every band of 16 rows is rewritten as a run of contiguous 16×16 tiles.
    let nt = n / TILE;
    let mut tiled = vec![0u8; k * n];
    tiled
        .par_chunks_mut(n * TILE)
        .enumerate()
        .for_each(|(tk, tile_block)| {
            for tn in 0..nt {
                for ik in 0..TILE {
                    for in_ in 0..TILE {
                        let src = (tk * TILE + ik) * n + (tn * TILE + in_);
                        let dst = tn * (TILE * TILE) + ik * TILE + in_;
                        tile_block[dst] = kn[src];
                    }
                }
            }
        });
    drop(kn);

    // Stage 3: permute within each 1024-byte block.
    let perm = build_marlin_perm();
    let total = k * n;
    let mut permuted = vec![0u8; total];
    permuted
        .par_chunks_mut(PERM_BLOCK)
        .zip(tiled.par_chunks(PERM_BLOCK))
        .for_each(|(out_blk, in_blk)| {
            for (dst, &src) in perm.iter().enumerate() {
                out_blk[dst] = in_blk[src];
            }
        });
    drop(tiled);

    // Stage 4: pack eight nibbles per word, lowest nibble first.
    let mut result = vec![0i32; total / 8];
    result
        .par_iter_mut()
        .zip(permuted.par_chunks_exact(8))
        .for_each(|(out, chunk)| {
            let mut word = 0u32;
            for (j, &b) in chunk.iter().enumerate() {
                word |= (b as u32) << (j * 4);
            }
            *out = word as i32;
        });
    result
}
/// Reorders GPTQ scales into the order used by the Marlin layout.
///
/// The first `(k / group_size) * n` scales are processed in fixed-size blocks:
/// 64-element blocks with an 8×8 transpose-style permutation when there is more than
/// one group, otherwise 32-element blocks with the single-group permutation. Any
/// trailing elements that do not fill a whole block are copied unchanged.
///
/// # Panics
///
/// Panics when `group_size` is zero or when `scales_gptq` holds fewer than
/// `(k / group_size) * n` elements.
pub fn repack_scales_to_marlin(
    scales_gptq: &[half::f16],
    k: usize,
    n: usize,
    group_size: usize,
) -> Vec<half::f16> {
    const SINGLE_GROUP_OFFSETS: [usize; 8] = [0, 1, 8, 9, 16, 17, 24, 25];

    let num_groups = k / group_size;

    let mut scale_perm: Vec<usize> = Vec::with_capacity(64);
    if num_groups > 1 {
        for i in 0..8 {
            for j in 0..8 {
                scale_perm.push(i + 8 * j);
            }
        }
    } else {
        for i in 0..4 {
            for &offset in &SINGLE_GROUP_OFFSETS {
                scale_perm.push(2 * i + offset);
            }
        }
    }

    let total = num_groups * n;
    let source = &scales_gptq[..total];

    let mut result = Vec::with_capacity(total);
    let blocks = source.chunks_exact(scale_perm.len());
    let leftover = blocks.remainder();
    for block in blocks {
        result.extend(scale_perm.iter().map(|&src| block[src]));
    }
    // A partial final block keeps its original order.
    result.extend_from_slice(leftover);
    result
}
/// Builds the 1024-entry index permutation applied to each 1024-element block of the
/// tiled weight matrix during Marlin repacking.
///
/// For each `i` in `0..32`, eight base indices are generated from four row positions and
/// two column blocks within a 16×16 tile. Those eight are then replicated for four
/// consecutive tiles (stride 256). Finally every group of eight entries is reordered by
/// the interleave pattern `[0, 2, 4, 6, 1, 3, 5, 7]`.
fn build_marlin_perm() -> Vec<usize> {
    const INTERLEAVE: [usize; 8] = [0, 2, 4, 6, 1, 3, 5, 7];

    let mut perm = Vec::with_capacity(1024);
    for i in 0..32 {
        let col = i / 4;
        let base_row = 2 * (i % 4);
        // Two adjacent rows, plus the same pair eight rows further down.
        let rows = [base_row, base_row + 1, base_row + 8, base_row + 9];

        let mut perm1 = [0usize; 8];
        for block in 0..2 {
            for (slot, &row) in rows.iter().enumerate() {
                perm1[block * 4 + slot] = 16 * row + col + 8 * block;
            }
        }
        for j in 0..4 {
            perm.extend(perm1.iter().map(|&p| p + 256 * j));
        }
    }
    assert_eq!(perm.len(), 1024);

    let mut interleaved = Vec::with_capacity(1024);
    for group in perm.chunks_exact(8) {
        interleaved.extend(INTERLEAVE.iter().map(|&slot| group[slot]));
    }
    interleaved
}
#[cfg(test)]
mod tests {
    use super::CudaMarlinRuntimeConfig;

    const FLAG: &str = "FERRUM_MARLIN_SKIP_WS_ZERO";

    /// The flag turns on when the variable is set to exactly "1".
    #[test]
    fn cuda_marlin_runtime_config_parses_skip_ws_zero() {
        let parsed = CudaMarlinRuntimeConfig::from_env_vars(vec![(FLAG, "1")]);
        assert_eq!(parsed, CudaMarlinRuntimeConfig { skip_ws_zero: true });
    }

    /// Any other value, even a truthy-looking one, keeps workspace zeroing enabled.
    #[test]
    fn cuda_marlin_runtime_config_defaults_to_zero_workspace() {
        let parsed = CudaMarlinRuntimeConfig::from_env_vars(vec![(FLAG, "true")]);
        assert_eq!(parsed, CudaMarlinRuntimeConfig { skip_ws_zero: false });
    }
}