use std::any::TypeId;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::marker::PhantomData;
#[cfg(not(feature = "async-tokio"))]
use std::thread::JoinHandle;
#[cfg(feature = "async-tokio")]
use futures::StreamExt;
use itertools::Itertools;
use typemap_rev::{TypeMap, TypeMapKey};
use crate::channel::Sender;
use crate::config::RuntimeConfig;
use crate::network::demultiplexer::DemuxHandle;
use crate::network::multiplexer::MultiplexingSender;
use crate::network::{
local_channel, BlockCoord, Coord, DemuxCoord, NetworkReceiver, NetworkSender, ReceiverEndpoint,
};
use crate::operator::ExchangeData;
use crate::scheduler::{BlockId, HostId};
use super::NetworkMessage;
/// Type-map key under which the receivers carrying `In` messages are kept until they are
/// handed out (and removed) by `NetworkTopology::get_receiver`.
struct ReceiverKey<In>(PhantomData<In>)
where
    In: ExchangeData;

impl<In> TypeMapKey for ReceiverKey<In>
where
    In: ExchangeData,
{
    /// Pending receivers, indexed by the endpoint they serve.
    type Value = HashMap<ReceiverEndpoint, NetworkReceiver<In>, crate::block::CoordHasherBuilder>;
}
/// Type-map key under which the senders carrying `In` messages are stored; they stay in the
/// map and are cloned out by `NetworkTopology::get_sender`.
struct SenderKey<In>(PhantomData<In>)
where
    In: ExchangeData;

impl<In> TypeMapKey for SenderKey<In>
where
    In: ExchangeData,
{
    /// Senders, indexed by the receiver endpoint they deliver to.
    type Value = HashMap<ReceiverEndpoint, NetworkSender<In>, crate::block::CoordHasherBuilder>;
}
/// Type-map key for the demultiplexers that receive `In` messages coming from other hosts.
struct DemultiplexingReceiverKey<In>(PhantomData<In>)
where
    In: ExchangeData;

impl<In> TypeMapKey for DemultiplexingReceiverKey<In>
where
    In: ExchangeData,
{
    /// One demultiplexer handle per demux coordinate.
    type Value = HashMap<DemuxCoord, DemuxHandle<In>, crate::block::CoordHasherBuilder>;
}
/// Type-map key for the multiplexers that ship `In` messages to other hosts' demultiplexers.
struct MultiplexingSenderKey<In>(PhantomData<In>)
where
    In: ExchangeData;

impl<In> TypeMapKey for MultiplexingSenderKey<In>
where
    In: ExchangeData,
{
    /// One multiplexing sender per destination demux coordinate.
    type Value = HashMap<DemuxCoord, MultiplexingSender<In>, crate::block::CoordHasherBuilder>;
}
/// Per-endpoint channel information collected by `NetworkTopology::connect`.
///
/// A plain `#[derive(Debug)]` is enough here: no field needs custom formatting, so the
/// `derivative` proc-macro is not required.
#[derive(Debug, Default)]
struct SenderMetadata {
    /// `true` when the receiving coordinate of the channel lives on a different host than
    /// the local one, i.e. the channel must go through a network multiplexer.
    to_remote: bool,
}
/// Wiring between the coordinates of the job graph, plus the channels (in-process or through
/// network multiplexers/demultiplexers) that carry each message type between them.
pub(crate) struct NetworkTopology {
    /// Runtime configuration; a remote configuration enables cross-host channels.
    config: RuntimeConfig,

    // --- job graph, filled by `connect` -------------------------------------------------
    /// For each `(sender, message type)`: the receiving coordinates and their `fragile` flag
    /// (fragile connections are recorded but skipped by `get_senders`).
    next: HashMap<(Coord, TypeId), Vec<(Coord, bool)>, crate::block::CoordHasherBuilder>,
    /// For each receiving coordinate: its senders together with the message type they send.
    prev: HashMap<Coord, Vec<(Coord, TypeId)>, crate::block::CoordHasherBuilder>,
    /// Every coordinate that appeared in a connection, grouped by block.
    block_replicas: HashMap<BlockId, HashSet<Coord>, crate::block::CoordHasherBuilder>,
    /// Channel information for every endpoint with at least one local end.
    senders_metadata: HashMap<ReceiverEndpoint, SenderMetadata, crate::block::CoordHasherBuilder>,
    /// Socket address of every demultiplexer, assigned by `build`.
    demultiplexer_addresses: HashMap<DemuxCoord, (String, u16), crate::block::CoordHasherBuilder>,

    // --- channel maps keyed by message type; all set to `None` by `finalize` -------------
    /// Receivers not handed out yet.
    receivers: Option<TypeMap>,
    /// Senders, cloned out on request.
    senders: Option<TypeMap>,
    /// Demultiplexer handles for traffic coming from other hosts.
    demultiplexers: Option<TypeMap>,
    /// Multiplexers for traffic going to other hosts.
    multiplexers: Option<TypeMap>,

    // --- bookkeeping ---------------------------------------------------------------------
    /// Endpoints whose receiver has already been taken with `get_receiver`.
    used_receivers: HashSet<ReceiverEndpoint>,
    /// Endpoints whose channel has already been created by `register_channel`.
    registered_receivers: HashSet<ReceiverEndpoint>,

    // --- background workers, waited for by `stop_and_wait` -------------------------------
    #[cfg(not(feature = "async-tokio"))]
    join_handles: Vec<JoinHandle<()>>,
    #[cfg(feature = "async-tokio")]
    async_join_handles: Vec<tokio::task::JoinHandle<()>>,
}
impl std::fmt::Debug for NetworkTopology {
    /// Prints the configuration and the graph bookkeeping; the type-erased channel maps and
    /// the worker handles are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("NetworkTopology");
        out.field("config", &self.config);
        out.field("next", &self.next);
        out.field("senders_metadata", &self.senders_metadata);
        out.field("block_replicas", &self.block_replicas);
        out.field("demultiplexer_addresses", &self.demultiplexer_addresses);
        out.finish()
    }
}
impl NetworkTopology {
pub(crate) fn new(config: RuntimeConfig) -> Self {
NetworkTopology {
config,
receivers: Some(TypeMap::new()),
senders: Some(TypeMap::new()),
demultiplexers: Some(TypeMap::new()),
multiplexers: Some(TypeMap::new()),
next: Default::default(),
prev: Default::default(),
senders_metadata: Default::default(),
block_replicas: Default::default(),
used_receivers: Default::default(),
registered_receivers: Default::default(),
demultiplexer_addresses: Default::default(),
#[cfg(not(feature = "async-tokio"))]
join_handles: Default::default(),
#[cfg(feature = "async-tokio")]
async_join_handles: Default::default(),
}
}
#[cfg(feature = "async-tokio")]
pub(crate) async fn stop_and_wait(&mut self) {
self.async_join_handles
.drain(..)
.collect::<futures::stream::FuturesUnordered<_>>()
.for_each(|h| {
h.unwrap();
futures::future::ready(())
})
.await;
}
#[cfg(not(feature = "async-tokio"))]
pub(crate) fn stop_and_wait(&mut self) {
for handle in self.join_handles.drain(..) {
handle.join().unwrap();
}
}
pub fn get_senders<T: ExchangeData>(
&mut self,
coord: Coord,
) -> Vec<(ReceiverEndpoint, NetworkSender<T>)> {
let typ = TypeId::of::<T>();
match self.next.get(&(coord, typ)) {
None => Default::default(),
Some(next) => {
let next = next.clone();
next.iter()
.filter_map(|&(c, fragile)| {
if fragile {
None
} else {
let receiver_endpoint = ReceiverEndpoint::new(c, coord.block_id);
Some((receiver_endpoint, self.get_sender(receiver_endpoint)))
}
})
.collect()
}
}
}
pub fn get_sender<T: ExchangeData>(
&mut self,
receiver_endpoint: ReceiverEndpoint,
) -> NetworkSender<T> {
if !self
.senders
.as_mut()
.unwrap()
.contains_key::<SenderKey<T>>()
{
self.senders
.as_mut()
.unwrap()
.insert::<SenderKey<T>>(Default::default());
}
let entry = self
.senders
.as_mut()
.unwrap()
.get_mut::<SenderKey<T>>()
.unwrap();
if !entry.contains_key(&receiver_endpoint) {
self.register_channel::<T>(receiver_endpoint);
}
self.senders
.as_mut()
.unwrap()
.get::<SenderKey<T>>()
.unwrap()
.get(&receiver_endpoint)
.unwrap()
.clone()
}
pub fn get_receiver<T: ExchangeData>(
&mut self,
receiver_endpoint: ReceiverEndpoint,
) -> NetworkReceiver<T> {
if self.used_receivers.contains(&receiver_endpoint) {
panic!("The receiver for {receiver_endpoint} has already been got",);
}
self.used_receivers.insert(receiver_endpoint);
let entry = self
.receivers
.as_mut()
.unwrap()
.entry::<ReceiverKey<T>>()
.or_default();
if !entry.contains_key(&receiver_endpoint) {
self.register_channel::<T>(receiver_endpoint);
}
self.receivers
.as_mut()
.unwrap()
.get_mut::<ReceiverKey<T>>()
.unwrap()
.remove(&receiver_endpoint)
.unwrap()
}
fn register_demux<T: ExchangeData>(
&mut self,
receiver_endpoint: ReceiverEndpoint,
local_sender: Sender<NetworkMessage<T>>,
) {
let demux_coord = DemuxCoord::from(receiver_endpoint);
let demuxes = self
.demultiplexers
.as_mut()
.unwrap()
.entry::<DemultiplexingReceiverKey<T>>()
.or_default();
if let Entry::Vacant(e) = demuxes.entry(demux_coord) {
let mut prev = HashSet::new();
for (&(from, typ), to) in &self.next {
if from.host_id == demux_coord.coord.host_id {
continue;
}
if typ != TypeId::of::<T>() {
continue;
}
for &(to, _fragile) in to {
if demux_coord.includes_channel(from, to) {
prev.insert(BlockCoord::from(from));
}
}
}
if !prev.is_empty() {
let address = self.demultiplexer_addresses[&demux_coord].clone();
let (demux, join_handle) = DemuxHandle::new(demux_coord, address, prev.len());
#[cfg(not(feature = "async-tokio"))]
self.join_handles.push(join_handle);
#[cfg(feature = "async-tokio")]
self.async_join_handles.push(join_handle);
e.insert(demux);
} else {
log::debug!("demux {} skipping (no remote predecessors)", demux_coord);
}
}
if let Some(demux) = demuxes.get_mut(&demux_coord) {
demux.register(receiver_endpoint, local_sender)
}
}
fn register_mux<T: ExchangeData>(
&mut self,
receiver_endpoint: ReceiverEndpoint,
) -> NetworkSender<T> {
let muxers = self
.multiplexers
.as_mut()
.unwrap()
.entry::<MultiplexingSenderKey<T>>()
.or_default();
let demux_coord = DemuxCoord::from(receiver_endpoint);
if let Entry::Vacant(e) = muxers.entry(demux_coord) {
let address = self.demultiplexer_addresses[&demux_coord].clone();
let (mux, join_handle) = MultiplexingSender::new(demux_coord, address);
#[cfg(not(feature = "async-tokio"))]
self.join_handles.push(join_handle);
#[cfg(feature = "async-tokio")]
self.async_join_handles.push(join_handle);
e.insert(mux);
}
muxers
.get_mut(&demux_coord)
.unwrap()
.get_sender(receiver_endpoint)
}
fn register_channel<T: ExchangeData>(&mut self, receiver_endpoint: ReceiverEndpoint) {
log::debug!("new endpoint {}", receiver_endpoint);
assert!(
!self.registered_receivers.contains(&receiver_endpoint),
"receiver {receiver_endpoint} has already been registered",
);
self.registered_receivers.insert(receiver_endpoint);
let sender_metadata = self
.senders_metadata
.get(&receiver_endpoint)
.unwrap_or_else(|| panic!("Channel for endpoint {receiver_endpoint} not registered",));
match &self.config {
RuntimeConfig::Remote(_) => {
if sender_metadata.to_remote {
let sender = self.register_mux(receiver_endpoint);
self.senders
.as_mut()
.unwrap()
.entry::<SenderKey<T>>()
.or_default()
.insert(receiver_endpoint, sender);
} else {
let (sender, receiver) = local_channel(receiver_endpoint);
if receiver_endpoint.coord.host_id == self.config.host_id().unwrap() {
self.register_demux(receiver_endpoint, sender.clone_inner());
}
self.receivers
.as_mut()
.unwrap()
.entry::<ReceiverKey<T>>()
.or_default()
.insert(receiver_endpoint, receiver);
self.senders
.as_mut()
.unwrap()
.entry::<SenderKey<T>>()
.or_default()
.insert(receiver_endpoint, sender);
};
}
RuntimeConfig::Local(_) => {
let (sender, receiver) = local_channel(receiver_endpoint);
self.receivers
.as_mut()
.unwrap()
.entry::<ReceiverKey<T>>()
.or_default()
.insert(receiver_endpoint, receiver);
self.senders
.as_mut()
.unwrap()
.entry::<SenderKey<T>>()
.or_default()
.insert(receiver_endpoint, sender);
}
}
}
pub fn connect(&mut self, from: Coord, to: Coord, typ: TypeId, fragile: bool) {
let host_id = self.config.host_id().unwrap();
let from_remote = from.host_id != host_id;
let to_remote = to.host_id != host_id;
log::trace!(
"new connection: {} -> {}, remote: ({}, {})",
from,
to,
from_remote,
to_remote
);
self.next
.entry((from, typ))
.or_default()
.push((to, fragile));
self.prev.entry(to).or_default().push((from, typ));
self.block_replicas
.entry(from.block_id)
.or_default()
.insert(from);
self.block_replicas
.entry(to.block_id)
.or_default()
.insert(to);
if from_remote && to_remote {
return;
}
let receiver_endpoint = ReceiverEndpoint::new(to, from.block_id);
self.senders_metadata.entry(receiver_endpoint).or_default();
if to_remote {
let metadata = self.senders_metadata.get_mut(&receiver_endpoint).unwrap();
metadata.to_remote = true;
}
}
pub fn prev(&self, coord: Coord) -> Vec<(Coord, TypeId)> {
if let Some(prev) = self.prev.get(&coord) {
prev.clone()
} else {
vec![]
}
}
pub fn replicas(&self, block_id: BlockId) -> Vec<Coord> {
if let Some(replicas) = self.block_replicas.get(&block_id) {
replicas.iter().cloned().collect()
} else {
vec![]
}
}
pub fn build(&mut self) {
log::debug!("finalizing topology");
let config = if let RuntimeConfig::Remote(config) = &self.config {
config
} else {
return;
};
let mut coords = HashSet::new();
for (&(from, _typ), to) in self.next.iter() {
for &(to, _fragile) in to {
let coord = DemuxCoord::new(from, to);
coords.insert(coord);
}
}
let mut used_ports: HashMap<HostId, u16> = HashMap::new();
for coord in coords.into_iter().sorted() {
let host_id = coord.coord.host_id;
let port_offset = used_ports.entry(host_id).or_default();
let host = &config.hosts[host_id as usize];
let port = host.base_port + *port_offset;
*port_offset += 1;
let address = (host.address.clone(), port);
log::debug!("demux {} socket: {:?}", coord, address);
self.demultiplexer_addresses.insert(coord, address);
}
}
pub fn finalize(&mut self) {
self.receivers.take();
self.senders.take();
self.demultiplexers.take();
self.multiplexers.take();
}
pub fn log(&self) {
let mut topology = "execution graph:".to_owned();
for ((coord, _typ), next) in self.next.iter().sorted() {
write!(&mut topology, "\n {coord}:",).unwrap();
for (next, fragile) in next.iter().sorted() {
write!(
&mut topology,
" {}{}",
next,
if *fragile { "*" } else { "" }
)
.unwrap();
}
}
log::trace!("{}", topology);
}
}
#[cfg(test)]
mod tests {
    use crate::network::NetworkMessage;
    use crate::operator::StreamElement;
    use super::*;

    /// Wraps `t` into a single-item network message attributed to `Coord::default()`.
    fn build_message<T>(t: T) -> NetworkMessage<T> {
        NetworkMessage::new_single(StreamElement::Item(t), Coord::default())
    }

    /// With a local configuration, messages sent through the senders returned by
    /// `get_senders` must come out of the matching `get_receiver`.
    #[test]
    fn test_local_topology() {
        let config = RuntimeConfig::local(4);
        let mut topology = NetworkTopology::new(config);
        let sender1 = Coord::new(0, 0, 0);
        let sender2 = Coord::new(1, 0, 0);
        let receiver1 = Coord::new(2, 0, 0);
        let receiver2 = Coord::new(3, 0, 0);
        // sender1 sends i32 to receiver1; sender2 sends u64 to both receivers.
        topology.connect(sender1, receiver1, TypeId::of::<i32>(), false);
        topology.connect(sender2, receiver1, TypeId::of::<u64>(), false);
        topology.connect(sender2, receiver2, TypeId::of::<u64>(), false);
        topology.build();
        // An endpoint is the receiving coordinate plus the sending block id, as in `connect`.
        let endpoint1 = ReceiverEndpoint::new(receiver1, 0);
        let endpoint2 = ReceiverEndpoint::new(receiver1, 1);
        let endpoint3 = ReceiverEndpoint::new(receiver2, 1);
        let tx1 = topology
            .get_senders::<i32>(sender1)
            .into_iter()
            .collect::<HashMap<_, _>>();
        assert_eq!(tx1.len(), 1);
        tx1[&endpoint1].send(build_message(123i32)).unwrap();
        let tx2 = topology
            .get_senders::<u64>(sender2)
            .into_iter()
            .collect::<HashMap<_, _>>();
        assert_eq!(tx2.len(), 2);
        tx2[&endpoint2].send(build_message(666u64)).unwrap();
        tx2[&endpoint3].send(build_message(42u64)).unwrap();
        // Receivers are taken after the sends: the messages are already buffered.
        let rx1 = topology.get_receiver::<i32>(endpoint1);
        assert_eq!(
            rx1.recv().unwrap().into_iter().collect::<Vec<_>>(),
            vec![StreamElement::Item(123i32)]
        );
        let rx2 = topology.get_receiver::<u64>(endpoint2);
        assert_eq!(
            rx2.recv().unwrap().into_iter().collect::<Vec<_>>(),
            vec![StreamElement::Item(666u64)]
        );
        let rx3 = topology.get_receiver::<u64>(endpoint3);
        assert_eq!(
            rx3.recv().unwrap().into_iter().collect::<Vec<_>>(),
            vec![StreamElement::Item(42u64)]
        );
    }

    /// Runs the same topology as two hosts on 127.0.0.1, one thread per host. Each host only
    /// creates the senders and receivers whose coordinates it owns. Every receiver must get
    /// the messages addressed to it, whether they come from its own host or the other one.
    #[cfg(not(feature = "async-tokio"))]
    #[test]
    fn test_remote_topology() {
        // Two-host configuration; the base ports must be free on the test machine.
        let mut config = tempfile::NamedTempFile::new().unwrap();
        let config_yaml = "hosts:\n".to_string()
            + " - address: 127.0.0.1\n"
            + " base_port: 21841\n"
            + " num_cores: 1\n"
            + " - address: 127.0.0.1\n"
            + " base_port: 31258\n"
            + " num_cores: 1\n";
        std::io::Write::write_all(&mut config, config_yaml.as_bytes()).unwrap();
        let config = RuntimeConfig::remote(config.path()).unwrap();
        // Body executed by each simulated host; `host` selects which coordinates it owns.
        let run = |mut config: RuntimeConfig, host: HostId| {
            if let RuntimeConfig::Remote(remote) = &mut config {
                remote.host_id = Some(host);
            }
            let mut topology = NetworkTopology::new(config);
            let s1 = Coord::new(0, 0, 0);
            let s2 = Coord::new(0, 1, 0);
            let s3 = Coord::new(1, 0, 0);
            let s4 = Coord::new(1, 0, 1);
            let r1 = Coord::new(2, 1, 0);
            let r2 = Coord::new(2, 1, 1);
            // Both hosts declare the full graph, so `build` should assign the same demux
            // addresses on each of them.
            topology.connect(s1, r1, TypeId::of::<i32>(), false);
            topology.connect(s2, r1, TypeId::of::<i32>(), false);
            topology.connect(s3, r1, TypeId::of::<u64>(), false);
            topology.connect(s4, r1, TypeId::of::<u64>(), false);
            topology.connect(s3, r2, TypeId::of::<u64>(), false);
            topology.connect(s4, r2, TypeId::of::<u64>(), false);
            topology.build();
            let endpoint1 = ReceiverEndpoint::new(r1, 0);
            let endpoint2 = ReceiverEndpoint::new(r1, 1);
            let endpoint3 = ReceiverEndpoint::new(r2, 1);
            // Senders: each host only sends from the coordinates it owns.
            if s1.host_id == host {
                let tx1 = topology
                    .get_senders::<i32>(s1)
                    .into_iter()
                    .collect::<HashMap<_, _>>();
                tx1[&endpoint1].send(build_message(123i32)).unwrap();
            }
            if s2.host_id == host {
                let tx2 = topology
                    .get_senders::<i32>(s2)
                    .into_iter()
                    .collect::<HashMap<_, _>>();
                assert_eq!(tx2.len(), 1);
                tx2[&endpoint1].send(build_message(456i32)).unwrap();
            }
            if s3.host_id == host {
                let tx3 = topology
                    .get_senders::<u64>(s3)
                    .into_iter()
                    .collect::<HashMap<_, _>>();
                assert_eq!(tx3.len(), 2);
                tx3[&endpoint2].send(build_message(666u64)).unwrap();
                tx3[&endpoint3].send(build_message(42u64)).unwrap();
            }
            if s4.host_id == host {
                let tx4 = topology
                    .get_senders::<u64>(s4)
                    .into_iter()
                    .collect::<HashMap<_, _>>();
                assert_eq!(tx4.len(), 2);
                tx4[&endpoint2].send(build_message(111u64)).unwrap();
                tx4[&endpoint3].send(build_message(4242u64)).unwrap();
            }
            // Receivers are drained on their own threads; `recv` may have to wait for
            // messages from the other host.
            let mut join_handles = vec![];
            if endpoint1.coord.host_id == host {
                let rx1 = topology.get_receiver::<i32>(endpoint1);
                join_handles.push(
                    std::thread::Builder::new()
                        .name("rx1".into())
                        .spawn(move || receiver(rx1, vec![123i32, 456i32]))
                        .unwrap(),
                );
            }
            if endpoint2.coord.host_id == host {
                let rx2 = topology.get_receiver::<u64>(endpoint2);
                join_handles.push(
                    std::thread::Builder::new()
                        .name("rx2".into())
                        .spawn(move || receiver(rx2, vec![111u64, 666u64]))
                        .unwrap(),
                );
            }
            if endpoint3.coord.host_id == host {
                let rx3 = topology.get_receiver::<u64>(endpoint3);
                join_handles.push(
                    std::thread::Builder::new()
                        .name("rx3".into())
                        .spawn(move || receiver(rx3, vec![42u64, 4242u64]))
                        .unwrap(),
                );
            }
            // Drop the channel maps (and the senders/receivers still held by them) before
            // waiting for the receivers and the network workers.
            topology.finalize();
            for handle in join_handles {
                handle.join().unwrap();
            }
            topology.stop_and_wait();
        };
        let config0 = config.clone();
        let join0 = std::thread::Builder::new()
            .name("host0".into())
            .spawn(move || run(config0, 0))
            .unwrap();
        let join1 = std::thread::Builder::new()
            .name("host1".into())
            .spawn(move || run(config, 1))
            .unwrap();
        join0.join().unwrap();
        join1.join().unwrap();
    }

    /// Receives `expected.len()` messages from `receiver` and checks that their items, once
    /// sorted, equal `expected`. `expected` must already be sorted, because the arrival
    /// order across different senders is not fixed.
    #[cfg(not(feature = "async-tokio"))]
    fn receiver<T: ExchangeData + Ord + std::fmt::Debug>(
        receiver: NetworkReceiver<T>,
        expected: Vec<T>,
    ) {
        let res = (0..expected.len())
            .flat_map(|_| receiver.recv().unwrap().into_iter())
            .sorted()
            .collect_vec();
        assert_eq!(
            res,
            expected.into_iter().map(StreamElement::Item).collect_vec()
        );
    }

    /// A single coordinate with outgoing channels of two different message types: each type
    /// gets its own sender/receiver pair, obtained here directly with `get_sender`.
    #[test]
    fn test_multiple_output_types() {
        let config = RuntimeConfig::local(4);
        let mut topology = NetworkTopology::new(config);
        // Same coordinate on purpose: one sender, two output types.
        let sender1 = Coord::new(0, 0, 0);
        let sender2 = Coord::new(0, 0, 0);
        let receiver1 = Coord::new(1, 0, 0);
        let receiver2 = Coord::new(2, 0, 0);
        topology.connect(sender1, receiver1, TypeId::of::<i32>(), false);
        topology.connect(sender2, receiver2, TypeId::of::<u64>(), false);
        topology.build();
        let endpoint1 = ReceiverEndpoint::new(receiver1, 0);
        let endpoint2 = ReceiverEndpoint::new(receiver2, 0);
        let tx1 = topology.get_sender::<i32>(endpoint1);
        tx1.send(build_message(123i32)).unwrap();
        let tx2 = topology.get_sender::<u64>(endpoint2);
        tx2.send(build_message(666u64)).unwrap();
        let rx1 = topology.get_receiver::<i32>(endpoint1);
        assert_eq!(
            rx1.recv().unwrap().into_iter().collect::<Vec<_>>(),
            vec![StreamElement::Item(123i32)]
        );
        let rx2 = topology.get_receiver::<u64>(endpoint2);
        assert_eq!(
            rx2.recv().unwrap().into_iter().collect::<Vec<_>>(),
            vec![StreamElement::Item(666u64)]
        );
    }
}