use std::collections::{BTreeMap, BTreeSet};
use std::future::Future;
use std::hash::Hash;
use futures::executor;
use itertools::Itertools;
use mpi::traits::{Destination, Equivalence, Source};
use msg_types::{MPIIncomingEdge, MPIMessage, MPIMessageTag, MPIRelRc};
use send_recv::{MPIAsyncSendRecv, MPIBufferedSendRecv, MPISendRecv, MPIStandardSendRecv};
use crate::{detached::Detached, hash_id::RelRcHash, RelRc};
use super::DetachedInnerData;
mod msg_types;
mod send_recv;
/// Selects which MPI send/receive wrapper is used to transfer a `RelRc`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MPIMode {
/// Transfers go through `MPIStandardSendRecv`. This is the default mode.
#[default]
Standard,
/// Transfers go through `MPIBufferedSendRecv`.
Buffered,
/// Receives go through `MPIAsyncSendRecv`. Sending in this mode is not
/// implemented, and the blocking `recv_relrc` panics when given this mode.
Async,
}
/// Transfer of [`RelRc`] objects over a communication channel, in blocking
/// and async flavours.
///
/// Implementors provide the two `*_async` methods; the blocking variants are
/// derived from them by driving the returned future on the calling thread.
pub trait RelRcCommunicator<N, E> {
    /// Sends `relrc` using `mode` and blocks until the transfer has finished.
    ///
    /// Runs [`Self::send_relrc_async`] to completion on the current thread.
    fn send_relrc(&self, relrc: &RelRc<N, E>, mode: MPIMode) {
        let transfer = self.send_relrc_async(relrc, mode);
        executor::block_on(transfer)
    }

    /// Receives a [`RelRc`] using `mode`, blocking until it can be attached
    /// to the objects in `attach_to`.
    ///
    /// # Panics
    ///
    /// Panics if `mode` is [`MPIMode::Async`]; use
    /// [`Self::recv_relrc_async`] for that mode.
    fn recv_relrc(
        &self,
        attach_to: impl IntoIterator<Item = RelRc<N, E>>,
        mode: MPIMode,
    ) -> RelRc<N, E> {
        match mode {
            MPIMode::Async => {
                panic!("Use recv_relrc_async instead of recv_relrc for async mode")
            }
            MPIMode::Standard | MPIMode::Buffered => {
                executor::block_on(self.recv_relrc_async(attach_to, mode))
            }
        }
    }

    /// Returns a future that sends `relrc` using `mode`.
    fn send_relrc_async(&self, relrc: &RelRc<N, E>, mode: MPIMode) -> impl Future<Output = ()>;

    /// Returns a future that receives a [`RelRc`] using `mode` and attaches
    /// it to the objects in `attach_to`.
    fn recv_relrc_async(
        &self,
        attach_to: impl IntoIterator<Item = RelRc<N, E>>,
        mode: MPIMode,
    ) -> impl Future<Output = RelRc<N, E>>;
}
/// Blanket implementation for every type that is both an MPI `Source` and
/// `Destination`, with node and edge weights that are `Equivalence` types.
impl<T, N, E> RelRcCommunicator<N, E> for T
where
    T: Source + Destination,
    N: Hash + Clone + Equivalence,
    E: Hash + Clone + Equivalence,
{
    /// Wraps `self` in the send/receive adapter matching `mode` and runs the
    /// sending side of the transfer.
    ///
    /// # Panics
    ///
    /// Panics for [`MPIMode::Async`], which is not implemented for sending.
    async fn send_relrc_async(&self, relrc: &RelRc<N, E>, mode: MPIMode) {
        match mode {
            MPIMode::Standard => {
                let channel = MPIStandardSendRecv(self);
                send_relrc(&channel, relrc).await;
            }
            MPIMode::Buffered => {
                let channel = MPIBufferedSendRecv(self);
                send_relrc(&channel, relrc).await;
            }
            MPIMode::Async => {
                unimplemented!(
                    "Async mode not supported for sending. Use Standard or Buffered mode instead."
                );
            }
        }
    }

    /// Wraps `self` in the send/receive adapter matching `mode` and runs the
    /// receiving side of the transfer.
    async fn recv_relrc_async(
        &self,
        attach_to: impl IntoIterator<Item = RelRc<N, E>>,
        mode: MPIMode,
    ) -> RelRc<N, E> {
        // Each arm builds a differently-typed adapter, so the call is
        // spelled out per mode rather than shared.
        match mode {
            MPIMode::Standard => {
                let channel = MPIStandardSendRecv(self);
                recv_relrc(&channel, attach_to).await
            }
            MPIMode::Buffered => {
                let channel = MPIBufferedSendRecv(self);
                recv_relrc(&channel, attach_to).await
            }
            MPIMode::Async => {
                let channel = MPIAsyncSendRecv(self);
                recv_relrc(&channel, attach_to).await
            }
        }
    }
}
/// Sending side of a `RelRc` transfer.
///
/// Detaches `relrc`, sends its own data first, and then answers every
/// `RequestRelRc` message from the peer with the data stored under the
/// requested hash. Returns once the peer sends `Done`.
///
/// # Panics
///
/// Panics with "Received unexpected message" if the peer replies with
/// anything other than `Done` or `RequestRelRc`. Looking up a hash that is
/// not part of the detached data also panics.
async fn send_relrc<N: Hash + Clone, E: Hash + Clone>(
    dest: &impl MPISendRecv<N, E>,
    relrc: &RelRc<N, E>,
) {
    let detached = relrc.detach(&BTreeSet::new());

    // The object itself always goes first; the rest is sent on demand.
    let root = detached.current;
    mpi_send(dest, root, &detached.all_data[&root]);

    loop {
        match dest.receive(MPIMessageTag::Ack).await {
            MPIMessage::Done => break,
            MPIMessage::RequestRelRc(requested) => {
                mpi_send(dest, requested, &detached.all_data[&requested]);
            }
            _ => panic!("Received unexpected message"),
        }
    }
}
/// Receiving side of a `RelRc` transfer.
///
/// Receives the first object unprompted, then keeps requesting the first
/// required hash that is not a key of `attach_to` until the accumulated
/// data attaches to `attach_to`. Finally sends `Done` and builds the
/// `RelRc` from the received data.
///
/// # Panics
///
/// Panics with "cannot attach but all required objects are known" if the
/// data does not attach yet every required hash is already in `attach_to`.
async fn recv_relrc<N: Hash + Clone, E: Hash + Clone>(
    source: &impl MPISendRecv<N, E>,
    attach_to: impl IntoIterator<Item = RelRc<N, E>>,
) -> RelRc<N, E> {
    // Index the known objects by hash for the membership checks below.
    let attach_to: BTreeMap<RelRcHash, RelRc<N, E>> = attach_to
        .into_iter()
        .map(|known| (known.hash_id(), known))
        .collect();

    // The first message arrives without a request and names the root object.
    let (root_hash, root_data) = mpi_recv(source).await;
    let mut detached: Detached<N, E> = Detached::empty(root_hash);
    detached.all_data.insert(root_hash, root_data);

    while !detached.attaches_to(&attach_to) {
        let missing = detached
            .required_hashes()
            .find(|hash| !attach_to.contains_key(hash))
            .expect("cannot attach but all required objects are known");
        source.send(&MPIMessage::RequestRelRc(missing));

        let (hash, data) = mpi_recv(source).await;
        detached.all_data.insert(hash, data);
    }

    source.send(&MPIMessage::Done);
    RelRc::attach(detached, attach_to.values().cloned())
}
/// Sends one detached object to `dest` as a fixed sequence of messages:
/// the object's hash, its node weight, the list of incoming-edge source
/// hashes, and then one edge-weight message per incoming edge, in order.
fn mpi_send<N: Clone, E: Clone>(
    dest: &impl MPISendRecv<N, E>,
    hash: RelRcHash,
    data: &DetachedInnerData<N, E>,
) {
    // Header: which object this is, followed by its node weight.
    let header = MPIRelRc { hash: hash.into() };
    dest.send(&header.into());
    dest.send(&MPIMessage::NodeWeight(data.value.clone()));

    // All edge sources travel together in a single message...
    let edge_sources = data
        .incoming
        .iter()
        .map(|(source, _)| MPIIncomingEdge {
            source_hash: (*source).into(),
        })
        .collect_vec();
    dest.send(&edge_sources.into());

    // ...and the edge weights follow one by one, in the same order.
    for (_, weight) in &data.incoming {
        dest.send(&MPIMessage::EdgeWeight(weight.clone()));
    }
}
/// Receives one detached object from `source`.
///
/// Expects, in order: a `RelRc` message with the object's hash, a
/// `NodeWeight` message, an `IncomingEdge` message listing the edge
/// sources, and one `EdgeWeight` message per listed edge.
///
/// # Panics
///
/// Panics if any received message is not of the expected variant.
async fn mpi_recv<N, E>(source: &impl MPISendRecv<N, E>) -> (RelRcHash, DetachedInnerData<N, E>) {
    let hash = match source.receive(MPIMessageTag::RelRc).await {
        MPIMessage::RelRc(hash) => hash,
        _ => panic!("Expected RelRc message"),
    };
    let value = match source.receive(MPIMessageTag::NodeWeight).await {
        MPIMessage::NodeWeight(weight) => weight,
        _ => panic!("Expected node weight message"),
    };
    let edge_sources = match source.receive(MPIMessageTag::IncomingEdge).await {
        MPIMessage::IncomingEdge(edges) => edges,
        _ => panic!("Expected incoming edge message"),
    };

    // Edge weights arrive one message at a time, in the order of the sources.
    let mut incoming = Vec::with_capacity(edge_sources.len());
    for edge_source in edge_sources {
        let edge_weight = match source.receive(MPIMessageTag::EdgeWeight).await {
            MPIMessage::EdgeWeight(weight) => weight,
            _ => panic!("Expected edge weight message"),
        };
        incoming.push((edge_source, edge_weight));
    }

    (hash, DetachedInnerData { value, incoming })
}