#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use super::error::{GpuError, GpuResult};
use super::GpuBackend;
use crate::kernel::{Complex, Float};
/// Buffer of complex samples that keeps a host-side copy plus optional backend handles.
///
/// `upload` copies caller data into the host-side copy and then passes the buffer to
/// the selected backend's `upload_buffer`; `download` calls the backend's
/// `download_buffer` and then copies the host-side copy out to the caller.
#[derive(Debug)]
pub struct GpuBuffer<T: Float> {
/// Number of `Complex<T>` elements; set in `new` and not reassigned in this file.
size: usize,
/// Backend selected at construction; `upload`/`download` dispatch on it.
backend: GpuBackend,
/// Host-side storage holding `size` elements.
cpu_data: Vec<Complex<T>>,
/// Raw CUDA handle; starts as `None` and is passed to `super::cuda::free_buffer` on drop.
#[cfg(feature = "cuda")]
cuda_ptr: Option<*mut core::ffi::c_void>,
/// Metal handle; starts as `None` and is passed to `super::metal::free_buffer` on drop.
#[cfg(feature = "metal")]
metal_buffer: Option<super::metal::MetalBufferHandle>,
}
// SAFETY: NOTE(review) - both impls are unconditional: they apply to every `T: Float`
// and, when the `cuda` / `metal` features are enabled, they also cover the raw
// `cuda_ptr` pointer and the `metal_buffer` handle. Nothing in this file shows that
// those handles may be moved to or shared with other threads, nor that `Float`
// implies `Send`/`Sync`; confirm both against the backend modules and the trait
// definition.
unsafe impl<T: Float> Send for GpuBuffer<T> {}
unsafe impl<T: Float> Sync for GpuBuffer<T> {}
impl<T: Float> GpuBuffer<T> {
pub fn new(size: usize, backend: GpuBackend) -> GpuResult<Self> {
if size == 0 {
return Err(GpuError::InvalidSize(size));
}
let cpu_data = vec![Complex::<T>::zero(); size];
Ok(Self {
size,
backend,
cpu_data,
#[cfg(feature = "cuda")]
cuda_ptr: None,
#[cfg(feature = "metal")]
metal_buffer: None,
})
}
/// Creates a buffer sized to `data` and immediately uploads `data` into it.
///
/// # Errors
///
/// Returns `GpuError::InvalidSize(0)` when `data` is empty. Any error reported by
/// `new` or `upload` is passed through unchanged, so this fails whenever the
/// upload for `backend` fails.
pub fn from_slice(data: &[Complex<T>], backend: GpuBackend) -> GpuResult<Self> {
    match data.len() {
        0 => Err(GpuError::InvalidSize(0)),
        len => {
            let mut staged = Self::new(len, backend)?;
            staged.upload(data)?;
            Ok(staged)
        }
    }
}
/// Returns the number of complex elements this buffer was created with.
#[must_use]
pub const fn size(&self) -> usize {
self.size
}
/// Returns the backend value that was passed in at construction.
#[must_use]
pub const fn backend(&self) -> GpuBackend {
self.backend
}
/// Copies `data` into the host-side storage and hands the buffer to the backend.
///
/// The host-side copy is overwritten *before* the backend is consulted, so it
/// holds `data` even when this method then returns a backend error.
///
/// # Errors
///
/// * `GpuError::SizeMismatch` when `data.len()` differs from the buffer size.
/// * `GpuError::NoBackendAvailable` when the requested backend's feature is not
///   compiled in, or when `Auto` finds no compiled-in backend reporting itself
///   available.
/// * `GpuError::Unsupported` for any other backend value.
/// * Whatever error the backend's upload routine returns.
pub fn upload(&mut self, data: &[Complex<T>]) -> GpuResult<()> {
    let (expected, got) = (self.size, data.len());
    if expected != got {
        return Err(GpuError::SizeMismatch { expected, got });
    }

    // Stage the new contents on the host side first; the backend call below
    // receives the whole buffer.
    self.cpu_data.copy_from_slice(data);

    match self.backend {
        #[cfg(feature = "cuda")]
        GpuBackend::Cuda => self.upload_cuda(),
        #[cfg(not(feature = "cuda"))]
        GpuBackend::Cuda => Err(GpuError::NoBackendAvailable),

        #[cfg(feature = "metal")]
        GpuBackend::Metal => self.upload_metal(),
        #[cfg(not(feature = "metal"))]
        GpuBackend::Metal => Err(GpuError::NoBackendAvailable),

        GpuBackend::Auto => {
            // CUDA is probed before Metal; the first available one wins.
            #[cfg(feature = "cuda")]
            if super::cuda::is_available() {
                return self.upload_cuda();
            }
            #[cfg(feature = "metal")]
            if super::metal::is_available() {
                return self.upload_metal();
            }
            Err(GpuError::NoBackendAvailable)
        }

        _ => Err(GpuError::Unsupported("Backend not implemented".into())),
    }
}
/// Runs the backend's download routine and copies the host-side storage into `data`.
///
/// `data` is only written after the backend call has succeeded.
///
/// # Errors
///
/// * `GpuError::SizeMismatch` when `data.len()` differs from the buffer size.
/// * `GpuError::NoBackendAvailable` when the requested backend's feature is not
///   compiled in, or when `Auto` finds no compiled-in backend reporting itself
///   available.
/// * `GpuError::Unsupported` for any other backend value.
/// * Whatever error the backend's download routine returns.
pub fn download(&mut self, data: &mut [Complex<T>]) -> GpuResult<()> {
    let (expected, got) = (self.size, data.len());
    if expected != got {
        return Err(GpuError::SizeMismatch { expected, got });
    }

    match self.backend {
        #[cfg(feature = "cuda")]
        GpuBackend::Cuda => {
            self.download_cuda()?;
            data.copy_from_slice(&self.cpu_data);
            Ok(())
        }
        #[cfg(not(feature = "cuda"))]
        GpuBackend::Cuda => Err(GpuError::NoBackendAvailable),

        #[cfg(feature = "metal")]
        GpuBackend::Metal => {
            self.download_metal()?;
            data.copy_from_slice(&self.cpu_data);
            Ok(())
        }
        #[cfg(not(feature = "metal"))]
        GpuBackend::Metal => Err(GpuError::NoBackendAvailable),

        GpuBackend::Auto => {
            // CUDA is probed before Metal; the first available one wins.
            #[cfg(feature = "cuda")]
            if super::cuda::is_available() {
                self.download_cuda()?;
                data.copy_from_slice(&self.cpu_data);
                return Ok(());
            }
            #[cfg(feature = "metal")]
            if super::metal::is_available() {
                self.download_metal()?;
                data.copy_from_slice(&self.cpu_data);
                return Ok(());
            }
            Err(GpuError::NoBackendAvailable)
        }

        _ => Err(GpuError::Unsupported("Backend not implemented".into())),
    }
}
/// Borrows the host-side storage as a read-only slice.
#[must_use]
pub fn cpu_data(&self) -> &[Complex<T>] {
    self.cpu_data.as_slice()
}
/// Borrows the host-side storage as a mutable slice.
///
/// Writing through this slice only changes the host-side copy; no backend
/// routine is called here.
pub fn cpu_data_mut(&mut self) -> &mut [Complex<T>] {
    self.cpu_data.as_mut_slice()
}
/// Passes this buffer to `super::cuda::upload_buffer` and returns its result.
/// Compiled only when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
fn upload_cuda(&mut self) -> GpuResult<()> {
super::cuda::upload_buffer(self)
}
/// Passes this buffer to `super::cuda::download_buffer` and returns its result.
/// Compiled only when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
fn download_cuda(&mut self) -> GpuResult<()> {
super::cuda::download_buffer(self)
}
/// Passes this buffer to `super::metal::upload_buffer` and returns its result.
/// Compiled only when the `metal` feature is enabled.
#[cfg(feature = "metal")]
fn upload_metal(&mut self) -> GpuResult<()> {
super::metal::upload_buffer(self)
}
/// Passes this buffer to `super::metal::download_buffer` and returns its result.
/// Compiled only when the `metal` feature is enabled.
#[cfg(feature = "metal")]
fn download_metal(&mut self) -> GpuResult<()> {
super::metal::download_buffer(self)
}
}
/// Releases whichever backend handles are still attached when the buffer is dropped.
impl<T: Float> Drop for GpuBuffer<T> {
    fn drop(&mut self) {
        // Each handle is taken out of its slot before being freed, leaving `None`
        // behind. The result of the free call is deliberately discarded: `drop`
        // has no way to report it.
        #[cfg(feature = "cuda")]
        {
            let released = self.cuda_ptr.take();
            if let Some(device_ptr) = released {
                let _ = super::cuda::free_buffer(device_ptr);
            }
        }
        #[cfg(feature = "metal")]
        {
            let released = self.metal_buffer.take();
            if let Some(metal_handle) = released {
                let _ = super::metal::free_buffer(metal_handle);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created buffer reports the element count it was asked for.
    #[test]
    fn test_gpu_buffer_creation() {
        let created = GpuBuffer::<f64>::new(1024, GpuBackend::Auto);
        let buffer = created.expect("Failed to create buffer");
        assert_eq!(buffer.size(), 1024);
    }

    /// A value written through `cpu_data_mut` is visible through `cpu_data`.
    #[test]
    fn test_gpu_buffer_cpu_data() {
        let mut buffer =
            GpuBuffer::<f64>::new(8, GpuBackend::Auto).expect("Failed to create buffer");

        let host = buffer.cpu_data_mut();
        host[0] = Complex::new(1.0, 2.0);

        let first = &buffer.cpu_data()[0];
        assert_eq!(*first, Complex::new(1.0, 2.0));
    }
}