#![allow(clippy::similar_names)]
#[cfg(feature = "cuda")]
use crate::driver::{CudaContext, CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use crate::error::Result;
#[cfg(feature = "cuda")]
use crate::kernels::{GemmKernel, Kernel};
#[cfg(feature = "cuda")]
use super::super::cache::compile_lock_launch;
#[cfg(feature = "cuda")]
use super::super::GpuResidentTensor;
#[cfg(feature = "cuda")]
impl GpuResidentTensor<f32> {
pub fn matmul(
&self,
ctx: &CudaContext,
other: &GpuResidentTensor<f32>,
m: u32,
n: u32,
k: u32,
) -> Result<GpuResidentTensor<f32>> {
let expected_a = (m * k) as usize;
let expected_b = (k * n) as usize;
let output_size = (m * n) as usize;
if self.len() != expected_a {
return Err(crate::GpuError::InvalidParameter(format!(
"A has {} elements, expected {} ({}x{})",
self.len(),
expected_a,
m,
k
)));
}
if other.len() != expected_b {
return Err(crate::GpuError::InvalidParameter(format!(
"B has {} elements, expected {} ({}x{})",
other.len(),
expected_b,
k,
n
)));
}
let output_buffer = GpuBuffer::new(ctx, output_size)?;
let tile_size = 16u32;
let force_fp32 = std::env::var("TRUENO_FORCE_FP32_GEMM").is_ok();
let use_wmma = !force_fp32 && k >= 64 && m >= 64 && n >= 64;
let use_tiled = !use_wmma && k >= 64;
let (kernel, cache_key, config) = if use_wmma {
let kernel = GemmKernel::wmma_fp16(m, n, k);
let key = format!("gemm_wmma_fp16:{}x{}x{}", m, n, k);
let grid_x = (n + 15) / 16;
let grid_y = (m + 15) / 16;
let cfg = LaunchConfig {
grid: (grid_x, grid_y, 1),
block: (32, 1, 1), shared_mem: 1024,
};
(kernel, key, cfg)
} else if use_tiled {
let kernel = GemmKernel::tiled_unrolled(m, n, k, tile_size);
let key = format!("gemm_tiled_unrolled:{}x{}x{}", m, n, k);
let grid_x = (n + tile_size - 1) / tile_size;
let grid_y = (m + tile_size - 1) / tile_size;
let cfg = LaunchConfig {
grid: (grid_x, grid_y, 1),
block: (tile_size, tile_size, 1),
shared_mem: tile_size * tile_size * 4 * 2,
};
(kernel, key, cfg)
} else {
let kernel = GemmKernel::naive(m, n, k);
let key = format!("gemm_naive:{}x{}x{}", m, n, k);
let block_size = 16u32;
let grid_x = (n + block_size - 1) / block_size;
let grid_y = (m + block_size - 1) / block_size;
let cfg = LaunchConfig {
grid: (grid_x, grid_y, 1),
block: (block_size, block_size, 1),
shared_mem: 0,
};
(kernel, key, cfg)
};
let ptx = kernel.emit_ptx();
let stream = CudaStream::new(ctx)?;
let a_ptr = self.as_ptr();
let b_ptr = other.as_ptr();
let c_ptr = output_buffer.as_ptr();
let m_val = m;
let n_val = n;
let k_val = k;
let mut args: Vec<*mut std::ffi::c_void> = vec![
std::ptr::addr_of!(a_ptr) as *mut _,
std::ptr::addr_of!(b_ptr) as *mut _,
std::ptr::addr_of!(c_ptr) as *mut _,
std::ptr::addr_of!(m_val) as *mut _,
std::ptr::addr_of!(n_val) as *mut _,
std::ptr::addr_of!(k_val) as *mut _,
];
compile_lock_launch(ctx, &stream, &cache_key, &ptx, kernel.name(), &config, &mut args)?;
stream.synchronize()?;
Ok(GpuResidentTensor::from_buffer_internal(output_buffer, 1))
}
pub fn matmul_with_stream(
&self,
ctx: &CudaContext,
other: &GpuResidentTensor<f32>,
m: u32,
n: u32,
k: u32,
stream: &CudaStream,
) -> Result<GpuResidentTensor<f32>> {
let expected_a = (m * k) as usize;
let expected_b = (k * n) as usize;
let output_size = (m * n) as usize;
if self.len() != expected_a {
return Err(crate::GpuError::InvalidParameter(format!(
"A has {} elements, expected {} ({}x{})",
self.len(),
expected_a,
m,
k
)));
}
if other.len() != expected_b {
return Err(crate::GpuError::InvalidParameter(format!(
"B has {} elements, expected {} ({}x{})",
other.len(),
expected_b,
k,
n
)));
}
let output_buffer = GpuBuffer::new(ctx, output_size)?;
let tile_size = 16u32;
let force_fp32 = std::env::var("TRUENO_FORCE_FP32_GEMM").is_ok();
let use_wmma = !force_fp32 && k >= 64 && m >= 64 && n >= 64;
let use_tiled = !use_wmma && k >= 64;
let (kernel, cache_key, config) = if use_wmma {
let kernel = GemmKernel::wmma_fp16(m, n, k);
let key = format!("gemm_wmma_fp16:{}x{}x{}", m, n, k);
let grid_x = (n + 15) / 16;
let grid_y = (m + 15) / 16;
let cfg =
LaunchConfig { grid: (grid_x, grid_y, 1), block: (32, 1, 1), shared_mem: 1024 };
(kernel, key, cfg)
} else if use_tiled {
let kernel = GemmKernel::tiled_unrolled(m, n, k, tile_size);
let key = format!("gemm_tiled_unrolled:{}x{}x{}", m, n, k);
let grid_x = (n + tile_size - 1) / tile_size;
let grid_y = (m + tile_size - 1) / tile_size;
let cfg = LaunchConfig {
grid: (grid_x, grid_y, 1),
block: (tile_size, tile_size, 1),
shared_mem: tile_size * tile_size * 4 * 2,
};
(kernel, key, cfg)
} else {
let kernel = GemmKernel::naive(m, n, k);
let key = format!("gemm_naive:{}x{}x{}", m, n, k);
let block_size = 16u32;
let grid_x = (n + block_size - 1) / block_size;
let grid_y = (m + block_size - 1) / block_size;
let cfg = LaunchConfig {
grid: (grid_x, grid_y, 1),
block: (block_size, block_size, 1),
shared_mem: 0,
};
(kernel, key, cfg)
};
let ptx = kernel.emit_ptx();
let a_ptr = self.as_ptr();
let b_ptr = other.as_ptr();
let c_ptr = output_buffer.as_ptr();
let m_val = m;
let n_val = n;
let k_val = k;
let mut args: Vec<*mut std::ffi::c_void> = vec![
std::ptr::addr_of!(a_ptr) as *mut _,
std::ptr::addr_of!(b_ptr) as *mut _,
std::ptr::addr_of!(c_ptr) as *mut _,
std::ptr::addr_of!(m_val) as *mut _,
std::ptr::addr_of!(n_val) as *mut _,
std::ptr::addr_of!(k_val) as *mut _,
];
compile_lock_launch(ctx, stream, &cache_key, &ptx, kernel.name(), &config, &mut args)?;
Ok(GpuResidentTensor::from_buffer_internal(output_buffer, 1))
}
}