use mpi::datatype::Equivalence;
use mpi::topology::Communicator;
use crate::kernel::Complex;
use super::distribution::LocalPartition;
use super::error::MpiError;
use super::pool::{MpiFloat, MpiPool};
/// Transposes a row-distributed `n0 x n1` complex matrix into a column-distributed one.
///
/// On entry this rank owns `local_n0` consecutive global rows of the matrix, stored
/// row-major in `input` (`input[row * n1 + col]`). On exit `output` holds this rank's share
/// of the `n1` columns, `local_n1 = LocalPartition::new(n1, size, rank).local_n`, stored as
/// `output[local_col * n0 + global_row]`, i.e. row-major in the transposed matrix.
/// Elements of `output` beyond `n0 * local_n1` are left untouched.
///
/// This performs a variable-count all-to-all on `pool`, so every rank must call it. The
/// receive counts assume source rank `p` owns `LocalPartition::new(n0, size, p).local_n`
/// rows, so `local_n0` must agree with that partition. `_local_0_start` is currently unused.
///
/// # Errors
///
/// Returns [`MpiError::SizeMismatch`] if `input` is shorter than `local_n0 * n1` or `output`
/// is shorter than `n0 * local_n1`. These checks run before the exchange, so a rank that fails
/// them does not enter the collective. Any error from the all-to-all itself is propagated.
///
/// # Panics
///
/// Panics if a per-rank element count or displacement exceeds `i32::MAX`, the range of
/// MPI element counts.
pub fn distributed_transpose<T, C>(
    pool: &MpiPool<C>,
    input: &[Complex<T>],
    output: &mut [Complex<T>],
    n0: usize,
    n1: usize,
    local_n0: usize,
    _local_0_start: usize,
) -> Result<(), MpiError>
where
    T: MpiFloat,
    C: Communicator,
    Complex<T>: Equivalence,
{
    let num_procs = pool.size();
    let rank = pool.rank();

    let expected_input_size = local_n0 * n1;
    if input.len() < expected_input_size {
        return Err(MpiError::SizeMismatch {
            expected: expected_input_size,
            actual: input.len(),
        });
    }

    let local_n1 = LocalPartition::new(n1, num_procs, rank).local_n;
    let expected_output_size = n0 * local_n1;
    if output.len() < expected_output_size {
        return Err(MpiError::SizeMismatch {
            expected: expected_output_size,
            actual: output.len(),
        });
    }

    // Column block destined for each rank, and row block arriving from each rank.
    let col_partitions: Vec<_> = (0..num_procs)
        .map(|p| LocalPartition::new(n1, num_procs, p))
        .collect();
    let row_partitions: Vec<_> = (0..num_procs)
        .map(|p| LocalPartition::new(n0, num_procs, p))
        .collect();

    let send_sizes: Vec<usize> = col_partitions
        .iter()
        .map(|part| local_n0 * part.local_n)
        .collect();
    let recv_sizes: Vec<usize> = row_partitions
        .iter()
        .map(|part| part.local_n * local_n1)
        .collect();
    let (send_counts, send_displs, total_send) = transpose_counts_and_displs(&send_sizes);
    let (recv_counts, recv_displs, total_recv) = transpose_counts_and_displs(&recv_sizes);

    // Pack in rank order so each destination's data starts at its displacement. Within a
    // destination's block, every local row contributes its contiguous column segment.
    let mut send_buffer = Vec::with_capacity(total_send);
    for part in &col_partitions {
        for row in 0..local_n0 {
            let start = row * n1 + part.local_start;
            send_buffer.extend_from_slice(&input[start..start + part.local_n]);
        }
    }

    let mut recv_buffer = vec![Complex::<T>::zero(); total_recv];
    pool.all_to_all_v_complex(
        &send_buffer,
        &send_counts,
        &send_displs,
        &mut recv_buffer,
        &recv_counts,
        &recv_displs,
    )?;

    // The block from source rank p lists its rows in order, each holding `local_n1` values
    // (one per local column); scatter them into the transposed layout.
    let mut received = recv_buffer.iter();
    for part in &row_partitions {
        for global_row in part.local_start..part.local_start + part.local_n {
            for (local_col, value) in received.by_ref().take(local_n1).enumerate() {
                output[local_col * n0 + global_row] = *value;
            }
        }
    }
    Ok(())
}

/// Turns per-rank element counts into the `(counts, displacements, total)` triple used by a
/// variable-count all-to-all. Displacements are exclusive prefix sums of `sizes`.
///
/// # Panics
///
/// Panics if any count or displacement does not fit in an `i32`.
fn transpose_counts_and_displs(sizes: &[usize]) -> (Vec<i32>, Vec<i32>, usize) {
    let mut counts = Vec::with_capacity(sizes.len());
    let mut displs = Vec::with_capacity(sizes.len());
    let mut offset = 0usize;
    for &size in sizes {
        counts.push(transpose_mpi_int(size));
        displs.push(transpose_mpi_int(offset));
        offset += size;
    }
    (counts, displs, offset)
}

/// Narrows an element count or offset to `i32`, panicking instead of wrapping on overflow.
fn transpose_mpi_int(value: usize) -> i32 {
    assert!(
        value <= i32::MAX as usize,
        "MPI element count or displacement {} exceeds i32::MAX",
        value
    );
    value as i32
}
#[allow(dead_code)]
pub fn distributed_transpose_inplace<T, C>(
pool: &MpiPool<C>,
data: &mut [Complex<T>],
scratch: &mut [Complex<T>],
n0: usize,
n1: usize,
local_n0: usize,
local_0_start: usize,
) -> Result<(), MpiError>
where
T: MpiFloat,
C: Communicator,
Complex<T>: Equivalence,
{
distributed_transpose(pool, data, scratch, n0, n1, local_n0, local_0_start)?;
let transposed_partition = LocalPartition::new(n1, pool.size(), pool.rank());
let local_n1 = transposed_partition.local_n;
let output_size = n0 * local_n1;
data[..output_size].copy_from_slice(&scratch[..output_size]);
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The per-rank element counts before the transpose (rows split over ranks) and after it
    /// (columns split over ranks) must each add up to the full `n0 * n1` grid.
    #[test]
    fn test_partition_calculation() {
        let (n0, n1, num_procs) = (16, 8, 4);

        let (total_send, total_recv) = (0..num_procs).fold((0, 0), |(send, recv), rank| {
            let rows_here = LocalPartition::new(n0, num_procs, rank).local_n;
            let cols_here = LocalPartition::new(n1, num_procs, rank).local_n;
            (send + rows_here * n1, recv + n0 * cols_here)
        });

        assert_eq!(total_send, n0 * n1);
        assert_eq!(total_recv, n0 * n1);
    }
}