impl CudaExecutor {
#[inline]
pub fn q6k_gemv_indexed_async(
&mut self,
weight_ptr: u64,
input: &GpuBuffer<f32>,
n: u32,
k: u32,
) -> Result<GpuBuffer<f32>, GpuError> {
if weight_ptr == 0 {
return Err(GpuError::InvalidLaunchConfig(
"null weight pointer in q6k_gemv_indexed_async".to_string(),
));
}
use crate::cuda::gpu_profile::Q6kVariant;
let num_warps = self.gpu_profile.mwv_warps;
let can_use_advanced = k.is_multiple_of(256);
if can_use_advanced && self.gpu_profile.q6k == Q6kVariant::HwDp4a {
let buf_output = GpuBuffer::<f32>::new(&self.context, n as usize)?;
self.hw_dp4a_q6k_gemv_into(weight_ptr, input, &buf_output, n, k)?;
return Ok(buf_output);
}
if can_use_advanced && self.gpu_profile.q6k == Q6kVariant::Dp4a {
let buf_output = GpuBuffer::<f32>::new(&self.context, n as usize)?;
self.dp4a_q6k_gemv_into(weight_ptr, input, &buf_output, n, k)?;
return Ok(buf_output);
}
let (kernel_type, cache_key, config) = if can_use_advanced && self.gpu_profile.q6k == Q6kVariant::Mwv {
let kt = KernelType::MwvQ6KGemv { k, n, num_warps };
let ck = format!("mwv_q6k_gemv_{}_{}_{}", k, n, num_warps);
let cfg = LaunchConfig::grid_2d(n, 1, num_warps * 32, 1);
(kt, ck, cfg)
} else {
let kt = KernelType::Q6KGemv { k, n };
let ck = format!("q6k_gemv_{}_{}", k, n);
let cfg = LaunchConfig::grid_2d(n, 1, 32, 1);
(kt, ck, cfg)
};
let kernel_name = self.kernels.kernel_name(&kernel_type);
if !self.modules.contains_key(&cache_key) {
let ptx = self.kernels.generate_ptx(&kernel_type);
let module = self.compile_ptx(&ptx)?;
self.modules.insert(cache_key.clone(), module);
}
let module = self
.modules
.get_mut(&cache_key)
.expect("module just inserted");
let buf_output = GpuBuffer::<f32>::new(&self.context, n as usize)?;
let mut ptr_output = buf_output.as_ptr();
let mut ptr_weights = weight_ptr;
let mut ptr_input = input.as_ptr();
let mut k_val = k;
let mut n_val = n;
unsafe {
self.stream.launch_kernel(
module,
kernel_name,
&config,
&mut [
std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_weights) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_input) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut k_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut n_val) as *mut std::ffi::c_void,
],
)?;
}
Ok(buf_output)
}
/// Dispatches a Q4_K GEMV to the implementation selected by the active GPU
/// profile (`self.gpu_profile.q4k`), writing into the caller-provided `output`.
///
/// All arguments are forwarded unchanged to the chosen implementation; this
/// method performs no validation of its own.
///
/// # Errors
/// Propagates whatever error the selected implementation returns.
#[inline]
pub fn q4k_gemv_into(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    use crate::cuda::gpu_profile::Q4kVariant as Variant;

    // Exhaustive on purpose: adding a variant must force a decision here.
    match self.gpu_profile.q4k {
        Variant::HwDp4a => self.hw_dp4a_q4k_gemv_into(weight_ptr, input, output, n, k),
        Variant::MwvDp4a => self.mwv_dp4a_q4k_gemv_into(weight_ptr, input, output, n, k),
        Variant::Mwv => self.mwv_q4k_gemv_into(weight_ptr, input, output, n, k),
        Variant::Vectorized => self.vectorized_q4k_gemv_into(weight_ptr, input, output, n, k),
        Variant::Wide => self.wide_q4k_gemv_into(weight_ptr, input, output, n, k),
        Variant::Legacy => self.q4k_gemv_into_legacy(weight_ptr, input, output, n, k),
    }
}
/// Legacy Q4_K GEMV path: picks one of three kernels from `k` and launches it
/// on `self.stream`, writing into the caller-provided `output`.
///
/// Kernel selection:
/// - `k` a multiple of 256 and `k > 12_288`: `ChunkedTiledQ4KGemv`;
/// - `k` a multiple of 256 and `k <= 12_288`: `TiledQ4KGemv`;
/// - otherwise: the basic `Q4KGemv` kernel with a grid of `n` and a block
///   dimension of 32.
///
/// Both tiled kernels produce 4 outputs per block, so they use
/// `ceil(n / 4)` blocks with a block dimension of 128.
///
/// PTX is compiled on first use and cached in `self.modules` under a key built
/// from the kernel family and its dimensions.
///
/// # Errors
/// Propagates failures from `validate_device_ptr`, PTX compilation and the
/// kernel launch.
///
/// # Panics
/// Panics only if the module cache lookup fails right after the module was
/// inserted, which would be an internal invariant violation.
fn q4k_gemv_into_legacy(
    &mut self,
    weight_ptr: u64,
    input: &GpuBuffer<f32>,
    output: &GpuBuffer<f32>,
    n: u32,
    k: u32,
) -> Result<(), GpuError> {
    // Largest `k` handled by the plain tiled kernel; larger aligned `k` goes to
    // the chunked variant. NOTE(review): the origin of this threshold is not
    // visible here (presumably a per-block memory budget) — confirm.
    const MAX_TILED_K: u32 = 12_288;

    validate_device_ptr(weight_ptr, "q4k_gemv_into_legacy")?;

    let outputs_per_block = 4u32;
    let (kernel_type, cache_key, config) = if k.is_multiple_of(256) {
        // `div_ceil` rounds up without the `n + outputs_per_block - 1`
        // intermediate, which overflows `u32` for `n` close to `u32::MAX`.
        let num_blocks = n.div_ceil(outputs_per_block);
        let cfg = LaunchConfig::grid_2d(num_blocks, 1, 128, 1);
        if k > MAX_TILED_K {
            let kt = KernelType::ChunkedTiledQ4KGemv {
                k,
                n,
                outputs_per_block,
            };
            let ck = format!("chunked_tiled_q4k_gemv_{}_{}_{}", k, n, outputs_per_block);
            (kt, ck, cfg)
        } else {
            let kt = KernelType::TiledQ4KGemv {
                k,
                n,
                outputs_per_block,
            };
            let ck = format!("tiled_q4k_gemv_{}_{}_{}", k, n, outputs_per_block);
            (kt, ck, cfg)
        }
    } else {
        // Unaligned `k`: fall back to the basic kernel.
        let kt = KernelType::Q4KGemv { k, n };
        let ck = format!("q4k_gemv_{}_{}", k, n);
        let cfg = LaunchConfig::grid_2d(n, 1, 32, 1);
        (kt, ck, cfg)
    };
    let kernel_name = self.kernels.kernel_name(&kernel_type);

    // Compile the PTX once per cache key; later calls reuse the cached module.
    if !self.modules.contains_key(&cache_key) {
        let ptx = self.kernels.generate_ptx(&kernel_type);
        let module = self.compile_ptx(&ptx)?;
        self.modules.insert(cache_key.clone(), module);
    }
    let module = self
        .modules
        .get_mut(&cache_key)
        .expect("module just inserted");

    // Kernel parameters are passed as pointers to these locals, in the order
    // output, weights, input, k, n.
    let mut ptr_output = output.as_ptr();
    let mut ptr_weights = weight_ptr;
    let mut ptr_input = input.as_ptr();
    let mut k_val = k;
    let mut n_val = n;

    // SAFETY: every argument pointer refers to a local that is alive for the
    // whole call. NOTE(review): the exact contract of `launch_kernel`
    // (argument order/types expected by the generated kernel, buffer sizes) is
    // not visible here — confirm against the kernel generator.
    unsafe {
        self.stream.launch_kernel(
            module,
            kernel_name,
            &config,
            &mut [
                std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_weights) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut ptr_input) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut k_val) as *mut std::ffi::c_void,
                std::ptr::from_mut(&mut n_val) as *mut std::ffi::c_void,
            ],
        )?;
    }
    Ok(())
}
#[inline]
pub fn coalesced_q4k_gemv_into(
&mut self,
weight_ptr: u64,
input: &GpuBuffer<f32>,
output: &GpuBuffer<f32>,
n: u32,
k: u32,
) -> Result<(), GpuError> {
validate_device_ptr(weight_ptr, "coalesced_q4k_gemv_into")?;
let kernel_type = KernelType::CoalescedQ4KGemv { k, n };
let kernel_name = self.kernels.kernel_name(&kernel_type);
let cache_key = format!("coalesced_q4k_gemv_{}_{}", k, n);
if !self.modules.contains_key(&cache_key) {
let ptx = self.kernels.generate_ptx(&kernel_type);
let module = self.compile_ptx(&ptx)?;
self.modules.insert(cache_key.clone(), module);
}
let module = self
.modules
.get_mut(&cache_key)
.expect("module just inserted");
let config = LaunchConfig::grid_2d(n, 1, 32, 1);
let mut ptr_output = output.as_ptr();
let mut ptr_weights = weight_ptr;
let mut ptr_input = input.as_ptr();
let mut k_val = k;
let mut n_val = n;
unsafe {
self.stream.launch_kernel(
module,
kernel_name,
&config,
&mut [
std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_weights) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_input) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut k_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut n_val) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
#[inline]
pub fn wide_q4k_gemv_into(
&mut self,
weight_ptr: u64,
input: &GpuBuffer<f32>,
output: &GpuBuffer<f32>,
n: u32,
k: u32,
) -> Result<(), GpuError> {
validate_device_ptr(weight_ptr, "wide_q4k_gemv_into")?;
let kernel_type = KernelType::WideQ4KGemv { k, n };
let kernel_name = self.kernels.kernel_name(&kernel_type);
let cache_key = format!("wide_q4k_gemv_{}_{}", k, n);
if !self.modules.contains_key(&cache_key) {
let ptx = self.kernels.generate_ptx(&kernel_type);
let module = self.compile_ptx(&ptx)?;
self.modules.insert(cache_key.clone(), module);
}
let module = self
.modules
.get_mut(&cache_key)
.expect("module just inserted");
let config = LaunchConfig::grid_2d(n, 1, 256, 1);
let mut ptr_output = output.as_ptr();
let mut ptr_weights = weight_ptr;
let mut ptr_input = input.as_ptr();
let mut k_val = k;
let mut n_val = n;
unsafe {
self.stream.launch_kernel(
module,
kernel_name,
&config,
&mut [
std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_weights) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_input) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut k_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut n_val) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
#[inline]
pub fn vectorized_q4k_gemv_into(
&mut self,
weight_ptr: u64,
input: &GpuBuffer<f32>,
output: &GpuBuffer<f32>,
n: u32,
k: u32,
) -> Result<(), GpuError> {
validate_device_ptr(weight_ptr, "vectorized_q4k_gemv_into")?;
let kernel_type = KernelType::VectorizedQ4KGemv { k, n };
let kernel_name = self.kernels.kernel_name(&kernel_type);
let cache_key = format!("vectorized_q4k_gemv_{}_{}", k, n);
if !self.modules.contains_key(&cache_key) {
let ptx = self.kernels.generate_ptx(&kernel_type);
let module = self.compile_ptx(&ptx)?;
self.modules.insert(cache_key.clone(), module);
}
let module = self
.modules
.get_mut(&cache_key)
.expect("module just inserted");
let config = LaunchConfig::grid_2d(n, 1, 32, 1);
let mut ptr_output = output.as_ptr();
let mut ptr_weights = weight_ptr;
let mut ptr_input = input.as_ptr();
let mut k_val = k;
let mut n_val = n;
unsafe {
self.stream.launch_kernel(
module,
kernel_name,
&config,
&mut [
std::ptr::from_mut(&mut ptr_output) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_weights) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut ptr_input) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut k_val) as *mut std::ffi::c_void,
std::ptr::from_mut(&mut n_val) as *mut std::ffi::c_void,
],
)?;
}
Ok(())
}
}