use crate::FftExecutor;
use num_complex::Complex;
use std::error::Error;
use std::fmt;
/// Plan handle exchanged with the hipFFT entry points declared in this file.
///
/// NOTE(review): upstream `hipfft.h` appears to declare `hipfftHandle` as an opaque
/// pointer (`struct hipfftHandle_t*`), which is wider than `i32` on 64-bit targets.
/// Confirm against the installed header; if so, this alias and every `= 0`
/// initialiser of a handle need to change together.
// The lower-case name mirrors the C typedef, hence the lint exemption.
#[allow(non_camel_case_types)]
type hipfftHandle = i32;
/// Single-precision complex value with C layout (`#[repr(C)]`), used as the element
/// type of the buffers handed to `hipfftExecC2C`.
// The lower-case name mirrors the C typedef, hence the lint exemption.
#[allow(non_camel_case_types)]
#[repr(C)]
struct hipfftComplex {
/// Real component; `exec_batch` fills it from `Complex::re`.
x: f32,
/// Imaginary component; `exec_batch` fills it from `Complex::im`.
y: f32,
}
/// Opaque stream pointer as taken by `hipfftSetStream`.
///
/// Nothing in the visible code creates or passes a stream; the alias only exists so
/// that the `hipfftSetStream` declaration can be written.
// The lower-case name mirrors the C typedef, hence the lint exemption.
#[allow(non_camel_case_types)]
type hipStream_t = *mut std::ffi::c_void;
// Numeric values below are presumably copied from the hipFFT / HIP C headers;
// verify them against the installed ROCm version when upgrading.

/// Transform-type code passed as `type_` to `hipfftPlan1d`
/// (presumably single-precision complex-to-complex — confirm in `hipfft.h`).
const HIPFFT_C2C: i32 = 0x29;
/// Direction argument used by `fft` when calling `hipfftExecC2C`.
const HIPFFT_FORWARD: i32 = -1;
/// Direction argument used by `ifft`; `exec_batch` also rescales the output by
/// `1 / n` when it sees this direction.
const HIPFFT_BACKWARD: i32 = 1;
/// `kind` argument of `hipMemcpy` for the host-to-device upload in `exec_batch`.
const HIP_MEMCPY_HOST_TO_DEVICE: i32 = 1;
/// `kind` argument of `hipMemcpy` for the device-to-host download in `exec_batch`.
const HIP_MEMCPY_DEVICE_TO_HOST: i32 = 2;
// Raw bindings to the `hipfft` shared library. Every function returns an `i32`
// status that the rest of this file feeds through `check`, which treats 0 as
// success. The signatures are hand-written and must match the C header exactly.
#[link(name = "hipfft")]
extern "C" {
/// Writes a plan handle through `plan`.
fn hipfftCreate(plan: *mut hipfftHandle) -> i32;
/// Releases the plan identified by `plan`.
fn hipfftDestroy(plan: hipfftHandle) -> i32;
/// Writes through `plan` a handle for a 1-D plan of length `nx`, transform type
/// `type_` and `batch` transforms.
/// NOTE(review): per the hipFFT docs this call allocates the plan itself, so it
/// presumably should not be preceded by `hipfftCreate` on the same slot — confirm.
fn hipfftPlan1d(plan: *mut hipfftHandle, nx: i32, type_: i32, batch: i32) -> i32;
/// Executes a complex-to-complex plan reading `idata` and writing `odata`;
/// `direction` is `HIPFFT_FORWARD` or `HIPFFT_BACKWARD`. In this file both
/// pointers refer to memory obtained from `hipMalloc`.
fn hipfftExecC2C(
plan: hipfftHandle,
idata: *mut hipfftComplex,
odata: *mut hipfftComplex,
direction: i32,
) -> i32;
/// Associates `stream` with `plan`. Declared but not called in the visible code.
fn hipfftSetStream(plan: hipfftHandle, stream: hipStream_t) -> i32;
}
// Raw bindings to the HIP runtime (`amdhip64`). Every function returns an `i32`
// status; this file feeds those through `check`, which treats 0 as success.
#[link(name = "amdhip64")]
extern "C" {
/// Allocates `size` bytes and writes the resulting pointer through `ptr`.
fn hipMalloc(ptr: *mut *mut std::ffi::c_void, size: usize) -> i32;
/// Releases an allocation previously obtained from `hipMalloc`.
fn hipFree(ptr: *mut std::ffi::c_void) -> i32;
/// Copies `count` bytes from `src` to `dst`; `kind` is one of the
/// `HIP_MEMCPY_*` constants and selects the transfer direction.
fn hipMemcpy(
dst: *mut std::ffi::c_void,
src: *const std::ffi::c_void,
count: usize,
kind: i32,
) -> i32;
/// Called by `exec_batch` after launching the transform and before copying the
/// result back (presumably blocks until outstanding device work is done — confirm
/// in the HIP runtime documentation).
fn hipDeviceSynchronize() -> i32;
}
/// Non-zero status code returned by one of the FFI calls in this file.
///
/// `check` constructs it for both hipFFT and HIP runtime calls, so the wrapped
/// integer is whatever the failing function returned.
#[derive(Debug)]
struct HipFftError(i32);

impl fmt::Display for HipFftError {
    /// Renders the raw status code behind a fixed prefix.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let HipFftError(code) = self;
        write!(f, "hipFFT error code {}", code)
    }
}

// Marker impl: lets the error travel inside `Box<dyn Error>` via `?`.
impl Error for HipFftError {}
/// Turns a raw FFI status code into a `Result`.
///
/// A status of `0` is success; any other value is wrapped unchanged in
/// `HipFftError` so callers can propagate it with `?`.
fn check(code: i32) -> Result<(), HipFftError> {
    match code {
        0 => Ok(()),
        failure => Err(HipFftError(failure)),
    }
}
/// FFT backend that runs batched complex-to-complex transforms through hipFFT.
///
/// Build it with `HipFft::new`, which probes for a usable device first.
pub struct HipFft {
/// Size supplied to `new`. The visible code stores it but never reads it; the
/// transform length is taken from the inputs of each call instead.
fft_size: usize,
}
impl HipFft {
    /// Creates an executor after probing that hipFFT can build a plan on this machine.
    ///
    /// The probe checks that `/dev/kfd` exists and then builds and destroys a
    /// throwaway one-point complex-to-complex plan.
    ///
    /// `fft_size` is stored on the returned value; the transform length used by each
    /// call is taken from that call's inputs.
    ///
    /// # Errors
    ///
    /// Returns an error when `/dev/kfd` is absent or when `hipfftPlan1d` reports a
    /// non-zero status.
    pub fn new(fft_size: usize) -> Result<Self, Box<dyn Error>> {
        if !std::path::Path::new("/dev/kfd").exists() {
            return Err("No AMD GPU detected (/dev/kfd absent); hipFFT requires ROCm".into());
        }

        let mut probe: hipfftHandle = 0;
        // SAFETY: `probe` is a live local that outlives both calls, and the handle is
        // only destroyed after `hipfftPlan1d` reported success.
        unsafe {
            // `hipfftPlan1d` allocates the plan itself; a preceding `hipfftCreate`
            // would have its handle overwritten and never destroyed.
            check(hipfftPlan1d(&mut probe, 1, HIPFFT_C2C, 1))?;
            // Best effort: a failure to tear down the probe plan is not reported.
            hipfftDestroy(probe);
        }

        Ok(Self { fft_size })
    }

    /// Runs one batched 1-D complex-to-complex transform over `inputs`.
    ///
    /// Every signal in `inputs` must have the same non-zero length `n`. The signals
    /// are packed back to back, uploaded, transformed out of place with a single
    /// plan of length `n` and batch size `inputs.len()`, and downloaded again. When
    /// `direction` is `HIPFFT_BACKWARD` each output sample is multiplied by `1 / n`.
    ///
    /// An empty `inputs` slice yields an empty result without touching the device.
    /// `self.fft_size` is not consulted.
    ///
    /// # Errors
    ///
    /// Returns an error when the signals are empty or differ in length, when `n` or
    /// the batch size does not fit in an `i32`, or when any HIP / hipFFT call
    /// reports a non-zero status. Device buffers and the plan are released on every
    /// path, including the error paths.
    fn exec_batch(
        &self,
        inputs: &[Vec<Complex<f32>>],
        direction: i32,
    ) -> Result<Vec<Vec<Complex<f32>>>, Box<dyn Error>> {
        // Owns one `hipMalloc` allocation and frees it when dropped.
        struct DeviceBuffer(*mut std::ffi::c_void);

        impl DeviceBuffer {
            fn alloc(bytes: usize) -> Result<Self, Box<dyn Error>> {
                let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
                // SAFETY: `ptr` is a valid out-parameter for the duration of the call.
                check(unsafe { hipMalloc(&mut ptr, bytes) })?;
                Ok(DeviceBuffer(ptr))
            }
        }

        impl Drop for DeviceBuffer {
            fn drop(&mut self) {
                // SAFETY: the pointer came from a successful `hipMalloc` and is freed
                // exactly once. The status is ignored: nothing useful can be done
                // about a failed free during cleanup.
                unsafe {
                    hipFree(self.0);
                }
            }
        }

        // Owns a plan handle and destroys it when dropped.
        struct PlanGuard(hipfftHandle);

        impl Drop for PlanGuard {
            fn drop(&mut self) {
                // SAFETY: the handle came from a successful `hipfftPlan1d` and is
                // destroyed exactly once. The status is ignored as for buffers.
                unsafe {
                    hipfftDestroy(self.0);
                }
            }
        }

        // Checked narrowing for the `int` parameters of `hipfftPlan1d`.
        fn to_c_int(value: usize, what: &str) -> Result<i32, Box<dyn Error>> {
            <i32 as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
                format!("{} {} exceeds the 32-bit range accepted by hipFFT", what, value).into()
            })
        }

        let n = match inputs.first() {
            Some(signal) => signal.len(),
            None => return Ok(Vec::new()),
        };
        if n == 0 {
            return Err("hipFFT cannot transform zero-length signals".into());
        }
        // The plan covers `n * batch` elements, so a ragged batch would make the
        // device read and write outside the buffers sized from the real total.
        if let Some(idx) = inputs.iter().position(|signal| signal.len() != n) {
            return Err(format!(
                "input {} has length {} but input 0 has length {}; \
                 all signals in a batch must have the same length",
                idx,
                inputs[idx].len(),
                n
            )
            .into());
        }
        let nx = to_c_int(n, "FFT length")?;
        let batch = to_c_int(inputs.len(), "batch size")?;

        // Pack all signals contiguously; the same host buffer receives the result.
        let mut host: Vec<hipfftComplex> = Vec::with_capacity(n * inputs.len());
        host.extend(
            inputs
                .iter()
                .flatten()
                .map(|c| hipfftComplex { x: c.re, y: c.im }),
        );
        let byte_count = host.len() * std::mem::size_of::<hipfftComplex>();

        {
            let d_in = DeviceBuffer::alloc(byte_count)?;
            let d_out = DeviceBuffer::alloc(byte_count)?;

            // SAFETY: both device buffers hold `byte_count` bytes, `host` holds
            // exactly `byte_count` bytes of `hipfftComplex`, and the plan is built
            // for `nx * batch` elements, which equals `host.len()`.
            unsafe {
                check(hipMemcpy(
                    d_in.0,
                    host.as_ptr() as *const std::ffi::c_void,
                    byte_count,
                    HIP_MEMCPY_HOST_TO_DEVICE,
                ))?;

                // `hipfftPlan1d` allocates the plan; no `hipfftCreate` beforehand.
                let mut handle: hipfftHandle = 0;
                check(hipfftPlan1d(&mut handle, nx, HIPFFT_C2C, batch))?;
                let plan = PlanGuard(handle);

                check(hipfftExecC2C(
                    plan.0,
                    d_in.0 as *mut hipfftComplex,
                    d_out.0 as *mut hipfftComplex,
                    direction,
                ))?;
                check(hipDeviceSynchronize())?;
                check(hipMemcpy(
                    host.as_mut_ptr() as *mut std::ffi::c_void,
                    d_out.0,
                    byte_count,
                    HIP_MEMCPY_DEVICE_TO_HOST,
                ))?;
            }
            // Guards drop here: plan first, then the output and input buffers.
        }

        // Rescale the backward transform so that `ifft(fft(x))` returns `x`
        // (presumably because hipFFT leaves its inverse unnormalised — confirm).
        let scale = if direction == HIPFFT_BACKWARD {
            1.0 / n as f32
        } else {
            1.0
        };

        Ok(host
            .chunks_exact(n)
            .map(|signal| {
                signal
                    .iter()
                    .map(|c| Complex::new(c.x * scale, c.y * scale))
                    .collect()
            })
            .collect())
    }
}
/// Exposes the hipFFT backend through the crate's `FftExecutor` trait by delegating
/// both transform directions to `HipFft::exec_batch`.
impl FftExecutor for HipFft {
    /// Label identifying this backend.
    fn name(&self) -> &str {
        const LABEL: &str = "hipFFT (ROCm/HIP reference)";
        LABEL
    }

    /// Transforms every signal in `inputs` with direction `HIPFFT_FORWARD`.
    fn fft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>, Box<dyn Error>> {
        let direction = HIPFFT_FORWARD;
        self.exec_batch(inputs, direction)
    }

    /// Transforms every signal in `inputs` with direction `HIPFFT_BACKWARD`.
    fn ifft(&self, inputs: &[Vec<Complex<f32>>]) -> Result<Vec<Vec<Complex<f32>>>, Box<dyn Error>> {
        let direction = HIPFFT_BACKWARD;
        self.exec_batch(inputs, direction)
    }
}