#[cfg(feature = "mpi")]
use crate::algebra::prelude::*;
#[cfg(feature = "mpi")]
use crate::error::KError;
#[cfg(feature = "mpi")]
use crate::parallel::{Comm, UniverseComm};
/// Per-rank communication plan consumed by `CommPlan::exchange_values`.
///
/// Only `exports`, `export_peers` and `import_peers` are read by the code
/// visible in this file; the remaining fields are documented as assumptions.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone)]
pub struct CommPlan {
/// NOTE(review): not referenced in the visible code; presumably the indices
/// this rank receives from each peer. Confirm against the plan builder.
pub imports: Vec<Vec<usize>>,
/// Indices to send, looked up as `exports[peer]` where `peer` is a value taken
/// from `export_peers`. `exchange_values` subtracts `row_start` from each entry
/// to obtain a position in the local slice.
pub exports: Vec<Vec<usize>>,
/// NOTE(review): not referenced in the visible code; presumably local storage
/// positions for the imported entries. Confirm against the plan builder.
pub import_locs: Vec<Vec<usize>>,
/// Peers passed as `recv_peers` to `alltoallv_scalar_sparse`.
pub import_peers: Vec<usize>,
/// Peers passed as `send_peers` to `alltoallv_scalar_sparse`; each value is
/// also used as an index into `exports`.
pub export_peers: Vec<usize>,
}
#[cfg(feature = "mpi")]
impl CommPlan {
    /// Gathers the planned export values from `local` and exchanges them with peers.
    ///
    /// For every `peer` in `export_peers`, the indices in `exports[peer]` are
    /// shifted down by `row_start` and used to read from `local`. The resulting
    /// per-peer buffers are handed to `alltoallv_scalar_sparse`, with
    /// `export_peers` as the send list and `import_peers` as the receive list.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `alltoallv_scalar_sparse`.
    ///
    /// # Panics
    ///
    /// Panics if a peer value is out of range for `exports`, or if an export
    /// index minus `row_start` does not address an element of `local`.
    pub fn exchange_values(
        &self,
        comm: &UniverseComm,
        row_start: usize,
        local: &[S],
    ) -> Result<Vec<Vec<S>>, KError> {
        // One outgoing buffer per export peer, in `export_peers` order.
        let outgoing: Vec<Vec<S>> = self
            .export_peers
            .iter()
            .map(|&peer| {
                self.exports[peer]
                    .iter()
                    .map(|&global| local[global - row_start])
                    .collect()
            })
            .collect();
        alltoallv_scalar_sparse(comm, &self.export_peers, &outgoing, &self.import_peers)
    }
}
/// Exchanges one `u64` buffer with every rank of `comm`.
///
/// `send` must hold exactly one buffer per rank; `send[p]` is sent to rank `p`.
/// The exchange runs in two rounds of `irecv_from_u64` / `isend_to_u64`
/// followed by `wait_all`: first the buffer lengths, then the payloads.
///
/// Returns one buffer per rank. Entry `p` is filled by the receive posted for
/// rank `p`; the entry for the calling rank is a copy of `send[rank]` and
/// involves no communication.
///
/// # Errors
///
/// Returns `KError::InvalidInput` when `send.len()` differs from `comm.size()`.
#[cfg(feature = "mpi")]
pub fn alltoallv_u64(comm: &UniverseComm, send: &[Vec<u64>]) -> Result<Vec<Vec<u64>>, KError> {
    let size = comm.size();
    if send.len() != size {
        return Err(KError::InvalidInput(
            "alltoallv_u64: send buffer length must match communicator size".into(),
        ));
    }
    let rank = comm.rank();

    // Round 1: exchange payload lengths. Each length lives in its own
    // one-element array so every request gets a distinct buffer.
    let send_counts: Vec<[u64; 1]> = send.iter().map(|buf| [buf.len() as u64]).collect();
    let mut recv_counts = vec![[0u64; 1]; size];
    let mut reqs = Vec::new();
    // `iter_mut` hands out disjoint mutable borrows, so no raw-pointer
    // arithmetic is needed to keep several receive buffers alive at once.
    for (peer, slot) in recv_counts.iter_mut().enumerate() {
        if peer == rank {
            continue;
        }
        reqs.push(comm.irecv_from_u64(slot, peer as i32));
    }
    for (peer, count) in send_counts.iter().enumerate() {
        if peer == rank {
            continue;
        }
        reqs.push(comm.isend_to_u64(count, peer as i32));
    }
    comm.wait_all(&mut reqs);

    // Round 2: size the receive buffers from the announced lengths; the local
    // slot is filled directly from `send`.
    let mut recv: Vec<Vec<u64>> = recv_counts
        .iter()
        .enumerate()
        .map(|(peer, count)| {
            if peer == rank {
                send[peer].clone()
            } else {
                vec![0u64; count[0] as usize]
            }
        })
        .collect();
    let mut reqs = Vec::new();
    for (peer, buf) in recv.iter_mut().enumerate() {
        if peer == rank {
            continue;
        }
        reqs.push(comm.irecv_from_u64(buf, peer as i32));
    }
    for (peer, payload) in send.iter().enumerate() {
        if peer == rank {
            continue;
        }
        reqs.push(comm.isend_to_u64(payload, peer as i32));
    }
    comm.wait_all(&mut reqs);
    Ok(recv)
}
/// Exchanges `u64` buffers with an explicit list of peers.
///
/// `send[i]` goes to `send_peers[i]`, and one buffer is received from each
/// entry of `recv_peers`. The returned vector is ordered like `recv_peers`.
///
/// Special cases visible in the code:
/// - on a single-rank communicator an empty vector is returned immediately;
/// - when both peer lists have `size - 1` entries and neither mentions the
///   calling rank, the work is delegated to `alltoallv_u64`.
///
/// Otherwise lengths are exchanged first and payloads second, each round
/// posting receives, then sends, then calling `wait_all`.
///
/// # Errors
///
/// Returns `KError::InvalidInput` when `send_peers` and `send` differ in
/// length, and propagates errors from `alltoallv_u64`.
#[cfg(feature = "mpi")]
pub fn alltoallv_u64_sparse(
    comm: &UniverseComm,
    send_peers: &[usize],
    send: &[Vec<u64>],
    recv_peers: &[usize],
) -> Result<Vec<Vec<u64>>, KError> {
    if send.len() != send_peers.len() {
        return Err(KError::InvalidInput(
            "alltoallv_u64_sparse: send_peers/send lengths must match".into(),
        ));
    }
    let nprocs = comm.size();
    let me = comm.rank();
    if nprocs == 1 {
        return Ok(Vec::new());
    }

    // A peer list "covers everyone" when it has one entry per other rank and
    // never names the calling rank.
    let others = nprocs.saturating_sub(1);
    let covers_everyone = |peers: &[usize]| peers.len() == others && !peers.contains(&me);
    if covers_everyone(send_peers) && covers_everyone(recv_peers) {
        let mut per_rank = vec![Vec::<u64>::new(); nprocs];
        for (payload, &peer) in send.iter().zip(send_peers) {
            per_rank[peer] = payload.clone();
        }
        let gathered = alltoallv_u64(comm, &per_rank)?;
        let picked = recv_peers
            .iter()
            .map(|&peer| gathered[peer].clone())
            .collect();
        return Ok(picked);
    }

    // Round 1: announce and learn payload lengths.
    let mut incoming_lens = vec![[0u64; 1]; recv_peers.len()];
    let outgoing_lens: Vec<[u64; 1]> = send
        .iter()
        .map(|payload| [payload.len() as u64])
        .collect();
    let mut pending = Vec::with_capacity(recv_peers.len() + send_peers.len());
    for (&peer, slot) in recv_peers.iter().zip(incoming_lens.iter_mut()) {
        pending.push(comm.irecv_from_u64(slot, peer as i32));
    }
    for (&peer, len) in send_peers.iter().zip(&outgoing_lens) {
        pending.push(comm.isend_to_u64(len, peer as i32));
    }
    comm.wait_all(&mut pending);

    // Round 2: receive buffers are sized from the announced lengths.
    let mut inbox: Vec<Vec<u64>> = incoming_lens
        .iter()
        .map(|&[len]| vec![0u64; len as usize])
        .collect();
    let mut pending = Vec::with_capacity(recv_peers.len() + send_peers.len());
    for (&peer, buf) in recv_peers.iter().zip(inbox.iter_mut()) {
        pending.push(comm.irecv_from_u64(buf, peer as i32));
    }
    for (&peer, payload) in send_peers.iter().zip(send) {
        pending.push(comm.isend_to_u64(payload, peer as i32));
    }
    comm.wait_all(&mut pending);
    Ok(inbox)
}
/// Scalar counterpart of `alltoallv_u64`.
///
/// Each scalar buffer is packed into `u64` words with `pack_scalar_buffers`,
/// exchanged through `alltoallv_u64`, and decoded with `unpack_scalar_recvs`.
///
/// # Errors
///
/// Returns `KError::InvalidInput` when `send.len()` differs from `comm.size()`,
/// and propagates errors from the exchange or from decoding.
#[cfg(feature = "mpi")]
pub fn alltoallv_scalar(comm: &UniverseComm, send: &[Vec<S>]) -> Result<Vec<Vec<S>>, KError> {
    if comm.size() != send.len() {
        return Err(KError::InvalidInput(
            "alltoallv_scalar: send buffer length must match communicator size".into(),
        ));
    }
    let words_out = pack_scalar_buffers(send);
    alltoallv_u64(comm, &words_out).and_then(|words_in| unpack_scalar_recvs(&words_in))
}
/// Scalar counterpart of `alltoallv_u64_sparse`.
///
/// Each scalar buffer is packed into `u64` words with `pack_scalar_buffers`,
/// exchanged through `alltoallv_u64_sparse` using the given peer lists, and
/// decoded with `unpack_scalar_recvs`.
///
/// # Errors
///
/// Returns `KError::InvalidInput` when `send_peers` and `send` differ in
/// length, and propagates errors from the exchange or from decoding.
#[cfg(feature = "mpi")]
pub fn alltoallv_scalar_sparse(
    comm: &UniverseComm,
    send_peers: &[usize],
    send: &[Vec<S>],
    recv_peers: &[usize],
) -> Result<Vec<Vec<S>>, KError> {
    if send.len() != send_peers.len() {
        return Err(KError::InvalidInput(
            "alltoallv_scalar_sparse: send_peers/send lengths must match".into(),
        ));
    }
    let words_out = pack_scalar_buffers(send);
    alltoallv_u64_sparse(comm, send_peers, &words_out, recv_peers)
        .and_then(|words_in| unpack_scalar_recvs(&words_in))
}
/// Number of `u64` words one packed scalar occupies: 2 when the `complex`
/// feature is enabled, 1 otherwise.
#[cfg(feature = "mpi")]
fn scalar_words() -> usize {
    if cfg!(feature = "complex") {
        2
    } else {
        1
    }
}
/// Packs every scalar buffer in `send` into a buffer of `u64` words.
///
/// The output has one packed buffer per input buffer, in the same order; each
/// scalar is appended with `pack_scalar`.
#[cfg(feature = "mpi")]
fn pack_scalar_buffers(send: &[Vec<S>]) -> Vec<Vec<u64>> {
    let words_per_scalar = scalar_words();
    let mut all_packed = Vec::with_capacity(send.len());
    for scalars in send {
        // Reserve the exact number of words this buffer will produce.
        let mut words = Vec::with_capacity(scalars.len() * words_per_scalar);
        scalars
            .iter()
            .for_each(|&value| pack_scalar(value, &mut words));
        all_packed.push(words);
    }
    all_packed
}
/// Decodes received `u64` word buffers back into scalar buffers.
///
/// Buffers are processed in order; each is split into groups of
/// `scalar_words()` words and every group is decoded with `unpack_scalar`.
///
/// # Errors
///
/// Returns `KError::InvalidInput` for the first buffer whose length is not a
/// multiple of `scalar_words()`, and propagates errors from `unpack_scalar`.
#[cfg(feature = "mpi")]
fn unpack_scalar_recvs(recv: &[Vec<u64>]) -> Result<Vec<Vec<S>>, KError> {
    let stride = scalar_words();
    recv.iter()
        .map(|packed| {
            if packed.len() % stride != 0 {
                return Err(KError::InvalidInput(
                    "alltoallv_scalar: corrupt packed scalar buffer".into(),
                ));
            }
            packed
                .chunks_exact(stride)
                .map(unpack_scalar)
                .collect::<Result<Vec<S>, KError>>()
        })
        .collect()
}
/// Appends the bit pattern of `value` to `dst`.
///
/// The real part is pushed first; with the `complex` feature enabled the
/// imaginary part follows as a second word.
#[cfg(feature = "mpi")]
fn pack_scalar(value: S, dst: &mut Vec<u64>) {
    let re_bits = value.real().to_bits();
    dst.push(re_bits);
    #[cfg(feature = "complex")]
    {
        let im_bits = value.imag().to_bits();
        dst.push(im_bits);
    }
}
/// Rebuilds one scalar from its packed `u64` words.
///
/// With the `complex` feature enabled exactly two words are expected (real
/// bits, then imaginary bits); otherwise exactly one word holding the real bits.
///
/// # Errors
///
/// Returns `KError::InvalidInput` when `words` has any other length.
#[cfg(feature = "mpi")]
fn unpack_scalar(words: &[u64]) -> Result<S, KError> {
    #[cfg(feature = "complex")]
    {
        match words {
            [re_bits, im_bits] => Ok(S::from_parts(
                f64::from_bits(*re_bits),
                f64::from_bits(*im_bits),
            )),
            _ => Err(KError::InvalidInput(
                "packed scalar buffer length mismatch".into(),
            )),
        }
    }
    #[cfg(not(feature = "complex"))]
    {
        match words {
            [bits] => Ok(S::from_real(f64::from_bits(*bits))),
            _ => Err(KError::InvalidInput(
                "packed scalar buffer length mismatch".into(),
            )),
        }
    }
}