use crate::error::{RusTorchError, RusTorchResult};
#[cfg(feature = "cuda")]
use cudarc::{
cublas::CudaBlas,
driver::CudaDevice,
};
/// Minimal CUDA executor holding a device handle, a cuBLAS handle created on
/// that device, and the ordinal the device was opened with.
#[cfg(feature = "cuda")]
pub struct SimpleCudaExecutor {
    /// Shared handle to the CUDA device.
    ///
    /// `CudaDevice::new` hands back an `Arc<CudaDevice>` (the same handle is
    /// cloned into `CudaBlas::new`), so the field stores the `Arc` rather
    /// than a bare `CudaDevice`.
    device: std::sync::Arc<CudaDevice>,
    /// cuBLAS handle created from a clone of `device`.
    cublas: CudaBlas,
    /// Device ordinal passed to the constructor.
    device_id: usize,
}
#[cfg(feature = "cuda")]
impl SimpleCudaExecutor {
    /// Creates an executor for the CUDA device with ordinal `device_id`.
    ///
    /// # Errors
    ///
    /// Returns a GPU error if the device cannot be initialized, or if a
    /// cuBLAS handle cannot be created on it.
    pub fn new(device_id: usize) -> RusTorchResult<Self> {
        let device = CudaDevice::new(device_id).map_err(|e| {
            RusTorchError::gpu(format!(
                "Failed to initialize CUDA device {}: {}",
                device_id, e
            ))
        })?;
        // cuBLAS takes its own handle to the device, hence the clone.
        let cublas = CudaBlas::new(device.clone())
            .map_err(|e| RusTorchError::gpu(format!("Failed to initialize cuBLAS: {}", e)))?;
        Ok(Self {
            device,
            cublas,
            device_id,
        })
    }

    /// Computes `c = a * b` for row-major matrices.
    ///
    /// `a` is `m x k`, `b` is `k x n`, and `c` is `m x n`; every element of
    /// `c` is overwritten. This implementation is a plain triple loop over
    /// the host slices and does not touch the stored device or cuBLAS handle.
    ///
    /// # Errors
    ///
    /// Returns a tensor-operation error when the slice lengths do not equal
    /// `m * k`, `k * n` and `m * n` respectively, including the case where
    /// any of those products overflows `usize`.
    pub fn matmul_simple(
        &self,
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> RusTorchResult<()> {
        // `checked_mul` yields `None` on overflow, which never equals
        // `Some(len)`. A plain `m * k` would wrap in release builds and
        // could let inconsistent dimensions slip past this check.
        let shapes_match = m.checked_mul(k) == Some(a.len())
            && k.checked_mul(n) == Some(b.len())
            && m.checked_mul(n) == Some(c.len());
        if !shapes_match {
            return Err(RusTorchError::tensor_op(
                "Matrix dimensions mismatch".to_string(),
            ));
        }

        for i in 0..m {
            // Row `i` of `a` and of `c`; in bounds thanks to the check above.
            let a_row = &a[i * k..(i + 1) * k];
            let c_row = &mut c[i * n..(i + 1) * n];
            for (j, out) in c_row.iter_mut().enumerate() {
                // Accumulate in index order so results match a naive
                // i-j-l triple loop bit for bit.
                let mut sum = 0.0_f32;
                for (l, &a_il) in a_row.iter().enumerate() {
                    sum += a_il * b[l * n + j];
                }
                *out = sum;
            }
        }
        Ok(())
    }

    /// Returns the device ordinal together with a display label of the form
    /// `"CUDA Device <id>"`.
    pub fn device_info(&self) -> (usize, String) {
        (self.device_id, format!("CUDA Device {}", self.device_id))
    }

    /// Adds two equally sized slices element by element into a new vector.
    ///
    /// The addition runs over the host slices; the stored device and cuBLAS
    /// handle are not used.
    ///
    /// # Errors
    ///
    /// Returns a tensor-operation error when `a` and `b` differ in length.
    pub fn elementwise_add(
        &self,
        a: &[f32],
        b: &[f32],
    ) -> RusTorchResult<Vec<f32>> {
        if a.len() != b.len() {
            return Err(RusTorchError::tensor_op(
                "Tensor size mismatch for addition".to_string(),
            ));
        }
        Ok(a.iter().zip(b).map(|(x, y)| x + y).collect())
    }
}
/// Placeholder executor type compiled when the `cuda` feature is disabled.
///
/// It carries no state; it exists so the type name is available in builds
/// without the feature.
#[cfg(not(feature = "cuda"))]
pub struct SimpleCudaExecutor;
#[cfg(not(feature = "cuda"))]
impl SimpleCudaExecutor {
    /// Always fails: this build was compiled without the `cuda` feature.
    ///
    /// # Errors
    ///
    /// Returns a GPU error with the message `"CUDA feature not enabled"`,
    /// regardless of the requested device ordinal.
    pub fn new(_device_id: usize) -> RusTorchResult<Self> {
        let reason = String::from("CUDA feature not enabled");
        Err(RusTorchError::gpu(reason))
    }
}