#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::string::{String, ToString};
use std::sync::Arc;
use oxicuda_driver::Context;
use oxicuda_fft::{FftHandle, FftPlan, FftType};
use oxicuda_fft::FftDirection as OxiFftDirection;
use super::buffer::GpuBuffer;
use super::error::{GpuError, GpuResult};
use super::plan::GpuDirection;
use super::GpuBackend;
use super::GpuCapabilities;
use crate::kernel::{Complex, Float};
#[must_use]
/// Returns `true` when the CUDA driver can be initialized and device 0 is
/// reachable. This is the cheap probe used by all other entry points.
#[must_use]
pub fn is_available() -> bool {
    if oxicuda_driver::init().is_err() {
        return false;
    }
    oxicuda_driver::Device::get(0).is_ok()
}
/// Queries static capabilities of the primary CUDA device (device 0).
///
/// `available_memory` and `compute_units` are reported as `0` — they are not
/// queried here; `max_fft_size` and `max_workgroup_size` are fixed constants.
///
/// # Errors
///
/// Returns [`GpuError::NoBackendAvailable`] when CUDA is unavailable, or
/// [`GpuError::InitializationFailed`] when a device query fails.
pub fn query_capabilities() -> GpuResult<GpuCapabilities> {
    if !is_available() {
        return Err(GpuError::NoBackendAvailable);
    }
    let dev = oxicuda_driver::Device::get(0)
        .map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
    let device_name = dev
        .name()
        .map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
    let memory_bytes = dev
        .total_memory()
        .map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
    Ok(GpuCapabilities {
        backend: GpuBackend::Cuda,
        device_name,
        total_memory: memory_bytes as u64,
        available_memory: 0,
        // 2^27 complex points — fixed ceiling for 1-D transforms.
        max_fft_size: 1 << 27,
        supports_f64: true,
        supports_f16: true,
        compute_units: 0,
        max_workgroup_size: 1024,
    })
}
/// Waits for all outstanding GPU work to complete.
///
/// Currently a no-op that always succeeds — presumably because execution
/// falls back to the synchronous CPU path (see `CudaFftPlan::execute_cpu`),
/// so there is never asynchronous device work to wait on. TODO confirm once
/// a real device execution path exists.
pub fn synchronize() -> GpuResult<()> {
Ok(())
}
/// A 1-D complex-to-complex FFT plan backed by the CUDA driver.
///
/// Holds the CUDA context and cuFFT-style handle/plan for the lifetime of the
/// plan. NOTE(review): `fft_handle`/`fft_plan` are created in `new` but the
/// current `execute` path runs on the CPU fallback — they appear to be
/// reserved for a future device execution path.
pub struct CudaFftPlan {
// Number of points in one transform.
size: usize,
// Number of transforms executed per `execute` call.
batch_size: usize,
// Keeps the CUDA context alive while the plan exists.
context: Arc<Context>,
// Driver-level FFT library handle.
fft_handle: FftHandle,
// The 1-D C2C plan created for `size` x `batch_size`.
fft_plan: FftPlan,
}
impl std::fmt::Debug for CudaFftPlan {
    /// Manual `Debug`: the `Context` field is skipped (it is not `Debug`),
    /// which is signalled by `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("CudaFftPlan");
        dbg.field("size", &self.size);
        dbg.field("batch_size", &self.batch_size);
        dbg.field("fft_handle", &self.fft_handle);
        dbg.field("fft_plan", &self.fft_plan);
        dbg.finish_non_exhaustive()
    }
}
impl CudaFftPlan {
pub fn new(size: usize, batch_size: usize) -> GpuResult<Self> {
if !is_available() {
return Err(GpuError::NoBackendAvailable);
}
if size == 0 {
return Err(GpuError::InvalidSize(size));
}
oxicuda_driver::init().map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
let device = oxicuda_driver::Device::get(0)
.map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
let raw_ctx =
Context::new(&device).map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
let context = Arc::new(raw_ctx);
let fft_handle =
FftHandle::new(&context).map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
let fft_plan = FftPlan::new_1d(size, FftType::C2C, batch_size)
.map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
Ok(Self {
size,
batch_size,
context,
fft_handle,
fft_plan,
})
}
pub fn execute<T: Float>(
&self,
input: &GpuBuffer<T>,
output: &mut GpuBuffer<T>,
direction: GpuDirection,
) -> GpuResult<()> {
let expected_size = self.size * self.batch_size;
if input.size() != expected_size || output.size() != expected_size {
return Err(GpuError::SizeMismatch {
expected: expected_size,
got: input.size().min(output.size()),
});
}
self.execute_cpu(input, output, direction)
}
fn execute_cpu<T: Float>(
&self,
input: &GpuBuffer<T>,
output: &mut GpuBuffer<T>,
direction: GpuDirection,
) -> GpuResult<()> {
use crate::api::{Direction, Flags, Plan};
let dir = match direction {
GpuDirection::Forward => Direction::Forward,
GpuDirection::Inverse => Direction::Backward,
};
for batch in 0..self.batch_size {
let start = batch * self.size;
let end = start + self.size;
let input_slice = &input.cpu_data()[start..end];
let output_slice = &mut output.cpu_data_mut()[start..end];
if let Some(plan) = Plan::dft_1d(self.size, dir, Flags::ESTIMATE) {
let input_f64: Vec<Complex<f64>> = input_slice
.iter()
.map(|c| {
Complex::new(c.re.to_f64().unwrap_or(0.0), c.im.to_f64().unwrap_or(0.0))
})
.collect();
let mut output_f64 = vec![Complex::<f64>::zero(); self.size];
plan.execute(&input_f64, &mut output_f64);
for (i, c) in output_f64.iter().enumerate() {
output_slice[i] = Complex::new(T::from_f64(c.re), T::from_f64(c.im));
}
} else {
return Err(GpuError::ExecutionFailed(
"Failed to create CPU fallback plan".into(),
));
}
}
Ok(())
}
#[allow(dead_code)]
pub fn context(&self) -> &Arc<Context> {
&self.context
}
#[allow(dead_code)]
fn oxi_direction(direction: GpuDirection) -> OxiFftDirection {
match direction {
GpuDirection::Forward => OxiFftDirection::Forward,
GpuDirection::Inverse => OxiFftDirection::Inverse,
}
}
}
impl Drop for CudaFftPlan {
    // Intentionally empty: the fields (`fft_plan`, `fft_handle`, `context`)
    // presumably release their driver resources via their own Drop impls —
    // TODO confirm this impl is still needed at all, since an empty Drop
    // only forbids partial moves out of the struct.
    fn drop(&mut self) {
    }
}
/// Uploads a buffer's host data to the device. Currently a no-op that always
/// succeeds — presumably because buffers stay host-resident for the CPU
/// fallback path (see `GpuBuffer::cpu_data`). TODO confirm when device
/// transfers are implemented.
pub fn upload_buffer<T: Float>(_buffer: &mut GpuBuffer<T>) -> GpuResult<()> {
Ok(())
}
/// Downloads a buffer's device data back to the host. Currently a no-op that
/// always succeeds — data never leaves host memory on the CPU fallback path.
/// TODO confirm when device transfers are implemented.
pub fn download_buffer<T: Float>(_buffer: &mut GpuBuffer<T>) -> GpuResult<()> {
Ok(())
}
/// Frees a raw device allocation. Currently a no-op that always succeeds —
/// no device memory is ever allocated by this backend yet, so `_ptr` is
/// ignored. TODO revisit when real device allocations exist.
pub fn free_buffer(_ptr: *mut core::ffi::c_void) -> GpuResult<()> {
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Probing availability must never panic, whatever the host has.
    #[test]
    fn test_cuda_availability() {
        let _probe = is_available();
    }

    /// On CUDA hosts, capability queries succeed and report the CUDA backend.
    #[test]
    fn test_cuda_capabilities() {
        if !is_available() {
            return;
        }
        let caps = query_capabilities().expect("Failed to query capabilities");
        assert_eq!(caps.backend, GpuBackend::Cuda);
        assert!(caps.supports_f64);
    }

    /// On CUDA hosts, a modest 1-D plan must be constructible.
    #[test]
    fn test_cuda_plan_creation() {
        if !is_available() {
            return;
        }
        assert!(CudaFftPlan::new(1024, 1).is_ok());
    }
}