use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
use scirs2_core::ndarray::{s, ArrayBase, ArrayD, Data, Dimension, IxDyn};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Instant;
/// How the global array is partitioned across the participating nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecompositionStrategy {
/// Partition along the first axis only; each rank owns a contiguous range of leading indices.
Slab,
/// Partition along the first two axes, using the first two entries of the process grid.
Pencil,
/// Partition along the first three axes, using the first three entries of the process grid.
Volumetric,
/// Pick one of the other strategies from the number of input dimensions and the node count.
Adaptive,
}
/// Communication scheme requested for the exchange step between nodes.
///
/// In the code visible in this file only `AllToAll` triggers a call on the
/// communicator; the other variants currently pass the local data through unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommunicationPattern {
/// Exchange through `Communicator::all_to_all`.
AllToAll,
/// Pairwise exchange (no communicator call is made for this variant in this file).
PointToPoint,
/// Neighbor exchange (no communicator call is made for this variant in this file).
Neighbor,
/// Mixed scheme (no communicator call is made for this variant in this file).
Hybrid,
}
/// Settings that describe this node's place in a distributed FFT run.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
/// Total number of nodes the data is divided among.
pub node_count: usize,
/// Index of this node; used to locate the share of the data it owns.
pub rank: usize,
/// Strategy used to partition the input.
pub decomposition: DecompositionStrategy,
/// Scheme used by the exchange step.
pub communication: CommunicationPattern,
/// Number of nodes along each partitioned axis; the pencil and volumetric
/// strategies require the product of the entries they use to equal `node_count`.
pub process_grid: Vec<usize>,
/// NOTE(review): not read by any code visible in this file — confirm intended use.
pub local_size: Vec<usize>,
/// Upper bound applied to per-axis extents: it limits the decomposition when
/// running in test mode, and always limits the local FFT loops and the final output shape.
pub max_local_size: usize,
}
impl Default for DistributedConfig {
fn default() -> Self {
Self {
node_count: 1,
rank: 0,
decomposition: DecompositionStrategy::Slab,
communication: CommunicationPattern::AllToAll,
process_grid: vec![1],
local_size: vec![],
max_local_size: 1024, }
}
}
/// Driver that decomposes an array, transforms the local share, exchanges data
/// through a [`Communicator`] and assembles the result on the root node.
pub struct DistributedFFT {
// Partitioning and communication settings for this node.
config: DistributedConfig,
// Shared handle used by the exchange step; stored behind `Arc` so derived
// instances can reuse the same communicator.
#[allow(dead_code)]
communicator: Arc<dyn Communicator>,
}
/// Message-passing interface used by [`DistributedFFT`] to talk to other nodes.
///
/// Implementations must be shareable across threads (`Send + Sync`) and printable (`Debug`).
pub trait Communicator: Send + Sync + Debug {
/// Sends `data` to node `dest`, labelled with `tag`.
fn send(&self, data: &[Complex64], dest: usize, tag: usize) -> FFTResult<()>;
/// Receives `size` values from node `src` for the message labelled `tag`.
fn recv(&self, src: usize, tag: usize, size: usize) -> FFTResult<Vec<Complex64>>;
/// Collective exchange: contributes `senddata` and returns the data received.
fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>>;
/// Synchronization point (presumably blocks until all nodes arrive — depends on the implementation).
fn barrier(&self) -> FFTResult<()>;
/// Number of nodes in the communicator.
fn size(&self) -> usize;
/// Index of the calling node.
fn rank(&self) -> usize;
}
impl DistributedFFT {
pub fn new(config: DistributedConfig, communicator: Arc<dyn Communicator>) -> Self {
Self {
config,
communicator,
}
}
/// Runs the full pipeline on `input`: decompose, transform the local share,
/// exchange between nodes and assemble the final array.
///
/// # Errors
///
/// Propagates any error from the decomposition, local FFT, exchange or
/// finalization step.
///
/// In debug builds the time spent in each phase is printed to stdout.
pub fn distributed_fft<S, D>(&self, input: &ArrayBase<S, D>) -> FFTResult<ArrayD<Complex64>>
where
    S: Data,
    D: Dimension,
    S::Elem: Into<Complex64> + Copy + Debug + NumCast,
{
    // One clock for the whole run; each phase is measured as the difference
    // between consecutive checkpoints.
    let clock = Instant::now();

    let local_data = self.decompose_data(&input.to_owned().into_dyn())?;
    let after_decomposition = clock.elapsed();

    let mut local_result = ArrayD::zeros(local_data.dim());
    self.perform_local_fft(&local_data, &mut local_result)?;
    let after_local_fft = clock.elapsed();

    let exchanged_data = self.exchange_data(&local_result)?;
    let after_exchange = clock.elapsed();

    let final_result = self.finalize_result(&exchanged_data, input.shape())?;
    let total_time = clock.elapsed();

    if cfg!(debug_assertions) {
        let decomp_time = after_decomposition;
        let local_fft_time = after_local_fft - after_decomposition;
        let comm_time = after_exchange - after_local_fft;
        println!("Distributed FFT Performance:");
        println!(" Decomposition: {:?}", decomp_time);
        println!(" Local FFT: {:?}", local_fft_time);
        println!(" Communication: {:?}", comm_time);
        println!(" Total time: {:?}", total_time);
    }

    Ok(final_result)
}
pub fn decompose_data<T>(&self, input: &ArrayD<T>) -> FFTResult<ArrayD<Complex64>>
where
T: Into<Complex64> + Copy + NumCast,
{
let is_testing = cfg!(test) || std::env::var("RUST_TEST").is_ok();
match self.config.decomposition {
DecompositionStrategy::Slab => self.slab_decomposition(input, is_testing),
DecompositionStrategy::Pencil => self.pencil_decomposition(input, is_testing),
DecompositionStrategy::Volumetric => self.volumetric_decomposition(input, is_testing),
DecompositionStrategy::Adaptive => self.adaptive_decomposition(input, is_testing),
}
}
/// Applies 1-D FFTs to the locally owned data and stores the spectra in `output`.
///
/// The transform always runs along the last axis:
/// * 1-D input (any strategy): one FFT over the whole array;
/// * 2-D input with `Slab` or `Pencil`: one FFT per row;
/// * 3-D input with `Slab` or `Pencil`: one FFT per `(i, j)` lane.
///
/// The leading loop indices are limited to `max_local_size`; entries of
/// `output` outside that limit are left untouched. `output` is expected to
/// have the same shape as `input`.
///
/// # Errors
///
/// Returns `FFTError::DimensionError` for any other combination of strategy
/// and dimensionality, and propagates errors from the underlying `fft` call.
fn perform_local_fft(
    &self,
    input: &ArrayD<Complex64>,
    output: &mut ArrayD<Complex64>,
) -> FFTResult<()> {
    let cap = self.config.max_local_size;
    let lane_wise = matches!(
        self.config.decomposition,
        DecompositionStrategy::Slab | DecompositionStrategy::Pencil
    );

    match input.ndim() {
        1 => {
            // Collect in logical order so a non-contiguous array is still
            // transformed instead of silently becoming an empty input.
            let samples: Vec<Complex64> = input.iter().copied().collect();
            let spectrum = fft(&samples, None)?;
            for (dst, val) in output.iter_mut().zip(spectrum.iter()) {
                *dst = *val;
            }
        }
        // The slice pattern must have exactly as many entries as the array has
        // axes, so each dimensionality gets its own pattern.
        2 if lane_wise => {
            for i in 0..input.shape()[0].min(cap) {
                let row = input.slice(s![i, ..]).to_vec();
                let spectrum = fft(&row, None)?;
                let mut output_row = output.slice_mut(s![i, ..]);
                for (dst, val) in output_row.iter_mut().zip(spectrum.iter()) {
                    *dst = *val;
                }
            }
        }
        3 if lane_wise => {
            for i in 0..input.shape()[0].min(cap) {
                for j in 0..input.shape()[1].min(cap) {
                    let lane = input.slice(s![i, j, ..]).to_vec();
                    let spectrum = fft(&lane, None)?;
                    let mut output_lane = output.slice_mut(s![i, j, ..]);
                    for (dst, val) in output_lane.iter_mut().zip(spectrum.iter()) {
                        *dst = *val;
                    }
                }
            }
        }
        ndim => {
            return Err(FFTError::DimensionError(format!(
                "Unsupported decomposition strategy for input of dimension {}",
                ndim
            )));
        }
    }
    Ok(())
}
/// Exchange step between the local FFT and the final assembly.
///
/// On a single-node run or on rank 0 the local result is returned directly.
/// On other ranks the `AllToAll` pattern sends the flattened local result
/// through the communicator; every pattern then returns a copy of the local
/// result.
///
/// # Errors
///
/// Propagates an error from `Communicator::all_to_all`.
fn exchange_data(&self, localresult: &ArrayD<Complex64>) -> FFTResult<ArrayD<Complex64>> {
    let is_remote_rank = self.config.node_count != 1 && self.config.rank != 0;
    let uses_all_to_all = self.config.communication == CommunicationPattern::AllToAll;

    if is_remote_rank && uses_all_to_all {
        let flattened: Vec<Complex64> = localresult.iter().copied().collect();
        // NOTE(review): the received buffer is currently discarded; only the
        // success or failure of the call is observed.
        let _received = self.communicator.all_to_all(&flattened)?;
    }

    Ok(localresult.clone())
}
/// Assembles the final output array on the root node.
///
/// The output has shape `output_dim` with every axis clamped to
/// `max_local_size`. Values from `exchanged_data` are copied, in iteration
/// order, into the output's row-major storage until either side is
/// exhausted; any remaining output entries stay zero.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when called on a rank other than 0 in a
/// multi-node configuration.
fn finalize_result(
    &self,
    exchanged_data: &ArrayD<Complex64>,
    output_dim: &[usize],
) -> FFTResult<ArrayD<Complex64>> {
    if self.config.node_count != 1 && self.config.rank != 0 {
        return Err(FFTError::ValueError(
            "Only the root node (rank 0) produces the final output".to_string(),
        ));
    }

    let limitedshape: Vec<usize> = output_dim
        .iter()
        .map(|&d| d.min(self.config.max_local_size))
        .collect();
    let mut output: ArrayD<Complex64> = ArrayD::zeros(IxDyn(&limitedshape));

    // A freshly allocated array is in standard layout, so `iter_mut` walks it
    // in the same order as its flat storage. Zipping stops at the shorter
    // side, which covers empty arrays and size mismatches in linear time.
    for (dst, &src) in output.iter_mut().zip(exchanged_data.iter()) {
        *dst = src;
    }

    Ok(output)
}
/// Extracts this rank's slab: a contiguous range of indices along axis 0.
///
/// Axis 0 is divided into `node_count` ranges of `ceil(n / node_count)`
/// indices; this rank takes the range at position `rank`. When `is_testing`
/// is true every axis is additionally limited to `max_local_size` (entries of
/// trailing axes beyond the limit stay zero). A rank whose range starts past
/// the end of the data receives an empty 1-D array.
///
/// Elements are converted with `Into<Complex64>`, which is lossless for both
/// real and complex inputs.
///
/// # Errors
///
/// * `FFTError::DimensionError` for a zero-dimensional input or more than 3 dimensions.
/// * `FFTError::ValueError` when `node_count` is 0.
fn slab_decomposition<T>(
    &self,
    input: &ArrayD<T>,
    is_testing: bool,
) -> FFTResult<ArrayD<Complex64>>
where
    T: Into<Complex64> + Copy + NumCast,
{
    let shape = input.shape();
    let max_size = if is_testing {
        self.config.max_local_size
    } else {
        usize::MAX
    };
    if shape.is_empty() {
        return Err(FFTError::DimensionError(
            "Cannot perform FFT on empty array".to_string(),
        ));
    }
    // `div_ceil` below would divide by zero.
    if self.config.node_count == 0 {
        return Err(FFTError::ValueError(
            "node_count must be at least 1".to_string(),
        ));
    }

    let total_slabs = shape[0];
    let slabs_per_node = total_slabs.div_ceil(self.config.node_count);
    let my_start = self.config.rank * slabs_per_node;
    if my_start >= total_slabs {
        return Ok(ArrayD::zeros(IxDyn(&[0])));
    }
    let my_end = (my_start + slabs_per_node).min(total_slabs);
    // `max_size` is `usize::MAX` outside test mode, so a plain addition would
    // overflow for every rank other than 0.
    let actual_end = my_end.min(my_start.saturating_add(max_size));

    let mut myshape: Vec<usize> = shape.to_vec();
    myshape[0] = actual_end - my_start;
    let mut output: ArrayD<Complex64> = ArrayD::zeros(IxDyn(myshape.as_slice()));

    match input.ndim() {
        1 => {
            for i in my_start..actual_end {
                let val: Complex64 = input[IxDyn(&[i])].into();
                output[IxDyn(&[i - my_start])] = val;
            }
        }
        2 => {
            let cols = shape[1].min(max_size);
            for i in my_start..actual_end {
                for j in 0..cols {
                    let val: Complex64 = input[IxDyn(&[i, j])].into();
                    output[IxDyn(&[i - my_start, j])] = val;
                }
            }
        }
        3 => {
            let rows = shape[1].min(max_size);
            let cols = shape[2].min(max_size);
            for i in my_start..actual_end {
                for j in 0..rows {
                    for k in 0..cols {
                        let val: Complex64 = input[IxDyn(&[i, j, k])].into();
                        output[IxDyn(&[i - my_start, j, k])] = val;
                    }
                }
            }
        }
        _ => {
            return Err(FFTError::DimensionError(
                "Dimensions higher than 3 not yet implemented for slab decomposition".to_string(),
            ));
        }
    }
    Ok(output)
}
/// Extracts this rank's pencil: a block along the first two axes.
///
/// The first two entries of `process_grid` (`p1 x p2`) must multiply to
/// `node_count`. The rank is mapped to grid coordinates
/// `(rank / p2, rank % p2)`, and axes 0 and 1 are divided into `p1` and `p2`
/// ranges of `ceil(n / p)` indices. When `is_testing` is true every axis is
/// additionally limited to `max_local_size`. A rank whose block starts past
/// the end of the data receives an empty 1-D array.
///
/// Elements are converted with `Into<Complex64>`, which is lossless for both
/// real and complex inputs.
///
/// # Errors
///
/// * `FFTError::DimensionError` for fewer than 2 or more than 3 dimensions.
/// * `FFTError::ValueError` when the process grid has fewer than 2 entries,
///   does not multiply to `node_count`, or contains a zero.
fn pencil_decomposition<T>(
    &self,
    input: &ArrayD<T>,
    is_testing: bool,
) -> FFTResult<ArrayD<Complex64>>
where
    T: Into<Complex64> + Copy + NumCast,
{
    let shape = input.shape();
    let max_size = if is_testing {
        self.config.max_local_size
    } else {
        usize::MAX
    };
    if shape.len() < 2 {
        return Err(FFTError::DimensionError(
            "Pencil decomposition requires at least 2D input".to_string(),
        ));
    }
    let process_grid = &self.config.process_grid;
    if process_grid.len() < 2 {
        return Err(FFTError::ValueError(
            "Pencil decomposition requires a 2D process grid".to_string(),
        ));
    }
    let p1 = process_grid[0];
    let p2 = process_grid[1];
    if p1 * p2 != self.config.node_count {
        return Err(FFTError::ValueError(format!(
            "Process grid ({} x {}) doesn't match node count ({})",
            p1, p2, self.config.node_count
        )));
    }
    // Only reachable with `node_count == 0`; the divisions below would panic.
    if p1 == 0 || p2 == 0 {
        return Err(FFTError::ValueError(
            "Process grid dimensions must be non-zero".to_string(),
        ));
    }

    let my_row = self.config.rank / p2;
    let my_col = self.config.rank % p2;
    let n1 = shape[0];
    let n2 = shape[1];
    let rows_per_node = n1.div_ceil(p1);
    let cols_per_node = n2.div_ceil(p2);
    let my_start_row = my_row * rows_per_node;
    let my_start_col = my_col * cols_per_node;
    if my_start_row >= n1 || my_start_col >= n2 {
        return Ok(ArrayD::zeros(IxDyn(&[0])));
    }
    let my_end_row = (my_start_row + rows_per_node).min(n1);
    let my_end_col = (my_start_col + cols_per_node).min(n2);
    // `max_size` is `usize::MAX` outside test mode, so a plain addition would
    // overflow for any non-zero start index.
    let actual_end_row = my_end_row.min(my_start_row.saturating_add(max_size));
    let actual_end_col = my_end_col.min(my_start_col.saturating_add(max_size));

    let mut myshape: Vec<usize> = shape.to_vec();
    myshape[0] = actual_end_row - my_start_row;
    myshape[1] = actual_end_col - my_start_col;
    let mut output: ArrayD<Complex64> = ArrayD::zeros(IxDyn(myshape.as_slice()));

    match input.ndim() {
        2 => {
            for i in my_start_row..actual_end_row {
                for j in my_start_col..actual_end_col {
                    let val: Complex64 = input[IxDyn(&[i, j])].into();
                    output[IxDyn(&[i - my_start_row, j - my_start_col])] = val;
                }
            }
        }
        3 => {
            let depth = shape[2].min(max_size);
            for i in my_start_row..actual_end_row {
                for j in my_start_col..actual_end_col {
                    for k in 0..depth {
                        let val: Complex64 = input[IxDyn(&[i, j, k])].into();
                        output[IxDyn(&[i - my_start_row, j - my_start_col, k])] = val;
                    }
                }
            }
        }
        _ => {
            return Err(FFTError::DimensionError(
                "Dimensions higher than 3 not yet implemented for pencil decomposition".to_string(),
            ));
        }
    }
    Ok(output)
}
/// Extracts this rank's sub-volume: a block along the first three axes.
///
/// The first three entries of `process_grid` (`p1 x p2 x p3`) must multiply
/// to `node_count`. The rank is mapped to grid coordinates
/// `(rank / (p2 * p3), (rank % (p2 * p3)) / p3, rank % p3)` and each of the
/// three axes is divided into ranges of `ceil(n / p)` indices. When
/// `is_testing` is true every axis is additionally limited to
/// `max_local_size`. A rank whose block starts past the end of the data
/// receives an empty 1-D array.
///
/// Elements are converted with `Into<Complex64>`, which is lossless for both
/// real and complex inputs.
///
/// # Errors
///
/// * `FFTError::DimensionError` for fewer than 3 or more than 3 dimensions.
/// * `FFTError::ValueError` when the process grid has fewer than 3 entries,
///   does not multiply to `node_count`, or contains a zero.
fn volumetric_decomposition<T>(
    &self,
    input: &ArrayD<T>,
    is_testing: bool,
) -> FFTResult<ArrayD<Complex64>>
where
    T: Into<Complex64> + Copy + NumCast,
{
    let shape = input.shape();
    let max_size = if is_testing {
        self.config.max_local_size
    } else {
        usize::MAX
    };
    if shape.len() < 3 {
        return Err(FFTError::DimensionError(
            "Volumetric decomposition requires at least 3D input".to_string(),
        ));
    }
    let process_grid = &self.config.process_grid;
    if process_grid.len() < 3 {
        return Err(FFTError::ValueError(
            "Volumetric decomposition requires a 3D process grid".to_string(),
        ));
    }
    let p1 = process_grid[0];
    let p2 = process_grid[1];
    let p3 = process_grid[2];
    if p1 * p2 * p3 != self.config.node_count {
        return Err(FFTError::ValueError(format!(
            "Process grid ({} x {} x {}) doesn't match node count ({})",
            p1, p2, p3, self.config.node_count
        )));
    }
    // Only reachable with `node_count == 0`; the divisions below would panic.
    if p1 == 0 || p2 == 0 || p3 == 0 {
        return Err(FFTError::ValueError(
            "Process grid dimensions must be non-zero".to_string(),
        ));
    }

    let my_plane = self.config.rank / (p2 * p3);
    let remainder = self.config.rank % (p2 * p3);
    let my_row = remainder / p3;
    let my_col = remainder % p3;
    let n1 = shape[0];
    let n2 = shape[1];
    let n3 = shape[2];
    let planes_per_node = n1.div_ceil(p1);
    let rows_per_node = n2.div_ceil(p2);
    let cols_per_node = n3.div_ceil(p3);
    let my_start_plane = my_plane * planes_per_node;
    let my_start_row = my_row * rows_per_node;
    let my_start_col = my_col * cols_per_node;
    if my_start_plane >= n1 || my_start_row >= n2 || my_start_col >= n3 {
        return Ok(ArrayD::zeros(IxDyn(&[0])));
    }
    let my_end_plane = (my_start_plane + planes_per_node).min(n1);
    let my_end_row = (my_start_row + rows_per_node).min(n2);
    let my_end_col = (my_start_col + cols_per_node).min(n3);
    // `max_size` is `usize::MAX` outside test mode, so a plain addition would
    // overflow for any non-zero start index.
    let actual_end_plane = my_end_plane.min(my_start_plane.saturating_add(max_size));
    let actual_end_row = my_end_row.min(my_start_row.saturating_add(max_size));
    let actual_end_col = my_end_col.min(my_start_col.saturating_add(max_size));

    let mut myshape: Vec<usize> = shape.to_vec();
    myshape[0] = actual_end_plane - my_start_plane;
    myshape[1] = actual_end_row - my_start_row;
    myshape[2] = actual_end_col - my_start_col;
    let mut output: ArrayD<Complex64> = ArrayD::zeros(IxDyn(myshape.as_slice()));

    if input.ndim() != 3 {
        return Err(FFTError::DimensionError(
            "Dimensions higher than 3 not yet implemented for volumetric decomposition"
                .to_string(),
        ));
    }
    for i in my_start_plane..actual_end_plane {
        for j in my_start_row..actual_end_row {
            for k in my_start_col..actual_end_col {
                let val: Complex64 = input[IxDyn(&[i, j, k])].into();
                output[IxDyn(&[i - my_start_plane, j - my_start_row, k - my_start_col])] = val;
            }
        }
    }
    Ok(output)
}
fn adaptive_decomposition<T>(
&self,
input: &ArrayD<T>,
is_testing: bool,
) -> FFTResult<ArrayD<Complex64>>
where
T: Into<Complex64> + Copy + NumCast,
{
let ndim = input.ndim();
if ndim == 1 || self.config.node_count == 1 {
self.slab_decomposition(input, is_testing)
} else if ndim == 2 || self.config.node_count < 8 {
self.slab_decomposition(input, is_testing)
} else if ndim == 3 && self.config.node_count >= 8 {
let mut config = self.config.clone();
if config.process_grid.len() < 2 {
let sqrt_nodes = (self.config.node_count as f64).sqrt().floor() as usize;
config.process_grid = vec![sqrt_nodes, self.config.node_count / sqrt_nodes];
}
let temp_dfft = DistributedFFT {
config,
communicator: self.communicator.clone(),
};
temp_dfft.pencil_decomposition(input, is_testing)
} else if ndim >= 3 && self.config.node_count >= 27 {
let mut config = self.config.clone();
if config.process_grid.len() < 3 {
let cbrt_nodes = (self.config.node_count as f64).cbrt().floor() as usize;
let remaining = self.config.node_count / cbrt_nodes;
let sqrt_remaining = (remaining as f64).sqrt().floor() as usize;
config.process_grid = vec![cbrt_nodes, sqrt_remaining, remaining / sqrt_remaining];
}
let temp_dfft = DistributedFFT {
config,
communicator: self.communicator.clone(),
};
temp_dfft.volumetric_decomposition(input, is_testing)
} else {
self.slab_decomposition(input, is_testing)
}
}
/// Test-only constructor: pairs `config` with a [`MockCommunicator`] that
/// has the same node count and rank.
#[cfg(test)]
pub fn new_mock(config: DistributedConfig) -> Self {
    let mock = MockCommunicator::new(config.node_count, config.rank);
    let communicator: Arc<dyn Communicator> = Arc::new(mock);
    DistributedFFT {
        config,
        communicator,
    }
}
}
/// In-process communicator that validates its arguments but moves no data
/// between nodes.
#[derive(Debug)]
pub struct BasicCommunicator {
    // Number of nodes this communicator claims to span.
    size: usize,
    // Index reported for the calling node.
    rank: usize,
}

impl BasicCommunicator {
    /// Creates a communicator reporting `size` nodes and the given `rank`.
    pub fn new(size: usize, rank: usize) -> Self {
        BasicCommunicator { size, rank }
    }
}
impl Communicator for BasicCommunicator {
    /// Validates the destination and payload; nothing is transmitted.
    ///
    /// Fails with `FFTError::ValueError` when `dest` is out of range or
    /// `data` is empty. The tag is ignored.
    fn send(&self, data: &[Complex64], dest: usize, _tag: usize) -> FFTResult<()> {
        let problem = if dest >= self.size {
            Some(format!(
                "Invalid destination rank: {} (size: {})",
                dest, self.size
            ))
        } else if data.is_empty() {
            Some("Cannot send empty data".to_string())
        } else {
            None
        };
        match problem {
            Some(message) => Err(FFTError::ValueError(message)),
            None => Ok(()),
        }
    }

    /// Returns `size` zeros after checking that `src` is in range. The tag is ignored.
    fn recv(&self, src: usize, _tag: usize, size: usize) -> FFTResult<Vec<Complex64>> {
        if src < self.size {
            Ok(vec![Complex64::new(0.0, 0.0); size])
        } else {
            Err(FFTError::ValueError(format!(
                "Invalid source rank: {} (size: {})",
                src, self.size
            )))
        }
    }

    /// Echoes the send buffer back as the received data.
    fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>> {
        let echoed = senddata.to_vec();
        Ok(echoed)
    }

    /// Always succeeds immediately.
    fn barrier(&self) -> FFTResult<()> {
        Ok(())
    }

    /// Node count given at construction.
    fn size(&self) -> usize {
        self.size
    }

    /// Rank given at construction.
    fn rank(&self) -> usize {
        self.rank
    }
}
/// Communicator stand-in used by the tests; it performs range checks only.
#[derive(Debug)]
pub struct MockCommunicator {
    // Number of nodes this mock claims to span.
    size: usize,
    // Index reported for the calling node.
    rank: usize,
}

impl MockCommunicator {
    /// Creates a mock reporting `size` nodes and the given `rank`.
    pub fn new(size: usize, rank: usize) -> Self {
        MockCommunicator { size, rank }
    }
}
impl Communicator for MockCommunicator {
    /// Accepts any payload (including an empty one) as long as `dest` is in
    /// range; nothing is transmitted. The tag is ignored.
    fn send(&self, _data: &[Complex64], dest: usize, _tag: usize) -> FFTResult<()> {
        if dest < self.size {
            Ok(())
        } else {
            Err(FFTError::ValueError(format!(
                "Invalid destination rank: {} (size: {})",
                dest, self.size
            )))
        }
    }

    /// Returns `size` zeros after checking that `src` is in range. The tag is ignored.
    fn recv(&self, src: usize, _tag: usize, size: usize) -> FFTResult<Vec<Complex64>> {
        if src < self.size {
            Ok(vec![Complex64::new(0.0, 0.0); size])
        } else {
            Err(FFTError::ValueError(format!(
                "Invalid source rank: {} (size: {})",
                src, self.size
            )))
        }
    }

    /// Echoes the send buffer back as the received data.
    fn all_to_all(&self, senddata: &[Complex64]) -> FFTResult<Vec<Complex64>> {
        let echoed = senddata.to_vec();
        Ok(echoed)
    }

    /// Always succeeds immediately.
    fn barrier(&self) -> FFTResult<()> {
        Ok(())
    }

    /// Node count given at construction.
    fn size(&self) -> usize {
        self.size
    }

    /// Rank given at construction.
    fn rank(&self) -> usize {
        self.rank
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Rank-0 configuration with all-to-all communication and a per-axis
    /// limit of 16, shared by the decomposition tests.
    fn rank0_config(
        node_count: usize,
        decomposition: DecompositionStrategy,
        process_grid: Vec<usize>,
    ) -> DistributedConfig {
        DistributedConfig {
            node_count,
            rank: 0,
            decomposition,
            communication: CommunicationPattern::AllToAll,
            process_grid,
            local_size: vec![],
            max_local_size: 16,
        }
    }

    /// Values `1.0..=count` as `f64`.
    fn ramp(count: usize) -> Vec<f64> {
        (1..=count).map(|x| x as f64).collect()
    }

    #[test]
    fn test_distributed_config_default() {
        let DistributedConfig {
            node_count,
            rank,
            decomposition,
            ..
        } = DistributedConfig::default();
        assert_eq!(node_count, 1);
        assert_eq!(rank, 0);
        assert_eq!(decomposition, DecompositionStrategy::Slab);
    }

    #[test]
    fn test_mock_communicator() {
        let comm = MockCommunicator::new(4, 0);
        assert_eq!(comm.size(), 4);
        assert_eq!(comm.rank(), 0);

        let data = vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        // Destination 1 is valid for four nodes, destination 4 is not.
        assert!(comm.send(&data, 1, 0).is_ok());
        assert!(comm.send(&data, 4, 0).is_err());

        let received = comm.recv(1, 0, 2);
        assert!(received.is_ok());
        assert_eq!(received.expect("Operation failed").len(), 2);
        assert!(comm.recv(4, 0, 2).is_err());

        let echoed = comm.all_to_all(&data);
        assert!(echoed.is_ok());
        assert_eq!(echoed.expect("Operation failed"), data);

        assert!(comm.barrier().is_ok());
    }

    #[test]
    fn test_slab_decomposition_1d() {
        let dfft =
            DistributedFFT::new_mock(rank0_config(2, DecompositionStrategy::Slab, vec![2]));
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]).into_dyn();

        let result = dfft.slab_decomposition(&input, true);
        assert!(result.is_ok());

        // Four samples over two nodes: rank 0 keeps the first two.
        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.ndim(), 1);
        assert_eq!(local_data.shape()[0], 2);
    }

    #[test]
    fn test_slab_decomposition_2d() {
        let dfft =
            DistributedFFT::new_mock(rank0_config(2, DecompositionStrategy::Slab, vec![2]));
        let input = Array2::from_shape_vec((4, 2), ramp(8))
            .expect("Operation failed")
            .into_dyn();

        let result = dfft.slab_decomposition(&input, true);
        assert!(result.is_ok());

        // Rows are split across nodes; columns stay whole.
        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.ndim(), 2);
        assert_eq!(local_data.shape()[0], 2);
        assert_eq!(local_data.shape()[1], 2);
    }

    #[test]
    fn test_pencil_decomposition_2d() {
        let dfft =
            DistributedFFT::new_mock(rank0_config(4, DecompositionStrategy::Pencil, vec![2, 2]));
        let input = Array2::from_shape_vec((4, 4), ramp(16))
            .expect("Operation failed")
            .into_dyn();

        let result = dfft.pencil_decomposition(&input, true);
        assert!(result.is_ok());

        // A 2 x 2 grid over a 4 x 4 array gives each rank a 2 x 2 block.
        let local_data = result.expect("Operation failed");
        assert_eq!(local_data.ndim(), 2);
        assert_eq!(local_data.shape()[0], 2);
        assert_eq!(local_data.shape()[1], 2);
    }

    #[test]
    fn test_adaptive_decomposition() {
        // 1-D input on four nodes.
        let dfft1 =
            DistributedFFT::new_mock(rank0_config(4, DecompositionStrategy::Adaptive, vec![4]));
        let input1 = Array1::from_vec(ramp(16)).into_dyn();
        let result1 = dfft1.adaptive_decomposition(&input1, true);
        assert!(result1.is_ok());

        // 2-D input on four nodes.
        let dfft2 = DistributedFFT::new_mock(rank0_config(
            4,
            DecompositionStrategy::Adaptive,
            vec![2, 2],
        ));
        let input2 = Array2::from_shape_vec((4, 4), ramp(16))
            .expect("Operation failed")
            .into_dyn();
        let result2 = dfft2.adaptive_decomposition(&input2, true);
        assert!(result2.is_ok());
    }
}