use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use hanzo_ml::cuda::cudarc::driver::{CudaSlice, DevicePtr};
use hanzo_ml::{
quantized::{GgmlDType, QTensor},
CudaDevice, CudaStorage, DType, Device, Result, Shape, Storage, Tensor,
};
use super::ffi;
use crate::utils::slice_ptr;
/// Element count used as the q8_1 block unit throughout this file.
/// NOTE(review): presumably the number of values per q8_1 block in the CUDA kernels — confirm.
const QK8_1: usize = 32;

/// Size in bytes of one MMQ-layout q8_1 block as used for scratch sizing:
/// `4 * QK8_1` bytes plus `4 * 4` further bytes (144 in total).
/// NOTE(review): the extra 16 bytes are presumably per-block scale metadata — confirm
/// against the kernel-side struct.
const BLOCK_Q8_1_MMQ_SIZE: usize = 4 * QK8_1 + 4 * 4;

/// Granularity to which the reduction dimension `k` is rounded up before the
/// quantized scratch buffer is sized.
const MATRIX_ROW_PADDING: usize = 512;
/// Rounds `p` up to the nearest multiple of `q`.
///
/// Values that are already a multiple of `q` (including zero) are returned
/// unchanged.
///
/// # Panics
///
/// Panics if `q` is zero.
#[inline]
fn pad(p: usize, q: usize) -> usize {
    let chunks = p.div_ceil(q);
    chunks * q
}
/// Builds the result shape for a matmul against a weight with `nrows` rows:
/// the dimensions of `xs` with the last one replaced by `nrows`.
///
/// `xs` must have at least one dimension; a rank-0 tensor makes the index
/// computation fail.
fn output_shape(xs: &Tensor, nrows: usize) -> Shape {
    let mut dims: Vec<usize> = xs.dims().to_vec();
    let tail = dims.len() - 1;
    dims[tail] = nrows;
    Shape::from(dims)
}
/// Returns `true` when `dtype` is one of the quantized weight formats this
/// module has MMQ launchers for (`Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0` and the
/// `Q2K`..`Q6K` family).
pub fn supports(dtype: GgmlDType) -> bool {
    // The formats that `qk_for` maps to a block size of 32.
    let is_legacy_quant = matches!(
        dtype,
        GgmlDType::Q4_0
            | GgmlDType::Q4_1
            | GgmlDType::Q5_0
            | GgmlDType::Q5_1
            | GgmlDType::Q8_0
    );
    // The formats that `qk_for` maps to a block size of 256.
    let is_k_quant = matches!(
        dtype,
        GgmlDType::Q2K | GgmlDType::Q3K | GgmlDType::Q4K | GgmlDType::Q5K | GgmlDType::Q6K
    );
    is_legacy_quant || is_k_quant
}
/// Returns the block size (elements per quantization block) for `dtype`:
/// 32 for `Q4_0`/`Q4_1`/`Q5_0`/`Q5_1`/`Q8_0`, 256 for `Q2K`..`Q6K`.
///
/// # Panics
///
/// Panics for any other dtype; callers are expected to check [`supports`]
/// first.
fn qk_for(dtype: GgmlDType) -> usize {
    const LEGACY_BLOCK: usize = 32;
    const K_QUANT_BLOCK: usize = 256;

    match dtype {
        GgmlDType::Q2K | GgmlDType::Q3K | GgmlDType::Q4K | GgmlDType::Q5K | GgmlDType::Q6K => {
            K_QUANT_BLOCK
        }
        GgmlDType::Q4_0
        | GgmlDType::Q4_1
        | GgmlDType::Q5_0
        | GgmlDType::Q5_1
        | GgmlDType::Q8_0 => LEGACY_BLOCK,
        _ => unreachable!(),
    }
}
/// Selects which variant of the q8_1 activation-quantization kernel is used
/// for a given weight dtype (see `ds_layout_for` and `quantize_launcher`).
///
/// NOTE(review): the variant names presumably describe how per-block
/// scale/sum metadata is packed in the quantized activations — confirm
/// against the CUDA kernels, which are not visible from this file.
enum DsLayout {
// Chosen for Q5_0, Q8_0, Q3K and Q6K weights.
D4,
// Chosen for Q4_0, Q4_1, Q5_1, Q4K and Q5K weights.
DS4,
// Chosen for Q2K weights only.
D2S6,
}
/// Maps a supported weight dtype to the [`DsLayout`] its activations must be
/// quantized with.
///
/// # Panics
///
/// Panics for dtypes outside the set accepted by [`supports`].
fn ds_layout_for(dtype: GgmlDType) -> DsLayout {
    match dtype {
        GgmlDType::Q5_0 | GgmlDType::Q8_0 | GgmlDType::Q3K | GgmlDType::Q6K => DsLayout::D4,
        GgmlDType::Q4_0
        | GgmlDType::Q4_1
        | GgmlDType::Q5_1
        | GgmlDType::Q4K
        | GgmlDType::Q5K => DsLayout::DS4,
        GgmlDType::Q2K => DsLayout::D2S6,
        _ => unreachable!(),
    }
}
/// Signature shared by the `ffi::launch_mmq_quantize_q8_1_*` entry points
/// returned by `quantize_launcher`.
///
/// As called in this file: `x` is the activation pointer, `ids` is either null
/// (`plain`) or the `ids_src` buffer (grouped paths), `vy` is the scratch
/// workspace that receives the quantized data, `type_x` is a dtype code
/// (0 = F32, 1 = F16, 30 = BF16), `ne00` and `s01` are both `k`, `s02`/`s03`
/// are 0, `ne0` is the padded `k`, `ne1` is the number of rows, `ne2`/`ne3`
/// are 1, and `stream` is the raw CUDA stream handle.
///
/// NOTE(review): the precise kernel-side meaning of each argument is defined
/// in the FFI layer — confirm there before changing call sites.
type QuantizeLauncher = unsafe extern "C" fn(
x: *const std::ffi::c_void,
ids: *const i32,
vy: *mut std::ffi::c_void,
type_x: i32,
ne00: i64,
s01: i64,
s02: i64,
s03: i64,
ne0: i64,
ne1: i64,
ne2: i64,
ne3: i64,
stream: *mut std::ffi::c_void,
);
/// Signature shared by the `ffi::launch_mmq_quantize_glu_q8_1_*_f32` entry
/// points returned by `quantize_glu_f32_launcher`.
///
/// As called from `grouped_from_glu_pair`: `gate` and `up` are F32 device
/// pointers, `ids` is the `ids_src` buffer, `vy` is the scratch workspace that
/// receives the quantized data, `ne00` and `s01` are both `k`, `ne0` is the
/// padded `k`, `ne1` is `total_assignments`, `activation` is forwarded
/// unchanged from the caller, and `stream` is the raw CUDA stream handle.
///
/// NOTE(review): presumably the kernel applies the activation to `gate`,
/// combines it with `up`, and quantizes the result — confirm against the
/// kernel source.
type QuantizeGluF32Launcher = unsafe extern "C" fn(
gate: *const f32,
up: *const f32,
ids: *const i32,
vy: *mut std::ffi::c_void,
ne00: i64,
s01: i64,
ne0: i64,
ne1: i64,
activation: i32,
stream: *mut std::ffi::c_void,
);
fn quantize_launcher(layout: DsLayout) -> QuantizeLauncher {
match layout {
DsLayout::D4 => ffi::launch_mmq_quantize_q8_1_D4,
DsLayout::DS4 => ffi::launch_mmq_quantize_q8_1_DS4,
DsLayout::D2S6 => ffi::launch_mmq_quantize_q8_1_D2S6,
}
}
fn quantize_glu_f32_launcher(layout: DsLayout) -> QuantizeGluF32Launcher {
match layout {
DsLayout::D4 => ffi::launch_mmq_quantize_glu_q8_1_D4_f32,
DsLayout::DS4 => ffi::launch_mmq_quantize_glu_q8_1_DS4_f32,
DsLayout::D2S6 => ffi::launch_mmq_quantize_glu_q8_1_D2S6_f32,
}
}
/// Signature shared by the `ffi::launch_mmq_gguf_*` entry points returned by
/// `mmq_launcher`.
///
/// As called from `plain`: `tmp_fixup` is the fixup workspace, `x` the
/// quantized weight pointer, `y` the quantized-activation scratch buffer,
/// `dst` the F32 output buffer, `ncols_x` is `k`, `nrows_x` the weight row
/// count, `ncols_y` the flattened batch size, `stride_row_x` is `k / qk`,
/// `stride_col_dst` is the weight row count, `cc`/`nsm`/`smpbo`/`warp_size`
/// come from `DeviceInfo`, and `stream` is the raw CUDA stream handle.
///
/// NOTE(review): kernel-side semantics live in the FFI layer — confirm there.
type MmqLauncher = unsafe extern "C" fn(
tmp_fixup: *mut std::ffi::c_void,
x: *const std::ffi::c_void,
y: *const std::ffi::c_void,
dst: *mut std::ffi::c_void,
ncols_x: i64,
nrows_x: i64,
ncols_y: i64,
stride_row_x: i64,
stride_col_dst: i64,
cc: i32,
nsm: i32,
smpbo: i64,
warp_size: i32,
stream: *mut std::ffi::c_void,
);
/// Signature shared by the `ffi::launch_mmq_gguf_*_moe` entry points returned
/// by `mmq_moe_launcher`.
///
/// As called from the grouped functions: `tmp_fixup` is the fixup workspace,
/// `x` the quantized expert-weight pointer, `y` the quantized-activation
/// scratch buffer, `ids_dst` and `expert_bounds` are caller-provided `u32`
/// device buffers reinterpreted as `i32`, `dst` the F32 output buffer,
/// `ncols_x` is `k`, `nrows_x` the per-expert row count, `ncols_dst` is
/// `total_assignments`, `stride_row_x` is `k / qk`, `stride_col_dst` is the
/// row count, `num_experts` and `ncols_max` are forwarded from the caller
/// (`grouped_pair` passes its token count as `ncols_max`),
/// `cc`/`nsm`/`smpbo`/`warp_size` come from `DeviceInfo`, and `stream` is the
/// raw CUDA stream handle.
///
/// NOTE(review): kernel-side semantics live in the FFI layer — confirm there.
type MmqMoeLauncher = unsafe extern "C" fn(
tmp_fixup: *mut std::ffi::c_void,
x: *const std::ffi::c_void,
y: *const std::ffi::c_void,
ids_dst: *const i32,
expert_bounds: *const i32,
dst: *mut std::ffi::c_void,
ncols_x: i64,
nrows_x: i64,
ncols_dst: i64,
stride_row_x: i64,
stride_col_dst: i64,
num_experts: i64,
ncols_max: i64,
cc: i32,
nsm: i32,
smpbo: i64,
warp_size: i32,
stream: *mut std::ffi::c_void,
);
/// Looks up the dense MMQ kernel launcher for `dtype`.
///
/// Returns `None` for dtypes that have no launcher; the covered set matches
/// the one accepted by [`supports`].
fn mmq_launcher(dtype: GgmlDType) -> Option<MmqLauncher> {
    // The explicit annotation coerces each distinct fn item to the common
    // function-pointer type.
    let entry_point: MmqLauncher = match dtype {
        GgmlDType::Q2K => ffi::launch_mmq_gguf_q2_k,
        GgmlDType::Q3K => ffi::launch_mmq_gguf_q3_k,
        GgmlDType::Q4K => ffi::launch_mmq_gguf_q4_k,
        GgmlDType::Q5K => ffi::launch_mmq_gguf_q5_k,
        GgmlDType::Q6K => ffi::launch_mmq_gguf_q6_k,
        GgmlDType::Q4_0 => ffi::launch_mmq_gguf_q4_0,
        GgmlDType::Q4_1 => ffi::launch_mmq_gguf_q4_1,
        GgmlDType::Q5_0 => ffi::launch_mmq_gguf_q5_0,
        GgmlDType::Q5_1 => ffi::launch_mmq_gguf_q5_1,
        GgmlDType::Q8_0 => ffi::launch_mmq_gguf_q8_0,
        _ => return None,
    };
    Some(entry_point)
}
/// Looks up the grouped (mixture-of-experts) MMQ kernel launcher for `dtype`.
///
/// Returns `None` for dtypes that have no launcher; the covered set matches
/// the one accepted by [`supports`].
fn mmq_moe_launcher(dtype: GgmlDType) -> Option<MmqMoeLauncher> {
    // The explicit annotation coerces each distinct fn item to the common
    // function-pointer type.
    let entry_point: MmqMoeLauncher = match dtype {
        GgmlDType::Q2K => ffi::launch_mmq_gguf_q2_k_moe,
        GgmlDType::Q3K => ffi::launch_mmq_gguf_q3_k_moe,
        GgmlDType::Q4K => ffi::launch_mmq_gguf_q4_k_moe,
        GgmlDType::Q5K => ffi::launch_mmq_gguf_q5_k_moe,
        GgmlDType::Q6K => ffi::launch_mmq_gguf_q6_k_moe,
        GgmlDType::Q4_0 => ffi::launch_mmq_gguf_q4_0_moe,
        GgmlDType::Q4_1 => ffi::launch_mmq_gguf_q4_1_moe,
        GgmlDType::Q5_0 => ffi::launch_mmq_gguf_q5_0_moe,
        GgmlDType::Q5_1 => ffi::launch_mmq_gguf_q5_1_moe,
        GgmlDType::Q8_0 => ffi::launch_mmq_gguf_q8_0_moe,
        _ => return None,
    };
    Some(entry_point)
}
/// A reusable device byte buffer together with the size it was allocated with.
///
/// `workspace_ensure` replaces `slice` with a larger allocation whenever a
/// request exceeds `cap`.
struct WorkspaceSlot {
// Device allocation backing the workspace.
slice: CudaSlice<u8>,
// Size in bytes that `slice` was allocated with.
cap: usize,
}
/// Registry mapping a CUDA device id to its workspace slot. Slots are leaked
/// (`Box::leak` in `workspace_ensure`) so that lock guards on them can carry a
/// `'static` lifetime.
type WsMap = Mutex<HashMap<hanzo_ml::cuda::DeviceId, &'static Mutex<WorkspaceSlot>>>;
/// Per-device scratch buffers that receive the quantized activations.
static MMQ_WORKSPACE: OnceLock<WsMap> = OnceLock::new();
/// Per-device buffers handed to the MMQ launchers as their `tmp_fixup` argument.
static FIXUP_WORKSPACE: OnceLock<WsMap> = OnceLock::new();
/// Device properties forwarded to the MMQ kernel launchers; populated by
/// `get_device_info`.
#[derive(Clone, Copy)]
struct DeviceInfo {
// Compute capability encoded as `major * 100 + minor * 10`.
cc: i32,
// Multiprocessor (SM) count.
nsm: i32,
// Maximum opt-in shared memory per block, in bytes.
smpbo: i64,
// Warp size reported by the driver.
warp_size: i32,
}
/// Cache of queried [`DeviceInfo`] values keyed by device id, filled lazily by
/// `get_device_info` so the driver is queried once per device.
static DEVICE_INFO: OnceLock<Mutex<HashMap<hanzo_ml::cuda::DeviceId, DeviceInfo>>> =
OnceLock::new();
/// Returns the cached [`DeviceInfo`] for `dev`, querying the CUDA driver on
/// first use.
///
/// Each attribute query that fails silently falls back to a default
/// (compute capability 8.0, 1 SM, 49152 bytes of opt-in shared memory, warp
/// size 32), and the resulting value is cached like any other.
///
/// # Panics
///
/// Panics if the cache mutex is poisoned.
fn get_device_info(dev: &CudaDevice) -> DeviceInfo {
    use hanzo_ml::cuda::cudarc::driver::result;
    use hanzo_ml::cuda::cudarc::driver::sys::CUdevice_attribute as Attr;

    let cache = DEVICE_INFO.get_or_init(|| Mutex::new(HashMap::new()));
    let device_key = dev.id();
    // The lock is held across the driver queries, so each device is probed at
    // most once even when several threads race here.
    let mut entries = cache.lock().unwrap();
    if let Some(cached) = entries.get(&device_key) {
        return *cached;
    }

    let cu_device = dev.cuda_stream().context().cu_device();
    let query = |attr: Attr, fallback: i32| -> i32 {
        // SAFETY: `cu_device` was just obtained from the context of `dev`'s
        // stream, so it names a device known to the driver.
        let queried = unsafe { result::device::get_attribute(cu_device, attr) };
        queried.unwrap_or(fallback)
    };

    let major = query(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, 8);
    let minor = query(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, 0);
    let nsm = query(Attr::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, 1);
    let smpbo = query(
        Attr::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        49152,
    );
    let warp_size = query(Attr::CU_DEVICE_ATTRIBUTE_WARP_SIZE, 32);

    let fresh = DeviceInfo {
        cc: major * 100 + minor * 10,
        nsm,
        smpbo: smpbo as i64,
        warp_size,
    };
    entries.insert(device_key, fresh);
    fresh
}
/// Locks the workspace registered for `dev` in `ws`, making sure it holds at
/// least `bytes` bytes, and returns its device address together with the lock
/// guard.
///
/// The slot is created on first use for a device (with at least one byte) and
/// leaked so the guard can be `'static`. If the existing buffer is smaller
/// than `bytes` it is replaced by a fresh allocation of exactly `bytes`. The
/// returned address refers to the buffer held by the slot behind the returned
/// guard.
///
/// # Errors
///
/// Propagates device allocation failures.
///
/// # Panics
///
/// Panics if the registry mutex or the slot mutex is poisoned.
fn workspace_ensure(
    ws: &'static OnceLock<WsMap>,
    dev: &CudaDevice,
    bytes: usize,
) -> Result<(u64, std::sync::MutexGuard<'static, WorkspaceSlot>)> {
    let registry = ws.get_or_init(|| Mutex::new(HashMap::new()));
    let device_key = dev.id();

    // The registry lock is scoped to this block so it is released before the
    // per-device slot lock is taken.
    let slot_mutex: &'static Mutex<WorkspaceSlot> = {
        let mut registry_guard = registry.lock().unwrap();
        if let Some(existing) = registry_guard.get(&device_key).copied() {
            existing
        } else {
            let initial_cap = bytes.max(1);
            // SAFETY: the buffer is handed out as raw scratch space; nothing in
            // this function reads it before it is written.
            let slice = unsafe { dev.alloc::<u8>(initial_cap)? };
            let leaked = Box::leak(Box::new(Mutex::new(WorkspaceSlot {
                slice,
                cap: initial_cap,
            })));
            registry_guard.insert(device_key, leaked);
            leaked
        }
    };

    let mut slot = slot_mutex.lock().unwrap();
    if bytes > slot.cap {
        // SAFETY: same as above — the replacement buffer is uninitialized
        // scratch space.
        slot.slice = unsafe { dev.alloc::<u8>(bytes)? };
        slot.cap = bytes;
    }
    let ptr = slot.slice.device_ptr(slot.slice.stream()).0;
    Ok((ptr, slot))
}
pub fn plain(w: &QTensor, xs: &Tensor) -> Result<Tensor> {
let dtype = w.dtype();
if !supports(dtype) {
hanzo_ml::bail!("fast_mmq: unsupported quant dtype {dtype:?}");
}
let Device::Cuda(dev) = w.device() else {
hanzo_ml::bail!("fast_mmq: weight must live on CUDA");
};
let (nrows, ncols) = w.shape().dims2()?;
let (b_size, k) = match xs.dims() {
[b, k] => (*b, *k),
[b, m, k] => (*b * *m, *k),
other => hanzo_ml::bail!("fast_mmq: unexpected input rank {other:?}"),
};
if k != ncols {
hanzo_ml::bail!("fast_mmq: shape mismatch — weight [{nrows}, {ncols}] vs input tail {k}");
}
if b_size == 0 {
hanzo_ml::bail!("fast_mmq: batch size must be > 0");
}
let qk = qk_for(dtype);
if k % qk != 0 {
hanzo_ml::bail!("fast_mmq: k={k} not divisible by qk={qk}");
}
let input_ty = xs.dtype();
if !matches!(input_ty, DType::BF16 | DType::F16 | DType::F32) {
hanzo_ml::bail!("fast_mmq: input dtype must be BF16, F16, or F32, got {input_ty:?}");
}
let xs = xs.contiguous()?;
let (xs_storage, xs_layout) = xs.storage_and_layout();
let Storage::Cuda(xs_cuda) = &*xs_storage else {
hanzo_ml::bail!("fast_mmq: input must live on CUDA");
};
let xs_offset = xs_layout.start_offset();
let type_x = match input_ty {
DType::F32 => 0,
DType::F16 => 1,
DType::BF16 => 30,
_ => unreachable!(),
};
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(k, MATRIX_ROW_PADDING);
let k_padded = pad(k_padded, 4 * QK8_1);
let blocks_per_row = k_padded / (4 * QK8_1);
let workspace_main = b_size * blocks_per_row * BLOCK_Q8_1_MMQ_SIZE;
let workspace_extra = 128 * BLOCK_Q8_1_MMQ_SIZE;
let workspace_bytes = workspace_main + workspace_extra;
let (scratch_ptr, _workspace_guard) = workspace_ensure(&MMQ_WORKSPACE, &dev, workspace_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
const MMQ_X_MAX: usize = 128;
const MMQ_Y_MAX: usize = 128;
const MAX_SMS: usize = 256; let fixup_bytes = MAX_SMS * MMQ_X_MAX * MMQ_Y_MAX * std::mem::size_of::<f32>();
let (fixup_ptr, _fixup_guard) = workspace_ensure(&FIXUP_WORKSPACE, &dev, fixup_bytes)?;
let fixup_ptr = fixup_ptr as *mut std::ffi::c_void;
let weight_ptr = w.device_ptr()? as *const std::ffi::c_void;
let stride_row_x = (k / qk) as i64;
let di = get_device_info(&dev);
let out = unsafe { dev.alloc::<f32>(nrows * b_size)? };
let stride_col_dst = nrows as i64;
let quantize = quantize_launcher(ds_layout_for(dtype));
let launcher = mmq_launcher(dtype).expect("supports() checked");
match input_ty {
DType::BF16 => {
let slice = xs_cuda.as_cuda_slice::<half::bf16>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
quantize(
xs_ptr as *const std::ffi::c_void,
std::ptr::null(),
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
b_size as i64,
1,
1,
stream_ptr,
);
launcher(
fixup_ptr,
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i64,
nrows as i64,
b_size as i64,
stride_row_x,
stride_col_dst,
di.cc,
di.nsm,
di.smpbo,
di.warp_size,
stream_ptr,
);
}
}
DType::F16 => {
let slice = xs_cuda.as_cuda_slice::<half::f16>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
quantize(
xs_ptr as *const std::ffi::c_void,
std::ptr::null(),
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
b_size as i64,
1,
1,
stream_ptr,
);
launcher(
fixup_ptr,
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i64,
nrows as i64,
b_size as i64,
stride_row_x,
stride_col_dst,
di.cc,
di.nsm,
di.smpbo,
di.warp_size,
stream_ptr,
);
}
}
DType::F32 => {
let slice = xs_cuda.as_cuda_slice::<f32>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
quantize(
xs_ptr as *const std::ffi::c_void,
std::ptr::null(),
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
b_size as i64,
1,
1,
stream_ptr,
);
launcher(
fixup_ptr,
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i64,
nrows as i64,
b_size as i64,
stride_row_x,
stride_col_dst,
di.cc,
di.nsm,
di.smpbo,
di.warp_size,
stream_ptr,
);
}
}
_ => unreachable!(),
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
let out_tensor = Tensor::from((Storage::Cuda(out_storage), output_shape(&xs, nrows)));
if input_ty == DType::F32 {
Ok(out_tensor)
} else {
out_tensor.to_dtype(input_ty)
}
}
#[allow(clippy::too_many_arguments)]
pub fn grouped(
weight: &QTensor,
xs: &Tensor,
ids_src: &CudaSlice<u32>,
ids_dst: &CudaSlice<u32>,
expert_bounds: &CudaSlice<u32>,
total_assignments: usize,
ncols_max: usize,
num_experts: usize,
dev: &CudaDevice,
) -> Result<Tensor> {
let dtype = weight.dtype();
if !supports(dtype) {
hanzo_ml::bail!("fast_mmq grouped: unsupported quant dtype {dtype:?}");
}
let (_, k) = xs.dims2()?;
let (weight_experts, nrows, ncols) = weight.shape().dims3()?;
if weight_experts != num_experts {
hanzo_ml::bail!("fast_mmq grouped: expected {num_experts} experts, got {weight_experts}");
}
if k != ncols {
hanzo_ml::bail!("fast_mmq grouped: shape mismatch — weight cols {ncols} vs input tail {k}");
}
let qk = qk_for(dtype);
if k % qk != 0 {
hanzo_ml::bail!("fast_mmq grouped: k={k} not divisible by qk={qk}");
}
let input_ty = xs.dtype();
if !matches!(input_ty, DType::BF16 | DType::F16 | DType::F32) {
hanzo_ml::bail!(
"fast_mmq grouped: input dtype must be BF16, F16, or F32, got {input_ty:?}"
);
}
let xs = xs.contiguous()?;
let (xs_storage, xs_layout) = xs.storage_and_layout();
let Storage::Cuda(xs_cuda) = &*xs_storage else {
hanzo_ml::bail!("fast_mmq grouped: input must live on CUDA");
};
let xs_offset = xs_layout.start_offset();
let type_x = match input_ty {
DType::F32 => 0,
DType::F16 => 1,
DType::BF16 => 30,
_ => unreachable!(),
};
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(pad(k, MATRIX_ROW_PADDING), 4 * QK8_1);
let blocks_per_row = k_padded / (4 * QK8_1);
let workspace_main = total_assignments * blocks_per_row * BLOCK_Q8_1_MMQ_SIZE;
let workspace_extra = 128 * BLOCK_Q8_1_MMQ_SIZE;
let workspace_bytes = workspace_main + workspace_extra;
let (scratch_ptr, _workspace_guard) = workspace_ensure(&MMQ_WORKSPACE, dev, workspace_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
const MMQ_X_MAX: usize = 128;
const MMQ_Y_MAX: usize = 128;
const MAX_SMS: usize = 256;
let fixup_bytes = MAX_SMS * MMQ_X_MAX * MMQ_Y_MAX * std::mem::size_of::<f32>();
let (fixup_ptr, _fixup_guard) = workspace_ensure(&FIXUP_WORKSPACE, dev, fixup_bytes)?;
let fixup_ptr = fixup_ptr as *mut std::ffi::c_void;
let out = unsafe { dev.alloc::<f32>(total_assignments * nrows)? };
let weight_ptr = weight.device_ptr()? as *const std::ffi::c_void;
let stride_row_x = (k / qk) as i64;
let stride_col_dst = nrows as i64;
let di = get_device_info(dev);
let quantize = quantize_launcher(ds_layout_for(dtype));
let launcher = mmq_moe_launcher(dtype).expect("supports() checked");
let (ids_src_ptr, _ids_src_guard) = slice_ptr(ids_src, 0);
let (ids_dst_ptr, _ids_dst_guard) = slice_ptr(ids_dst, 0);
let (bounds_ptr, _bounds_guard) = slice_ptr(expert_bounds, 0);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
match input_ty {
DType::BF16 => {
let slice = xs_cuda.as_cuda_slice::<half::bf16>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
quantize(
xs_ptr as *const std::ffi::c_void,
ids_src_ptr as *const i32,
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
total_assignments as i64,
1,
1,
stream_ptr,
);
}
DType::F16 => {
let slice = xs_cuda.as_cuda_slice::<half::f16>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
quantize(
xs_ptr as *const std::ffi::c_void,
ids_src_ptr as *const i32,
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
total_assignments as i64,
1,
1,
stream_ptr,
);
}
DType::F32 => {
let slice = xs_cuda.as_cuda_slice::<f32>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
quantize(
xs_ptr as *const std::ffi::c_void,
ids_src_ptr as *const i32,
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
total_assignments as i64,
1,
1,
stream_ptr,
);
}
_ => unreachable!(),
}
launcher(
fixup_ptr,
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
ids_dst_ptr as *const i32,
bounds_ptr as *const i32,
out_ptr as *mut std::ffi::c_void,
k as i64,
nrows as i64,
total_assignments as i64,
stride_row_x,
stride_col_dst,
num_experts as i64,
ncols_max as i64,
di.cc,
di.nsm,
di.smpbo,
di.warp_size,
stream_ptr,
);
}
drop(_out_guard);
drop(_bounds_guard);
drop(_ids_dst_guard);
drop(_ids_src_guard);
let out_shape: Shape = vec![total_assignments, nrows].into();
Ok(Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(out, dev.clone())),
out_shape,
)))
}
#[allow(clippy::too_many_arguments)]
pub fn grouped_from_glu_pair(
weight: &QTensor,
gate: &Tensor,
up: &Tensor,
ids_src: &CudaSlice<u32>,
ids_dst: &CudaSlice<u32>,
expert_bounds: &CudaSlice<u32>,
total_assignments: usize,
ncols_max: usize,
num_experts: usize,
activation: i32,
dev: &CudaDevice,
) -> Result<Tensor> {
let dtype = weight.dtype();
if !supports(dtype) {
hanzo_ml::bail!("fast_mmq grouped_from_glu_pair: unsupported quant dtype {dtype:?}");
}
let (gate_rows, k) = gate.dims2()?;
let (up_rows, up_k) = up.dims2()?;
if gate_rows != total_assignments || up_rows != total_assignments || up_k != k {
hanzo_ml::bail!(
"fast_mmq grouped_from_glu_pair: gate/up shape mismatch {:?} vs {:?}, total_assignments={total_assignments}",
gate.shape(),
up.shape()
);
}
if gate.dtype() != DType::F32 || up.dtype() != DType::F32 {
hanzo_ml::bail!(
"fast_mmq grouped_from_glu_pair: gate/up must be F32, got {:?} and {:?}",
gate.dtype(),
up.dtype()
);
}
let (weight_experts, nrows, ncols) = weight.shape().dims3()?;
if weight_experts != num_experts {
hanzo_ml::bail!(
"fast_mmq grouped_from_glu_pair: expected {num_experts} experts, got {weight_experts}"
);
}
if k != ncols {
hanzo_ml::bail!(
"fast_mmq grouped_from_glu_pair: shape mismatch — weight cols {ncols} vs input tail {k}"
);
}
let qk = qk_for(dtype);
if k % qk != 0 {
hanzo_ml::bail!("fast_mmq grouped_from_glu_pair: k={k} not divisible by qk={qk}");
}
let gate = gate.contiguous()?;
let up = up.contiguous()?;
let (gate_storage, gate_layout) = gate.storage_and_layout();
let Storage::Cuda(gate_cuda) = &*gate_storage else {
hanzo_ml::bail!("fast_mmq grouped_from_glu_pair: gate must live on CUDA");
};
let (up_storage, up_layout) = up.storage_and_layout();
let Storage::Cuda(up_cuda) = &*up_storage else {
hanzo_ml::bail!("fast_mmq grouped_from_glu_pair: up must live on CUDA");
};
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(pad(k, MATRIX_ROW_PADDING), 4 * QK8_1);
let blocks_per_row = k_padded / (4 * QK8_1);
let workspace_main = total_assignments * blocks_per_row * BLOCK_Q8_1_MMQ_SIZE;
let workspace_extra = 128 * BLOCK_Q8_1_MMQ_SIZE;
let workspace_bytes = workspace_main + workspace_extra;
let (scratch_ptr, _workspace_guard) = workspace_ensure(&MMQ_WORKSPACE, dev, workspace_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
const MMQ_X_MAX: usize = 128;
const MMQ_Y_MAX: usize = 128;
const MAX_SMS: usize = 256;
let fixup_bytes = MAX_SMS * MMQ_X_MAX * MMQ_Y_MAX * std::mem::size_of::<f32>();
let (fixup_ptr, _fixup_guard) = workspace_ensure(&FIXUP_WORKSPACE, dev, fixup_bytes)?;
let fixup_ptr = fixup_ptr as *mut std::ffi::c_void;
let out = unsafe { dev.alloc::<f32>(total_assignments * nrows)? };
let weight_ptr = weight.device_ptr()? as *const std::ffi::c_void;
let stride_row_x = (k / qk) as i64;
let stride_col_dst = nrows as i64;
let di = get_device_info(dev);
let quantize = quantize_glu_f32_launcher(ds_layout_for(dtype));
let launcher = mmq_moe_launcher(dtype).expect("supports() checked");
let gate_slice = gate_cuda.as_cuda_slice::<f32>()?;
let up_slice = up_cuda.as_cuda_slice::<f32>()?;
let (gate_ptr, _gate_guard) = slice_ptr(gate_slice, gate_layout.start_offset());
let (up_ptr, _up_guard) = slice_ptr(up_slice, up_layout.start_offset());
let (ids_src_ptr, _ids_src_guard) = slice_ptr(ids_src, 0);
let (ids_dst_ptr, _ids_dst_guard) = slice_ptr(ids_dst, 0);
let (bounds_ptr, _bounds_guard) = slice_ptr(expert_bounds, 0);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
quantize(
gate_ptr as *const f32,
up_ptr as *const f32,
ids_src_ptr as *const i32,
scratch_ptr,
k as i64,
k as i64,
k_padded as i64,
total_assignments as i64,
activation,
stream_ptr,
);
launcher(
fixup_ptr,
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
ids_dst_ptr as *const i32,
bounds_ptr as *const i32,
out_ptr as *mut std::ffi::c_void,
k as i64,
nrows as i64,
total_assignments as i64,
stride_row_x,
stride_col_dst,
num_experts as i64,
ncols_max as i64,
di.cc,
di.nsm,
di.smpbo,
di.warp_size,
stream_ptr,
);
}
drop(_out_guard);
drop(_bounds_guard);
drop(_ids_dst_guard);
drop(_ids_src_guard);
drop(_up_guard);
drop(_gate_guard);
let out_shape: Shape = vec![total_assignments, nrows].into();
Ok(Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(out, dev.clone())),
out_shape,
)))
}
#[allow(clippy::too_many_arguments)]
pub fn grouped_pair(
gate: &QTensor,
up: &QTensor,
xs: &Tensor,
ids_src: &CudaSlice<u32>,
ids_dst: &CudaSlice<u32>,
expert_bounds: &CudaSlice<u32>,
total_assignments: usize,
topk: usize,
num_experts: usize,
dev: &CudaDevice,
) -> Result<(Tensor, Tensor)> {
let dtype = gate.dtype();
if dtype != up.dtype() {
hanzo_ml::bail!(
"fast_mmq grouped_pair requires matching gate/up dtypes, got {:?} and {:?}",
dtype,
up.dtype()
);
}
if !supports(dtype) {
hanzo_ml::bail!("fast_mmq grouped_pair: unsupported quant dtype {dtype:?}");
}
let (num_tokens, k) = xs.dims2()?;
if total_assignments != num_tokens * topk {
hanzo_ml::bail!(
"fast_mmq grouped_pair: total_assignments={total_assignments} does not match num_tokens={num_tokens} * topk={topk}"
);
}
let (gate_experts, nrows, ncols) = gate.shape().dims3()?;
let (up_experts, up_nrows, up_ncols) = up.shape().dims3()?;
if gate_experts != num_experts || up_experts != num_experts {
hanzo_ml::bail!(
"fast_mmq grouped_pair: expected {num_experts} experts, got gate={gate_experts} up={up_experts}"
);
}
if nrows != up_nrows || ncols != up_ncols {
hanzo_ml::bail!(
"fast_mmq grouped_pair: gate/up shape mismatch {:?} vs {:?}",
gate.shape(),
up.shape()
);
}
if k != ncols {
hanzo_ml::bail!(
"fast_mmq grouped_pair: shape mismatch — weight cols {ncols} vs input tail {k}"
);
}
let qk = qk_for(dtype);
if k % qk != 0 {
hanzo_ml::bail!("fast_mmq grouped_pair: k={k} not divisible by qk={qk}");
}
let input_ty = xs.dtype();
if !matches!(input_ty, DType::BF16 | DType::F16 | DType::F32) {
hanzo_ml::bail!(
"fast_mmq grouped_pair: input dtype must be BF16, F16, or F32, got {input_ty:?}"
);
}
let xs = xs.contiguous()?;
let (xs_storage, xs_layout) = xs.storage_and_layout();
let Storage::Cuda(xs_cuda) = &*xs_storage else {
hanzo_ml::bail!("fast_mmq grouped_pair: input must live on CUDA");
};
let xs_offset = xs_layout.start_offset();
let type_x = match input_ty {
DType::F32 => 0,
DType::F16 => 1,
DType::BF16 => 30,
_ => unreachable!(),
};
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(pad(k, MATRIX_ROW_PADDING), 4 * QK8_1);
let blocks_per_row = k_padded / (4 * QK8_1);
let workspace_main = total_assignments * blocks_per_row * BLOCK_Q8_1_MMQ_SIZE;
let workspace_extra = 128 * BLOCK_Q8_1_MMQ_SIZE;
let workspace_bytes = workspace_main + workspace_extra;
let (scratch_ptr, _workspace_guard) = workspace_ensure(&MMQ_WORKSPACE, dev, workspace_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
const MMQ_X_MAX: usize = 128;
const MMQ_Y_MAX: usize = 128;
const MAX_SMS: usize = 256;
let fixup_bytes = MAX_SMS * MMQ_X_MAX * MMQ_Y_MAX * std::mem::size_of::<f32>();
let (fixup_ptr, _fixup_guard) = workspace_ensure(&FIXUP_WORKSPACE, dev, fixup_bytes)?;
let fixup_ptr = fixup_ptr as *mut std::ffi::c_void;
let gate_out = unsafe { dev.alloc::<f32>(total_assignments * nrows)? };
let up_out = unsafe { dev.alloc::<f32>(total_assignments * nrows)? };
let gate_ptr = gate.device_ptr()? as *const std::ffi::c_void;
let up_ptr = up.device_ptr()? as *const std::ffi::c_void;
let stride_row_x = (k / qk) as i64;
let stride_col_dst = nrows as i64;
let di = get_device_info(dev);
let quantize = quantize_launcher(ds_layout_for(dtype));
let launcher = mmq_moe_launcher(dtype).expect("supports() checked");
let (ids_src_ptr, _ids_src_guard) = slice_ptr(ids_src, 0);
let (ids_dst_ptr, _ids_dst_guard) = slice_ptr(ids_dst, 0);
let (bounds_ptr, _bounds_guard) = slice_ptr(expert_bounds, 0);
let (gate_out_ptr, _gate_out_guard) = slice_ptr(&gate_out, 0);
let (up_out_ptr, _up_out_guard) = slice_ptr(&up_out, 0);
unsafe {
match input_ty {
DType::BF16 => {
let slice = xs_cuda.as_cuda_slice::<half::bf16>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
quantize(
xs_ptr as *const std::ffi::c_void,
ids_src_ptr as *const i32,
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
total_assignments as i64,
1,
1,
stream_ptr,
);
}
DType::F16 => {
let slice = xs_cuda.as_cuda_slice::<half::f16>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
quantize(
xs_ptr as *const std::ffi::c_void,
ids_src_ptr as *const i32,
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
total_assignments as i64,
1,
1,
stream_ptr,
);
}
DType::F32 => {
let slice = xs_cuda.as_cuda_slice::<f32>()?;
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
quantize(
xs_ptr as *const std::ffi::c_void,
ids_src_ptr as *const i32,
scratch_ptr,
type_x,
k as i64,
k as i64,
0,
0,
k_padded as i64,
total_assignments as i64,
1,
1,
stream_ptr,
);
}
_ => unreachable!(),
}
for (weight_ptr, out_ptr) in [
(gate_ptr, gate_out_ptr as *mut std::ffi::c_void),
(up_ptr, up_out_ptr as *mut std::ffi::c_void),
] {
launcher(
fixup_ptr,
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
ids_dst_ptr as *const i32,
bounds_ptr as *const i32,
out_ptr,
k as i64,
nrows as i64,
total_assignments as i64,
stride_row_x,
stride_col_dst,
num_experts as i64,
num_tokens as i64,
di.cc,
di.nsm,
di.smpbo,
di.warp_size,
stream_ptr,
);
}
}
drop(_gate_out_guard);
drop(_up_out_guard);
drop(_bounds_guard);
drop(_ids_dst_guard);
drop(_ids_src_guard);
let out_shape: Shape = vec![total_assignments, nrows].into();
let gate_tensor = Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(gate_out, dev.clone())),
out_shape.clone(),
));
let up_tensor = Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(up_out, dev.clone())),
out_shape,
));
Ok((gate_tensor, up_tensor))
}