impl CudaExecutor {
/// Batched variant of the Q6K GEMV launcher: one launch covering `m` problems
/// that share the same quantized weights (exact batch layout is defined by the
/// generated kernel — see `KernelType::BatchedQ6KGemv`).
///
/// `weight_ptr` is a raw device pointer to Q6K-packed weights; `k` must be a
/// whole number of 256-wide Q6K super-blocks (debug-checked only).
pub fn batched_q6k_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    m: u32,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "batched_q6k_gemv_into")?;
    debug_assert!(
        k.is_multiple_of(256),
        "K must be multiple of 256 for Q6K super-blocks"
    );
    let kernel_type = KernelType::BatchedQ6KGemv { k, n, m };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("batched_q6k_gemv_{}_{}_{}", m, k, n);
    // JIT-compile the PTX once per (m, k, n) shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    let config = LaunchConfig::grid_2d(n, 1, 32, 1);
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut m_arg = m;
    let mut params: [*mut std::ffi::c_void; 6] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
        std::ptr::addr_of_mut!(m_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    Ok(())
}
#[inline]
/// Q6K GEMV entry point: picks the fastest available specialized kernel for
/// this GPU profile (hardware DP4A, emulated DP4A, or multi-warp vector), and
/// falls back to the generic per-shape JIT kernel otherwise.
pub fn q6k_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "q6k_gemv_into")?;
    use crate::cuda::gpu_profile::Q6kVariant;
    // All specialized paths require K to be a whole number of 256-wide Q6K
    // super-blocks; otherwise only the generic kernel below is usable.
    if k.is_multiple_of(256) {
        if self.gpu_profile.q6k == Q6kVariant::HwDp4a {
            return self.hw_dp4a_q6k_gemv_into(weight_ptr, input, output, n, k);
        }
        if self.gpu_profile.q6k == Q6kVariant::Dp4a {
            return self.dp4a_q6k_gemv_into(weight_ptr, input, output, n, k);
        }
        if self.gpu_profile.q6k == Q6kVariant::Mwv {
            return self.mwv_q6k_gemv_into(weight_ptr, input, output, n, k);
        }
    }
    // Generic fallback: JIT-compiled per (k, n) shape and cached by key.
    let kernel_type = KernelType::Q6KGemv { k, n };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("q6k_gemv_{}_{}", k, n);
    let config = LaunchConfig::grid_2d(n, 1, 32, 1);
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut params: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    if self.graph_recording {
        // The earlier module borrow ended at the launch; re-fetch it and
        // remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![out_arg, weights_arg, in_arg, k_arg as u64, n_arg as u64],
        });
    }
    Ok(())
}
#[inline]
/// Multi-warp-vector (MWV) Q6K GEMV: launches `mwv_warps * 32` threads per
/// block, grid sized by `n`. Requires `k` to be a whole number of 256-wide
/// Q6K super-blocks (debug-checked only).
pub fn mwv_q6k_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "mwv_q6k_gemv_into")?;
    debug_assert!(
        k.is_multiple_of(256),
        "K must be multiple of 256 for Q6K super-blocks"
    );
    let num_warps = self.gpu_profile.mwv_warps;
    let kernel_type = KernelType::MwvQ6KGemv { k, n, num_warps };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("mwv_q6k_gemv_{}_{}_{}", k, n, num_warps);
    // JIT-compile the PTX once per (k, n, warps) shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    let threads = num_warps * 32;
    let config = LaunchConfig::grid_2d(n, 1, threads, 1);
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut params: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    if self.graph_recording {
        // Remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![out_arg, weights_arg, in_arg, k_arg as u64, n_arg as u64],
        });
    }
    Ok(())
}
/// Q6K GEMV using DP4A-style integer dot products against Q8-quantized
/// activations. Quantizes `input` into the workspace Q8 buffer on first use
/// (cached until `q8_activation_valid` is cleared elsewhere).
///
/// Fix over the previous version: the non-owning `GpuBuffer` alias of the
/// workspace buffer is now wrapped in `ManuallyDrop`. Previously it was only
/// neutralized by a trailing `std::mem::forget`, which was skipped whenever a
/// `?` returned early (PTX compile, launch, or `get_function` failure) — the
/// alias then dropped and its destructor ran on memory still owned by
/// `self.workspace`.
pub fn dp4a_q6k_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "dp4a_q6k_gemv_into")?;
    // Snapshot pointer/len of the persistent Q8 activation scratch buffer
    // (single lookup; the old code ran the Option::expect twice).
    let (q8_ptr, q8_len) = {
        let buf = self
            .workspace
            .q8_activation_buf
            .as_ref()
            .expect("dp4a_q6k: workspace.q8_activation_buf not initialized");
        (buf.as_ptr(), buf.len())
    };
    // Non-owning alias of the workspace buffer; ManuallyDrop guarantees its
    // destructor can never run, even on early-return error paths.
    // SAFETY: `q8_ptr`/`q8_len` come from a live buffer owned by
    // `self.workspace`, which outlives this stack-local view.
    let q8_buf =
        std::mem::ManuallyDrop::new(unsafe { GpuBuffer::<u8>::from_raw_parts(q8_ptr, q8_len) });
    // Quantize the f32 activations into the Q8 buffer once; reuse afterwards.
    if !self.q8_activation_valid {
        self.q8_quantize_into(input, &q8_buf, k)?;
        self.q8_activation_valid = true;
    }
    let num_warps = self.gpu_profile.mwv_warps;
    let kernel_type = KernelType::Dp4aQ6KGemv { k, n, num_warps };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("dp4a_q6k_gemv_{}_{}_{}", k, n, num_warps);
    // JIT-compile the PTX once per (k, n, warps) shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let module = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), module);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    let threads = num_warps * 32;
    // Grid capped relative to SM count; the kernel is expected to cover the
    // remaining rows itself when grid_x < n (confirm against kernel source).
    let grid_x = n.min(self.num_sms * 16);
    let config = LaunchConfig::grid_2d(grid_x, 1, threads, 1);
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut ptr_output = output.as_ptr();
    let mut ptr_weights = weight_ptr;
    let mut ptr_q8 = q8_buf.as_ptr();
    let mut k_val = k;
    let mut n_val = n;
    // SAFETY: every argument pointer targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream.launch_kernel(
            module,
            kernel_name,
            &config,
            &mut [
                std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_weights) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_q8) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut k_val) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut n_val) as *mut std::ffi::c_void,
            ],
        )?;
    }
    if self.graph_recording {
        // Remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![ptr_output, ptr_weights, ptr_q8, k_val as u64, n_val as u64],
        });
    }
    // `q8_buf` is ManuallyDrop: intentionally never dropped.
    Ok(())
}
#[inline]
/// Q6K GEMV using the hardware DP4A instruction against Q8-quantized
/// activations. Quantizes `input` into the workspace Q8 buffer on first use
/// (cached until `q8_activation_valid` is cleared elsewhere).
///
/// Fix over the previous version: the non-owning `GpuBuffer` alias of the
/// workspace buffer is now wrapped in `ManuallyDrop`. Previously it was only
/// neutralized by a trailing `std::mem::forget`, which was skipped whenever a
/// `?` returned early (PTX compile, launch, or `get_function` failure) — the
/// alias then dropped and its destructor ran on memory still owned by
/// `self.workspace`.
pub fn hw_dp4a_q6k_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "hw_dp4a_q6k_gemv_into")?;
    // Snapshot pointer/len of the persistent Q8 activation scratch buffer
    // (single lookup; the old code ran the Option::expect twice).
    let (q8_ptr, q8_len) = {
        let buf = self
            .workspace
            .q8_activation_buf
            .as_ref()
            .expect("hw_dp4a_q6k: workspace.q8_activation_buf not initialized");
        (buf.as_ptr(), buf.len())
    };
    // Non-owning alias of the workspace buffer; ManuallyDrop guarantees its
    // destructor can never run, even on early-return error paths.
    // SAFETY: `q8_ptr`/`q8_len` come from a live buffer owned by
    // `self.workspace`, which outlives this stack-local view.
    let q8_buf =
        std::mem::ManuallyDrop::new(unsafe { GpuBuffer::<u8>::from_raw_parts(q8_ptr, q8_len) });
    // Quantize the f32 activations into the Q8 buffer once; reuse afterwards.
    if !self.q8_activation_valid {
        self.q8_quantize_into(input, &q8_buf, k)?;
        self.q8_activation_valid = true;
    }
    let num_warps = self.gpu_profile.mwv_warps;
    let kernel_type = KernelType::HwDp4aQ6KGemv { k, n, num_warps };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("hw_dp4a_q6k_gemv_{}_{}_{}", k, n, num_warps);
    // JIT-compile the PTX once per (k, n, warps) shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let module = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), module);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    let threads = num_warps * 32;
    // Grid capped relative to SM count; the kernel is expected to cover the
    // remaining rows itself when grid_x < n (confirm against kernel source).
    let grid_x = n.min(self.num_sms * 16);
    let config = LaunchConfig::grid_2d(grid_x, 1, threads, 1);
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut ptr_output = output.as_ptr();
    let mut ptr_weights = weight_ptr;
    let mut ptr_q8 = q8_buf.as_ptr();
    let mut k_val = k;
    let mut n_val = n;
    // SAFETY: every argument pointer targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream.launch_kernel(
            module,
            kernel_name,
            &config,
            &mut [
                std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_weights) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_q8) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut k_val) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut n_val) as *mut std::ffi::c_void,
            ],
        )?;
    }
    if self.graph_recording {
        // Remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![ptr_output, ptr_weights, ptr_q8, k_val as u64, n_val as u64],
        });
    }
    // `q8_buf` is ManuallyDrop: intentionally never dropped.
    Ok(())
}
#[inline]
/// Coalesced-access Q6K GEMV launcher: JIT-compiles and caches a per-(k, n)
/// kernel, then launches it with a 32-thread block and grid sized by `n`.
/// Note: this path does not participate in CUDA-graph recording.
pub fn coalesced_q6k_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "coalesced_q6k_gemv_into")?;
    let kernel_type = KernelType::CoalescedQ6KGemv { k, n };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("coalesced_q6k_gemv_{}_{}", k, n);
    let config = LaunchConfig::grid_2d(n, 1, 32, 1);
    // JIT-compile the PTX once per shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut params: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    Ok(())
}
#[inline]
/// Q8_0 quantized GEMV launcher: JIT-compiles and caches a per-(k, n) kernel,
/// launches it with a 32-thread block and grid sized by `n`, and records the
/// launch when CUDA-graph capture is active.
pub fn q8_0_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "q8_0_gemv_into")?;
    let kernel_type = KernelType::Q8_0Gemv { k, n };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("q8_0_gemv_{}_{}", k, n);
    let config = LaunchConfig::grid_2d(n, 1, 32, 1);
    // JIT-compile the PTX once per shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut params: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    if self.graph_recording {
        // Remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![out_arg, weights_arg, in_arg, k_arg as u64, n_arg as u64],
        });
    }
    Ok(())
}
#[inline]
/// Q5_0 quantized GEMV launcher: JIT-compiles and caches a per-(k, n) kernel,
/// launches it with a 32-thread block and grid sized by `n`, and records the
/// launch when CUDA-graph capture is active.
pub fn q5_0_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "q5_0_gemv_into")?;
    let kernel_type = KernelType::Q5_0Gemv { k, n };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("q5_0_gemv_{}_{}", k, n);
    let config = LaunchConfig::grid_2d(n, 1, 32, 1);
    // JIT-compile the PTX once per shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut params: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    if self.graph_recording {
        // Remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![out_arg, weights_arg, in_arg, k_arg as u64, n_arg as u64],
        });
    }
    Ok(())
}
#[inline]
/// Q4_0 quantized GEMV launcher: JIT-compiles and caches a per-(k, n) kernel,
/// launches it with a 32-thread block and grid sized by `n`, and records the
/// launch when CUDA-graph capture is active.
pub fn q4_0_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "q4_0_gemv_into")?;
    let kernel_type = KernelType::Q4_0Gemv { k, n };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("q4_0_gemv_{}_{}", k, n);
    let config = LaunchConfig::grid_2d(n, 1, 32, 1);
    // JIT-compile the PTX once per shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut params: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    if self.graph_recording {
        // Remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![out_arg, weights_arg, in_arg, k_arg as u64, n_arg as u64],
        });
    }
    Ok(())
}
/// Unquantized f32 GEMV launcher: JIT-compiles and caches a per-(k, n) kernel,
/// launches it with a 32-thread block and grid sized by `n`, and records the
/// launch when CUDA-graph capture is active.
pub fn f32_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    validate_device_ptr(weight_ptr, "f32_gemv_into")?;
    let kernel_type = KernelType::Gemv { k, n };
    let kernel_name = self.kernels.kernel_name(&kernel_type);
    let cache_key = format!("f32_gemv_{}_{}", k, n);
    let config = LaunchConfig::grid_2d(n, 1, 32, 1);
    // JIT-compile the PTX once per shape; later calls hit the cache.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let compiled = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), compiled);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");
    // Kernel arguments: each local must stay alive until the launch returns,
    // because the driver reads the values through these raw addresses.
    let mut out_arg = output.as_ptr();
    let mut weights_arg = weight_ptr;
    let mut in_arg = input.as_ptr();
    let mut k_arg = k;
    let mut n_arg = n;
    let mut params: [*mut std::ffi::c_void; 5] = [
        std::ptr::addr_of_mut!(out_arg).cast(),
        std::ptr::addr_of_mut!(weights_arg).cast(),
        std::ptr::addr_of_mut!(in_arg).cast(),
        std::ptr::addr_of_mut!(k_arg).cast(),
        std::ptr::addr_of_mut!(n_arg).cast(),
    ];
    // SAFETY: every pointer in `params` targets a live stack local, and the
    // module was compiled for exactly this kernel configuration.
    unsafe {
        self.stream
            .launch_kernel(module, kernel_name, &config, &mut params)?;
    }
    if self.graph_recording {
        // Remember this launch for later CUDA-graph replay.
        let module = self.modules.get_mut(&cache_key).expect("module exists");
        let func = module.get_function(kernel_name)?;
        self.graph_recorded_kernels.push(RecordedKernel {
            func: SendCUfunction(func),
            config,
            arg_data: vec![out_arg, weights_arg, in_arg, k_arg as u64, n_arg as u64],
        });
    }
    Ok(())
}
}