use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[cfg(feature = "cuda")]
use cudarc::{
cublas::{CudaBlas, Gemm},
driver::{CudaDevice, CudaSlice, DevicePtr, LaunchAsync, LaunchConfig},
nvrtc::compile_ptx,
};
#[cfg(feature = "cuda")]
pub struct CudaMatrixExecutor {
device: Arc<CudaDevice>,
cublas: CudaBlas,
device_id: usize,
streams: Vec<cudarc::driver::CudaStream>,
memory_pool: Arc<Mutex<CudaMemoryPool>>,
}
#[cfg(feature = "cuda")]
impl CudaMatrixExecutor {
pub fn new(device_id: usize) -> RusTorchResult<Self> {
let device = CudaDevice::new(device_id).map_err(|e| {
RusTorchError::tensor_op(format!(
"Failed to initialize CUDA device {}: {}",
device_id, e
))
})?;
let cublas = CudaBlas::new(device.clone())
.map_err(|e| RusTorchError::tensor_op(format!("Failed to initialize cuBLAS: {}", e)))?;
let mut streams = Vec::new();
for i in 0..4 {
let stream = device.fork_default_stream().map_err(|e| {
RusTorchError::tensor_op(format!("Failed to create CUDA stream {}: {}", i, e))
})?;
streams.push(stream);
}
let memory_pool = Arc::new(Mutex::new(CudaMemoryPool::new()));
Ok(Self {
device,
cublas,
device_id,
streams,
memory_pool,
})
}
pub fn matmul_f32(
&self,
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
use_tensor_cores: bool,
) -> RusTorchResult<()> {
if a.len() != m * k || b.len() != k * n || c.len() != m * n {
return Err(RusTorchError::shape_mismatch(
&[m, k, k, n, m, n],
&[a.len(), 0, b.len(), 0, c.len(), 0],
));
}
let a_dev = self.device.htod_copy(a.to_vec()).map_err(|e| {
RusTorchError::tensor_op(format!("Failed to copy matrix A to GPU: {}", e))
})?;
let b_dev = self.device.htod_copy(b.to_vec()).map_err(|e| {
RusTorchError::tensor_op(format!("Failed to copy matrix B to GPU: {}", e))
})?;
let mut c_dev = self.device.alloc_zeros(m * n).map_err(|e| {
RusTorchError::tensor_op(format!("Failed to allocate GPU memory for result: {}", e))
})?;
self.perform_standard_gemm(&a_dev, &b_dev, &mut c_dev, m, n, k)?;
self.device.dtoh_sync_copy_into(&c_dev, c).map_err(|e| {
RusTorchError::tensor_op(format!("Failed to copy result from GPU: {}", e))
})?;
Ok(())
}
fn perform_tensor_core_gemm(
&self,
a: &CudaSlice<f32>,
b: &CudaSlice<f32>,
c: &mut CudaSlice<f32>,
m: usize,
n: usize,
k: usize,
) -> RusTorchResult<()> {
self.perform_standard_gemm(a, b, c, m, n, k)
}
fn perform_standard_gemm(
&self,
a: &CudaSlice<f32>,
b: &CudaSlice<f32>,
c: &mut CudaSlice<f32>,
m: usize,
n: usize,
k: usize,
) -> RusTorchResult<()> {
let alpha = 1.0f32;
let beta = 0.0f32;
use cudarc::cublas::GemmConfig;
let cfg = GemmConfig {
transa: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
transb: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
m: m as i32,
n: n as i32,
k: k as i32,
lda: k as i32, ldb: n as i32, ldc: n as i32, alpha,
beta,
};
unsafe {
self.cublas
.gemm(cfg, a, b, c)
.map_err(|e| RusTorchError::tensor_op(format!("cuBLAS SGEMM failed: {}", e)))?;
}
Ok(())
}
pub fn batch_matmul_f32(
&self,
a_batch: &[Vec<f32>],
b_batch: &[Vec<f32>],
c_batch: &mut [Vec<f32>],
m: usize,
n: usize,
k: usize,
) -> RusTorchResult<()> {
if a_batch.len() != b_batch.len() || a_batch.len() != c_batch.len() {
return Err(RusTorchError::shape_mismatch(
&[a_batch.len()],
&[b_batch.len(), c_batch.len()],
));
}
let batch_size = a_batch.len();
let stream_count = self.streams.len().min(batch_size);
for (batch_idx, ((a, b), c)) in a_batch
.iter()
.zip(b_batch.iter())
.zip(c_batch.iter_mut())
.enumerate()
{
let stream_idx = batch_idx % stream_count;
let _stream = &self.streams[stream_idx];
self.matmul_f32(a, b, c, m, n, k, false)?;
}
self.device.synchronize().map_err(|e| {
RusTorchError::tensor_op(format!("Failed to synchronize CUDA device: {}", e))
})?;
Ok(())
}
pub fn get_memory_info(&self) -> RusTorchResult<(usize, usize)> {
let _ = self.device.synchronize();
let total = 8usize * 1024 * 1024 * 1024; let free = total / 2;
Ok((free, total))
}
pub fn get_compute_capability(&self) -> RusTorchResult<(i32, i32)> {
let major = self.device.attribute(
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
).map_err(|e| {
RusTorchError::tensor_op(format!("Failed to get compute capability major: {}", e))
})? as i32;
let minor = self.device.attribute(
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
).map_err(|e| {
RusTorchError::tensor_op(format!("Failed to get compute capability minor: {}", e))
})? as i32;
Ok((major, minor))
}
}
#[allow(dead_code)]
struct CudaMemoryPool {
buffers: HashMap<usize, Vec<*mut u8>>,
}
#[allow(dead_code)]
impl CudaMemoryPool {
fn new() -> Self {
Self {
buffers: HashMap::new(),
}
}
fn get_buffer(&mut self, size: usize) -> Option<*mut u8> {
self.buffers.get_mut(&size)?.pop()
}
fn return_buffer(&mut self, ptr: *mut u8, size: usize) {
self.buffers.entry(size).or_default().push(ptr);
}
}
#[cfg(not(feature = "cuda"))]
pub struct CudaMatrixExecutor;
#[cfg(not(feature = "cuda"))]
impl CudaMatrixExecutor {
pub fn new(_device_id: usize) -> RusTorchResult<Self> {
Err(RusTorchError::UnsupportedDevice(
"CUDA not available".to_string(),
))
}
#[allow(clippy::too_many_arguments)]
pub fn matmul_f32(
&self,
_a: &[f32],
_b: &[f32],
_c: &mut [f32],
_m: usize,
_n: usize,
_k: usize,
_use_tensor_cores: bool,
) -> RusTorchResult<()> {
Err(RusTorchError::UnsupportedDevice(
"CUDA not available".to_string(),
))
}
pub fn batch_matmul_f32(
&self,
_a_batch: &[Vec<f32>],
_b_batch: &[Vec<f32>],
_c_batch: &mut [Vec<f32>],
_m: usize,
_n: usize,
_k: usize,
) -> RusTorchResult<()> {
Err(RusTorchError::UnsupportedDevice(
"CUDA not available".to_string(),
))
}
pub fn get_memory_info(&self) -> RusTorchResult<(usize, usize)> {
Err(RusTorchError::UnsupportedDevice(
"CUDA not available".to_string(),
))
}
pub fn get_compute_capability(&self) -> RusTorchResult<(i32, i32)> {
Err(RusTorchError::UnsupportedDevice(
"CUDA not available".to_string(),
))
}
}
#[allow(clippy::too_many_arguments, clippy::many_single_char_names)] pub fn cuda_matmul_f32(
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
device_id: usize,
use_tensor_cores: bool,
) -> RusTorchResult<()> {
#[cfg(feature = "cuda")]
{
let executor = CudaMatrixExecutor::new(device_id)?;
executor.matmul_f32(a, b, c, m, n, k, use_tensor_cores)
}
#[cfg(not(feature = "cuda"))]
{
let _ = (a, b, c, m, n, k, device_id, use_tensor_cores);
Err(RusTorchError::UnsupportedDevice(
"CUDA not available".to_string(),
))
}
}
pub fn cuda_batch_matmul_f32(
a_batch: &[Vec<f32>],
b_batch: &[Vec<f32>],
c_batch: &mut [Vec<f32>],
m: usize,
n: usize,
k: usize,
device_id: usize,
) -> RusTorchResult<()> {
#[cfg(feature = "cuda")]
{
let executor = CudaMatrixExecutor::new(device_id)?;
executor.batch_matmul_f32(a_batch, b_batch, c_batch, m, n, k)
}
#[cfg(not(feature = "cuda"))]
{
let _ = (a_batch, b_batch, c_batch, m, n, k, device_id);
Err(RusTorchError::UnsupportedDevice(
"CUDA not available".to_string(),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cuda_executor_creation() {
let result = CudaMatrixExecutor::new(0);
#[cfg(not(feature = "cuda"))]
assert!(result.is_err());
}
#[test]
fn test_cuda_matmul_interface() {
let a = vec![1.0f32; 64];
let b = vec![2.0f32; 64];
let mut c = vec![0.0f32; 64];
let result = cuda_matmul_f32(&a, &b, &mut c, 8, 8, 8, 0, false);
#[cfg(not(feature = "cuda"))]
assert!(result.is_err());
}
}