use crate::error::{Result, SynaError};
#[cfg(feature = "gpu")]
mod gpu_impl {
use super::*;
use std::ffi::c_void;
#[link(name = "cudart")]
extern "C" {
fn cudaSetDevice(device: i32) -> i32;
fn cudaGetDeviceCount(count: *mut i32) -> i32;
fn cudaMalloc(devPtr: *mut *mut c_void, size: usize) -> i32;
fn cudaFree(devPtr: *mut c_void) -> i32;
fn cudaMemcpy(dst: *mut c_void, src: *const c_void, count: usize, kind: i32) -> i32;
fn cudaMemcpyAsync(
dst: *mut c_void,
src: *const c_void,
count: usize,
kind: i32,
stream: *mut c_void,
) -> i32;
fn cudaMallocHost(ptr: *mut *mut c_void, size: usize) -> i32;
fn cudaFreeHost(ptr: *mut c_void) -> i32;
fn cudaDeviceSynchronize() -> i32;
fn cudaGetErrorString(error: i32) -> *const std::ffi::c_char;
}
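    // cudaMemcpyKind values from the CUDA runtime API:
    // 1 = cudaMemcpyHostToDevice, 2 = cudaMemcpyDeviceToHost.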
const CUDA_MEMCPY_HOST_TO_DEVICE: i32 = 1;
const CUDA_MEMCPY_DEVICE_TO_HOST: i32 = 2;
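    /// Handle to a single CUDA device, selected by its zero-based index.
    ///
    /// A typical round trip (illustrative sketch; assumes the caller has these
    /// types in scope and a `?`-compatible error context):
    ///
    /// ```ignore
    /// let ctx = GpuContext::new(0)?;
    /// let tensor = ctx.upload(&[1.0, 2.0, 3.0])?;
    /// assert_eq!(tensor.download()?, vec![1.0, 2.0, 3.0]);
    /// ```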
pub struct GpuContext {
device: i32,
}
impl GpuContext {
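        /// Selects `device` as the active CUDA device for this context.
        ///
        /// Fails with [`SynaError::GpuUnavailable`] if the device count cannot
        /// be queried, no devices are present, or `device` is out of range.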
pub fn new(device: i32) -> Result<Self> {
unsafe {
let mut count: i32 = 0;
let result = cudaGetDeviceCount(&mut count);
if result != 0 {
return Err(SynaError::GpuUnavailable(format!(
"Failed to get device count: {}",
get_cuda_error_string(result)
)));
}
if count == 0 {
return Err(SynaError::GpuUnavailable(
"No CUDA devices available".to_string(),
));
}
if device < 0 || device >= count {
return Err(SynaError::GpuUnavailable(format!(
"Invalid device {}: only {} devices available",
device, count
)));
}
let result = cudaSetDevice(device);
if result != 0 {
return Err(SynaError::GpuUnavailable(format!(
"Failed to set device {}: {}",
device,
get_cuda_error_string(result)
)));
}
}
Ok(Self { device })
}
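        /// Returns the number of visible CUDA devices, or 0 if the query fails.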
pub fn device_count() -> i32 {
unsafe {
let mut count: i32 = 0;
let result = cudaGetDeviceCount(&mut count);
if result != 0 {
return 0;
}
count
}
}
pub fn device(&self) -> i32 {
self.device
}
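        /// Copies `data` into freshly allocated device memory and returns an
        /// owning [`GpuTensor`]. An empty slice yields an empty tensor without
        /// touching the device.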
pub fn upload(&self, data: &[f32]) -> Result<GpuTensor> {
if data.is_empty() {
return Ok(GpuTensor {
ptr: std::ptr::null_mut(),
len: 0,
device: self.device,
});
}
let size = data.len() * std::mem::size_of::<f32>();
let mut device_ptr: *mut c_void = std::ptr::null_mut();
unsafe {
cudaSetDevice(self.device);
let result = cudaMalloc(&mut device_ptr, size);
if result != 0 {
return Err(SynaError::GpuOutOfMemory(format!(
"Failed to allocate {} bytes: {}",
size,
get_cuda_error_string(result)
)));
}
let result = cudaMemcpy(
device_ptr,
data.as_ptr() as *const c_void,
size,
CUDA_MEMCPY_HOST_TO_DEVICE,
);
if result != 0 {
cudaFree(device_ptr);
return Err(SynaError::GpuUnavailable(format!(
"Failed to copy data to device: {}",
get_cuda_error_string(result)
)));
}
}
Ok(GpuTensor {
ptr: device_ptr as *mut f32,
len: data.len(),
device: self.device,
})
}
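        /// Copies `data` to the device through a page-locked (pinned) staging
        /// buffer allocated with `cudaMallocHost`, issuing the transfer with
        /// `cudaMemcpyAsync` on the default stream and synchronizing before the
        /// staging buffer is released. Pinned staging can improve transfer
        /// throughput for large buffers, at the cost of an extra host-side copy.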
pub fn upload_pinned(&self, data: &[f32]) -> Result<GpuTensor> {
if data.is_empty() {
return Ok(GpuTensor {
ptr: std::ptr::null_mut(),
len: 0,
device: self.device,
});
}
let size = data.len() * std::mem::size_of::<f32>();
let mut device_ptr: *mut c_void = std::ptr::null_mut();
let mut pinned_ptr: *mut c_void = std::ptr::null_mut();
unsafe {
cudaSetDevice(self.device);
let result = cudaMallocHost(&mut pinned_ptr, size);
if result != 0 {
return Err(SynaError::GpuOutOfMemory(format!(
"Failed to allocate pinned memory: {}",
get_cuda_error_string(result)
)));
}
std::ptr::copy_nonoverlapping(data.as_ptr(), pinned_ptr as *mut f32, data.len());
let result = cudaMalloc(&mut device_ptr, size);
if result != 0 {
cudaFreeHost(pinned_ptr);
return Err(SynaError::GpuOutOfMemory(format!(
"Failed to allocate device memory: {}",
get_cuda_error_string(result)
)));
}
let result = cudaMemcpyAsync(
device_ptr,
pinned_ptr,
size,
CUDA_MEMCPY_HOST_TO_DEVICE,
                    std::ptr::null_mut(),
                );
if result != 0 {
cudaFree(device_ptr);
cudaFreeHost(pinned_ptr);
return Err(SynaError::GpuUnavailable(format!(
"Failed to copy data to device: {}",
get_cuda_error_string(result)
)));
}
                let result = cudaDeviceSynchronize();
                cudaFreeHost(pinned_ptr);
                if result != 0 {
                    cudaFree(device_ptr);
                    return Err(SynaError::GpuUnavailable(format!(
                        "Synchronization after async copy failed: {}",
                        get_cuda_error_string(result)
                    )));
                }
}
Ok(GpuTensor {
ptr: device_ptr as *mut f32,
len: data.len(),
device: self.device,
})
}
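        /// Blocks until all outstanding work on this context's device has completed.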
pub fn synchronize(&self) -> Result<()> {
unsafe {
cudaSetDevice(self.device);
let result = cudaDeviceSynchronize();
if result != 0 {
return Err(SynaError::GpuUnavailable(format!(
"Device synchronization failed: {}",
get_cuda_error_string(result)
)));
}
}
Ok(())
}
}
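    /// Owning handle to a device-resident buffer of `f32` values.
    ///
    /// The underlying allocation is released via `cudaFree` when the tensor is dropped.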
pub struct GpuTensor {
ptr: *mut f32,
len: usize,
device: i32,
}
impl GpuTensor {
pub fn as_ptr(&self) -> *mut f32 {
self.ptr
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn device(&self) -> i32 {
self.device
}
pub fn size_bytes(&self) -> usize {
self.len * std::mem::size_of::<f32>()
}
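        /// Copies the tensor's contents back to the host. Returns an empty
        /// `Vec` for empty tensors without touching the device.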
pub fn download(&self) -> Result<Vec<f32>> {
if self.len == 0 || self.ptr.is_null() {
return Ok(Vec::new());
}
let mut data = vec![0.0f32; self.len];
let size = self.len * std::mem::size_of::<f32>();
unsafe {
cudaSetDevice(self.device);
let result = cudaMemcpy(
data.as_mut_ptr() as *mut c_void,
self.ptr as *const c_void,
size,
CUDA_MEMCPY_DEVICE_TO_HOST,
);
if result != 0 {
return Err(SynaError::GpuUnavailable(format!(
"Failed to copy data from device: {}",
get_cuda_error_string(result)
)));
}
}
Ok(data)
}
}
impl Drop for GpuTensor {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe {
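                // Errors from cudaSetDevice/cudaFree are deliberately ignored:
                // Drop has no way to report them and must not panic.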
cudaSetDevice(self.device);
cudaFree(self.ptr as *mut c_void);
}
}
}
}
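    // SAFETY: GpuTensor exclusively owns its device allocation, and CUDA device
    // pointers are not tied to the host thread that created them, so moving the
    // tensor to another thread is sound.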
unsafe impl Send for GpuTensor {}
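    /// Translates a CUDA runtime error code into a readable message via
    /// `cudaGetErrorString`, falling back to the numeric code if none exists.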
fn get_cuda_error_string(error: i32) -> String {
unsafe {
let ptr = cudaGetErrorString(error);
if ptr.is_null() {
return format!("Unknown error ({})", error);
}
std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned()
}
}
}
#[cfg(feature = "gpu")]
pub use gpu_impl::{GpuContext, GpuTensor};
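/// Stub used when the crate is built without the `gpu` feature;
/// [`GpuContext::new`] always returns [`SynaError::GpuUnavailable`].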
#[cfg(not(feature = "gpu"))]
pub struct GpuContext;
#[cfg(not(feature = "gpu"))]
impl GpuContext {
pub fn new(_device: i32) -> Result<Self> {
Err(SynaError::GpuUnavailable(
"GPU support not compiled. Rebuild with --features gpu".to_string(),
))
}
pub fn device_count() -> i32 {
0
}
}
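/// Stub tensor used when the crate is built without the `gpu` feature;
/// it never holds data.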
#[cfg(not(feature = "gpu"))]
pub struct GpuTensor;
#[cfg(not(feature = "gpu"))]
impl GpuTensor {
pub fn as_ptr(&self) -> *mut f32 {
std::ptr::null_mut()
}
pub fn len(&self) -> usize {
0
}
pub fn is_empty(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[cfg(not(feature = "gpu"))]
fn test_gpu_unavailable_without_feature() {
let result = GpuContext::new(0);
assert!(result.is_err());
match result {
Err(SynaError::GpuUnavailable(msg)) => {
assert!(msg.contains("GPU support not compiled"));
}
_ => panic!("Expected GpuUnavailable error"),
}
}
#[test]
#[cfg(not(feature = "gpu"))]
fn test_device_count_without_feature() {
assert_eq!(GpuContext::device_count(), 0);
}
#[test]
#[cfg(not(feature = "gpu"))]
fn test_stub_tensor() {
let tensor = GpuTensor;
assert!(tensor.as_ptr().is_null());
assert_eq!(tensor.len(), 0);
assert!(tensor.is_empty());
}
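    // Round-trip sanity check for the GPU path. Illustrative: it only exercises
    // the real CUDA code when built with `--features gpu` on a machine that has
    // at least one visible device, and quietly skips otherwise.
    #[test]
    #[cfg(feature = "gpu")]
    fn test_upload_download_roundtrip() {
        if GpuContext::device_count() == 0 {
            return;
        }
        let ctx = GpuContext::new(0).expect("device 0 should be usable");
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = ctx.upload(&data).expect("upload should succeed");
        assert_eq!(tensor.len(), data.len());
        assert!(!tensor.is_empty());
        assert_eq!(tensor.download().expect("download should succeed"), data);
    }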
}