#![cfg(feature = "cufft")]
use std::sync::Arc;
use cudarc::cufft::sys::float2;
use cudarc::driver::{CudaSlice, CudaStream};
use num_complex::Complex32;
use crate::{CudaDevice, GpuError};
// Compile-time guard: the build fails if `Complex32` and cuFFT's `float2`
// ever stop occupying the same number of bytes.
const _: () = {
    let complex_bytes = core::mem::size_of::<Complex32>();
    let float2_bytes = core::mem::size_of::<float2>();
    assert!(
        complex_bytes == float2_bytes,
        "Complex32 and cufft::float2 must be the same size"
    );
};
/// A buffer of complex single-precision values held in a `CudaSlice<float2>`.
///
/// Instances are created with `zeroed` or `from_slice` and can be read back
/// into host memory with `to_vec`.
pub struct CudaBuffer {
    /// Underlying `cudarc` slice storing the elements as cuFFT `float2` values.
    pub(crate) slice: CudaSlice<float2>,
    /// Stream the slice was allocated on; also used by `to_vec` for the copy
    /// back to the host.
    stream: Arc<CudaStream>,
    /// Element count recorded at construction (the requested length for
    /// `zeroed`, the host slice length for `from_slice`).
    len: usize,
}
impl CudaBuffer {
pub fn zeroed(dev: &CudaDevice, len: usize) -> Result<Self, GpuError> {
let stream = dev.cuda_context().default_stream();
let slice = stream
.alloc_zeros::<float2>(len)
.map_err(map_cuda_err)?;
Ok(Self { slice, stream, len })
}
pub fn from_slice(dev: &CudaDevice, host: &[Complex32]) -> Result<Self, GpuError> {
let stream = dev.cuda_context().default_stream();
let host_f2: Vec<float2> =
host.iter().map(|c| float2 { x: c.re, y: c.im }).collect();
let slice = stream.clone_htod(&host_f2).map_err(map_cuda_err)?;
Ok(Self { slice, stream, len: host.len() })
}
pub fn to_vec(&self) -> Result<Vec<Complex32>, GpuError> {
let host_f2: Vec<float2> =
self.stream.clone_dtoh(&self.slice).map_err(map_cuda_err)?;
Ok(host_f2.iter().map(|f| Complex32::new(f.x, f.y)).collect())
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
}
/// Wraps any displayable error in `GpuError::CudaError`, keeping only its
/// `Display` text.
fn map_cuda_err<E>(err: E) -> GpuError
where
    E: core::fmt::Display,
{
    let message = err.to_string();
    GpuError::CudaError(message)
}