#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use super::error::{GpuError, GpuResult};
use super::GpuBackend;
use crate::kernel::{Complex, Float};
/// Buffer of complex samples associated with a [`GpuBackend`].
///
/// In the code visible here all element storage lives in host memory
/// (`cpu_data`): `upload` and `download` copy into and out of it.
/// NOTE(review): presumably a placeholder until device-side storage is
/// implemented — confirm.
#[derive(Debug)]
pub struct GpuBuffer<T: Float> {
/// Number of complex elements; equal to `cpu_data.len()` for every buffer
/// built by `new`/`from_slice` (the storage is never resized afterwards).
size: usize,
/// Backend selected at construction; exposed through `backend()`.
backend: GpuBackend,
/// Host-side copy of the buffer contents.
cpu_data: Vec<Complex<T>>,
}
// SAFETY (review): `GpuBuffer` holds only a `usize`, a `GpuBackend` and a
// `Vec<Complex<T>>`; no raw pointers or interior mutability are visible here.
// NOTE(review): these impls are only sound if `GpuBackend` and `Complex<T>` are
// themselves `Send`/`Sync`. For `Complex<T>` that presumably follows from the
// `Float` bound (e.g. `f32`/`f64`), but confirm it. If both hold, the compiler
// would derive these auto traits anyway and the `unsafe impl`s could be
// removed. If either does not hold, these impls are unsound.
unsafe impl<T: Float> Send for GpuBuffer<T> {}
unsafe impl<T: Float> Sync for GpuBuffer<T> {}
impl<T: Float> GpuBuffer<T> {
    /// Creates a buffer of `size` elements, each initialised to `Complex::zero()`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::InvalidSize`] if `size` is zero.
    pub fn new(size: usize, backend: GpuBackend) -> GpuResult<Self> {
        if size == 0 {
            return Err(GpuError::InvalidSize(size));
        }
        // `Vec::resize` rather than `vec![..]`: in `no_std` builds the `vec!`
        // macro is only in scope if the crate root uses
        // `#[macro_use] extern crate alloc`, whereas `Vec` is imported above.
        let mut cpu_data = Vec::with_capacity(size);
        cpu_data.resize(size, Complex::<T>::zero());
        Ok(Self {
            size,
            backend,
            cpu_data,
        })
    }

    /// Creates a buffer whose contents are a copy of `data`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::InvalidSize`] (with size `0`) if `data` is empty.
    pub fn from_slice(data: &[Complex<T>], backend: GpuBackend) -> GpuResult<Self> {
        if data.is_empty() {
            return Err(GpuError::InvalidSize(0));
        }
        // Copy directly from `data`. Going through `new` + `upload` would
        // zero-fill the storage only to overwrite every element right away.
        Ok(Self {
            size: data.len(),
            backend,
            cpu_data: data.to_vec(),
        })
    }

    /// Number of complex elements held by the buffer.
    #[must_use]
    pub const fn size(&self) -> usize {
        self.size
    }

    /// Backend this buffer was created with.
    #[must_use]
    pub const fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Overwrites the buffer contents with `data`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::SizeMismatch`] if `data.len()` differs from
    /// [`size`](Self::size); the buffer is left unchanged in that case.
    pub fn upload(&mut self, data: &[Complex<T>]) -> GpuResult<()> {
        self.ensure_len_matches(data.len())?;
        self.cpu_data.copy_from_slice(data);
        Ok(())
    }

    /// Copies the buffer contents into `data`.
    ///
    /// This only reads the buffer, so it takes `&self`. Existing calls made
    /// through a `&mut GpuBuffer` still compile because `&mut` coerces to `&`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::SizeMismatch`] if `data.len()` differs from
    /// [`size`](Self::size); `data` is left unchanged in that case.
    pub fn download(&self, data: &mut [Complex<T>]) -> GpuResult<()> {
        self.ensure_len_matches(data.len())?;
        data.copy_from_slice(&self.cpu_data);
        Ok(())
    }

    /// Read-only view of the host-side storage.
    #[must_use]
    pub fn cpu_data(&self) -> &[Complex<T>] {
        &self.cpu_data
    }

    /// Mutable view of the host-side storage.
    ///
    /// A slice is returned (not the `Vec`), so callers can modify elements
    /// but cannot change the length, which keeps `size` accurate.
    pub fn cpu_data_mut(&mut self) -> &mut [Complex<T>] {
        &mut self.cpu_data
    }

    /// Fails with [`GpuError::SizeMismatch`] unless `got` equals the buffer size.
    fn ensure_len_matches(&self, got: usize) -> GpuResult<()> {
        if got == self.size {
            Ok(())
        } else {
            Err(GpuError::SizeMismatch {
                expected: self.size,
                got,
            })
        }
    }
}
// Intentionally empty: the only owned resource visible in this type is
// `cpu_data`, which `Vec` frees on its own.
// NOTE(review): presumably a hook for releasing device allocations once a real
// GPU backend is wired in; confirm. While this impl exists, fields cannot be
// moved out of a `GpuBuffer` by destructuring (error E0509).
impl<T: Float> Drop for GpuBuffer<T> {
fn drop(&mut self) {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_buffer_creation() {
        let buffer: GpuBuffer<f64> =
            GpuBuffer::new(1024, GpuBackend::Auto).expect("Failed to create buffer");
        assert_eq!(buffer.size(), 1024);
        assert_eq!(buffer.cpu_data().len(), 1024);
        assert!(buffer
            .cpu_data()
            .iter()
            .all(|c| *c == Complex::<f64>::zero()));
    }

    #[test]
    fn test_gpu_buffer_cpu_data() {
        let mut buffer: GpuBuffer<f64> =
            GpuBuffer::new(8, GpuBackend::Auto).expect("Failed to create buffer");
        buffer.cpu_data_mut()[0] = Complex::new(1.0, 2.0);
        assert_eq!(buffer.cpu_data()[0], Complex::new(1.0, 2.0));
    }

    /// Zero-length buffers are rejected by both constructors.
    #[test]
    fn test_gpu_buffer_rejects_zero_size() {
        assert!(matches!(
            GpuBuffer::<f64>::new(0, GpuBackend::Auto),
            Err(GpuError::InvalidSize(0))
        ));
        let empty: [Complex<f64>; 0] = [];
        assert!(matches!(
            GpuBuffer::from_slice(&empty, GpuBackend::Auto),
            Err(GpuError::InvalidSize(0))
        ));
    }

    /// Data survives `from_slice` -> `download`, and `upload` replaces it.
    #[test]
    fn test_gpu_buffer_round_trip() {
        let input = [Complex::new(1.0, 2.0), Complex::new(-3.0, 4.5)];
        let mut buffer: GpuBuffer<f64> =
            GpuBuffer::from_slice(&input, GpuBackend::Auto).expect("Failed to create buffer");
        assert_eq!(buffer.size(), 2);

        let mut output = [Complex::new(0.0, 0.0); 2];
        buffer.download(&mut output).expect("download failed");
        assert_eq!(output, input);

        let replacement = [Complex::new(5.0, 6.0), Complex::new(7.0, 8.0)];
        buffer.upload(&replacement).expect("upload failed");
        assert_eq!(buffer.cpu_data(), &replacement[..]);
    }

    /// Transfers with the wrong length report both sizes.
    #[test]
    fn test_gpu_buffer_size_mismatch() {
        let mut buffer: GpuBuffer<f64> =
            GpuBuffer::new(4, GpuBackend::Auto).expect("Failed to create buffer");
        let short = [Complex::new(0.0, 0.0); 3];
        assert!(matches!(
            buffer.upload(&short),
            Err(GpuError::SizeMismatch {
                expected: 4,
                got: 3
            })
        ));
        let mut long = [Complex::new(0.0, 0.0); 5];
        assert!(matches!(
            buffer.download(&mut long),
            Err(GpuError::SizeMismatch {
                expected: 4,
                got: 5
            })
        ));
    }
}