#[cfg(feature = "cuda")]
use std::{ffi::c_void, sync::Arc};
#[cfg(feature = "cuda")]
use cudarc::driver::CudaStream;
use realfft::RealFftPlanner;
use rustfft::FftPlanner;
use crate::Complex_;
/// Forward 3-D real-to-complex FFT of an `(nx, ny, nz)` grid.
///
/// The real input is indexed as `ix * (ny * nz) + iy * nz + iz` (z fastest).
/// The returned half-spectrum keeps only `nzc = nz / 2 + 1` bins along z and
/// is indexed as `ix * (ny * nzc) + iy * nzc + izc`.
///
/// The transform is done as three 1-D passes: real-to-complex along z, then
/// complex FFTs along y and x. No scaling is applied in this function.
///
/// # Parameters
/// * `data_r` - real samples; at least `nx * ny * nz` elements are read. It is
///   taken mutably because the real-to-complex transform receives each row as
///   `&mut` (realfft documents that it may use the input as scratch space), so
///   treat its contents as unspecified afterwards.
/// * `dims` - `(nx, ny, nz)` grid dimensions.
/// * `planner` - complex FFT planner used for the y and x plans.
///
/// # Panics
/// Panics if `data_r` is too short for the requested dimensions.
pub fn fft3d_r2c(
    data_r: &mut [f32],
    dims: (usize, usize, usize),
    planner: &mut FftPlanner<f32>,
) -> Vec<Complex_> {
    let (nx, ny, nz) = dims;
    // Number of complex bins kept along z for a real input of length `nz`.
    let nzc = nz / 2 + 1;
    // Distance (in complex elements) between consecutive x-planes of the output.
    let plane_k = ny * nzc;

    let mut rplanner = RealFftPlanner::<f32>::new();
    let r2c_z = rplanner.plan_fft_forward(nz);
    let fft_y = planner.plan_fft_forward(ny);
    let fft_x = planner.plan_fft_forward(nx);

    let zero = Complex_::new(0.0, 0.0);
    let mut out = vec![zero; nx * plane_k];

    // Pass 1: real-to-complex along z. Rows are contiguous in both layouts, so
    // each one is transformed straight from `data_r` into `out`. The scratch
    // buffer is allocated once instead of once per row.
    let mut scratch_z = r2c_z.make_scratch_vec();
    for ix in 0..nx {
        for iy in 0..ny {
            let row_r = ix * (ny * nz) + iy * nz;
            let row_k = ix * plane_k + iy * nzc;
            r2c_z
                .process_with_scratch(
                    &mut data_r[row_r..row_r + nz],
                    &mut out[row_k..row_k + nzc],
                    &mut scratch_z,
                )
                .expect("row slices are sized to match the z plan (nz in, nz/2+1 out)");
        }
    }

    // Pass 2: complex FFT along y. Elements of one y-line are `nzc` apart, so
    // gather them into a contiguous buffer, transform, and scatter back.
    let mut line_y = vec![zero; ny];
    let mut scratch_y = vec![zero; fft_y.get_inplace_scratch_len()];
    for ix in 0..nx {
        for izc in 0..nzc {
            let base = ix * plane_k + izc;
            for (iy, slot) in line_y.iter_mut().enumerate() {
                *slot = out[base + iy * nzc];
            }
            fft_y.process_with_scratch(&mut line_y, &mut scratch_y);
            for (iy, value) in line_y.iter().enumerate() {
                out[base + iy * nzc] = *value;
            }
        }
    }

    // Pass 3: complex FFT along x. Elements of one x-line are `plane_k` apart.
    let mut line_x = vec![zero; nx];
    let mut scratch_x = vec![zero; fft_x.get_inplace_scratch_len()];
    for iy in 0..ny {
        for izc in 0..nzc {
            let base = iy * nzc + izc;
            for (ix, slot) in line_x.iter_mut().enumerate() {
                *slot = out[base + ix * plane_k];
            }
            fft_x.process_with_scratch(&mut line_x, &mut scratch_x);
            for (ix, value) in line_x.iter().enumerate() {
                out[base + ix * plane_k] = *value;
            }
        }
    }

    out
}
/// Inverse 3-D complex-to-real FFT of a half-spectrum on an `(nx, ny, nz)` grid.
///
/// The input half-spectrum is indexed as `ix * (ny * nzc) + iy * nzc + izc`
/// with `nzc = nz / 2 + 1`; the returned real grid is indexed as
/// `ix * (ny * nz) + iy * nz + iz`.
///
/// The transform is done as three 1-D passes: inverse complex FFTs along x and
/// y (in place in `data_k`), then complex-to-real along z. No scaling is
/// applied in this function; rustfft and realfft document their transforms as
/// unnormalised, so a forward/inverse round trip is presumably scaled by
/// `nx * ny * nz` unless the caller divides it out.
///
/// # Parameters
/// * `data_k` - half-spectrum; at least `nx * ny * nzc` elements are used. It
///   is overwritten: the x and y passes run in place, and the imaginary parts
///   of the first (and, for even `nz`, last) z-bin of every row are set to zero.
/// * `dims` - `(nx, ny, nz)` dimensions of the real grid.
/// * `planner` - complex FFT planner used for the x and y plans.
///
/// # Panics
/// Panics if `data_k` is too short for the requested dimensions.
pub fn fft3d_c2r(
    data_k: &mut [Complex_],
    dims: (usize, usize, usize),
    planner: &mut FftPlanner<f32>,
) -> Vec<f32> {
    let (nx, ny, nz) = dims;
    // Number of complex bins stored along z for a real signal of length `nz`.
    let nzc = nz / 2 + 1;
    // Distance (in complex elements) between consecutive x-planes of the input.
    let plane_k = ny * nzc;

    let mut planner_real = RealFftPlanner::<f32>::new();
    let c2r_z = planner_real.plan_fft_inverse(nz);
    let ifft_y = planner.plan_fft_inverse(ny);
    let ifft_x = planner.plan_fft_inverse(nx);

    let zero = Complex_::new(0.0, 0.0);

    // Pass 1: inverse complex FFT along x. Elements of one x-line are
    // `plane_k` apart, so gather, transform, and scatter back. The scratch
    // buffer is allocated once instead of once per line.
    let mut line_x = vec![zero; nx];
    let mut scratch_x = vec![zero; ifft_x.get_inplace_scratch_len()];
    for iy in 0..ny {
        for izc in 0..nzc {
            let base = iy * nzc + izc;
            for (ix, slot) in line_x.iter_mut().enumerate() {
                *slot = data_k[base + ix * plane_k];
            }
            ifft_x.process_with_scratch(&mut line_x, &mut scratch_x);
            for (ix, value) in line_x.iter().enumerate() {
                data_k[base + ix * plane_k] = *value;
            }
        }
    }

    // Pass 2: inverse complex FFT along y. Elements of one y-line are `nzc` apart.
    let mut line_y = vec![zero; ny];
    let mut scratch_y = vec![zero; ifft_y.get_inplace_scratch_len()];
    for ix in 0..nx {
        for izc in 0..nzc {
            let base = ix * plane_k + izc;
            for (iy, slot) in line_y.iter_mut().enumerate() {
                *slot = data_k[base + iy * nzc];
            }
            ifft_y.process_with_scratch(&mut line_y, &mut scratch_y);
            for (iy, value) in line_y.iter().enumerate() {
                data_k[base + iy * nzc] = *value;
            }
        }
    }

    // Pass 3: complex-to-real along z, one contiguous row at a time.
    let mut out = vec![0.0_f32; nx * ny * nz];
    let mut scratch_z = c2r_z.make_scratch_vec();
    for ix in 0..nx {
        for iy in 0..ny {
            let row_k = ix * plane_k + iy * nzc;
            let row_r = ix * (ny * nz) + iy * nz;
            let spectrum = &mut data_k[row_k..row_k + nzc];

            // The DC bin (and the last bin when `nz` is even) must be purely
            // real for a real-valued signal. Clear any residual imaginary part
            // so the c2r transform does not report an invalid-input error.
            spectrum[0].im = 0.0;
            if nz % 2 == 0 {
                spectrum[nzc - 1].im = 0.0;
            }

            c2r_z
                .process_with_scratch(spectrum, &mut out[row_r..row_r + nz], &mut scratch_z)
                .expect("row slices match the z plan and the DC/Nyquist imaginary parts are zero");
        }
    }

    out
}
// Foreign (C ABI) entry points used by the CUDA build. Only the signatures are
// visible here; the implementations, and therefore their exact contracts, live
// outside this file. All pointers are passed as untyped `*mut c_void`.
// NOTE(review): argument meanings below are inferred from names — confirm
// against the native implementation.
#[cfg(feature = "cuda")]
unsafe extern "C" {
/// Creates a plan for an `nx` x `ny` x `nz` grid bound to the raw stream
/// handle `cu_stream`, returning an opaque plan pointer.
/// NOTE(review): failure signalling (e.g. a null return) is not visible here.
pub(crate) fn make_plan(nx: i32, ny: i32, nz: i32, cu_stream: *mut c_void) -> *mut c_void;
/// Releases a plan; presumably `plan` must come from `make_plan` and must not
/// be used afterwards — TODO confirm.
pub(crate) fn destroy_plan(plan: *mut c_void);
/// Runs the forward transform for `plan`; presumably reads real data from
/// `rho_real` and writes the complex result to `rho` — TODO confirm buffer
/// types, sizes, and memory space.
pub(crate) fn exec_forward(plan: *mut c_void, rho_real: *mut c_void, rho: *mut c_void);
/// Runs the inverse transform for `plan` on three buffers; presumably
/// `exk`/`eyk`/`ezk` are the spectral inputs and `ex`/`ey`/`ez` the matching
/// real outputs — TODO confirm.
pub(crate) fn exec_inverse(
plan: *mut c_void,
exk: *mut c_void,
eyk: *mut c_void,
ezk: *mut c_void,
ex: *mut c_void,
ey: *mut c_void,
ez: *mut c_void,
);
}
/// Creates a native GPU FFT plan for a grid of size `dims` on `stream`.
///
/// Returns the opaque plan pointer produced by the foreign `make_plan`
/// function. The pointer is returned unchecked; whether it can be null on
/// failure is not visible here. Presumably the caller releases it with
/// `destroy_plan` and keeps `stream` alive while the plan is in use — TODO
/// confirm against the native implementation.
///
/// # Panics
/// Panics if any dimension does not fit in an `i32`, rather than silently
/// truncating it before handing it to the foreign planner.
#[cfg(feature = "cuda")]
pub(crate) fn create_gpu_plan(
    dims: (usize, usize, usize),
    stream: &Arc<CudaStream>,
) -> *mut c_void {
    // The foreign API takes 32-bit signed sizes; a plain `as` cast would wrap
    // oversized dimensions into bogus (possibly negative) values.
    let to_i32 = |n: usize, axis: &str| -> i32 {
        i32::try_from(n)
            .unwrap_or_else(|_| panic!("FFT dimension {axis} = {n} does not fit in i32"))
    };
    let nx = to_i32(dims.0, "nx");
    let ny = to_i32(dims.1, "ny");
    let nz = to_i32(dims.2, "nz");

    let raw_stream: *mut c_void = stream.cu_stream() as *mut c_void;

    // SAFETY: `make_plan` is a foreign function that receives three plain
    // integers and the raw handle of a stream that is alive for the duration
    // of this call (borrowed via `stream`). Any further requirements of the
    // native side are not visible from this file.
    unsafe { make_plan(nx, ny, nz, raw_stream) }
}