use cudarc::driver::{CudaSlice, CudaStream, DevicePtr};
use std::sync::Arc;
// Raw FFI declaration of the Marlin CUDA GEMM launcher. It is only declared
// when the `marlin` cargo feature is enabled; the symbol itself must be
// provided by code linked into the final binary.
//
// The parameter notes below describe how `marlin_gemm` in this file uses each
// argument; the authoritative contract lives in the kernel sources.
#[cfg(feature = "marlin")]
extern "C" {
/// Launches the Marlin kernel on `stream`.
///
/// Returns `0` on success; `marlin_gemm` treats any non-zero value as failure.
fn marlin_cuda(
// Device pointer to the activations (the `input` f16 buffer at the call site).
A: *const std::ffi::c_void,
// Device pointer to the packed quantized weights (`MarlinWeight::qweight`).
B: *const std::ffi::c_void,
// Device pointer to the result buffer (the `output` f16 buffer at the call site).
C: *mut std::ffi::c_void,
// Device pointer to the quantization scales (`MarlinWeight::scales`).
s: *const std::ffi::c_void,
// Problem dimensions: the caller passes `m`, `weight.n`, `weight.k`.
prob_m: i32,
prob_n: i32,
prob_k: i32,
// Device scratch buffer; the caller zeroes it before every launch.
workspace: *mut std::ffi::c_void,
// Quantization group size (`MarlinWeight::group_size`).
groupsize: i32,
// CUDA device ordinal; the caller currently always passes 0.
dev: i32,
// Stream the kernel is enqueued on.
stream: cudarc::driver::sys::CUstream,
// Tuning knobs. The caller passes -1 for `thread_k`, `thread_n` and `sms`
// (presumably "let the launcher choose" — confirm against the kernel
// sources) and 16 for `max_par`.
thread_k: i32,
thread_n: i32,
sms: i32,
max_par: i32,
) -> i32;
}
/// Reports whether this build was compiled with the `marlin` cargo feature.
///
/// The answer is fixed at compile time; no runtime probing of the GPU or the
/// driver takes place.
pub fn is_available() -> bool {
    const MARLIN_COMPILED_IN: bool = cfg!(feature = "marlin");
    MARLIN_COMPILED_IN
}
/// Device-resident buffers and shape metadata for one quantized weight
/// matrix consumed by `marlin_gemm`.
pub struct MarlinWeight {
/// Packed quantized weights, handed to the kernel as its `B` operand.
/// Presumably in the layout produced by `repack_gptq_to_marlin` — confirm
/// against the loader.
pub qweight: CudaSlice<i32>,
/// f16 quantization scales, handed to the kernel as its `s` operand.
pub scales: CudaSlice<half::f16>,
/// Scratch buffer for the kernel; `marlin_gemm` zeroes it before every launch.
pub workspace: CudaSlice<i32>,
/// Passed to the kernel as `prob_k`.
pub k: usize,
/// Passed to the kernel as `prob_n`.
pub n: usize,
/// Quantization group size, forwarded to the kernel unchanged.
pub group_size: i32,
}
/// Runs the Marlin quantized GEMM on `stream`: `output = input x weight`.
///
/// `input` is read as an `m x weight.k` f16 matrix and `output` is written as
/// an `m x weight.n` f16 matrix. The weight workspace is zeroed on `stream`
/// before the kernel is launched.
///
/// # Errors
///
/// Returns `candle_core::Error::Msg` when
/// * `weight.n` or `weight.k` does not fit in an `i32`,
/// * `input` or `output` holds fewer elements than the `m x k` / `m x n`
///   problem needs,
/// * zeroing the workspace fails, or
/// * `marlin_cuda` returns a non-zero status.
#[cfg(feature = "marlin")]
pub fn marlin_gemm(
    stream: &Arc<CudaStream>,
    input: &CudaSlice<half::f16>,
    weight: &MarlinWeight,
    output: &mut CudaSlice<half::f16>,
    m: i32,
) -> candle_core::Result<()> {
    // The kernel takes 32-bit dimensions; reject values that a bare `as`
    // cast would silently truncate.
    let dim_to_i32 = |name: &str, value: usize| -> candle_core::Result<i32> {
        <i32 as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
            candle_core::Error::Msg(format!(
                "marlin_gemm: {name}={value} does not fit in i32"
            ))
        })
    };
    let n = dim_to_i32("n", weight.n)?;
    let k = dim_to_i32("k", weight.k)?;

    // Make sure the activation and result buffers cover the problem before
    // any raw pointer is handed to the kernel. A non-positive `m` needs no
    // elements, so it is left for the kernel to handle as before.
    let rows = if m > 0 { m as usize } else { 0 };
    let need_in = rows.saturating_mul(weight.k);
    let need_out = rows.saturating_mul(weight.n);
    if input.len() < need_in || output.len() < need_out {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm: buffers too small for m={m}, n={n}, k={k} \
             (input has {}, needs {need_in}; output has {}, needs {need_out})",
            input.len(),
            output.len()
        )));
    }

    let raw_stream = stream.cu_stream();

    {
        // Zero the workspace on the same stream, ahead of the kernel launch.
        let (ws_ptr, _guard) = weight.workspace.device_ptr(stream);
        // SAFETY: `ws_ptr` addresses `weight.workspace.len()` 32-bit words
        // owned by `weight.workspace`, which outlives this call, and
        // `raw_stream` is the live stream the pointer was obtained for.
        let status = unsafe {
            cudarc::driver::sys::cuMemsetD32Async(ws_ptr, 0, weight.workspace.len(), raw_stream)
        };
        if status != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm: zeroing the workspace failed: {status:?}"
            )));
        }
    }

    let (a_ptr, _a_guard) = input.device_ptr(stream);
    let (b_ptr, _b_guard) = weight.qweight.device_ptr(stream);
    // The kernel writes to `output`, so take the pointer through the mutable
    // accessor; this lets cudarc track the access as a write.
    let (c_ptr, _c_guard) = cudarc::driver::DevicePtrMut::device_ptr_mut(&mut *output, stream);
    let (s_ptr, _s_guard) = weight.scales.device_ptr(stream);
    let (ws_ptr, _ws_guard) = weight.workspace.device_ptr(stream);

    // SAFETY: every pointer comes from a live `CudaSlice` whose guard is held
    // until after the launch, and the input/output extents were checked
    // against `m`, `n` and `k` above.
    let ret = unsafe {
        marlin_cuda(
            a_ptr as *const _,
            b_ptr as *const _,
            c_ptr as *mut _,
            s_ptr as *const _,
            m,
            n,
            k,
            ws_ptr as *mut _,
            weight.group_size,
            // NOTE(review): the device ordinal is hard-coded to 0; confirm
            // this is right when `stream` belongs to another device.
            0,
            raw_stream,
            -1, // thread_k
            -1, // thread_n
            -1, // sms
            16, // max_par
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin_cuda failed: ret={ret} (m={m}, n={n}, k={k}, gs={})",
            weight.group_size
        )));
    }
    Ok(())
}
#[cfg(not(feature = "marlin"))]
pub fn marlin_gemm(
_stream: &Arc<CudaStream>,
_input: &CudaSlice<half::f16>,
_weight: &MarlinWeight,
_output: &mut CudaSlice<half::f16>,
_m: i32,
) -> candle_core::Result<()> {
Err(candle_core::Error::Msg(
"Marlin kernel not available (compile with --features marlin)".into(),
))
}
/// Repacks 4-bit weights from the GPTQ packing into the Marlin packing.
///
/// `qweight_gptq` is read as a row-major `(k / 8, n)` matrix of 32-bit words,
/// each holding eight 4-bit values for eight consecutive `k` rows of one
/// column (lowest nibble = first row). The 4-bit values are copied through
/// unchanged; only their order is altered:
///
/// 1. unpack into a `k x n` matrix of nibbles,
/// 2. regroup into contiguous 16 x 16 tiles,
/// 3. shuffle every run of 1024 values with the permutation from
///    `build_marlin_perm`,
/// 4. pack eight consecutive values per 32-bit word, lowest nibble first.
///
/// Returns `k * n / 8` words.
///
/// # Panics
///
/// Panics if `k` or `n` is not a multiple of 16, if `k * n` is not a multiple
/// of 1024, or if `qweight_gptq` holds fewer than `(k / 8) * n` words. These
/// shapes cannot be expressed in the tiled layout; previously they yielded a
/// partially zero-filled result without any diagnostic.
pub fn repack_gptq_to_marlin(
    qweight_gptq: &[i32],
    k: usize,
    n: usize,
) -> Vec<i32> {
    const NIBBLES_PER_WORD: usize = 8;
    const TILE: usize = 16;
    const PERM_BLOCK: usize = 1024;

    let total = k * n;
    let packed_rows = k / NIBBLES_PER_WORD;
    assert!(
        k % TILE == 0 && n % TILE == 0,
        "repack_gptq_to_marlin: k ({}) and n ({}) must both be multiples of {}",
        k,
        n,
        TILE
    );
    assert!(
        total % PERM_BLOCK == 0,
        "repack_gptq_to_marlin: k * n ({}) must be a multiple of {}",
        total,
        PERM_BLOCK
    );
    assert!(
        qweight_gptq.len() >= packed_rows * n,
        "repack_gptq_to_marlin: qweight has {} words, expected at least {}",
        qweight_gptq.len(),
        packed_rows * n
    );

    // Stage 1: expand every packed word into eight nibbles along k.
    let mut kn = vec![0u8; total];
    for pr in 0..packed_rows {
        for col in 0..n {
            let packed = qweight_gptq[pr * n + col];
            for i in 0..NIBBLES_PER_WORD {
                // The mask keeps only the low nibble, so the arithmetic shift
                // of a negative word cannot leak sign bits.
                kn[(pr * NIBBLES_PER_WORD + i) * n + col] = ((packed >> (i * 4)) & 0xF) as u8;
            }
        }
    }

    // Stage 2: lay the matrix out as consecutive 16 x 16 tiles, k-tile major.
    let mut tiled = vec![0u8; total];
    for tk in 0..k / TILE {
        for tn in 0..n / TILE {
            let tile_base = tk * (n * TILE) + tn * (TILE * TILE);
            for ik in 0..TILE {
                let src = (tk * TILE + ik) * n + tn * TILE;
                let dst = tile_base + ik * TILE;
                tiled[dst..dst + TILE].copy_from_slice(&kn[src..src + TILE]);
            }
        }
    }
    drop(kn);

    // Stage 3: apply the 1024-entry permutation to every run of 1024 values.
    let perm = build_marlin_perm();
    let mut permuted = vec![0u8; total];
    for (dst_block, src_block) in permuted
        .chunks_exact_mut(PERM_BLOCK)
        .zip(tiled.chunks_exact(PERM_BLOCK))
    {
        for (dst, &src) in dst_block.iter_mut().zip(perm.iter()) {
            *dst = src_block[src];
        }
    }

    // Stage 4: pack eight nibbles per word, lowest nibble first.
    permuted
        .chunks_exact(NIBBLES_PER_WORD)
        .map(|nibbles| {
            let word = nibbles
                .iter()
                .enumerate()
                .fold(0u32, |acc, (j, &v)| acc | (u32::from(v) << (j * 4)));
            // Bit-for-bit reinterpretation; the kernel consumes raw words.
            word as i32
        })
        .collect()
}
/// Reorders GPTQ scales into the order used by the Marlin layout.
///
/// The first `(k / group_size) * n` entries of `scales_gptq` are processed as
/// one flat sequence, in fixed-size blocks:
///
/// * more than one group (`k / group_size > 1`): blocks of 64, where output
///   slot `8 * i + j` takes input slot `i + 8 * j`;
/// * otherwise: blocks of 32, using the offsets `[0, 1, 8, 9, 16, 17, 24, 25]`
///   shifted by `2 * i` for `i` in `0..4`.
///
/// Entries past the last complete block are copied through unchanged.
///
/// # Panics
///
/// Panics if `group_size` is zero or if `scales_gptq` holds fewer than
/// `(k / group_size) * n` entries.
pub fn repack_scales_to_marlin(
    scales_gptq: &[half::f16],
    k: usize,
    n: usize,
    group_size: usize,
) -> Vec<half::f16> {
    let groups = k / group_size;

    // Source index for each destination slot inside one block.
    let mut block_perm: Vec<usize> = Vec::with_capacity(64);
    if groups > 1 {
        for i in 0..8 {
            for j in 0..8 {
                block_perm.push(i + 8 * j);
            }
        }
    } else {
        for i in 0..4 {
            for j in [0, 1, 8, 9, 16, 17, 24, 25] {
                block_perm.push(2 * i + j);
            }
        }
    }

    let count = groups * n;
    let source = &scales_gptq[..count];
    let mut reordered = Vec::with_capacity(count);

    let mut blocks = source.chunks_exact(block_perm.len());
    for block in blocks.by_ref() {
        reordered.extend(block_perm.iter().map(|&src| block[src]));
    }
    // Whatever does not fill a whole block keeps its original order.
    reordered.extend_from_slice(blocks.remainder());
    reordered
}
/// Builds the 1024-entry permutation applied to each run of 1024 tiled
/// weight values by `repack_gptq_to_marlin`.
///
/// Entry `d` of the result is the source index (within the run) whose value
/// ends up at position `d`. Construction:
///
/// * for each `i` in `0..32`, with `col = i / 4` and `r = i % 4`, eight base
///   indices `16 * row + col + 8 * block` are generated for
///   `row` in `[2r, 2r + 1, 2r + 8, 2r + 9]` and `block` in `0..2`;
/// * those eight indices are repeated four times with offsets `0, 256, 512,
///   768`;
/// * finally every group of eight entries is reordered by
///   `[0, 2, 4, 6, 1, 3, 5, 7]`.
///
/// The result contains every index in `0..1024` exactly once.
fn build_marlin_perm() -> Vec<usize> {
    let interleave = [0usize, 2, 4, 6, 1, 3, 5, 7];

    let mut perm: Vec<usize> = Vec::with_capacity(1024);
    for i in 0..32usize {
        let col = i / 4;
        let r = i % 4;
        let rows = [2 * r, 2 * r + 1, 2 * (r + 4), 2 * (r + 4) + 1];

        // Eight base indices: the four rows for block 0, then for block 1.
        let mut base = [0usize; 8];
        for block in 0..2 {
            for (slot, &row) in rows.iter().enumerate() {
                base[block * 4 + slot] = 16 * row + col + 8 * block;
            }
        }

        // Replicate the base indices across the four 256-element sub-tiles.
        for j in 0..4 {
            perm.extend(base.iter().map(|&p| p + 256 * j));
        }
    }
    assert_eq!(perm.len(), 1024);

    // Interleave within every group of eight entries.
    let mut interleaved = Vec::with_capacity(perm.len());
    for group in perm.chunks_exact(8) {
        interleaved.extend(interleave.iter().map(|&src| group[src]));
    }
    interleaved
}