use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use hanzo_ml::cuda::cudarc::driver::{CudaSlice, DevicePtr};
use hanzo_ml::{
quantized::{GgmlDType, QTensor},
CudaDevice, CudaStorage, DType, Device, Result, Shape, Storage, Tensor,
};
use super::ffi;
use crate::{utils::slice_ptr, GluActivationType};
const Q8_1_BLOCK_SIZE: usize = 32;
const Q8_1_TYPE_SIZE: usize = 36; const MATRIX_ROW_PADDING: usize = 512;
#[inline]
fn pad(p: usize, q: usize) -> usize {
p.div_ceil(q) * q
}
fn output_shape(xs: &Tensor, nrows: usize) -> Shape {
let mut out_dims = xs.dims().to_vec();
let last = out_dims.len() - 1;
out_dims[last] = nrows;
Shape::from(out_dims)
}
pub fn supports(dtype: GgmlDType) -> bool {
matches!(
dtype,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
)
}
pub const MMVQ_MAX_BATCH: usize = 8;
struct WorkspaceSlot {
slice: CudaSlice<u8>,
cap: usize,
}
type WsMap = Mutex<HashMap<hanzo_ml::cuda::DeviceId, &'static Mutex<WorkspaceSlot>>>;
static WORKSPACE: OnceLock<WsMap> = OnceLock::new();
fn workspace_ensure(
dev: &CudaDevice,
bytes: usize,
) -> Result<(u64, std::sync::MutexGuard<'static, WorkspaceSlot>)> {
let map = WORKSPACE.get_or_init(|| Mutex::new(HashMap::new()));
let device_key = dev.id();
let device_mtx: &'static Mutex<WorkspaceSlot> = {
let mut guard = map.lock().unwrap();
match guard.get(&device_key).copied() {
Some(mtx) => mtx,
None => {
let slice = unsafe { dev.alloc::<u8>(bytes.max(1))? };
let leaked = Box::leak(Box::new(Mutex::new(WorkspaceSlot {
slice,
cap: bytes.max(1),
})));
guard.insert(device_key, leaked);
leaked
}
}
};
let mut slot = device_mtx.lock().unwrap();
if slot.cap < bytes {
slot.slice = unsafe { dev.alloc::<u8>(bytes)? };
slot.cap = bytes;
}
let ptr = slot.slice.device_ptr(slot.slice.stream()).0;
Ok((ptr, slot))
}
type PlainLauncher = unsafe extern "C" fn(
vx: *const std::ffi::c_void,
vy: *const std::ffi::c_void,
dst: *mut std::ffi::c_void,
ncols_x: i32,
nrows_x: i32,
stride_col_y: i32,
stride_col_dst: i32,
b_size: i32,
stream: *mut std::ffi::c_void,
);
type FusedGluLauncher = unsafe extern "C" fn(
vx_gate: *const std::ffi::c_void,
vx_up: *const std::ffi::c_void,
vy: *const std::ffi::c_void,
dst: *mut std::ffi::c_void,
ncols_x: i32,
nrows_x: i32,
stride_col_y: i32,
stride_col_dst: i32,
b_size: i32,
activation: i32,
stream: *mut std::ffi::c_void,
);
type FusedQkvLauncher = unsafe extern "C" fn(
vx_q: *const std::ffi::c_void,
vx_k: *const std::ffi::c_void,
vx_v: *const std::ffi::c_void,
vy: *const std::ffi::c_void,
q_dst: *mut std::ffi::c_void,
k_dst: *mut std::ffi::c_void,
v_dst: *mut std::ffi::c_void,
ncols_x: i32,
nrows_q: i32,
nrows_k: i32,
nrows_v: i32,
stride_col_y: i32,
b_size: i32,
stream: *mut std::ffi::c_void,
);
fn plain_launcher_bf16(dtype: GgmlDType) -> Option<PlainLauncher> {
let f: PlainLauncher = match dtype {
GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_bf16_plain,
GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_bf16_plain,
GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_bf16_plain,
GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_bf16_plain,
GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_bf16_plain,
GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_bf16_plain,
GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_bf16_plain,
GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_bf16_plain,
GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_bf16_plain,
GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_bf16_plain,
_ => return None,
};
Some(f)
}
fn plain_launcher_f16(dtype: GgmlDType) -> Option<PlainLauncher> {
let f: PlainLauncher = match dtype {
GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_f16_plain,
GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_f16_plain,
GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_f16_plain,
GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_f16_plain,
GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_f16_plain,
GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_f16_plain,
GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_f16_plain,
GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_f16_plain,
GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_f16_plain,
GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_f16_plain,
_ => return None,
};
Some(f)
}
fn plain_launcher_f32(dtype: GgmlDType) -> Option<PlainLauncher> {
let f: PlainLauncher = match dtype {
GgmlDType::Q4_0 => ffi::launch_mmvq_gguf_q4_0_f32_plain,
GgmlDType::Q4_1 => ffi::launch_mmvq_gguf_q4_1_f32_plain,
GgmlDType::Q5_0 => ffi::launch_mmvq_gguf_q5_0_f32_plain,
GgmlDType::Q5_1 => ffi::launch_mmvq_gguf_q5_1_f32_plain,
GgmlDType::Q8_0 => ffi::launch_mmvq_gguf_q8_0_f32_plain,
GgmlDType::Q2K => ffi::launch_mmvq_gguf_q2_k_f32_plain,
GgmlDType::Q3K => ffi::launch_mmvq_gguf_q3_k_f32_plain,
GgmlDType::Q4K => ffi::launch_mmvq_gguf_q4_k_f32_plain,
GgmlDType::Q5K => ffi::launch_mmvq_gguf_q5_k_f32_plain,
GgmlDType::Q6K => ffi::launch_mmvq_gguf_q6_k_f32_plain,
_ => return None,
};
Some(f)
}
fn fused_glu_launcher(input_ty: DType, dtype: GgmlDType) -> Option<FusedGluLauncher> {
match (input_ty, dtype) {
(DType::BF16, GgmlDType::Q8_0) => Some(ffi::launch_mmvq_gguf_q8_0_bf16_fused_glu),
(DType::F16, GgmlDType::Q8_0) => Some(ffi::launch_mmvq_gguf_q8_0_f16_fused_glu),
(DType::F32, GgmlDType::Q8_0) => Some(ffi::launch_mmvq_gguf_q8_0_f32_fused_glu),
_ => None,
}
}
fn fused_qkv_launcher(input_ty: DType, dtype: GgmlDType) -> Option<FusedQkvLauncher> {
match (input_ty, dtype) {
(DType::BF16, GgmlDType::Q4_0) => Some(ffi::launch_mmvq_gguf_q4_0_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q4_1) => Some(ffi::launch_mmvq_gguf_q4_1_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q5_0) => Some(ffi::launch_mmvq_gguf_q5_0_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q5_1) => Some(ffi::launch_mmvq_gguf_q5_1_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q8_0) => Some(ffi::launch_mmvq_gguf_q8_0_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q2K) => Some(ffi::launch_mmvq_gguf_q2_k_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q3K) => Some(ffi::launch_mmvq_gguf_q3_k_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q4K) => Some(ffi::launch_mmvq_gguf_q4_k_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q5K) => Some(ffi::launch_mmvq_gguf_q5_k_bf16_fused_qkv),
(DType::BF16, GgmlDType::Q6K) => Some(ffi::launch_mmvq_gguf_q6_k_bf16_fused_qkv),
(DType::F16, GgmlDType::Q4_0) => Some(ffi::launch_mmvq_gguf_q4_0_f16_fused_qkv),
(DType::F16, GgmlDType::Q4_1) => Some(ffi::launch_mmvq_gguf_q4_1_f16_fused_qkv),
(DType::F16, GgmlDType::Q5_0) => Some(ffi::launch_mmvq_gguf_q5_0_f16_fused_qkv),
(DType::F16, GgmlDType::Q5_1) => Some(ffi::launch_mmvq_gguf_q5_1_f16_fused_qkv),
(DType::F16, GgmlDType::Q8_0) => Some(ffi::launch_mmvq_gguf_q8_0_f16_fused_qkv),
(DType::F16, GgmlDType::Q2K) => Some(ffi::launch_mmvq_gguf_q2_k_f16_fused_qkv),
(DType::F16, GgmlDType::Q3K) => Some(ffi::launch_mmvq_gguf_q3_k_f16_fused_qkv),
(DType::F16, GgmlDType::Q4K) => Some(ffi::launch_mmvq_gguf_q4_k_f16_fused_qkv),
(DType::F16, GgmlDType::Q5K) => Some(ffi::launch_mmvq_gguf_q5_k_f16_fused_qkv),
(DType::F16, GgmlDType::Q6K) => Some(ffi::launch_mmvq_gguf_q6_k_f16_fused_qkv),
(DType::F32, GgmlDType::Q4_0) => Some(ffi::launch_mmvq_gguf_q4_0_f32_fused_qkv),
(DType::F32, GgmlDType::Q4_1) => Some(ffi::launch_mmvq_gguf_q4_1_f32_fused_qkv),
(DType::F32, GgmlDType::Q5_0) => Some(ffi::launch_mmvq_gguf_q5_0_f32_fused_qkv),
(DType::F32, GgmlDType::Q5_1) => Some(ffi::launch_mmvq_gguf_q5_1_f32_fused_qkv),
(DType::F32, GgmlDType::Q8_0) => Some(ffi::launch_mmvq_gguf_q8_0_f32_fused_qkv),
(DType::F32, GgmlDType::Q2K) => Some(ffi::launch_mmvq_gguf_q2_k_f32_fused_qkv),
(DType::F32, GgmlDType::Q3K) => Some(ffi::launch_mmvq_gguf_q3_k_f32_fused_qkv),
(DType::F32, GgmlDType::Q4K) => Some(ffi::launch_mmvq_gguf_q4_k_f32_fused_qkv),
(DType::F32, GgmlDType::Q5K) => Some(ffi::launch_mmvq_gguf_q5_k_f32_fused_qkv),
(DType::F32, GgmlDType::Q6K) => Some(ffi::launch_mmvq_gguf_q6_k_f32_fused_qkv),
_ => None,
}
}
pub fn plain(w: &QTensor, xs: &Tensor) -> Result<Tensor> {
let dtype = w.dtype();
if !supports(dtype) {
hanzo_ml::bail!("fast_mmvq: unsupported quant dtype {dtype:?}");
}
let Device::Cuda(dev) = w.device() else {
hanzo_ml::bail!("fast_mmvq: weight must live on CUDA");
};
let (nrows, ncols) = w.shape().dims2()?;
let (b_size, k) = match xs.dims() {
[b, k] => (*b, *k),
[b, m, k] => (*b * *m, *k),
other => hanzo_ml::bail!("fast_mmvq: unexpected input rank {other:?}"),
};
if k != ncols {
hanzo_ml::bail!("fast_mmvq: shape mismatch — weight [{nrows}, {ncols}] vs input tail {k}");
}
if b_size == 0 || b_size > MMVQ_MAX_BATCH {
hanzo_ml::bail!(
"fast_mmvq: batch size {b_size} out of supported range 1..={MMVQ_MAX_BATCH}"
);
}
let input_ty = xs.dtype();
if !matches!(input_ty, DType::BF16 | DType::F16 | DType::F32) {
hanzo_ml::bail!("fast_mmvq: input dtype must be BF16, F16, or F32, got {input_ty:?}");
}
let xs = xs.contiguous()?;
let (xs_storage, xs_layout) = xs.storage_and_layout();
let Storage::Cuda(xs_cuda) = &*xs_storage else {
hanzo_ml::bail!("fast_mmvq: input must live on CUDA");
};
let xs_offset = xs_layout.start_offset();
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(k, MATRIX_ROW_PADDING);
let num_blocks_per_row = k_padded / Q8_1_BLOCK_SIZE;
let dst_row_bytes = num_blocks_per_row * Q8_1_TYPE_SIZE;
let scratch_bytes = b_size * dst_row_bytes;
let (scratch_ptr, _workspace_guard) = workspace_ensure(&dev, scratch_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
let stride_col_y = (k_padded / Q8_1_BLOCK_SIZE) as i32;
let stride_col_dst = nrows as i32;
let weight_ptr = w.device_ptr()? as *const std::ffi::c_void;
match input_ty {
DType::BF16 => {
let slice = xs_cuda.as_cuda_slice::<half::bf16>()?;
let out = unsafe { dev.alloc::<half::bf16>(nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_bf16(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
let launcher = plain_launcher_bf16(dtype).expect("supports() checked");
launcher(
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
stream_ptr,
);
}
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Tensor::from((
Storage::Cuda(out_storage),
output_shape(&xs, nrows),
)))
}
DType::F16 => {
let slice = xs_cuda.as_cuda_slice::<half::f16>()?;
let out = unsafe { dev.alloc::<half::f16>(nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f16(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
let launcher = plain_launcher_f16(dtype).expect("supports() checked");
launcher(
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
stream_ptr,
);
}
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Tensor::from((
Storage::Cuda(out_storage),
output_shape(&xs, nrows),
)))
}
DType::F32 => {
let slice = xs_cuda.as_cuda_slice::<f32>()?;
let out = unsafe { dev.alloc::<f32>(nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f32(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
let launcher = plain_launcher_f32(dtype).expect("supports() checked");
launcher(
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
stream_ptr,
);
}
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Tensor::from((
Storage::Cuda(out_storage),
output_shape(&xs, nrows),
)))
}
_ => unreachable!(),
}
}
pub fn fused_glu(
gate_w: &QTensor,
up_w: &QTensor,
xs: &Tensor,
activation: GluActivationType,
) -> Result<Tensor> {
let dtype = gate_w.dtype();
if dtype != up_w.dtype() {
hanzo_ml::bail!(
"fast_mmvq fused_glu: gate/up dtype mismatch {:?} vs {:?}",
dtype,
up_w.dtype()
);
}
let Some(launcher) = fused_glu_launcher(xs.dtype(), dtype) else {
hanzo_ml::bail!("fast_mmvq fused_glu: unsupported dtype combination");
};
let Device::Cuda(dev) = gate_w.device() else {
hanzo_ml::bail!("fast_mmvq fused_glu: gate weight must live on CUDA");
};
let Device::Cuda(up_dev) = up_w.device() else {
hanzo_ml::bail!("fast_mmvq fused_glu: up weight must live on CUDA");
};
if dev.id() != up_dev.id() {
hanzo_ml::bail!("fast_mmvq fused_glu: gate/up weights are on different CUDA devices");
}
let (nrows, ncols) = gate_w.shape().dims2()?;
let (up_nrows, up_ncols) = up_w.shape().dims2()?;
if (nrows, ncols) != (up_nrows, up_ncols) {
hanzo_ml::bail!(
"fast_mmvq fused_glu: gate/up shape mismatch [{nrows}, {ncols}] vs [{up_nrows}, {up_ncols}]"
);
}
let (b_size, k) = match xs.dims() {
[b, k] => (*b, *k),
[b, m, k] => (*b * *m, *k),
other => hanzo_ml::bail!("fast_mmvq fused_glu: unexpected input rank {other:?}"),
};
if k != ncols {
hanzo_ml::bail!(
"fast_mmvq fused_glu: shape mismatch — weight [{nrows}, {ncols}] vs input tail {k}"
);
}
if b_size == 0 || b_size > MMVQ_MAX_BATCH {
hanzo_ml::bail!(
"fast_mmvq fused_glu: batch size {b_size} out of supported range 1..={MMVQ_MAX_BATCH}"
);
}
let input_ty = xs.dtype();
if !matches!(input_ty, DType::BF16 | DType::F16 | DType::F32) {
hanzo_ml::bail!(
"fast_mmvq fused_glu: input dtype must be BF16, F16, or F32, got {input_ty:?}"
);
}
let xs = xs.contiguous()?;
let (xs_storage, xs_layout) = xs.storage_and_layout();
let Storage::Cuda(xs_cuda) = &*xs_storage else {
hanzo_ml::bail!("fast_mmvq fused_glu: input must live on CUDA");
};
let xs_offset = xs_layout.start_offset();
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(k, MATRIX_ROW_PADDING);
let num_blocks_per_row = k_padded / Q8_1_BLOCK_SIZE;
let dst_row_bytes = num_blocks_per_row * Q8_1_TYPE_SIZE;
let scratch_bytes = b_size * dst_row_bytes;
let (scratch_ptr, _workspace_guard) = workspace_ensure(&dev, scratch_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
let stride_col_y = (k_padded / Q8_1_BLOCK_SIZE) as i32;
let stride_col_dst = nrows as i32;
let gate_ptr = gate_w.device_ptr()? as *const std::ffi::c_void;
let up_ptr = up_w.device_ptr()? as *const std::ffi::c_void;
let activation = activation as i32;
match input_ty {
DType::BF16 => {
let slice = xs_cuda.as_cuda_slice::<half::bf16>()?;
let out = unsafe { dev.alloc::<half::bf16>(nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_bf16(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
launcher(
gate_ptr,
up_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
activation,
stream_ptr,
);
}
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Tensor::from((
Storage::Cuda(out_storage),
output_shape(&xs, nrows),
)))
}
DType::F16 => {
let slice = xs_cuda.as_cuda_slice::<half::f16>()?;
let out = unsafe { dev.alloc::<half::f16>(nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f16(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
launcher(
gate_ptr,
up_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
activation,
stream_ptr,
);
}
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Tensor::from((
Storage::Cuda(out_storage),
output_shape(&xs, nrows),
)))
}
DType::F32 => {
let slice = xs_cuda.as_cuda_slice::<f32>()?;
let out = unsafe { dev.alloc::<f32>(nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (out_ptr, _out_guard) = slice_ptr(&out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f32(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
launcher(
gate_ptr,
up_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr as *mut std::ffi::c_void,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
activation,
stream_ptr,
);
}
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Tensor::from((
Storage::Cuda(out_storage),
output_shape(&xs, nrows),
)))
}
_ => unreachable!(),
}
}
pub fn fused_qkv(
q_w: &QTensor,
k_w: &QTensor,
v_w: &QTensor,
xs: &Tensor,
) -> Result<(Tensor, Tensor, Tensor)> {
let dtype = q_w.dtype();
if dtype != k_w.dtype() || dtype != v_w.dtype() {
hanzo_ml::bail!(
"fast_mmvq fused_qkv: q/k/v dtype mismatch {:?}, {:?}, {:?}",
dtype,
k_w.dtype(),
v_w.dtype()
);
}
let Some(launcher) = fused_qkv_launcher(xs.dtype(), dtype) else {
hanzo_ml::bail!("fast_mmvq fused_qkv: unsupported dtype combination");
};
let Device::Cuda(dev) = q_w.device() else {
hanzo_ml::bail!("fast_mmvq fused_qkv: q weight must live on CUDA");
};
let Device::Cuda(k_dev) = k_w.device() else {
hanzo_ml::bail!("fast_mmvq fused_qkv: k weight must live on CUDA");
};
let Device::Cuda(v_dev) = v_w.device() else {
hanzo_ml::bail!("fast_mmvq fused_qkv: v weight must live on CUDA");
};
if dev.id() != k_dev.id() || dev.id() != v_dev.id() {
hanzo_ml::bail!("fast_mmvq fused_qkv: q/k/v weights are on different CUDA devices");
}
let (q_nrows, ncols) = q_w.shape().dims2()?;
let (k_nrows, k_ncols) = k_w.shape().dims2()?;
let (v_nrows, v_ncols) = v_w.shape().dims2()?;
if ncols != k_ncols || ncols != v_ncols {
hanzo_ml::bail!("fast_mmvq fused_qkv: q/k/v ncols mismatch {ncols}, {k_ncols}, {v_ncols}");
}
let (b_size, k) = match xs.dims() {
[b, k] => (*b, *k),
[b, m, k] => (*b * *m, *k),
other => hanzo_ml::bail!("fast_mmvq fused_qkv: unexpected input rank {other:?}"),
};
if k != ncols {
hanzo_ml::bail!(
"fast_mmvq fused_qkv: shape mismatch — weight ncols {ncols} vs input tail {k}"
);
}
if b_size == 0 || b_size > MMVQ_MAX_BATCH {
hanzo_ml::bail!(
"fast_mmvq fused_qkv: batch size {b_size} out of supported range 1..={MMVQ_MAX_BATCH}"
);
}
let input_ty = xs.dtype();
if !matches!(input_ty, DType::BF16 | DType::F16 | DType::F32) {
hanzo_ml::bail!(
"fast_mmvq fused_qkv: input dtype must be BF16, F16, or F32, got {input_ty:?}"
);
}
let xs = xs.contiguous()?;
let (xs_storage, xs_layout) = xs.storage_and_layout();
let Storage::Cuda(xs_cuda) = &*xs_storage else {
hanzo_ml::bail!("fast_mmvq fused_qkv: input must live on CUDA");
};
let xs_offset = xs_layout.start_offset();
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(k, MATRIX_ROW_PADDING);
let num_blocks_per_row = k_padded / Q8_1_BLOCK_SIZE;
let dst_row_bytes = num_blocks_per_row * Q8_1_TYPE_SIZE;
let scratch_bytes = b_size * dst_row_bytes;
let (scratch_ptr, _workspace_guard) = workspace_ensure(&dev, scratch_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
let stride_col_y = (k_padded / Q8_1_BLOCK_SIZE) as i32;
let q_ptr = q_w.device_ptr()? as *const std::ffi::c_void;
let k_ptr = k_w.device_ptr()? as *const std::ffi::c_void;
let v_ptr = v_w.device_ptr()? as *const std::ffi::c_void;
match input_ty {
DType::BF16 => {
let slice = xs_cuda.as_cuda_slice::<half::bf16>()?;
let q_out = unsafe { dev.alloc::<half::bf16>(q_nrows * b_size)? };
let k_out = unsafe { dev.alloc::<half::bf16>(k_nrows * b_size)? };
let v_out = unsafe { dev.alloc::<half::bf16>(v_nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (q_out_ptr, _q_out_guard) = slice_ptr(&q_out, 0);
let (k_out_ptr, _k_out_guard) = slice_ptr(&k_out, 0);
let (v_out_ptr, _v_out_guard) = slice_ptr(&v_out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_bf16(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
launcher(
q_ptr,
k_ptr,
v_ptr,
scratch_ptr as *const std::ffi::c_void,
q_out_ptr as *mut std::ffi::c_void,
k_out_ptr as *mut std::ffi::c_void,
v_out_ptr as *mut std::ffi::c_void,
k as i32,
q_nrows as i32,
k_nrows as i32,
v_nrows as i32,
stride_col_y,
b_size as i32,
stream_ptr,
);
}
}
Ok((
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(q_out, dev.clone())),
output_shape(&xs, q_nrows),
)),
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(k_out, dev.clone())),
output_shape(&xs, k_nrows),
)),
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(v_out, dev.clone())),
output_shape(&xs, v_nrows),
)),
))
}
DType::F16 => {
let slice = xs_cuda.as_cuda_slice::<half::f16>()?;
let q_out = unsafe { dev.alloc::<half::f16>(q_nrows * b_size)? };
let k_out = unsafe { dev.alloc::<half::f16>(k_nrows * b_size)? };
let v_out = unsafe { dev.alloc::<half::f16>(v_nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (q_out_ptr, _q_out_guard) = slice_ptr(&q_out, 0);
let (k_out_ptr, _k_out_guard) = slice_ptr(&k_out, 0);
let (v_out_ptr, _v_out_guard) = slice_ptr(&v_out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f16(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
launcher(
q_ptr,
k_ptr,
v_ptr,
scratch_ptr as *const std::ffi::c_void,
q_out_ptr as *mut std::ffi::c_void,
k_out_ptr as *mut std::ffi::c_void,
v_out_ptr as *mut std::ffi::c_void,
k as i32,
q_nrows as i32,
k_nrows as i32,
v_nrows as i32,
stride_col_y,
b_size as i32,
stream_ptr,
);
}
}
Ok((
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(q_out, dev.clone())),
output_shape(&xs, q_nrows),
)),
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(k_out, dev.clone())),
output_shape(&xs, k_nrows),
)),
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(v_out, dev.clone())),
output_shape(&xs, v_nrows),
)),
))
}
DType::F32 => {
let slice = xs_cuda.as_cuda_slice::<f32>()?;
let q_out = unsafe { dev.alloc::<f32>(q_nrows * b_size)? };
let k_out = unsafe { dev.alloc::<f32>(k_nrows * b_size)? };
let v_out = unsafe { dev.alloc::<f32>(v_nrows * b_size)? };
{
let (xs_ptr, _xs_guard) = slice_ptr(slice, xs_offset);
let (q_out_ptr, _q_out_guard) = slice_ptr(&q_out, 0);
let (k_out_ptr, _k_out_guard) = slice_ptr(&k_out, 0);
let (v_out_ptr, _v_out_guard) = slice_ptr(&v_out, 0);
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f32(
xs_ptr as *const std::ffi::c_void,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
launcher(
q_ptr,
k_ptr,
v_ptr,
scratch_ptr as *const std::ffi::c_void,
q_out_ptr as *mut std::ffi::c_void,
k_out_ptr as *mut std::ffi::c_void,
v_out_ptr as *mut std::ffi::c_void,
k as i32,
q_nrows as i32,
k_nrows as i32,
v_nrows as i32,
stride_col_y,
b_size as i32,
stream_ptr,
);
}
}
Ok((
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(q_out, dev.clone())),
output_shape(&xs, q_nrows),
)),
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(k_out, dev.clone())),
output_shape(&xs, k_nrows),
)),
Tensor::from((
Storage::Cuda(CudaStorage::wrap_cuda_slice(v_out, dev.clone())),
output_shape(&xs, v_nrows),
)),
))
}
_ => unreachable!(),
}
}