use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use super::cuda::{QCudaStorage, MATRIX_ROW_PADDING};
use super::GgmlDType;
use crate::cuda_backend::DeviceId;
use crate::{backend::BackendStorage, CudaDevice, CudaStorage, DType, Result, Shape};
use cudarc::driver::{CudaSlice, DevicePtr};
/// Number of activation values packed into one Q8_1 block.
const Q8_1_BLOCK_SIZE: usize = 32;
/// Size in bytes of one Q8_1 block: one byte per quantized value plus 4 bytes
/// of per-block scale data (presumably ggml's `block_q8_1` layout with a packed
/// pair of f16 scales — confirm against the quantize kernel).
const Q8_1_TYPE_SIZE: usize = Q8_1_BLOCK_SIZE + 4;
/// Rounds `p` up to the nearest multiple of `q`.
///
/// # Panics
/// Panics if `q` is zero.
#[inline]
fn pad(p: usize, q: usize) -> usize {
    p.next_multiple_of(q)
}
/// Reports whether `try_fwd` has MMVQ kernels for weights of this GGUF dtype.
///
/// Must stay in sync with the `plain_launcher_*` lookups, which return `Some`
/// for exactly these variants.
fn supports(dtype: GgmlDType) -> bool {
    let plain_variant = matches!(
        dtype,
        GgmlDType::Q4_0 | GgmlDType::Q4_1 | GgmlDType::Q5_0 | GgmlDType::Q5_1 | GgmlDType::Q8_0
    );
    let k_variant = matches!(
        dtype,
        GgmlDType::Q2K | GgmlDType::Q3K | GgmlDType::Q4K | GgmlDType::Q5K | GgmlDType::Q6K
    );
    plain_variant || k_variant
}
/// Largest flattened activation batch (`b` for `[b, k]`, `b * m` for
/// `[b, m, k]`) that `try_fwd` dispatches to the MMVQ kernels; larger batches
/// make it return `Ok(None)` so the caller can use another matmul path.
const MMVQ_MAX_BATCH: usize = 8;
/// One cached device scratch allocation, stored per device in `WORKSPACE`.
struct WorkspaceSlot {
    /// The scratch buffer; replaced by a fresh, larger allocation whenever a
    /// request exceeds `cap`.
    slice: CudaSlice<u8>,
    /// Size in bytes that `slice` was allocated with.
    cap: usize,
}
/// Process-wide scratch buffers keyed by `CudaDevice::id()`, created lazily by
/// `workspace_ensure`. The mutex also serialises use of the buffers: callers
/// keep its guard alive while enqueueing work that uses a slot's pointer.
///
/// NOTE(review): nothing in this file removes entries, so a slot outlives the
/// `CudaDevice` that created it — confirm whether that keeps device memory
/// (or the CUDA context it references) alive for the rest of the process.
static WORKSPACE: OnceLock<Mutex<HashMap<DeviceId, WorkspaceSlot>>> = OnceLock::new();
/// Returns the device address of a per-device scratch buffer of at least
/// `bytes` bytes, together with the guard of the global workspace map.
///
/// The buffer for `dev.id()` is allocated lazily on first use and replaced by
/// a larger allocation whenever a request exceeds its capacity; it never
/// shrinks. The address is only meaningful while the guard is held: once it is
/// dropped another caller may replace (and thereby free) the buffer, so callers
/// keep the guard alive until all work using the address has been enqueued.
///
/// # Errors
/// Propagates device allocation failures. On failure the previously cached
/// buffer (if any) is left in place.
fn workspace_ensure(
    dev: &CudaDevice,
    bytes: usize,
) -> Result<(
    u64,
    std::sync::MutexGuard<'static, HashMap<DeviceId, WorkspaceSlot>>,
)> {
    let map = WORKSPACE.get_or_init(|| Mutex::new(HashMap::new()));
    // The map only caches allocations: a panic elsewhere while the lock was held
    // cannot leave it logically inconsistent, so recover the guard instead of
    // turning one panic into a panic for every later MMVQ call.
    let mut guard = map
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let slot = match guard.entry(dev.id()) {
        std::collections::hash_map::Entry::Occupied(entry) => entry.into_mut(),
        std::collections::hash_map::Entry::Vacant(entry) => {
            // SAFETY: the buffer is handed out uninitialised; callers must write
            // it before reading (in this file the Q8_1 quantize kernel fills it
            // before the mat-vec kernel reads it — assumption about the kernels).
            let slice = unsafe { dev.alloc::<u8>(bytes)? };
            entry.insert(WorkspaceSlot { slice, cap: bytes })
        }
    };
    if slot.cap < bytes {
        // SAFETY: as above; the previous buffer is only released once the new
        // allocation has succeeded.
        slot.slice = unsafe { dev.alloc::<u8>(bytes)? };
        slot.cap = bytes;
    }
    // NOTE(review): `.0` drops cudarc's `SyncOnDrop` immediately, i.e. before any
    // kernel using this address is enqueued; correctness relies on all users
    // sharing the slice's stream.
    let ptr = slot.slice.device_ptr(slot.slice.stream()).0;
    Ok((ptr, guard))
}
/// Common signature of the `launch_mmvq_gguf_*_plain` FFI entry points, which
/// multiply a quantized weight matrix by Q8_1-quantized activation rows.
///
/// Arguments, in order, as `try_fwd` passes them: the weight data, the Q8_1
/// scratch holding the activations, the output buffer, the reduction length
/// `k`, the number of weight rows, the activation stride in Q8_1 blocks, the
/// output stride (the weight row count), the number of activation rows, and
/// the raw CUDA stream to enqueue on.
type PlainLauncher = unsafe extern "C" fn(
    weights: *const std::ffi::c_void,
    rhs_q8_1: *const std::ffi::c_void,
    out: *mut std::ffi::c_void,
    ncols: i32,
    nrows: i32,
    rhs_blocks_stride: i32,
    out_stride: i32,
    batch: i32,
    cu_stream: *mut std::ffi::c_void,
);
/// Looks up the MMVQ kernel for BF16 activations and weights of `dtype`.
///
/// Returns `None` for weight dtypes without a `*_bf16_plain` kernel.
fn plain_launcher_bf16(dtype: GgmlDType) -> Option<PlainLauncher> {
    use hanzo_kernels::ffi;
    match dtype {
        GgmlDType::Q4_0 => Some(ffi::launch_mmvq_gguf_q4_0_bf16_plain),
        GgmlDType::Q4_1 => Some(ffi::launch_mmvq_gguf_q4_1_bf16_plain),
        GgmlDType::Q5_0 => Some(ffi::launch_mmvq_gguf_q5_0_bf16_plain),
        GgmlDType::Q5_1 => Some(ffi::launch_mmvq_gguf_q5_1_bf16_plain),
        GgmlDType::Q8_0 => Some(ffi::launch_mmvq_gguf_q8_0_bf16_plain),
        GgmlDType::Q2K => Some(ffi::launch_mmvq_gguf_q2_k_bf16_plain),
        GgmlDType::Q3K => Some(ffi::launch_mmvq_gguf_q3_k_bf16_plain),
        GgmlDType::Q4K => Some(ffi::launch_mmvq_gguf_q4_k_bf16_plain),
        GgmlDType::Q5K => Some(ffi::launch_mmvq_gguf_q5_k_bf16_plain),
        GgmlDType::Q6K => Some(ffi::launch_mmvq_gguf_q6_k_bf16_plain),
        _ => None,
    }
}
/// Looks up the MMVQ kernel for F16 activations and weights of `dtype`.
///
/// Returns `None` for weight dtypes without a `*_f16_plain` kernel.
fn plain_launcher_f16(dtype: GgmlDType) -> Option<PlainLauncher> {
    use hanzo_kernels::ffi;
    match dtype {
        GgmlDType::Q4_0 => Some(ffi::launch_mmvq_gguf_q4_0_f16_plain),
        GgmlDType::Q4_1 => Some(ffi::launch_mmvq_gguf_q4_1_f16_plain),
        GgmlDType::Q5_0 => Some(ffi::launch_mmvq_gguf_q5_0_f16_plain),
        GgmlDType::Q5_1 => Some(ffi::launch_mmvq_gguf_q5_1_f16_plain),
        GgmlDType::Q8_0 => Some(ffi::launch_mmvq_gguf_q8_0_f16_plain),
        GgmlDType::Q2K => Some(ffi::launch_mmvq_gguf_q2_k_f16_plain),
        GgmlDType::Q3K => Some(ffi::launch_mmvq_gguf_q3_k_f16_plain),
        GgmlDType::Q4K => Some(ffi::launch_mmvq_gguf_q4_k_f16_plain),
        GgmlDType::Q5K => Some(ffi::launch_mmvq_gguf_q5_k_f16_plain),
        GgmlDType::Q6K => Some(ffi::launch_mmvq_gguf_q6_k_f16_plain),
        _ => None,
    }
}
/// Looks up the MMVQ kernel for F32 activations and weights of `dtype`.
///
/// Returns `None` for weight dtypes without a `*_f32_plain` kernel.
fn plain_launcher_f32(dtype: GgmlDType) -> Option<PlainLauncher> {
    use hanzo_kernels::ffi;
    match dtype {
        GgmlDType::Q4_0 => Some(ffi::launch_mmvq_gguf_q4_0_f32_plain),
        GgmlDType::Q4_1 => Some(ffi::launch_mmvq_gguf_q4_1_f32_plain),
        GgmlDType::Q5_0 => Some(ffi::launch_mmvq_gguf_q5_0_f32_plain),
        GgmlDType::Q5_1 => Some(ffi::launch_mmvq_gguf_q5_1_f32_plain),
        GgmlDType::Q8_0 => Some(ffi::launch_mmvq_gguf_q8_0_f32_plain),
        GgmlDType::Q2K => Some(ffi::launch_mmvq_gguf_q2_k_f32_plain),
        GgmlDType::Q3K => Some(ffi::launch_mmvq_gguf_q3_k_f32_plain),
        GgmlDType::Q4K => Some(ffi::launch_mmvq_gguf_q4_k_f32_plain),
        GgmlDType::Q5K => Some(ffi::launch_mmvq_gguf_q5_k_f32_plain),
        GgmlDType::Q6K => Some(ffi::launch_mmvq_gguf_q6_k_f32_plain),
        _ => None,
    }
}
pub fn try_fwd(
qstorage: &QCudaStorage,
self_shape: &Shape,
rhs: &CudaStorage,
rhs_l: &crate::Layout,
) -> Result<Option<(CudaStorage, Shape)>> {
use hanzo_kernels::ffi;
let w_dtype = qstorage.dtype();
if !supports(w_dtype) {
return Ok(None);
}
let input_dtype = rhs.dtype();
if !matches!(input_dtype, DType::BF16 | DType::F16 | DType::F32) {
return Ok(None);
}
let (nrows, ncols) = self_shape.dims2()?;
let (b_size, k) = match rhs_l.shape().dims() {
[b, m, k] => (b * m, *k),
[b, k] => (*b, *k),
_ => return Ok(None),
};
if ncols != k {
return Ok(None);
}
if b_size == 0 || b_size > MMVQ_MAX_BATCH {
return Ok(None);
}
let (o1, o2) = match rhs_l.contiguous_offsets() {
Some(offsets) => offsets,
None => return Ok(None),
};
let dev = qstorage.device();
let stream_ptr = dev.cuda_stream().cu_stream() as *mut std::ffi::c_void;
let k_padded = pad(k, MATRIX_ROW_PADDING);
let num_blocks_per_row = k_padded / Q8_1_BLOCK_SIZE;
let dst_row_bytes = num_blocks_per_row * Q8_1_TYPE_SIZE;
let scratch_bytes = b_size * dst_row_bytes;
let (scratch_ptr, _workspace_guard) = workspace_ensure(dev, scratch_bytes)?;
let scratch_ptr = scratch_ptr as *mut std::ffi::c_void;
let stride_col_y = (k_padded / Q8_1_BLOCK_SIZE) as i32;
let stride_col_dst = nrows as i32;
let weight_ptr = qstorage.device_ptr()? as *const std::ffi::c_void;
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
let stream = dev.cuda_stream();
match input_dtype {
DType::BF16 => {
let rhs_slice = rhs.as_cuda_slice::<half::bf16>()?;
let rhs_slice = rhs_slice.slice(o1..o2);
let out = unsafe { dev.alloc::<half::bf16>(nrows * b_size)? };
let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void;
let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void;
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_bf16(
rhs_ptr,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
let launcher = plain_launcher_bf16(w_dtype).unwrap();
launcher(
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
stream_ptr,
);
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Some((out_storage, out_shape.into())))
}
DType::F16 => {
let rhs_slice = rhs.as_cuda_slice::<half::f16>()?;
let rhs_slice = rhs_slice.slice(o1..o2);
let out = unsafe { dev.alloc::<half::f16>(nrows * b_size)? };
let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void;
let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void;
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f16(
rhs_ptr,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
let launcher = plain_launcher_f16(w_dtype).unwrap();
launcher(
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
stream_ptr,
);
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Some((out_storage, out_shape.into())))
}
DType::F32 => {
let rhs_slice = rhs.as_cuda_slice::<f32>()?;
let rhs_slice = rhs_slice.slice(o1..o2);
let out = unsafe { dev.alloc::<f32>(nrows * b_size)? };
let rhs_ptr = rhs_slice.device_ptr(&stream).0 as *const std::ffi::c_void;
let out_ptr = out.device_ptr(&stream).0 as *mut std::ffi::c_void;
unsafe {
ffi::launch_mmvq_gguf_quantize_q8_1_f32(
rhs_ptr,
scratch_ptr,
k as i32,
k_padded as i32,
b_size as i32,
stream_ptr,
);
let launcher = plain_launcher_f32(w_dtype).unwrap();
launcher(
weight_ptr,
scratch_ptr as *const std::ffi::c_void,
out_ptr,
k as i32,
nrows as i32,
stride_col_y,
stride_col_dst,
b_size as i32,
stream_ptr,
);
}
let out_storage = CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok(Some((out_storage, out_shape.into())))
}
_ => Ok(None),
}
}