use mpi::datatype::Equivalence;
use mpi::topology::Communicator;
use crate::api::{Direction, Plan};
use crate::kernel::{Complex, Float};
use crate::mpi::distribution::LocalPartition;
use crate::mpi::error::MpiError;
use crate::mpi::pool::{MpiFloat, MpiPool};
use crate::mpi::transpose::distributed_transpose;
use crate::mpi::MpiFlags;
/// Plan for a complex 3D FFT over an `n0 x n1 x n2` array that is split
/// along its first dimension across the ranks of an [`MpiPool`].
///
/// The plan keeps one 1D sub-plan per axis plus a scratch buffer that holds
/// the data while it is in transposed (`n1`-distributed) layout.
///
/// # Lifetime of the pool
///
/// The pool is stored as a raw pointer, so the compiler does not check that
/// it outlives the plan. The `MpiPool` passed to `new` must stay alive and
/// must not be moved for as long as this plan is used.
// NOTE(review): `dead_code` is presumably allowed because `direction` is
// stored but never read in the code visible in this file — confirm.
#[allow(dead_code)]
pub struct MpiPlan3D<T: Float, C: Communicator> {
/// Global extents `[n0, n1, n2]` as passed to `new`.
dims: [usize; 3],
/// `local_n` of the pool's partition of `n0`; used as the number of
/// `n0` indices held by this rank.
local_n0: usize,
/// `local_start` of the pool's partition of `n0`.
local_0_start: usize,
/// Transform direction handed to every 1D sub-plan.
direction: Direction,
/// Planner flags; `base` is forwarded to the 1D sub-plans and
/// `transposed_out` selects the output layout of `execute_inplace`.
flags: MpiFlags,
/// Address of the pool passed to `new`. Not lifetime-checked; it is
/// dereferenced on every `execute_inplace` call.
pool: *const MpiPool<C>,
/// 1D plan of length `n2`.
plan_n2: Plan<T>,
/// 1D plan of length `n1`.
plan_n1: Plan<T>,
/// 1D plan of length `n0`.
plan_n0: Plan<T>,
/// Work buffer sized to the larger of the normal (`local_n0 * n1 * n2`)
/// and transposed (`n0 * local_n1 * n2`) local blocks.
scratch: Vec<Complex<T>>,
/// Marker tying the plan to its element and communicator types.
_phantom: core::marker::PhantomData<(T, C)>,
}
// The raw `pool` pointer field prevents `Send`/`Sync` from being derived
// automatically, so both are asserted by hand here.
//
// SAFETY — NOTE(review): these impls are only sound if (a) the pointed-to
// `MpiPool<C>` may be accessed through a shared reference from another thread
// whenever `C` is `Send` (resp. `Sync`), and (b) `Plan<T>` is itself safe to
// send/share. Neither is checked by the bounds below or visible in this file;
// confirm against the definitions of `MpiPool` and `Plan`.
unsafe impl<T: Float, C: Communicator + Send> Send for MpiPlan3D<T, C> {}
unsafe impl<T: Float, C: Communicator + Sync> Sync for MpiPlan3D<T, C> {}
impl<T: Float + MpiFloat, C: Communicator> MpiPlan3D<T, C>
where
Complex<T>: Equivalence,
{
pub fn new(
n0: usize,
n1: usize,
n2: usize,
direction: Direction,
flags: MpiFlags,
pool: &MpiPool<C>,
) -> Result<Self, MpiError> {
if n0 == 0 || n1 == 0 || n2 == 0 {
return Err(MpiError::InvalidDimension {
dim: if n0 == 0 {
0
} else if n1 == 0 {
1
} else {
2
},
size: if n0 == 0 {
n0
} else if n1 == 0 {
n1
} else {
n2
},
message: "Dimension size cannot be zero".to_string(),
});
}
let partition = pool.local_partition(n0);
let local_n0 = partition.local_n;
let local_0_start = partition.local_start;
let transposed_partition = LocalPartition::new(n1, pool.size(), pool.rank());
let normal_size = local_n0 * n1 * n2;
let transposed_size = n0 * transposed_partition.local_n * n2;
let scratch_size = normal_size.max(transposed_size);
let plan_n2 =
Plan::dft_1d(n2, direction, flags.base).ok_or_else(|| MpiError::FftError {
message: format!("Failed to create n2 plan for size {n2}"),
})?;
let plan_n1 =
Plan::dft_1d(n1, direction, flags.base).ok_or_else(|| MpiError::FftError {
message: format!("Failed to create n1 plan for size {n1}"),
})?;
let plan_n0 =
Plan::dft_1d(n0, direction, flags.base).ok_or_else(|| MpiError::FftError {
message: format!("Failed to create n0 plan for size {n0}"),
})?;
let scratch = vec![Complex::<T>::zero(); scratch_size];
Ok(Self {
dims: [n0, n1, n2],
local_n0,
local_0_start,
direction,
flags,
pool: std::ptr::from_ref(pool),
plan_n2,
plan_n1,
plan_n0,
scratch,
_phantom: core::marker::PhantomData,
})
}
/// Returns the global extents `[n0, n1, n2]` this plan was created with.
pub fn dims(&self) -> [usize; 3] {
    let [n0, n1, n2] = self.dims;
    [n0, n1, n2]
}
/// Returns `(local_n0, local_0_start, n1, n2)`: this rank's share of the
/// first dimension, its partition start, and the two full trailing extents.
pub fn local_dims(&self) -> (usize, usize, usize, usize) {
    let [_, n1, n2] = self.dims;
    (self.local_n0, self.local_0_start, n1, n2)
}
/// Runs the distributed 3D transform in place on this rank's block.
///
/// On entry `data` holds the local `local_n0 x n1 x n2` block with `n2`
/// varying fastest (element `(i0, i1, i2)` at `i0 * n1 * n2 + i1 * n2 + i2`).
///
/// Steps:
/// 1. 1D transforms along `n2`, then along `n1`, on the local block.
/// 2. One `distributed_transpose` per `i2` plane, gathering the result into
///    the scratch buffer in `local_n1 x n0 x n2` layout.
/// 3. 1D transforms along `n0` on the scratch buffer.
/// 4. If `flags.transposed_out` is set, the first `local_n1 * n0 * n2`
///    elements of the scratch buffer are copied to `data` as-is; otherwise
///    the planes are transposed back into the original layout in `data`.
///
/// # Errors
///
/// * [`MpiError::SizeMismatch`] if `data` is shorter than
///   `local_n0 * n1 * n2`, or — when `flags.transposed_out` is set — shorter
///   than the transposed block `local_n1 * n0 * n2`. Both checks happen
///   before `data` is modified.
/// * Any error returned by `distributed_transpose`; in that case `data` has
///   already been partially transformed.
pub fn execute_inplace(&mut self, data: &mut [Complex<T>]) -> Result<(), MpiError> {
    let [n0, n1, n2] = self.dims;
    let local_n0 = self.local_n0;
    let expected_size = local_n0 * n1 * n2;
    if data.len() < expected_size {
        return Err(MpiError::SizeMismatch {
            expected: expected_size,
            actual: data.len(),
        });
    }

    // SAFETY: `self.pool` was created in `new` from a valid `&MpiPool<C>`.
    // The plan carries no lifetime for it, so soundness relies on the caller
    // keeping that pool alive and unmoved while the plan is in use.
    let pool = unsafe { &*self.pool };

    let transposed_partition = LocalPartition::new(n1, pool.size(), pool.rank());
    let local_n1 = transposed_partition.local_n;
    let transposed_size = local_n1 * n0 * n2;

    // The transposed block can be larger than the normal one (the two
    // partitions round differently), so check it before doing any work
    // instead of panicking in the final copy.
    if self.flags.transposed_out && data.len() < transposed_size {
        return Err(MpiError::SizeMismatch {
            expected: transposed_size,
            actual: data.len(),
        });
    }

    let zero = Complex::<T>::zero();

    // Transforms along n2: each contiguous run of n2 elements is one row.
    // `n2` is non-zero because `new` rejects zero extents.
    let mut buffer_n2 = vec![zero; n2];
    for row in data[..expected_size].chunks_exact_mut(n2) {
        buffer_n2.copy_from_slice(row);
        self.plan_n2.execute(&buffer_n2, row);
    }

    // Transforms along n1: gather the strided column, transform, scatter.
    // Both buffers are reused across iterations; every element is rewritten
    // before it is read.
    let mut buffer_n1 = vec![zero; n1];
    let mut output_n1 = vec![zero; n1];
    for i0 in 0..local_n0 {
        let plane = i0 * n1 * n2;
        for i2 in 0..n2 {
            for (i1, slot) in buffer_n1.iter_mut().enumerate() {
                *slot = data[plane + i1 * n2 + i2];
            }
            self.plan_n1.execute(&buffer_n1, &mut output_n1);
            for (i1, &value) in output_n1.iter().enumerate() {
                data[plane + i1 * n2 + i2] = value;
            }
        }
    }

    // Per-plane staging buffers, shared by the forward and backward
    // transposes: `normal_slice` is `local_n0 x n1`, `transposed_slice` is
    // `local_n1 x n0`.
    let mut normal_slice = vec![zero; local_n0 * n1];
    let mut transposed_slice = vec![zero; n0 * local_n1];

    // Forward transpose, one i2 plane at a time, into the scratch buffer.
    for i2 in 0..n2 {
        for i0 in 0..local_n0 {
            for i1 in 0..n1 {
                normal_slice[i0 * n1 + i1] = data[i0 * n1 * n2 + i1 * n2 + i2];
            }
        }
        // `distributed_transpose` is defined elsewhere; hand it a zeroed
        // output buffer on every call, as a freshly allocated one would be.
        transposed_slice.fill(zero);
        distributed_transpose(
            pool,
            &normal_slice,
            &mut transposed_slice,
            n0,
            n1,
            local_n0,
            self.local_0_start,
        )?;
        for i1_local in 0..local_n1 {
            for i0 in 0..n0 {
                self.scratch[i1_local * n0 * n2 + i0 * n2 + i2] =
                    transposed_slice[i1_local * n0 + i0];
            }
        }
    }

    // Transforms along n0 on the transposed data held in scratch.
    let mut buffer_n0 = vec![zero; n0];
    let mut output_n0 = vec![zero; n0];
    for i1_local in 0..local_n1 {
        let base = i1_local * n0 * n2;
        for i2 in 0..n2 {
            for (i0, slot) in buffer_n0.iter_mut().enumerate() {
                *slot = self.scratch[base + i0 * n2 + i2];
            }
            self.plan_n0.execute(&buffer_n0, &mut output_n0);
            for (i0, &value) in output_n0.iter().enumerate() {
                self.scratch[base + i0 * n2 + i2] = value;
            }
        }
    }

    if self.flags.transposed_out {
        // Caller asked for transposed layout: hand back scratch unchanged.
        data[..transposed_size].copy_from_slice(&self.scratch[..transposed_size]);
    } else {
        // Transpose each plane back into the original n0-distributed layout.
        for i2 in 0..n2 {
            for i1_local in 0..local_n1 {
                for i0 in 0..n0 {
                    transposed_slice[i1_local * n0 + i0] =
                        self.scratch[i1_local * n0 * n2 + i0 * n2 + i2];
                }
            }
            normal_slice.fill(zero);
            distributed_transpose(
                pool,
                &transposed_slice,
                &mut normal_slice,
                n1,
                n0,
                local_n1,
                transposed_partition.local_start,
            )?;
            for i0 in 0..local_n0 {
                for i1 in 0..n1 {
                    data[i0 * n1 * n2 + i1 * n2 + i2] = normal_slice[i0 * n1 + i1];
                }
            }
        }
    }
    Ok(())
}
/// Out-of-place variant of `execute_inplace`.
///
/// Copies the first `local_n0 * n1 * n2` elements of `input` into `output`
/// and then transforms `output` in place. `input` is left untouched.
///
/// # Errors
///
/// * [`MpiError::SizeMismatch`] if `input` or `output` holds fewer than
///   `local_n0 * n1 * n2` elements. The `output` check avoids a slice
///   out-of-bounds panic in the copy below.
/// * Any error returned by `execute_inplace`.
pub fn execute(
    &mut self,
    input: &[Complex<T>],
    output: &mut [Complex<T>],
) -> Result<(), MpiError> {
    let expected_size = self.local_n0 * self.dims[1] * self.dims[2];
    if input.len() < expected_size {
        return Err(MpiError::SizeMismatch {
            expected: expected_size,
            actual: input.len(),
        });
    }
    // Validate the destination before slicing it, so a short buffer is
    // reported as an error rather than a panic.
    if output.len() < expected_size {
        return Err(MpiError::SizeMismatch {
            expected: expected_size,
            actual: output.len(),
        });
    }
    output[..expected_size].copy_from_slice(&input[..expected_size]);
    self.execute_inplace(output)
}
}
#[cfg(test)]
mod tests {
    use crate::mpi::distribution::LocalPartition;

    /// Rank 0 of 4 ranks over 32 elements is expected to own the first 8.
    #[test]
    fn test_3d_partition() {
        let rank0 = LocalPartition::new(32, 4, 0);
        assert_eq!(rank0.local_n, 8);
        assert_eq!(rank0.local_start, 0);
    }
}