use crate::{
compat::*,
cuda::*,
error::{CudaKernelError, Result},
kernel::Kernels,
kernels::macros::ops,
source::Source,
};
// Declares the `matmul` and `dot` kernel identifiers via the project's
// `ops!` macro (see `kernels::macros`); presumably this generates the
// `Kernel` values whose `.0` is passed to `load_function` below —
// TODO(review): confirm against the macro definition.
ops!(matmul, dot);
/// Launches a tiled batched matrix-multiply kernel from the `OpsMatrix` source.
///
/// `metadata` is copied to the device and handed to the kernel verbatim; the
/// host only decodes enough of it to size the launch grid. Layout assumed here
/// (derived from the index arithmetic below — confirm against the kernel
/// source): `metadata[1..=3]` hold the lhs/rhs/batch rank, the batch shape
/// starts at offset `4 + lhs_ndim + rhs_ndim`, and after both shape/stride
/// sections the scalars at `base + 2` / `base + 4` are the output dimensions
/// M and N.
///
/// # Errors
/// Returns an error if the kernel cannot be loaded, the metadata copy fails,
/// or the launch itself fails.
pub fn call_ops_matmul<T>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: cudarc::driver::DeviceRepr,
{
    // One block computes one TILE_SIZE x TILE_SIZE tile of the output.
    const TILE_SIZE: u32 = 16;

    let func = kernels.load_function(context, Source::OpsMatrix, kernel.0)?;

    let (lhs_ndim, rhs_ndim, batch_ndim) = (metadata[1], metadata[2], metadata[3]);
    // Skip the header word plus both shape sections, the batch shape, and
    // both stride sections to reach the trailing scalar block.
    let base = 4 + 2 * (lhs_ndim + rhs_ndim) + batch_ndim;
    let (m, n) = (metadata[base + 2], metadata[base + 4]);

    // Product of the batch shape; an empty batch shape means a single matmul.
    let num_batches: usize = if batch_ndim == 0 {
        1
    } else {
        metadata[4 + lhs_ndim + rhs_ndim..][..batch_ndim]
            .iter()
            .product()
    };

    // Ceil-divide the output extent by the tile size; keep at least one
    // block per axis so degenerate (zero-sized) problems still launch.
    let cfg = LaunchConfig {
        grid_dim: (
            (n as u32).div_ceil(TILE_SIZE).max(1),
            (m as u32).div_ceil(TILE_SIZE).max(1),
            num_batches as u32,
        ),
        block_dim: (TILE_SIZE, TILE_SIZE, 1),
        shared_mem_bytes: 0,
    };

    let stream = context.default_stream();
    let metadata_dev = match stream.memcpy_stod(metadata) {
        Ok(dev) => dev,
        Err(e) => {
            return Err(
                CudaKernelError::MemoryError(format!("Failed to copy metadata: {:?}", e)).into(),
            )
        }
    };

    // SAFETY: raw kernel launch — the argument order (lhs, rhs, output,
    // metadata) is assumed to match the device function's parameter list;
    // confirm against the OpsMatrix kernel source.
    unsafe {
        if let Err(e) = func.launch(&stream, cfg, |args| {
            args.arg(lhs).arg(rhs).arg(output).arg(&metadata_dev);
        }) {
            return Err(
                CudaKernelError::LaunchError(format!("Failed to launch kernel: {:?}", e)).into(),
            );
        }
    }
    Ok(())
}
/// Launches the 2-D dot-product kernel from the `OpsMatrix` source.
///
/// Only `metadata[0]` (taken as M) and `metadata[2]` (taken as N) are decoded
/// on the host to size the grid — assumed layout; confirm against the kernel
/// source. The full `metadata` slice is copied to the device and passed to
/// the kernel untouched. Each block covers a DOT_TILE_SIZE-square region of
/// the output with BLOCK_SIZE outputs per thread per dimension.
///
/// # Errors
/// Returns an error if the kernel cannot be loaded, the metadata copy fails,
/// or the launch itself fails.
pub fn call_ops_dot<T>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: cudarc::driver::DeviceRepr,
{
    // Output tile edge per block, and outputs computed per thread per axis.
    const DOT_TILE_SIZE: u32 = 32;
    const BLOCK_SIZE: u32 = 4;
    // 32 / 4 = 8 threads along each block dimension.
    const THREADS_PER_DIM: u32 = DOT_TILE_SIZE / BLOCK_SIZE;

    let func = kernels.load_function(context, Source::OpsMatrix, kernel.0)?;

    let (m, n) = (metadata[0], metadata[2]);

    // Ceil-divide each output extent by the tile edge; clamp to at least one
    // block per axis so zero-sized problems still produce a valid grid.
    let blocks_x = (n as u32).div_ceil(DOT_TILE_SIZE).max(1);
    let blocks_y = (m as u32).div_ceil(DOT_TILE_SIZE).max(1);
    let cfg = LaunchConfig {
        grid_dim: (blocks_x, blocks_y, 1),
        block_dim: (THREADS_PER_DIM, THREADS_PER_DIM, 1),
        shared_mem_bytes: 0,
    };

    let stream = context.default_stream();
    let metadata_dev = match stream.memcpy_stod(metadata) {
        Ok(dev) => dev,
        Err(e) => {
            return Err(
                CudaKernelError::MemoryError(format!("Failed to copy metadata: {:?}", e)).into(),
            )
        }
    };

    // SAFETY: raw kernel launch — the argument order (lhs, rhs, output,
    // metadata) is assumed to match the device function's parameter list;
    // confirm against the OpsMatrix kernel source.
    unsafe {
        if let Err(e) = func.launch(&stream, cfg, |args| {
            args.arg(lhs).arg(rhs).arg(output).arg(&metadata_dev);
        }) {
            return Err(
                CudaKernelError::LaunchError(format!("Failed to launch kernel: {:?}", e)).into(),
            );
        }
    }
    Ok(())
}