use mpi::datatype::Equivalence;
use mpi::topology::{Color, Communicator, SimpleCommunicator};
use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
use crate::mpi::distribution::LocalPartition;
use crate::mpi::error::MpiError;
use crate::mpi::pool::{MpiFloat, MpiPool};
/// Shape of the 2-D process grid used by the pencil decomposition.
///
/// Global ranks are mapped onto the grid in row-major order:
/// `global_rank == row_rank * n_cols + col_rank`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PencilGrid {
    /// Number of process rows in the grid.
    pub n_rows: usize,
    /// Number of process columns in the grid.
    pub n_cols: usize,
}

impl PencilGrid {
    /// Creates a grid with `n_rows` process rows and `n_cols` process columns.
    pub fn new(n_rows: usize, n_cols: usize) -> Self {
        Self { n_rows, n_cols }
    }

    /// Returns the number of processes the grid describes (`n_rows * n_cols`).
    ///
    /// The product saturates at `usize::MAX` instead of overflowing, so an
    /// absurdly large grid can never wrap around to a small value and
    /// accidentally match a real communicator size.
    pub fn total_procs(&self) -> usize {
        self.n_rows.saturating_mul(self.n_cols)
    }

    /// Returns the grid row that `global_rank` falls in.
    ///
    /// # Panics
    /// Panics (division by zero) if `n_cols` is zero.
    pub fn row_rank(&self, global_rank: usize) -> usize {
        global_rank / self.n_cols
    }

    /// Returns the grid column that `global_rank` falls in.
    ///
    /// # Panics
    /// Panics (remainder by zero) if `n_cols` is zero.
    pub fn col_rank(&self, global_rank: usize) -> usize {
        global_rank % self.n_cols
    }
}
/// Plan for a 3-D complex FFT whose data is distributed over a 2-D process grid.
///
/// `new` splits axis 0 across `grid.n_rows` and axis 1 across `grid.n_cols`
/// (via `LocalPartition`) and creates one 1-D plan per axis. Only the
/// single-process path of `execute_inplace` is implemented in this file.
pub struct PencilPlan3D<T: Float, C: Communicator> {
// Global extents `[n0, n1, n2]` exactly as passed to `new`.
dims: [usize; 3],
// Process grid passed to `new`.
grid: PencilGrid,
// `grid.row_rank(pool.rank())` of the calling process.
row_rank: usize,
// `grid.col_rank(pool.rank())` of the calling process.
col_rank: usize,
// `local_n` of `LocalPartition::new(n0, grid.n_rows, row_rank)`.
local_n0: usize,
// `local_start` of the same axis-0 partition.
local_0_start: usize,
// `local_n` of `LocalPartition::new(n1, grid.n_cols, col_rank)`.
local_n1: usize,
// `local_start` of the same axis-1 partition.
local_1_start: usize,
// Transform direction; the same value is used for all three 1-D plans.
direction: Direction,
// 1-D plan of length `n0` (axis 0).
plan_x: Plan<T>,
// 1-D plan of length `n1` (axis 1).
plan_y: Plan<T>,
// 1-D plan of length `n2` (axis 2).
plan_z: Plan<T>,
// Zero-filled buffer allocated in `new`, sized to the largest of three
// local-layout products. No method visible in this file reads or writes it.
_scratch: Vec<Complex<T>>,
// Raw pointer to the pool passed to `new`. It is never dereferenced in the
// code visible here, and nothing ties its validity to the pool's lifetime,
// so it must be treated as potentially dangling.
_pool: *const MpiPool<C>,
// Pool over the communicator obtained by splitting on `row_rank`;
// `None` when the parent pool has exactly one process.
row_pool: Option<MpiPool<SimpleCommunicator>>,
// Pool over the communicator obtained by splitting on `col_rank`;
// `None` when the parent pool has exactly one process.
col_pool: Option<MpiPool<SimpleCommunicator>>,
// Marks the otherwise-unowned type parameters `T` and `C`.
_phantom: core::marker::PhantomData<(T, C)>,
}
// SAFETY (review note): the raw `_pool` pointer field alone prevents the
// compiler from deriving `Send`/`Sync` for `PencilPlan3D`; these impls opt
// back in. No code visible in this file dereferences `_pool`, so the pointer
// itself is inert here.
// NOTE(review): these impls also vouch for every other field, including
// `Plan<T>` and `MpiPool<SimpleCommunicator>`, without any bound on them —
// TODO confirm those types are actually safe to send/share across threads.
unsafe impl<T: Float, C: Communicator + Send> Send for PencilPlan3D<T, C> {}
unsafe impl<T: Float, C: Communicator + Sync> Sync for PencilPlan3D<T, C> {}
impl<T: Float + MpiFloat, C: Communicator> PencilPlan3D<T, C>
where
Complex<T>: Equivalence,
{
pub fn new(
n0: usize,
n1: usize,
n2: usize,
grid: PencilGrid,
direction: Direction,
flags: Flags,
pool: &MpiPool<C>,
) -> Result<Self, MpiError> {
let dims = [n0, n1, n2];
for (i, &d) in dims.iter().enumerate() {
if d == 0 {
return Err(MpiError::InvalidDimension {
dim: i,
size: d,
message: "Dimension size cannot be zero".to_string(),
});
}
}
if grid.total_procs() != pool.size() {
return Err(MpiError::InsufficientProcesses {
required: grid.total_procs(),
available: pool.size(),
});
}
let global_rank = pool.rank();
let row_rank = grid.row_rank(global_rank);
let col_rank = grid.col_rank(global_rank);
let part0 = LocalPartition::new(n0, grid.n_rows, row_rank);
let part1 = LocalPartition::new(n1, grid.n_cols, col_rank);
let local_n0 = part0.local_n;
let local_0_start = part0.local_start;
let local_n1 = part1.local_n;
let local_1_start = part1.local_start;
let plan_x = Plan::dft_1d(n0, direction, flags).ok_or_else(|| MpiError::FftError {
message: format!("Failed to create X (n0={n0}) plan"),
})?;
let plan_y = Plan::dft_1d(n1, direction, flags).ok_or_else(|| MpiError::FftError {
message: format!("Failed to create Y (n1={n1}) plan"),
})?;
let plan_z = Plan::dft_1d(n2, direction, flags).ok_or_else(|| MpiError::FftError {
message: format!("Failed to create Z (n2={n2}) plan"),
})?;
let local_n2_col = LocalPartition::new(n2, grid.n_cols, col_rank).local_n;
let local_n1_row = LocalPartition::new(n1, grid.n_rows, row_rank).local_n;
let scratch_size = (local_n0 * local_n1 * n2)
.max(local_n0 * n1 * local_n2_col)
.max(n0 * local_n1_row * local_n2_col);
let scratch = vec![Complex::<T>::zero(); scratch_size];
let (row_pool, col_pool) = if pool.size() == 1 {
(None, None)
} else {
let row_comm = pool
.comm()
.split_by_color(Color::with_value(row_rank as i32))
.ok_or_else(|| MpiError::CommunicationError {
message: format!("Failed to create row sub-comm (row_rank={row_rank})"),
})?;
let col_comm = pool
.comm()
.split_by_color(Color::with_value(col_rank as i32))
.ok_or_else(|| MpiError::CommunicationError {
message: format!("Failed to create col sub-comm (col_rank={col_rank})"),
})?;
(Some(MpiPool::new(row_comm)), Some(MpiPool::new(col_comm)))
};
Ok(Self {
dims,
grid,
row_rank,
col_rank,
local_n0,
local_0_start,
local_n1,
local_1_start,
direction,
plan_x,
plan_y,
plan_z,
_scratch: scratch,
_pool: core::ptr::from_ref(pool),
row_pool,
col_pool,
_phantom: core::marker::PhantomData,
})
}
pub fn dims(&self) -> [usize; 3] {
self.dims
}
pub fn grid(&self) -> PencilGrid {
self.grid
}
pub fn direction(&self) -> Direction {
self.direction
}
pub fn row_rank(&self) -> usize {
self.row_rank
}
pub fn col_rank(&self) -> usize {
self.col_rank
}
pub fn local_dims(&self) -> (usize, usize, usize, usize, usize) {
(
self.local_n0,
self.local_0_start,
self.local_n1,
self.local_1_start,
self.dims[2],
)
}
pub fn col_pool(&self) -> Option<&MpiPool<SimpleCommunicator>> {
self.col_pool.as_ref()
}
pub fn execute_inplace(&mut self, data: &mut [Complex<T>]) -> Result<(), MpiError> {
let [n0, n1, n2] = self.dims;
let expected = self.local_n0 * self.local_n1 * n2;
if data.len() < expected {
return Err(MpiError::SizeMismatch {
expected,
actual: data.len(),
});
}
if self.row_pool.is_none() {
pure::fft_3d_zyx_with_plans(data, n0, n1, n2, &self.plan_x, &self.plan_y, &self.plan_z);
Ok(())
} else {
Err(MpiError::FftError {
message: "multi-rank pencil execution not yet implemented".to_string(),
})
}
}
pub fn execute(
&mut self,
input: &[Complex<T>],
output: &mut [Complex<T>],
) -> Result<(), MpiError> {
let expected = self.local_n0 * self.local_n1 * self.dims[2];
if input.len() < expected {
return Err(MpiError::SizeMismatch {
expected,
actual: input.len(),
});
}
if output.len() < expected {
return Err(MpiError::SizeMismatch {
expected,
actual: output.len(),
});
}
output[..expected].copy_from_slice(&input[..expected]);
self.execute_inplace(output)
}
}
/// Single-process reference implementation of the 3-D transform, built from
/// 1-D plans applied along each axis in turn.
pub mod pure {
    use super::*;

    /// Applies `plan_z`, then `plan_y`, then `plan_x` along axes 2, 1 and 0 of
    /// `data`, in place.
    ///
    /// `data` is indexed as `i0 * n1 * n2 + i1 * n2 + i2`, i.e. axis 2 is
    /// contiguous.
    ///
    /// # Panics
    /// Panics if `data` holds fewer than `n0 * n1 * n2` elements. The check
    /// happens up front, before any element is modified.
    pub(super) fn fft_3d_zyx_with_plans<T: Float>(
        data: &mut [Complex<T>],
        n0: usize,
        n1: usize,
        n2: usize,
        plan_x: &Plan<T>,
        plan_y: &Plan<T>,
        plan_z: &Plan<T>,
    ) {
        let plane = n1 * n2;
        // Fail fast on a short buffer instead of panicking half-way through
        // the transform with partially overwritten data.
        let data = &mut data[..n0 * plane];

        // Pass 1: axis 2 (contiguous pencils of length n2).
        {
            let mut tmp = vec![Complex::<T>::zero(); n2];
            for i0 in 0..n0 {
                for i1 in 0..n1 {
                    let off = i0 * plane + i1 * n2;
                    let pencil = &mut data[off..off + n2];
                    // `tmp` is a reusable staging copy of the pencil, so the
                    // plan can read from it while writing back into `data`.
                    tmp.copy_from_slice(pencil);
                    plan_z.execute(&tmp, pencil);
                }
            }
        }

        // Pass 2: axis 1 (stride n2), gathered into a contiguous buffer.
        {
            let mut col_in = vec![Complex::<T>::zero(); n1];
            let mut col_out = vec![Complex::<T>::zero(); n1];
            for i0 in 0..n0 {
                let base = i0 * plane;
                for i2 in 0..n2 {
                    for (i1, slot) in col_in.iter_mut().enumerate() {
                        *slot = data[base + i1 * n2 + i2];
                    }
                    plan_y.execute(&col_in, &mut col_out);
                    for (i1, &value) in col_out.iter().enumerate() {
                        data[base + i1 * n2 + i2] = value;
                    }
                }
            }
        }

        // Pass 3: axis 0 (stride n1 * n2), gathered into a contiguous buffer.
        {
            let mut row_in = vec![Complex::<T>::zero(); n0];
            let mut row_out = vec![Complex::<T>::zero(); n0];
            for i1 in 0..n1 {
                for i2 in 0..n2 {
                    let base = i1 * n2 + i2;
                    for (i0, slot) in row_in.iter_mut().enumerate() {
                        *slot = data[i0 * plane + base];
                    }
                    plan_x.execute(&row_in, &mut row_out);
                    for (i0, &value) in row_out.iter().enumerate() {
                        data[i0 * plane + base] = value;
                    }
                }
            }
        }
    }

    /// Test helper: creates `ESTIMATE` plans for the three axes and transforms
    /// `data` in place.
    ///
    /// # Errors
    /// * `MpiError::InvalidDimension` if any of `n0`, `n1`, `n2` is zero.
    /// * `MpiError::SizeMismatch` if `data` holds fewer than `n0 * n1 * n2`
    ///   elements.
    /// * `MpiError::FftError` if a 1-D plan cannot be created.
    #[cfg(test)]
    pub fn fft_3d_zyx<T: Float>(
        data: &mut [Complex<T>],
        n0: usize,
        n1: usize,
        n2: usize,
        direction: Direction,
    ) -> Result<(), MpiError> {
        for (i, &d) in [n0, n1, n2].iter().enumerate() {
            if d == 0 {
                return Err(MpiError::InvalidDimension {
                    dim: i,
                    size: d,
                    message: "Dimension cannot be zero".to_string(),
                });
            }
        }

        // Reject short buffers with an error rather than an indexing panic.
        let expected = n0 * n1 * n2;
        if data.len() < expected {
            return Err(MpiError::SizeMismatch {
                expected,
                actual: data.len(),
            });
        }

        let flags = Flags::ESTIMATE;
        let plan_z = Plan::dft_1d(n2, direction, flags).ok_or_else(|| MpiError::FftError {
            message: format!("Failed to create Z plan for size {n2}"),
        })?;
        let plan_y = Plan::dft_1d(n1, direction, flags).ok_or_else(|| MpiError::FftError {
            message: format!("Failed to create Y plan for size {n1}"),
        })?;
        let plan_x = Plan::dft_1d(n0, direction, flags).ok_or_else(|| MpiError::FftError {
            message: format!("Failed to create X plan for size {n0}"),
        })?;
        fft_3d_zyx_with_plans(data, n0, n1, n2, &plan_x, &plan_y, &plan_z);
        Ok(())
    }

    /// Test helper: largest magnitude of the element-wise difference `a - b`.
    ///
    /// Elements are paired with `zip`, so any surplus elements of the longer
    /// slice are not compared. Returns `T::zero()` for empty input.
    ///
    /// NOTE(review): a difference that never compares greater than the running
    /// maximum (presumably the case for NaN) is skipped by the fold — confirm
    /// whether callers need NaN to be reported.
    #[cfg(test)]
    pub fn max_abs_error<T: Float>(a: &[Complex<T>], b: &[Complex<T>]) -> T {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| {
                let diff = *x - *y;
                Float::sqrt(diff.re * diff.re + diff.im * diff.im)
            })
            .fold(T::zero(), |acc, v| if v > acc { v } else { acc })
    }
}
#[cfg(test)]
mod tests {
    use super::pure::{fft_3d_zyx, max_abs_error};
    use super::PencilGrid;
    use crate::api::Direction;
    use crate::kernel::Complex;
    use core::f64::consts::PI;

    /// Deterministic, non-trivial complex signal with `n0 * n1 * n2` samples.
    fn make_test_input_f64(n0: usize, n1: usize, n2: usize) -> Vec<Complex<f64>> {
        let total = n0 * n1 * n2;
        let mut samples = Vec::with_capacity(total);
        for i in 0..total {
            let t = i as f64 / total as f64;
            samples.push(Complex {
                re: (2.0 * PI * t * 3.0).cos(),
                im: (2.0 * PI * t * 5.0).sin(),
            });
        }
        samples
    }

    /// A unit impulse at the origin of an `n^3` cube must transform to all ones.
    fn impulse_gives_ones(n: usize, failure: &str) {
        let mut data = vec![Complex::<f64>::zero(); n * n * n];
        data[0] = Complex { re: 1.0, im: 0.0 };
        fft_3d_zyx(&mut data, n, n, n, Direction::Forward).expect(failure);
        for (i, v) in data.iter().enumerate() {
            assert!(
                (v.re - 1.0).abs() < 1e-10,
                "coeff[{i}].re = {:.2e} (expected 1.0)",
                v.re
            );
            assert!(
                v.im.abs() < 1e-10,
                "coeff[{i}].im = {:.2e} (expected 0.0)",
                v.im
            );
        }
    }

    /// Forward then backward transform of an `n^3` cube, rescaled by the
    /// element count, must reproduce the input.
    fn roundtrip_recovers_input(n: usize) {
        let original = make_test_input_f64(n, n, n);
        let scale = (n * n * n) as f64;
        let mut data = original.clone();
        fft_3d_zyx(&mut data, n, n, n, Direction::Forward).expect("forward fft should succeed");
        fft_3d_zyx(&mut data, n, n, n, Direction::Backward).expect("inverse fft should succeed");
        for v in &mut data {
            v.re /= scale;
            v.im /= scale;
        }
        let err = max_abs_error(&original, &data);
        assert!(
            err < 1e-10,
            "roundtrip error {err:.2e} exceeds 1e-10 for {n}x{n}x{n}"
        );
    }

    #[test]
    fn pencil_grid_basic() {
        let g = PencilGrid::new(2, 4);
        assert_eq!(g.total_procs(), 8);
        assert_eq!(g.row_rank(0), 0);
        assert_eq!(g.row_rank(4), 1);
        assert_eq!(g.col_rank(0), 0);
        assert_eq!(g.col_rank(3), 3);
        assert_eq!(g.col_rank(4), 0);
        assert_eq!(g.col_rank(7), 3);
    }

    #[test]
    fn pencil_grid_single_proc() {
        let g = PencilGrid::new(1, 1);
        assert_eq!(g.total_procs(), 1);
        assert_eq!(g.row_rank(0), 0);
        assert_eq!(g.col_rank(0), 0);
    }

    #[test]
    fn pencil_pure_fft_4x4x4_impulse_gives_ones() {
        impulse_gives_ones(4, "fft_3d_zyx 4x4x4 should succeed");
    }

    #[test]
    fn pencil_pure_fft_8x8x8_impulse_gives_ones() {
        impulse_gives_ones(8, "fft_3d_zyx 8x8x8 should succeed");
    }

    #[test]
    fn pencil_pure_roundtrip_4x4x4() {
        roundtrip_recovers_input(4);
    }

    #[test]
    fn pencil_pure_roundtrip_8x8x8() {
        roundtrip_recovers_input(8);
    }

    /// FFT(a*x + b*y) must equal a*FFT(x) + b*FFT(y).
    #[test]
    fn pencil_pure_linearity_8x8x8() {
        let (n0, n1, n2) = (8, 8, 8);
        let n = n0 * n1 * n2;

        let x = make_test_input_f64(n0, n1, n2);
        let y: Vec<Complex<f64>> = (0..n)
            .map(|i| {
                let t = (i + 7) as f64 / n as f64;
                Complex {
                    re: (2.0 * PI * t).cos(),
                    im: 0.0,
                }
            })
            .collect();

        let a = Complex::<f64> { re: 2.0, im: -1.0 };
        let b = Complex::<f64> { re: -0.5, im: 3.0 };
        let mix = |p: &Complex<f64>, q: &Complex<f64>| a * *p + b * *q;

        // Transform of the linear combination.
        let mut combined: Vec<Complex<f64>> =
            x.iter().zip(y.iter()).map(|(p, q)| mix(p, q)).collect();
        fft_3d_zyx(&mut combined, n0, n1, n2, Direction::Forward)
            .expect("combined fft should succeed");

        // Linear combination of the individual transforms.
        let mut fx = x;
        let mut fy = y;
        fft_3d_zyx(&mut fx, n0, n1, n2, Direction::Forward).expect("fx fft should succeed");
        fft_3d_zyx(&mut fy, n0, n1, n2, Direction::Forward).expect("fy fft should succeed");
        let linear: Vec<Complex<f64>> =
            fx.iter().zip(fy.iter()).map(|(p, q)| mix(p, q)).collect();

        let err = max_abs_error(&combined, &linear);
        assert!(
            err < 1e-8,
            "linearity error {err:.2e} exceeds 1e-8 for 8x8x8"
        );
    }

    #[test]
    fn pencil_pure_zero_dim_error() {
        let mut data: Vec<Complex<f64>> = Vec::new();
        let result = fft_3d_zyx(&mut data, 0, 4, 4, Direction::Forward);
        assert!(result.is_err(), "expected error for zero n0");
    }

    // Placeholders for tests that need a real MPI launcher; bodies are
    // intentionally empty and the tests are ignored by default.
    #[cfg(feature = "mpi")]
    mod mpi_required {
        #[test]
        #[ignore = "Requires MPI runtime: mpirun -n 1 cargo test --features mpi"]
        fn pencil_mpi_construction_p1() {}

        #[test]
        #[ignore = "Requires MPI runtime with 4 ranks: mpirun -n 4 cargo test --features mpi"]
        fn pencil_mpi_4rank_8x8x8() {}
    }
}