#[cfg(cuda_available)]
use crate::{DType, Device, Result, Tensor, TensorError};
#[cfg(cuda_available)]
use std::collections::HashMap;
#[cfg(cuda_available)]
use std::ffi::{CStr, CString};
#[cfg(cuda_available)]
use std::sync::Arc;
#[cfg(cuda_available)]
/// Handle to a single CUDA device, owning per-device caches of compiled
/// modules/kernels, a small pool of streams, and the queried properties.
#[derive(Debug)]
pub struct CudaDevice {
    // Ordinal handed to `cudaSetDevice` at construction.
    device_id: i32,
    // Lightweight context marker (see `CudaContext` below).
    context: CudaContext,
    // Streams created eagerly in `new`; used when a caller supplies one
    // through `CudaKernelConfig`.
    streams: Vec<CudaStream>,
    // Loaded modules, keyed by "<kernel_name>_module" (see `compile_kernel`).
    module_cache: HashMap<String, CudaModule>,
    // Compiled kernel entry points, keyed by kernel name.
    kernel_cache: HashMap<String, CudaKernel>,
    // Device properties captured once at construction.
    properties: CudaDeviceProperties,
}
#[cfg(cuda_available)]
/// Launch configuration for a single kernel invocation.
#[derive(Debug, Clone)]
pub struct CudaKernelConfig {
    /// Grid dimensions (x, y, z) in thread blocks.
    pub grid_dim: (u32, u32, u32),
    /// Block dimensions (x, y, z) in threads.
    pub block_dim: (u32, u32, u32),
    /// Dynamic shared memory per block, in bytes.
    pub shared_memory: u32,
    /// Stream to launch on; `None` launches on the default stream and the
    /// device is synchronized after the launch (see `launch_kernel`).
    pub stream: Option<CudaStream>,
}
#[cfg(cuda_available)]
/// Size-bucketed pool of device allocations.
///
/// NOTE(review): nothing in this file references this type — confirm it is
/// used elsewhere in the crate or remove it. The raw pointers also make it
/// neither `Send` nor `Sync` by default; verify that is intended.
#[derive(Debug)]
pub struct CudaMemoryPool {
    // Presumably free chunks keyed by allocation size in bytes — confirm.
    available_chunks: HashMap<usize, Vec<*mut std::ffi::c_void>>,
    // Running total of bytes allocated through the pool.
    total_allocated: usize,
    // Alignment applied to allocations, in bytes.
    alignment: usize,
}
#[cfg(cuda_available)]
impl CudaDevice {
    /// Initializes the CUDA driver, binds `device_id`, creates a small pool
    /// of streams, and queries the device properties.
    ///
    /// # Errors
    /// Returns a device error if any underlying CUDA call fails.
    pub fn new(device_id: i32) -> Result<Self> {
        // SAFETY: cuInit must run before any other driver-API call, and the
        // device must be bound before streams are created.
        unsafe {
            cuda_init(0)?;
            cuda_set_device(device_id)?;
        }
        let context = CudaContext::new(device_id)?;
        // Fixed pool of 4 streams; callers select one via CudaKernelConfig.
        let mut streams = Vec::with_capacity(4);
        for _ in 0..4 {
            streams.push(CudaStream::new()?);
        }
        let properties = CudaDeviceProperties::query(device_id)?;
        Ok(CudaDevice {
            device_id,
            context,
            streams,
            module_cache: HashMap::new(),
            kernel_cache: HashMap::new(),
            properties,
        })
    }

    /// Returns the device properties captured at construction.
    pub fn get_device_properties(&self) -> &CudaDeviceProperties {
        &self.properties
    }

    /// Compiles `source` (PTX) into a module, resolves `kernel_name` inside
    /// it, and caches both for later launches.
    pub fn compile_kernel(
        &mut self,
        source: &str,
        kernel_name: &str,
        compile_options: &[&str],
    ) -> Result<()> {
        let module = self.compile_module(source, compile_options)?;
        let kernel = module.get_function(kernel_name)?;
        self.module_cache
            .insert(format!("{}_module", kernel_name), module);
        self.kernel_cache.insert(kernel_name.to_string(), kernel);
        Ok(())
    }

    /// Launches a previously compiled kernel.
    ///
    /// Each slice in `args` is copied into a fresh device buffer; one extra
    /// buffer sized like `args[0]` (or one `T` when `args` is empty) is
    /// allocated for the output and passed as the kernel's last parameter.
    /// Input buffers are freed after the launch; the returned output buffer
    /// is owned by the caller, who must copy it back and `cuda_free` it.
    ///
    /// # Errors
    /// Fails if the kernel was never compiled or any CUDA call fails; all
    /// buffers allocated here are released on the error paths.
    pub fn launch_kernel<T>(
        &mut self,
        kernel_name: &str,
        config: &CudaKernelConfig,
        args: &[&[T]],
    ) -> Result<*mut std::ffi::c_void>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        let function = self
            .kernel_cache
            .get(kernel_name)
            .ok_or_else(|| {
                TensorError::invalid_operation_simple(format!(
                    "Kernel '{}' not found. Compile it first.",
                    kernel_name
                ))
            })?
            .function;
        // Upload every input slice into its own device allocation.
        let mut device_ptrs: Vec<*mut std::ffi::c_void> = Vec::with_capacity(args.len() + 1);
        for arg in args {
            let byte_len = std::mem::size_of_val(*arg);
            let device_ptr = self.allocate_device_memory(byte_len)?;
            // SAFETY: `device_ptr` was just allocated with `byte_len` bytes
            // and the host slice is valid for the same length.
            let copied = unsafe {
                cuda_memcpy_htod(device_ptr, arg.as_ptr() as *const std::ffi::c_void, byte_len)
            };
            if let Err(e) = copied {
                // Release everything allocated so far before propagating.
                Self::free_device_ptrs(&device_ptrs);
                Self::free_device_ptrs(&[device_ptr]);
                return Err(e);
            }
            device_ptrs.push(device_ptr);
        }
        // NOTE(review): the output is assumed to be the size of the first
        // input (or a single element when there are no inputs) — confirm this
        // matches the PTX kernels' expectations.
        let output_size = if let Some(first) = args.first() {
            std::mem::size_of_val(*first)
        } else {
            std::mem::size_of::<T>()
        };
        let output_ptr = match self.allocate_device_memory(output_size) {
            Ok(p) => p,
            Err(e) => {
                Self::free_device_ptrs(&device_ptrs);
                return Err(e);
            }
        };
        device_ptrs.push(output_ptr);
        // cuLaunchKernel's kernelParams is an array of pointers *to* the
        // argument values; for pointer parameters that means the address of
        // each device-pointer slot, not the device pointer itself.
        let mut kernel_args: Vec<*mut std::ffi::c_void> = device_ptrs
            .iter_mut()
            .map(|slot| slot as *mut *mut std::ffi::c_void as *mut std::ffi::c_void)
            .collect();
        // SAFETY: `function` is a live kernel handle, `kernel_args` points
        // into `device_ptrs` which outlives the launch, and the stream handle
        // (or the NULL default stream) is valid.
        let mut launch_status = unsafe {
            cuda_launch_kernel(
                function,
                config.grid_dim,
                config.block_dim,
                kernel_args.as_mut_ptr(),
                config.shared_memory,
                config
                    .stream
                    .as_ref()
                    .map(|s| s.handle)
                    .unwrap_or(std::ptr::null_mut()),
            )
        };
        if launch_status.is_ok() && config.stream.is_none() {
            // Without an explicit stream, block until the kernel finishes so
            // the input buffers can be freed safely below.
            // SAFETY: plain runtime call with no preconditions.
            launch_status = unsafe { cuda_device_synchronize() };
        }
        // NOTE(review): when a stream IS supplied, the inputs are freed
        // without waiting for the (asynchronous) launch to finish — confirm
        // the stream is synchronized elsewhere before relying on this path.
        if let Err(e) = launch_status {
            // Free everything, including the output buffer, before bailing.
            Self::free_device_ptrs(&device_ptrs);
            return Err(e);
        }
        // Release the input buffers; the final entry is the output, which
        // the caller takes ownership of.
        for &ptr in &device_ptrs[..device_ptrs.len() - 1] {
            // SAFETY: each pointer came from cuda_malloc and is freed once.
            unsafe {
                cuda_free(ptr)?;
            }
        }
        Ok(output_ptr)
    }

    /// Matrix multiplication `a (m×k) · b (k×n)` on the GPU, compiling the
    /// GEMM kernel lazily on first use.
    ///
    /// # Errors
    /// Fails when either tensor is not 2-D, the inner dimensions disagree,
    /// or any CUDA call fails.
    pub fn matmul_cuda<T>(&mut self, a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        let a_shape = a.shape();
        let b_shape = b.shape();
        if a_shape.len() != 2 || b_shape.len() != 2 {
            return Err(TensorError::invalid_operation_simple(
                "Matrix multiplication requires 2D tensors".to_string(),
            ));
        }
        let (m, k) = (a_shape[0], a_shape[1]);
        let (k2, n) = (b_shape[0], b_shape[1]);
        if k != k2 {
            return Err(TensorError::invalid_operation_simple(format!(
                "Matrix dimension mismatch: {} vs {}",
                k, k2
            )));
        }
        if !self.kernel_cache.contains_key("cuda_gemm") {
            let gemm_ptx = self.generate_gemm_ptx()?;
            self.compile_kernel(&gemm_ptx, "cuda_gemm", &["-O3", "-use_fast_math"])?;
        }
        let config = self.calculate_gemm_config(m, n, k)?;
        let a_data = a.as_slice().ok_or_else(|| {
            TensorError::invalid_operation_simple("Failed to access tensor data".to_string())
        })?;
        let b_data = b.as_slice().ok_or_else(|| {
            TensorError::invalid_operation_simple("Failed to access tensor data".to_string())
        })?;
        // NOTE(review): the matrix dimensions (m, n, k) are not passed to the
        // kernel — confirm the PTX derives them some other way.
        let output_ptr = self.launch_kernel("cuda_gemm", &config, &[a_data, b_data])?;
        let output_size = m * n;
        let mut output_data = vec![T::default(); output_size];
        // SAFETY: `output_data` holds exactly `output_size` elements, so the
        // destination is valid for the copied byte count.
        let copy_back = unsafe {
            cuda_memcpy_dtoh(
                output_data.as_mut_ptr() as *mut std::ffi::c_void,
                output_ptr,
                output_size * std::mem::size_of::<T>(),
            )
        };
        // Free the device buffer regardless of whether the copy succeeded.
        // SAFETY: `output_ptr` came from cuda_malloc and is freed once.
        unsafe {
            cuda_free(output_ptr)?;
        }
        copy_back?;
        let output_shape = vec![m, n];
        Tensor::from_vec(output_data, &output_shape)
    }

    /// Element-wise binary operation between two equally sized tensors,
    /// compiling the matching kernel lazily on first use.
    ///
    /// # Errors
    /// Fails when the tensors' element counts differ or any CUDA call fails.
    pub fn elementwise_cuda<T>(
        &mut self,
        a: &Tensor<T>,
        b: &Tensor<T>,
        operation: ElementwiseOp,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        // Mismatched sizes would make the kernel read past the shorter
        // buffer on the device; reject them up front.
        if a.numel() != b.numel() {
            return Err(TensorError::invalid_operation_simple(format!(
                "Elementwise operation requires equal element counts: {} vs {}",
                a.numel(),
                b.numel()
            )));
        }
        let kernel_name = match operation {
            ElementwiseOp::Add => "cuda_elementwise_add",
            ElementwiseOp::Mul => "cuda_elementwise_mul",
            ElementwiseOp::Sub => "cuda_elementwise_sub",
            ElementwiseOp::Div => "cuda_elementwise_div",
        };
        if !self.kernel_cache.contains_key(kernel_name) {
            let elementwise_ptx = self.generate_elementwise_ptx(operation)?;
            self.compile_kernel(&elementwise_ptx, kernel_name, &["-O3", "-use_fast_math"])?;
        }
        let total_elements = a.numel();
        let config = self.calculate_elementwise_config(total_elements)?;
        let a_data = a.as_slice().ok_or_else(|| {
            TensorError::invalid_operation_simple("Failed to access tensor data".to_string())
        })?;
        let b_data = b.as_slice().ok_or_else(|| {
            TensorError::invalid_operation_simple("Failed to access tensor data".to_string())
        })?;
        let output_ptr = self.launch_kernel(kernel_name, &config, &[a_data, b_data])?;
        let mut output_data = vec![T::default(); total_elements];
        // SAFETY: destination holds exactly `total_elements` elements.
        let copy_back = unsafe {
            cuda_memcpy_dtoh(
                output_data.as_mut_ptr() as *mut std::ffi::c_void,
                output_ptr,
                total_elements * std::mem::size_of::<T>(),
            )
        };
        // SAFETY: `output_ptr` came from cuda_malloc and is freed once.
        unsafe {
            cuda_free(output_ptr)?;
        }
        copy_back?;
        Tensor::from_vec(output_data, a.shape().dims())
    }

    /// Compiles (if needed) and launches an arbitrary caller-supplied PTX
    /// kernel over `inputs`, returning a tensor shaped like the first input
    /// (or a scalar when there are no inputs).
    pub fn execute_custom_kernel<T>(
        &mut self,
        ptx_source: &str,
        kernel_name: &str,
        config: &CudaKernelConfig,
        inputs: &[&Tensor<T>],
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        if !self.kernel_cache.contains_key(kernel_name) {
            self.compile_kernel(ptx_source, kernel_name, &["-O3"])?;
        }
        let input_slices: Vec<&[T]> = inputs
            .iter()
            .map(|t| {
                t.as_slice().ok_or_else(|| {
                    TensorError::invalid_operation_simple(
                        "Failed to access tensor data".to_string(),
                    )
                })
            })
            .collect::<Result<Vec<_>>>()?;
        let output_ptr = self.launch_kernel(kernel_name, config, &input_slices)?;
        // NOTE(review): assumes the kernel's output matches the first
        // input's shape — confirm for each custom kernel.
        let output_shape = if !inputs.is_empty() {
            inputs[0].shape().to_vec()
        } else {
            vec![1]
        };
        let output_size = output_shape.iter().product::<usize>();
        let mut output_data = vec![T::default(); output_size];
        // SAFETY: destination holds exactly `output_size` elements.
        let copy_back = unsafe {
            cuda_memcpy_dtoh(
                output_data.as_mut_ptr() as *mut std::ffi::c_void,
                output_ptr,
                output_size * std::mem::size_of::<T>(),
            )
        };
        // SAFETY: `output_ptr` came from cuda_malloc and is freed once.
        unsafe {
            cuda_free(output_ptr)?;
        }
        copy_back?;
        Tensor::from_vec(output_data, &output_shape)
    }

    /// Loads a PTX image into a new driver-API module.
    fn compile_module(&mut self, source: &str, options: &[&str]) -> Result<CudaModule> {
        let c_options: Vec<CString> = options
            .iter()
            .map(|&s| {
                CString::new(s).map_err(|e| TensorError::ComputeError {
                    operation: "compile_cuda_module".to_string(),
                    details: format!("Invalid C string in CUDA options: {}", e),
                    retry_possible: false,
                    context: None,
                })
            })
            .collect::<Result<Vec<_>>>()?;
        let option_ptrs: Vec<*const i8> = c_options.iter().map(|s| s.as_ptr()).collect();
        // The driver API expects the module image to be a NUL-terminated
        // string; a bare `&str` pointer is not.
        let c_source = CString::new(source).map_err(|e| TensorError::ComputeError {
            operation: "compile_cuda_module".to_string(),
            details: format!("Invalid C string in CUDA source: {}", e),
            retry_possible: false,
            context: None,
        })?;
        // One value slot per option: the value array must be as long as
        // `numOptions`, even for options that take no value.
        let option_values: Vec<*const std::ffi::c_void> =
            vec![std::ptr::null(); option_ptrs.len()];
        // SAFETY: all pointer arguments reference live, NUL-terminated
        // buffers that outlive the call.
        unsafe {
            let mut module = std::ptr::null_mut();
            cuda_module_load_data_ex(
                &mut module,
                c_source.as_ptr() as *const std::ffi::c_void,
                option_ptrs.len() as u32,
                option_ptrs.as_ptr(),
                option_values.as_ptr(),
            )?;
            Ok(CudaModule { handle: module })
        }
    }

    /// Allocates `size` bytes of device memory via `cudaMalloc`.
    fn allocate_device_memory(&self, size: usize) -> Result<*mut std::ffi::c_void> {
        // SAFETY: `ptr` is a valid out-parameter for cudaMalloc.
        unsafe {
            let mut ptr = std::ptr::null_mut();
            cuda_malloc(&mut ptr, size)?;
            Ok(ptr)
        }
    }

    /// Best-effort release of a list of device allocations; used on error
    /// paths where a free failure must not mask the original error.
    fn free_device_ptrs(ptrs: &[*mut std::ffi::c_void]) {
        for &ptr in ptrs {
            // SAFETY: every pointer came from cuda_malloc and is freed once.
            unsafe {
                let _ = cuda_free(ptr);
            }
        }
    }

    /// Launch geometry for the tiled GEMM kernel: one 16×16 thread block per
    /// 16×16 output tile; geometry depends only on the output shape (m × n).
    fn calculate_gemm_config(&self, m: usize, n: usize, k: usize) -> Result<CudaKernelConfig> {
        let _ = k; // inner dimension does not affect launch geometry
        const TILE_SIZE: usize = 16;
        let grid_x = (n + TILE_SIZE - 1) / TILE_SIZE;
        let grid_y = (m + TILE_SIZE - 1) / TILE_SIZE;
        Ok(CudaKernelConfig {
            grid_dim: (grid_x as u32, grid_y as u32, 1),
            block_dim: (TILE_SIZE as u32, TILE_SIZE as u32, 1),
            // Two TILE_SIZE×TILE_SIZE tiles (A and B) in shared memory.
            // NOTE(review): hard-codes 4-byte elements — confirm the GEMM
            // kernel only ever runs on f32 data.
            shared_memory: (2 * TILE_SIZE * TILE_SIZE * std::mem::size_of::<f32>()) as u32,
            stream: None,
        })
    }

    /// One thread per element, up to 256 threads per block.
    fn calculate_elementwise_config(&self, total_elements: usize) -> Result<CudaKernelConfig> {
        // Clamp to at least one thread/block so the empty-tensor case does
        // not underflow or divide by zero below.
        let threads_per_block = 256.min(total_elements).max(1);
        let blocks_needed =
            ((total_elements + threads_per_block - 1) / threads_per_block).max(1);
        Ok(CudaKernelConfig {
            grid_dim: (blocks_needed as u32, 1, 1),
            block_dim: (threads_per_block as u32, 1, 1),
            shared_memory: 0,
            stream: None,
        })
    }

    /// Returns the PTX source of the bundled GEMM kernel.
    fn generate_gemm_ptx(&self) -> Result<String> {
        Ok(include_str!("cuda_kernels/gemm_kernel.ptx").to_string())
    }

    /// Returns the PTX source of the bundled element-wise kernel for `operation`.
    fn generate_elementwise_ptx(&self, operation: ElementwiseOp) -> Result<String> {
        match operation {
            ElementwiseOp::Add => Ok(include_str!("cuda_kernels/elementwise_add.ptx").to_string()),
            ElementwiseOp::Mul => Ok(include_str!("cuda_kernels/elementwise_mul.ptx").to_string()),
            ElementwiseOp::Sub => Ok(include_str!("cuda_kernels/elementwise_sub.ptx").to_string()),
            ElementwiseOp::Div => Ok(include_str!("cuda_kernels/elementwise_div.ptx").to_string()),
        }
    }
}
#[cfg(cuda_available)]
/// Binary element-wise operation selector; each variant maps to its own PTX
/// kernel (see `CudaDevice::elementwise_cuda`).
#[derive(Debug, Clone, Copy)]
pub enum ElementwiseOp {
    Add,
    Mul,
    Sub,
    Div,
}
#[cfg(cuda_available)]
/// Safe, owned snapshot of a device's properties, populated from the raw
/// `CudaDeviceProp` FFI struct by `query`.
#[derive(Debug, Clone)]
pub struct CudaDeviceProperties {
    pub name: String,
    /// Total global memory in bytes.
    pub total_global_memory: usize,
    /// Shared memory available per block, in bytes.
    pub shared_memory_per_block: usize,
    pub max_threads_per_block: usize,
    pub max_grid_size: [u32; 3],
    pub max_block_size: [u32; 3],
    pub warp_size: usize,
    /// (major, minor), e.g. (7, 0) for Volta.
    pub compute_capability: (i32, i32),
    pub multiprocessor_count: i32,
    // NOTE(review): CUDA reports memoryClockRate in kHz — confirm consumers
    // of this field account for the unit.
    pub memory_clock_rate: i32,
    /// Memory bus width in bits.
    pub memory_bus_width: i32,
}
#[cfg(cuda_available)]
impl CudaDeviceProperties {
    /// Queries the CUDA runtime for the properties of `device_id`.
    fn query(device_id: i32) -> Result<Self> {
        // SAFETY: `CudaDeviceProp` is #[repr(C)] and all-zero bytes are a
        // valid value for every field; the runtime then fills it in.
        unsafe {
            let mut props: CudaDeviceProp = std::mem::zeroed();
            cuda_get_device_properties(&mut props, device_id)?;
            Ok(CudaDeviceProperties {
                name: CStr::from_ptr(props.name.as_ptr())
                    .to_string_lossy()
                    .into_owned(),
                total_global_memory: props.total_global_memory,
                shared_memory_per_block: props.shared_memory_per_block,
                max_threads_per_block: props.max_threads_per_block,
                max_grid_size: props.max_grid_size,
                max_block_size: props.max_block_size,
                warp_size: props.warp_size,
                compute_capability: (props.major, props.minor),
                multiprocessor_count: props.multiprocessor_count,
                memory_clock_rate: props.memory_clock_rate,
                memory_bus_width: props.memory_bus_width,
            })
        }
    }
    /// Cooperative groups require compute capability 6.0 (Pascal) or newer.
    pub fn supports_cooperative_groups(&self) -> bool {
        self.compute_capability >= (6, 0)
    }
    /// Tensor cores require compute capability 7.0 (Volta) or newer.
    pub fn supports_tensor_cores(&self) -> bool {
        self.compute_capability >= (7, 0)
    }
    /// Theoretical peak memory bandwidth in GB/s.
    ///
    /// CUDA reports `memoryClockRate` in kHz; DDR memory transfers twice per
    /// clock, and the bus width is in bits (÷8 for bytes). The previous
    /// formula treated the clock as Hz, underestimating by a factor of 1000.
    pub fn memory_bandwidth_gb_s(&self) -> f64 {
        // kHz → Hz (×1e3), ×2 for DDR, bits → bytes (÷8), Hz·B → GB/s (÷1e9)
        (self.memory_clock_rate as f64 * 1e3 * 2.0 * self.memory_bus_width as f64 / 8.0) / 1e9
    }
}
#[cfg(cuda_available)]
/// Minimal `#[repr(C)]` mirror of the runtime's `cudaDeviceProp`, passed to
/// `cudaGetDeviceProperties`.
///
/// NOTE(review): the real `cudaDeviceProp` contains many more fields in a
/// different order; filling this struct through the actual runtime is only
/// sound if the linked ABI genuinely matches this layout — verify against
/// the bindings this is linked with.
#[repr(C)]
struct CudaDeviceProp {
    // NUL-terminated device name.
    name: [i8; 256],
    total_global_memory: usize,
    shared_memory_per_block: usize,
    max_threads_per_block: usize,
    max_grid_size: [u32; 3],
    max_block_size: [u32; 3],
    warp_size: usize,
    // Compute capability major/minor.
    major: i32,
    minor: i32,
    multiprocessor_count: i32,
    memory_clock_rate: i32,
    memory_bus_width: i32,
}
#[derive(Debug)]
#[cfg(cuda_available)]
/// Placeholder context marker tied to a device.
///
/// NOTE(review): this does not create or own a real driver context — the
/// `device_id` field is stored but never read in this file; confirm whether a
/// genuine `cuCtxCreate`-backed context is intended here.
struct CudaContext {
    device_id: i32,
}
#[cfg(cuda_available)]
impl CudaContext {
    // Infallible today; returns Result to leave room for real context setup.
    fn new(device_id: i32) -> Result<Self> {
        Ok(Self { device_id })
    }
}
#[cfg(cuda_available)]
/// Thin wrapper around a raw `cudaStream_t` handle.
///
/// NOTE(review): `Clone` duplicates the raw handle without reference
/// counting, and no `Drop` calls `cudaStreamDestroy` (declared below but
/// unused) — streams leak for the process lifetime. Confirm that is
/// acceptable, or add refcounting before destroying on drop.
#[derive(Debug, Clone)]
pub struct CudaStream {
    handle: *mut std::ffi::c_void,
}
#[cfg(cuda_available)]
impl CudaStream {
    /// Creates a new asynchronous CUDA stream via `cudaStreamCreate`.
    fn new() -> Result<Self> {
        let mut handle = std::ptr::null_mut();
        // SAFETY: `handle` is a valid out-parameter for cudaStreamCreate.
        unsafe {
            cuda_stream_create(&mut handle)?;
        }
        Ok(CudaStream { handle })
    }
}
#[derive(Debug)]
#[cfg(cuda_available)]
/// Owned handle to a loaded driver-API module.
///
/// Unloads the module on drop — previously `cuModuleUnload` was declared but
/// never called, so every compiled module leaked for the process lifetime.
struct CudaModule {
    handle: *mut std::ffi::c_void,
}
#[cfg(cuda_available)]
impl Drop for CudaModule {
    fn drop(&mut self) {
        // Best-effort unload; errors cannot be propagated from drop.
        // SAFETY: `handle` came from cuModuleLoadDataEx, the struct is not
        // Clone, and drop runs at most once, so it is unloaded exactly once.
        unsafe {
            let _ = cuModuleUnload(self.handle);
        }
    }
}
#[cfg(cuda_available)]
impl CudaModule {
    /// Resolves the kernel entry point named `name` inside this module.
    fn get_function(&self, name: &str) -> Result<CudaKernel> {
        let c_name = CString::new(name).map_err(|e| TensorError::ComputeError {
            operation: "get_cuda_function".to_string(),
            details: format!("Invalid function name: {}", e),
            retry_possible: false,
            context: None,
        })?;
        let mut function = std::ptr::null_mut();
        // SAFETY: the module handle is live and `c_name` is NUL-terminated.
        unsafe {
            cuda_module_get_function(&mut function, self.handle, c_name.as_ptr())?;
        }
        Ok(CudaKernel { function })
    }
}
#[derive(Debug)]
#[cfg(cuda_available)]
/// Raw handle to a kernel entry point (`CUfunction`); owned by the module it
/// was resolved from and cached in `CudaDevice::kernel_cache`.
struct CudaKernel {
    function: *mut std::ffi::c_void,
}
#[cfg(cuda_available)]
// Raw status code shared by the runtime (cudaError_t) and driver (CUresult)
// declarations below; 0 means success in both APIs.
type CudaResult = i32;
#[cfg(cuda_available)]
const CUDA_SUCCESS: CudaResult = 0;
#[cfg(cuda_available)]
/// Converts a raw CUDA status code into `Ok(())` or a `DeviceError` that
/// records the failing operation and numeric code.
fn check_cuda_error(result: CudaResult, operation: &str) -> Result<()> {
    match result {
        CUDA_SUCCESS => Ok(()),
        code => Err(TensorError::DeviceError {
            operation: operation.to_string(),
            details: format!("CUDA error code {}", code),
            device: "CUDA".to_string(),
            context: None,
        }),
    }
}
#[cfg(cuda_available)]
// Raw FFI surface. Mixes the CUDA *runtime* API (`cuda*`) and *driver* API
// (`cu*`); their status codes live in different enums even though both use 0
// for success — NOTE(review): `check_cuda_error` treats them uniformly, so a
// nonzero code's meaning depends on which API produced it.
extern "C" {
    fn cudaSetDevice(device: i32) -> CudaResult;
    fn cudaGetDeviceProperties(prop: *mut CudaDeviceProp, device: i32) -> CudaResult;
    fn cudaMalloc(devPtr: *mut *mut std::ffi::c_void, size: usize) -> CudaResult;
    fn cudaFree(devPtr: *mut std::ffi::c_void) -> CudaResult;
    fn cudaMemcpy(
        dst: *mut std::ffi::c_void,
        src: *const std::ffi::c_void,
        count: usize,
        kind: u32,
    ) -> CudaResult;
    fn cudaStreamCreate(pStream: *mut *mut std::ffi::c_void) -> CudaResult;
    // Declared but currently unused: streams are never destroyed in this file.
    fn cudaStreamDestroy(stream: *mut std::ffi::c_void) -> CudaResult;
    fn cudaDeviceSynchronize() -> CudaResult;
    // Declared but currently unused.
    fn cudaGetLastError() -> CudaResult;
    // Declared but currently unused; only valid for runtime-API error codes.
    fn cudaGetErrorString(error: CudaResult) -> *const i8;
    fn cuInit(flags: u32) -> CudaResult;
    fn cuModuleLoadDataEx(
        module: *mut *mut std::ffi::c_void,
        image: *const std::ffi::c_void,
        numOptions: u32,
        options: *const *const i8,
        optionValues: *const *const std::ffi::c_void,
    ) -> CudaResult;
    fn cuModuleGetFunction(
        hfunc: *mut *mut std::ffi::c_void,
        hmod: *mut std::ffi::c_void,
        name: *const i8,
    ) -> CudaResult;
    fn cuLaunchKernel(
        f: *mut std::ffi::c_void,
        gridDimX: u32,
        gridDimY: u32,
        gridDimZ: u32,
        blockDimX: u32,
        blockDimY: u32,
        blockDimZ: u32,
        sharedMemBytes: u32,
        hStream: *mut std::ffi::c_void,
        // Array of pointers to each argument value (not the values themselves).
        kernelParams: *mut *mut std::ffi::c_void,
        extra: *mut *mut std::ffi::c_void,
    ) -> CudaResult;
    fn cuModuleUnload(hmod: *mut std::ffi::c_void) -> CudaResult;
}
#[cfg(cuda_available)]
// `cudaMemcpyKind` values for host→device and device→host transfers.
const CUDA_MEMCPY_HOST_TO_DEVICE: u32 = 1;
#[cfg(cuda_available)]
const CUDA_MEMCPY_DEVICE_TO_HOST: u32 = 2;
#[cfg(cuda_available)]
/// Initializes the CUDA driver API.
///
/// # Safety
/// Must be called before any other driver-API (`cu*`) function.
unsafe fn cuda_init(flags: u32) -> Result<()> {
    check_cuda_error(cuInit(flags), "cuInit")
}
#[cfg(cuda_available)]
/// Binds the calling host thread to `device_id`.
///
/// # Safety
/// Plain FFI call; `device_id` is validated by the runtime.
unsafe fn cuda_set_device(device_id: i32) -> Result<()> {
    check_cuda_error(cudaSetDevice(device_id), "cudaSetDevice")
}
#[cfg(cuda_available)]
/// Fills `props` with the properties of `device`.
///
/// # Safety
/// `props` must point to valid, writable memory for one `CudaDeviceProp`.
unsafe fn cuda_get_device_properties(props: *mut CudaDeviceProp, device: i32) -> Result<()> {
    check_cuda_error(
        cudaGetDeviceProperties(props, device),
        "cudaGetDeviceProperties",
    )
}
#[cfg(cuda_available)]
/// Allocates `size` bytes of device memory, writing the pointer to `*ptr`.
///
/// # Safety
/// `ptr` must be a valid out-parameter location.
unsafe fn cuda_malloc(ptr: *mut *mut std::ffi::c_void, size: usize) -> Result<()> {
    check_cuda_error(cudaMalloc(ptr, size), "cudaMalloc")
}
#[cfg(cuda_available)]
/// Frees a device allocation.
///
/// # Safety
/// `ptr` must have come from `cuda_malloc` and not already be freed.
unsafe fn cuda_free(ptr: *mut std::ffi::c_void) -> Result<()> {
    check_cuda_error(cudaFree(ptr), "cudaFree")
}
#[cfg(cuda_available)]
/// Copies `size` bytes from host memory to device memory.
///
/// # Safety
/// `dst` must be device memory and `src` host memory, each valid for `size` bytes.
unsafe fn cuda_memcpy_htod(
    dst: *mut std::ffi::c_void,
    src: *const std::ffi::c_void,
    size: usize,
) -> Result<()> {
    check_cuda_error(
        cudaMemcpy(dst, src, size, CUDA_MEMCPY_HOST_TO_DEVICE),
        "cudaMemcpy H2D",
    )
}
#[cfg(cuda_available)]
/// Copies `size` bytes from device memory to host memory.
///
/// # Safety
/// `dst` must be host memory and `src` device memory, each valid for `size` bytes.
unsafe fn cuda_memcpy_dtoh(
    dst: *mut std::ffi::c_void,
    src: *const std::ffi::c_void,
    size: usize,
) -> Result<()> {
    check_cuda_error(
        cudaMemcpy(dst, src, size, CUDA_MEMCPY_DEVICE_TO_HOST),
        "cudaMemcpy D2H",
    )
}
#[cfg(cuda_available)]
/// Creates an asynchronous stream, writing its handle to `*stream`.
///
/// # Safety
/// `stream` must be a valid out-parameter location.
unsafe fn cuda_stream_create(stream: *mut *mut std::ffi::c_void) -> Result<()> {
    check_cuda_error(cudaStreamCreate(stream), "cudaStreamCreate")
}
#[cfg(cuda_available)]
/// Loads a module from a PTX image with JIT options.
///
/// # Safety
/// `image` must point to a NUL-terminated PTX image; `options`/`option_values`
/// must each be valid for `num_options` entries.
unsafe fn cuda_module_load_data_ex(
    module: *mut *mut std::ffi::c_void,
    image: *const std::ffi::c_void,
    num_options: u32,
    options: *const *const i8,
    option_values: *const *const std::ffi::c_void,
) -> Result<()> {
    check_cuda_error(
        cuModuleLoadDataEx(module, image, num_options, options, option_values),
        "cuModuleLoadDataEx",
    )
}
#[cfg(cuda_available)]
/// Resolves kernel `name` inside `module`.
///
/// # Safety
/// `module` must be a live module handle and `name` NUL-terminated.
unsafe fn cuda_module_get_function(
    function: *mut *mut std::ffi::c_void,
    module: *mut std::ffi::c_void,
    name: *const i8,
) -> Result<()> {
    check_cuda_error(
        cuModuleGetFunction(function, module, name),
        "cuModuleGetFunction",
    )
}
#[cfg(cuda_available)]
/// Launches `function` with the given geometry and argument array.
///
/// # Safety
/// `args` must be an array of pointers *to* each argument value, valid for
/// the kernel's signature and alive for the duration of the call; `stream`
/// must be a live stream handle or null (default stream).
unsafe fn cuda_launch_kernel(
    function: *mut std::ffi::c_void,
    grid_dim: (u32, u32, u32),
    block_dim: (u32, u32, u32),
    args: *mut *mut std::ffi::c_void,
    shared_memory: u32,
    stream: *mut std::ffi::c_void,
) -> Result<()> {
    check_cuda_error(
        cuLaunchKernel(
            function,
            grid_dim.0,
            grid_dim.1,
            grid_dim.2,
            block_dim.0,
            block_dim.1,
            block_dim.2,
            shared_memory,
            stream,
            args,
            std::ptr::null_mut(),
        ),
        "cuLaunchKernel",
    )
}
#[cfg(cuda_available)]
/// Blocks the host until all previously issued device work completes.
///
/// # Safety
/// Plain FFI call with no pointer preconditions.
unsafe fn cuda_device_synchronize() -> Result<()> {
    check_cuda_error(cudaDeviceSynchronize(), "cudaDeviceSynchronize")
}
#[cfg(cuda_available)]
/// Compile-time CUDA support flag: `true` when built with `cuda_available`.
/// Note this reflects the build configuration only, not a runtime probe.
pub fn is_cuda_available() -> bool {
    true
}
#[cfg(not(cuda_available))]
/// Fallback when the crate is built without CUDA support.
pub fn is_cuda_available() -> bool {
    false
}
#[cfg(not(cuda_available))]
/// Stub module compiled in when CUDA support is absent; callers can route
/// GPU code paths here to get a uniform error.
pub mod cuda_stub {
    use crate::{Result, TensorError};
    /// Always returns the "CUDA unavailable" device error.
    pub fn cuda_not_available() -> Result<()> {
        Err(TensorError::device_error_simple(
            "CUDA kernels are only available with the 'cuda' feature enabled".to_string(),
        ))
    }
}
#[cfg(cuda_available)]
pub mod benchmarks {
    use super::*;
    use std::time::{Duration, Instant};

    /// Runs GEMM workloads on one device and accumulates timing results.
    pub struct CudaBenchmark {
        device: CudaDevice,
        results: Vec<BenchmarkResult>,
    }

    /// Timing and derived throughput figures for one benchmarked operation.
    #[derive(Debug, Clone)]
    pub struct BenchmarkResult {
        pub operation: String,
        pub input_shape: Vec<usize>,
        pub duration: Duration,
        pub throughput_gflops: f64,
        pub memory_bandwidth_gb_s: f64,
        pub kernel_efficiency: f64,
    }

    impl CudaBenchmark {
        /// Builds a benchmark harness bound to `device_id`.
        pub fn new(device_id: i32) -> Result<Self> {
            let device = CudaDevice::new(device_id)?;
            Ok(CudaBenchmark {
                device,
                results: Vec::new(),
            })
        }

        /// Times one matmul per `(m, n, k)` size, records each result, and
        /// returns the batch. Timings include any first-use kernel compile.
        pub fn benchmark_kernels(
            &mut self,
            sizes: &[(usize, usize, usize)],
        ) -> Result<Vec<BenchmarkResult>> {
            let mut batch = Vec::with_capacity(sizes.len());
            for &(m, n, k) in sizes {
                let lhs = Tensor::<f32>::zeros(&[m, k]);
                let rhs = Tensor::<f32>::zeros(&[k, n]);
                let started = Instant::now();
                let _product = self.device.matmul_cuda(&lhs, &rhs)?;
                let elapsed = started.elapsed();
                let secs = elapsed.as_secs_f64();
                // A GEMM performs 2·m·n·k floating-point operations.
                let flop_count = 2 * m * n * k;
                let gflops = flop_count as f64 / secs / 1e9;
                // Bytes moved: read A (m·k) and B (k·n), write C (m·n), 4 B each.
                let bytes_moved = (m * k + k * n + m * n) * 4;
                let bandwidth = bytes_moved as f64 / secs / 1e9;
                let peak = self.device.get_device_properties().memory_bandwidth_gb_s();
                batch.push(BenchmarkResult {
                    operation: format!("cuda_matmul_{}x{}x{}", m, n, k),
                    input_shape: vec![m, k, n],
                    duration: elapsed,
                    throughput_gflops: gflops,
                    memory_bandwidth_gb_s: bandwidth,
                    kernel_efficiency: bandwidth / peak,
                });
            }
            self.results.extend(batch.iter().cloned());
            Ok(batch)
        }

        /// Renders every recorded result as a human-readable report.
        pub fn generate_report(&self) -> String {
            let mut report = String::from("CUDA Kernel Performance Report\n");
            report.push_str("==================================\n\n");
            for entry in &self.results {
                report.push_str(&format!(
                    "Operation: {}\n Duration: {:?}\n Throughput: {:.2} GFLOPS\n Bandwidth: {:.2} GB/s\n Efficiency: {:.1}%\n\n",
                    entry.operation, entry.duration, entry.throughput_gflops,
                    entry.memory_bandwidth_gb_s, entry.kernel_efficiency * 100.0
                ));
            }
            report
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    #[cfg(cuda_available)]
    // Accepts either outcome: device 0 exists, or a CUDA-tagged error.
    fn test_cuda_device_creation() {
        let result = CudaDevice::new(0);
        assert!(result.is_ok() || result.unwrap_err().to_string().contains("CUDA"));
    }
    #[test]
    #[cfg(not(cuda_available))]
    // The stub must always report the feature-gated error.
    fn test_cuda_not_available() {
        let result = cuda_stub::cuda_not_available();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("CUDA kernels are only available"));
    }
    #[test]
    #[cfg(cuda_available)]
    // Sanity-checks the GEMM launch geometry when a device is present.
    fn test_kernel_config_calculation() {
        if let Ok(device) = CudaDevice::new(0) {
            let config = device.calculate_gemm_config(1024, 1024, 1024);
            assert!(config.is_ok());
            let cfg = config.expect("test: operation should succeed");
            assert!(cfg.grid_dim.0 > 0);
            assert!(cfg.block_dim.0 > 0);
        }
    }
}