use mpi::datatype::Equivalence;
use mpi::topology::Communicator;
use crate::api::{Direction, Plan};
use crate::kernel::{Complex, Float};
use crate::mpi::distribution::LocalPartition;
use crate::mpi::error::MpiError;
use crate::mpi::pool::{MpiFloat, MpiPool};
use crate::mpi::transpose::distributed_transpose;
use crate::mpi::MpiFlags;
/// Plan for a 2-D complex FFT of a global `n0 x n1` array whose first
/// dimension is distributed over the ranks of an [`MpiPool`].
///
/// This rank's block of the first dimension starts at `local_0_start` and
/// spans `local_n0` full rows of length `n1`. Execution runs row FFTs, a
/// distributed transpose, column FFTs and, unless transposed output was
/// requested, a transpose back.
pub struct MpiPlan2D<T: Float, C: Communicator> {
    // Global problem shape.
    n0: usize,
    n1: usize,
    // This rank's share of the first dimension.
    local_0_start: usize,
    local_n0: usize,
    // Transform configuration chosen at construction time.
    direction: Direction,
    flags: MpiFlags,
    // 1-D plans applied along rows (length `n1`) and columns (length `n0`).
    row_plan: Plan<T>,
    col_plan: Plan<T>,
    // Work buffer that receives the transposed (column-distributed) slab.
    scratch: Vec<Complex<T>>,
    // Pool passed to `new` by reference, stored without a lifetime: the
    // caller must keep it alive and unmoved while the plan is in use.
    pool: *const MpiPool<C>,
    _phantom: core::marker::PhantomData<(T, C)>,
}
// SAFETY: the auto-derived `Send` is suppressed by the raw `pool` pointer
// (raw pointers are neither `Send` nor `Sync`), so it is asserted manually.
// After a move, the receiving thread dereferences `pool` in `execute_inplace`,
// i.e. it uses `&MpiPool<C>` while the sending thread may still hold the pool.
// NOTE(review): that sharing would normally require `MpiPool<C>: Sync`, but the
// only bound here is `C: Send`. This impl also overrides the auto-trait check
// for every other field (`Plan<T>`, `Vec<Complex<T>>`). Confirm these are safe
// to move between threads and that MPI's thread-support level allows it.
unsafe impl<T: Float, C: Communicator + Send> Send for MpiPlan2D<T, C> {}
// SAFETY: in this file, the only `&self` methods are `dims`, `local_dims` and
// `direction`. They read plain fields and never dereference `pool`. Everything
// that touches `pool` or `scratch` takes `&mut self`.
// NOTE(review): this also asserts `Sync` for `Plan<T>` without checking it, and
// the argument must be revisited if `&self` methods that use `pool` are added.
unsafe impl<T: Float, C: Communicator + Sync> Sync for MpiPlan2D<T, C> {}
impl<T: Float + MpiFloat, C: Communicator> MpiPlan2D<T, C>
where
    Complex<T>: Equivalence,
{
    /// Creates a plan for a distributed 2-D complex FFT of an `n0 x n1` array.
    ///
    /// The first dimension is split across the ranks of `pool`
    /// (`pool.local_partition(n0)`), so each rank holds `local_n0` full rows of
    /// length `n1`. The plan builds:
    /// * a 1-D plan of length `n1` for the rows;
    /// * a 1-D plan of length `n0` for the columns;
    /// * a scratch buffer large enough for both the row-distributed slab
    ///   (`local_n0 * n1`) and the transposed, column-distributed slab
    ///   (`n0 * local_n1`).
    ///
    /// Only a raw pointer to `pool` is stored. The caller must keep `pool`
    /// alive and at the same address for as long as the plan is executed.
    ///
    /// # Errors
    ///
    /// * [`MpiError::InvalidDimension`] if `n0` or `n1` is zero; `dim` is the
    ///   index of the first zero dimension.
    /// * [`MpiError::FftError`] if either 1-D plan cannot be created.
    pub fn new(
        n0: usize,
        n1: usize,
        direction: Direction,
        flags: MpiFlags,
        pool: &MpiPool<C>,
    ) -> Result<Self, MpiError> {
        if n0 == 0 || n1 == 0 {
            return Err(MpiError::InvalidDimension {
                dim: usize::from(n0 != 0),
                size: if n0 == 0 { n0 } else { n1 },
                message: "Dimension size cannot be zero".to_string(),
            });
        }
        let partition = pool.local_partition(n0);
        let local_n0 = partition.local_n;
        let local_0_start = partition.local_start;
        // Slab of the second dimension this rank owns after the transpose. The
        // two slabs can differ in size, so scratch must hold the larger one.
        let transposed_partition = LocalPartition::new(n1, pool.size(), pool.rank());
        let scratch_size = (local_n0 * n1).max(n0 * transposed_partition.local_n);
        let row_plan =
            Plan::dft_1d(n1, direction, flags.base).ok_or_else(|| MpiError::FftError {
                message: format!("Failed to create row plan for size {n1}"),
            })?;
        let col_plan =
            Plan::dft_1d(n0, direction, flags.base).ok_or_else(|| MpiError::FftError {
                message: format!("Failed to create column plan for size {n0}"),
            })?;
        let scratch = vec![Complex::<T>::zero(); scratch_size];
        Ok(Self {
            n0,
            n1,
            local_n0,
            local_0_start,
            direction,
            flags,
            pool: std::ptr::from_ref(pool),
            row_plan,
            col_plan,
            scratch,
            _phantom: core::marker::PhantomData,
        })
    }

    /// Global dimensions `(n0, n1)` of the transform.
    pub fn dims(&self) -> (usize, usize) {
        (self.n0, self.n1)
    }

    /// This rank's input slab as `(local_n0, local_0_start, n1)`: the number of
    /// local rows, the global index of the first local row, and the row length.
    pub fn local_dims(&self) -> (usize, usize, usize) {
        (self.local_n0, self.local_0_start, self.n1)
    }

    /// Transform direction the plan was created with.
    pub fn direction(&self) -> Direction {
        self.direction
    }

    /// Executes the 2-D transform in place on this rank's slab.
    ///
    /// On entry, `data` starts with the `local_n0` local rows (each of length
    /// `n1`) in row-major order. The steps are:
    /// 1. 1-D FFT of every local row.
    /// 2. `distributed_transpose` into the scratch buffer, so this rank owns
    ///    `local_n1` full columns of length `n0`.
    /// 3. 1-D FFT of every local column.
    /// 4. The final layout depends on the flags:
    ///    * if `flags.transposed_out` is set, the transposed slab
    ///      (`local_n1 * n0` elements) is copied to the front of `data`;
    ///    * otherwise, a second `distributed_transpose` puts the result back
    ///      into `data` in the original row layout.
    ///
    /// NOTE(review): `distributed_transpose` presumably involves communication
    /// with the other ranks, so every rank is expected to call this together.
    ///
    /// # Errors
    ///
    /// * [`MpiError::SizeMismatch`] if `data` is shorter than `local_n0 * n1`.
    ///   With `flags.transposed_out`, it must also hold `local_n1 * n0`
    ///   elements. The check happens before `data` is modified.
    /// * Any error returned by `distributed_transpose`.
    pub fn execute_inplace(&mut self, data: &mut [Complex<T>]) -> Result<(), MpiError> {
        // SAFETY: `self.pool` was created in `new` from a `&MpiPool<C>`. Nothing
        // in the type ties the plan's lifetime to that pool, so this relies on
        // the caller keeping it alive and unmoved while the plan is used. The
        // reference is not tied to `self`, which lets `self.scratch` be borrowed
        // alongside it below.
        let pool = unsafe { &*self.pool };

        let transposed_partition = LocalPartition::new(self.n1, pool.size(), pool.rank());
        let local_n1 = transposed_partition.local_n;
        let input_len = self.local_n0 * self.n1;
        let transposed_len = local_n1 * self.n0;

        // With `transposed_out`, the result stays in the column-distributed
        // layout, whose size need not equal the input slab's. Validate it here
        // so a short buffer is reported as an error, not a panic on the final
        // copy.
        let required_len = if self.flags.transposed_out {
            input_len.max(transposed_len)
        } else {
            input_len
        };
        if data.len() < required_len {
            return Err(MpiError::SizeMismatch {
                expected: required_len,
                actual: data.len(),
            });
        }

        // Step 1: FFT each local row. `Plan::execute` takes separate input and
        // output slices, so each row is staged through `row_buffer`.
        let mut row_buffer = vec![Complex::<T>::zero(); self.n1];
        for row in data[..input_len].chunks_exact_mut(self.n1) {
            row_buffer.copy_from_slice(row);
            self.row_plan.execute(&row_buffer, row);
        }

        // Step 2: redistribute so this rank owns `local_n1` full columns,
        // stored contiguously (`n0` elements each) in `scratch`.
        distributed_transpose(
            pool,
            data,
            &mut self.scratch,
            self.n0,
            self.n1,
            self.local_n0,
            self.local_0_start,
        )?;

        // Step 3: FFT each local column.
        let mut col_buffer = vec![Complex::<T>::zero(); self.n0];
        for col in self.scratch[..transposed_len].chunks_exact_mut(self.n0) {
            col_buffer.copy_from_slice(col);
            self.col_plan.execute(&col_buffer, col);
        }

        // Step 4: hand back the transposed slab, or transpose back into row
        // layout. `data` is a separate buffer, so `scratch` can be read
        // directly without cloning it first.
        if self.flags.transposed_out {
            data[..transposed_len].copy_from_slice(&self.scratch[..transposed_len]);
        } else {
            distributed_transpose(
                pool,
                &self.scratch,
                data,
                self.n1,
                self.n0,
                local_n1,
                transposed_partition.local_start,
            )?;
        }
        Ok(())
    }

    /// Out-of-place variant of [`Self::execute_inplace`].
    ///
    /// Copies this rank's `local_n0 * n1` input elements into `output`, then
    /// transforms `output` in place. `input` is never modified.
    ///
    /// # Errors
    ///
    /// * [`MpiError::SizeMismatch`] if `input` or `output` holds fewer than
    ///   `local_n0 * n1` elements (checked before `output` is written).
    /// * Any error from [`Self::execute_inplace`]. This includes `output` being
    ///   too short for the transposed slab when `flags.transposed_out` is set;
    ///   in that case `output` already holds the copied input.
    pub fn execute(
        &mut self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
    ) -> Result<(), MpiError> {
        let expected_size = self.local_n0 * self.n1;
        if input.len() < expected_size {
            return Err(MpiError::SizeMismatch {
                expected: expected_size,
                actual: input.len(),
            });
        }
        // Report a short `output` as an error instead of panicking in the copy.
        if output.len() < expected_size {
            return Err(MpiError::SizeMismatch {
                expected: expected_size,
                actual: output.len(),
            });
        }
        output[..expected_size].copy_from_slice(&input[..expected_size]);
        self.execute_inplace(output)
    }
}
#[cfg(test)]
mod tests {
    use crate::mpi::distribution::LocalPartition;

    /// Splitting 16 elements over 4 ranks gives rank 0 the first 4 of them.
    #[test]
    fn test_local_partition() {
        let part = LocalPartition::new(16, 4, 0);
        assert_eq!((part.local_n, part.local_start), (4, 0));
    }
}