use mpi::datatype::Equivalence;
use mpi::topology::Communicator;
use crate::api::Direction;
use crate::kernel::{Complex, Float};
use crate::mpi::error::MpiError;
use crate::mpi::pool::{MpiFloat, MpiPool};
use crate::mpi::MpiFlags;
#[allow(dead_code)]
/// Distributed N-dimensional FFT plan.
///
/// Dimension 0 is block-distributed ("slab decomposed") across the ranks of
/// the owning [`MpiPool`]; all remaining dimensions are fully local to each
/// rank.
pub struct MpiPlanND<T: Float, C: Communicator> {
    /// Global extent of every dimension (not the local extents).
    dims: Vec<usize>,
    /// Number of dim-0 planes owned by this rank.
    local_n0: usize,
    /// Global dim-0 index of this rank's first plane.
    local_0_start: usize,
    direction: Direction,
    flags: MpiFlags,
    /// Raw pointer back to the pool the plan was created from.
    /// NOTE(review): lifetime is not tracked by the type system — the caller
    /// must keep the pool alive for as long as the plan exists; verify at
    /// call sites.
    pool: *const MpiPool<C>,
    /// One 1-D sub-plan per dimension, indexed by dimension.
    local_plans: Vec<crate::api::Plan<T>>,
    /// Preallocated workspace (currently unused by the execute paths).
    scratch: Vec<Complex<T>>,
    _phantom: core::marker::PhantomData<(T, C)>,
}
// SAFETY: the only field that blocks an auto-derived Send/Sync is the raw
// `pool: *const MpiPool<C>` pointer, which originates from `&MpiPool<C>` in
// `new`, so sending/sharing the plan is sound exactly when the communicator
// itself is Send/Sync (hence the extra `C: Send` / `C: Sync` bounds).
// NOTE(review): soundness still assumes the pool outlives the plan — that
// invariant is not enforced here; confirm at every construction site.
unsafe impl<T: Float, C: Communicator + Send> Send for MpiPlanND<T, C> {}
unsafe impl<T: Float, C: Communicator + Sync> Sync for MpiPlanND<T, C> {}
impl<T: Float + MpiFloat, C: Communicator> MpiPlanND<T, C>
where
    Complex<T>: Equivalence,
{
    /// Builds an N-dimensional distributed FFT plan.
    ///
    /// Dimension 0 (`dims[0]`) is block-distributed across the ranks of
    /// `pool`; the remaining dimensions are local to each rank.
    ///
    /// # Errors
    /// Returns [`MpiError::InvalidDimension`] when `dims` is empty or any
    /// dimension is zero, and [`MpiError::FftError`] when a 1-D sub-plan
    /// cannot be created.
    pub fn new(
        dims: &[usize],
        direction: Direction,
        flags: MpiFlags,
        pool: &MpiPool<C>,
    ) -> Result<Self, MpiError> {
        if dims.is_empty() {
            return Err(MpiError::InvalidDimension {
                dim: 0,
                size: 0,
                message: "Cannot create plan with zero dimensions".to_string(),
            });
        }
        for (i, &size) in dims.iter().enumerate() {
            if size == 0 {
                return Err(MpiError::InvalidDimension {
                    dim: i,
                    size,
                    message: "Dimension size cannot be zero".to_string(),
                });
            }
        }
        // Slab decomposition: this rank owns `local_n0` contiguous planes of
        // dimension 0, starting at global index `local_0_start`.
        let partition = pool.local_partition(dims[0]);
        let local_n0 = partition.local_n;
        let local_0_start = partition.local_start;
        let remaining_product: usize = dims[1..].iter().product();
        let local_size = local_n0 * remaining_product;
        // One 1-D plan per dimension, reused for every fiber along that
        // dimension during execution.
        let mut local_plans = Vec::with_capacity(dims.len());
        for (i, &n) in dims.iter().enumerate() {
            let plan = crate::api::Plan::dft_1d(n, direction, flags.base).ok_or_else(|| {
                MpiError::FftError {
                    message: format!("Failed to create plan for dimension {i} (size {n})"),
                }
            })?;
            local_plans.push(plan);
        }
        // Workspace sized for a future transpose-based exchange; the current
        // execute paths do not use it.
        let scratch_size = local_size * 2;
        let scratch = vec![Complex::<T>::zero(); scratch_size];
        Ok(Self {
            dims: dims.to_vec(),
            local_n0,
            local_0_start,
            direction,
            flags,
            pool: std::ptr::from_ref(pool),
            local_plans,
            scratch,
            _phantom: core::marker::PhantomData,
        })
    }

    /// Global extents of all dimensions.
    pub fn dims(&self) -> &[usize] {
        &self.dims
    }

    /// Number of dimensions of the plan.
    pub fn ndim(&self) -> usize {
        self.dims.len()
    }

    /// Returns `(local_n0, local_0_start)`: this rank's dim-0 plane count and
    /// its first global dim-0 index.
    pub fn local_info(&self) -> (usize, usize) {
        (self.local_n0, self.local_0_start)
    }

    /// Executes the FFT in place on this rank's local slab.
    ///
    /// `data` holds `local_n0 * dims[1..].product()` complex values in
    /// row-major order. The local dimensions (1..N-1) are transformed first,
    /// then the distributed dimension 0.
    ///
    /// # Errors
    /// Returns [`MpiError::SizeMismatch`] when `data` is shorter than the
    /// local slab.
    pub fn execute_inplace(&mut self, data: &mut [Complex<T>]) -> Result<(), MpiError> {
        let ndim = self.dims.len();
        let remaining_product: usize = self.dims[1..].iter().product();
        let expected_size = self.local_n0 * remaining_product;
        if data.len() < expected_size {
            return Err(MpiError::SizeMismatch {
                expected: expected_size,
                actual: data.len(),
            });
        }
        // SAFETY: `pool` was created from `&MpiPool<C>` in `new`; the caller
        // is required to keep the pool alive while the plan exists (see the
        // Send/Sync impls), so the pointer is valid here.
        let pool = unsafe { &*self.pool };
        // Transform the purely local dimensions first. For ndim == 1 this
        // loop is naturally empty. Bug fix: the previous `ndim == 1` early
        // return skipped `distributed_fft_dim0` entirely, so a 1-D plan
        // silently left the data untransformed.
        for d in (1..ndim).rev() {
            self.fft_along_dimension(data, d)?;
        }
        self.distributed_fft_dim0(data, pool)?;
        Ok(())
    }

    /// Applies the 1-D FFT along local dimension `dim` (1 <= dim < ndim) for
    /// every fiber of the local slab, gathering/scattering the strided
    /// elements through small contiguous buffers.
    #[allow(clippy::needless_pass_by_ref_mut)]
    fn fft_along_dimension(&mut self, data: &mut [Complex<T>], dim: usize) -> Result<(), MpiError> {
        let n_dim = self.dims[dim];
        let plan = &self.local_plans[dim];
        // Consecutive elements along `dim` are `inner_product` apart;
        // consecutive outer blocks are `outer_stride` apart. Both are
        // loop-invariant: the original recomputed `dims[dim..].product()`
        // for every single element of every fiber.
        let inner_product: usize = self.dims[dim + 1..].iter().product();
        let outer_product: usize = self.local_n0 * self.dims[1..dim].iter().product::<usize>();
        let outer_stride = n_dim * inner_product; // == dims[dim..].product()
        let mut buffer = vec![Complex::<T>::zero(); n_dim];
        let mut output = vec![Complex::<T>::zero(); n_dim];
        for outer in 0..outer_product {
            let base = outer * outer_stride;
            for inner in 0..inner_product {
                // Gather the strided fiber into a contiguous buffer.
                for (i, slot) in buffer.iter_mut().enumerate() {
                    *slot = data[base + i * inner_product + inner];
                }
                plan.execute(&buffer, &mut output);
                // Scatter the transformed fiber back to its strided home.
                for (i, &val) in output.iter().enumerate() {
                    data[base + i * inner_product + inner] = val;
                }
            }
        }
        Ok(())
    }

    /// Transforms the distributed dimension 0, fiber by fiber.
    ///
    /// NOTE(review): no inter-rank exchange happens here — each rank fills
    /// only its own slice of `global_fiber` and transforms a vector that is
    /// zero everywhere else. That is only correct when running on a single
    /// rank; a real multi-rank implementation needs an allgather/transpose
    /// before the FFT. TODO confirm intended scope.
    #[allow(clippy::needless_pass_by_ref_mut)]
    fn distributed_fft_dim0(
        &mut self,
        data: &mut [Complex<T>],
        pool: &MpiPool<C>,
    ) -> Result<(), MpiError> {
        let n0 = self.dims[0];
        let plan_n0 = &self.local_plans[0];
        // A "fiber" is one line along dim 0; there are `stride` of them in
        // the local slab, and its elements sit `stride` apart in `data`.
        let stride: usize = self.dims[1..].iter().product();
        // Hoisted out of the fiber loop: the dim-0 partition and the work
        // buffers are identical for every fiber (the original recomputed /
        // reallocated them per iteration). `global_fiber`'s non-local entries
        // stay zero across iterations because only the local slice is written.
        let local_partition = pool.local_partition(n0);
        let mut global_fiber = vec![Complex::<T>::zero(); n0];
        let mut fft_result = vec![Complex::<T>::zero(); n0];
        for fiber_idx in 0..stride {
            // Gather this rank's strided samples into the global fiber.
            for i0 in 0..self.local_n0 {
                global_fiber[local_partition.local_start + i0] = data[i0 * stride + fiber_idx];
            }
            // Kept from the original structure: every rank loops `stride`
            // times, so the collective barriers match up across ranks.
            pool.barrier();
            plan_n0.execute(&global_fiber, &mut fft_result);
            // Scatter back only the locally owned portion of the result.
            for i0 in 0..self.local_n0 {
                data[i0 * stride + fiber_idx] =
                    fft_result[local_partition.local_start + i0];
            }
        }
        Ok(())
    }

    /// Out-of-place execution: copies the local slab from `input` into
    /// `output`, then runs [`Self::execute_inplace`] on `output`.
    ///
    /// # Errors
    /// Returns [`MpiError::SizeMismatch`] when either slice is shorter than
    /// the local slab. (Robustness fix: the original panicked on a short
    /// `output` instead of returning the error it already used for `input`.)
    pub fn execute(
        &mut self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
    ) -> Result<(), MpiError> {
        let remaining_product: usize = self.dims[1..].iter().product();
        let expected_size = self.local_n0 * remaining_product;
        if input.len() < expected_size {
            return Err(MpiError::SizeMismatch {
                expected: expected_size,
                actual: input.len(),
            });
        }
        if output.len() < expected_size {
            return Err(MpiError::SizeMismatch {
                expected: expected_size,
                actual: output.len(),
            });
        }
        output[..expected_size].copy_from_slice(&input[..expected_size]);
        self.execute_inplace(output)
    }
}
#[cfg(test)]
mod tests {
    /// Sanity check on the layout arithmetic used by the plan: for a
    /// 16x8x4 plan, the non-distributed dimensions contribute 8 * 4 = 32
    /// elements per dim-0 plane.
    #[test]
    fn test_nd_dimensions() {
        let dims = [16usize, 8, 4];
        let remaining = dims.iter().skip(1).product::<usize>();
        assert_eq!(remaining, 32);
    }
}