#[cfg(feature = "async-tokio")]
use tokio::io::AsyncWriteExt;
#[cfg(feature = "async-tokio")]
use tokio::net::{TcpListener, TcpStream};
#[cfg(feature = "async-tokio")]
use tokio::task::JoinHandle;
use anyhow::anyhow;
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use crate::channel::{self, Sender, UnboundedReceiver, UnboundedSender};
use crate::network::remote::remote_recv;
use crate::network::{DemuxCoord, NetworkMessage, ReceiverEndpoint};
use crate::operator::ExchangeData;
/// Handle to a network demultiplexer, used to register the local `Sender` that should
/// receive the messages addressed to a given `ReceiverEndpoint`.
#[derive(Debug)]
pub(crate) struct DemuxHandle<In>
where
    In: Send + 'static,
{
    /// Identifies the demultiplexer this handle refers to (used in logs and panics).
    coord: DemuxCoord,
    /// Channel on which `(endpoint, sender)` registrations are forwarded to the task
    /// spawned by `DemuxHandle::new`.
    tx_senders: UnboundedSender<(ReceiverEndpoint, Sender<NetworkMessage<In>>)>,
}
#[cfg(feature = "async-tokio")]
impl<In: ExchangeData> DemuxHandle<In> {
    /// Spawns the demultiplexer task for `coord` and returns a handle for registering
    /// local senders, together with the `JoinHandle` of the spawned task.
    ///
    /// The spawned task binds `address`, accepts `num_clients` connections and, once
    /// registration is over, forwards every incoming message to the sender registered
    /// for its destination endpoint.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime (a requirement of `tokio::spawn`).
    pub fn new(
        coord: DemuxCoord,
        address: (String, u16),
        num_clients: usize,
    ) -> (Self, JoinHandle<()>) {
        let (tx_senders, rx_senders) = channel::unbounded();
        let demux_task = bind_remotes(coord, address, num_clients, rx_senders);
        let join_handle = tokio::spawn(demux_task);
        let handle = Self { coord, tx_senders };
        (handle, join_handle)
    }

    /// Registers `sender` as the destination for messages addressed to
    /// `receiver_endpoint`.
    ///
    /// # Panics
    ///
    /// Panics if the registration cannot be sent (presumably because the receiving end
    /// owned by the demultiplexer task has been dropped).
    pub fn register(
        &mut self,
        receiver_endpoint: ReceiverEndpoint,
        sender: Sender<NetworkMessage<In>>,
    ) {
        log::debug!(
            "registering {} to the demultiplexer of {}",
            receiver_endpoint,
            self.coord
        );
        let registration = (receiver_endpoint, sender);
        if self.tx_senders.send(registration).is_err() {
            panic!("register for {:?} failed", self.coord);
        }
    }
}
/// Binds a TCP listener at `address`, accepts exactly `num_clients` connections and
/// spawns one demultiplexing task per accepted connection.
///
/// Every `(ReceiverEndpoint, Sender)` pair received on `rx_senders` is broadcast to all
/// per-connection tasks. Each task only starts reading from its socket once its own
/// registration channel is closed, which happens after `rx_senders.recv()` returns an
/// error and the broadcast senders are dropped. The function returns once every
/// per-connection task has finished.
///
/// # Panics
///
/// Panics if `address` cannot be resolved, if the listener cannot be bound (including the
/// case where the host resolves to no address at all), or if any spawned per-connection
/// task panicked.
#[cfg(feature = "async-tokio")]
async fn bind_remotes<In: ExchangeData>(
    coord: DemuxCoord,
    address: (String, u16),
    num_clients: usize,
    rx_senders: UnboundedReceiver<(ReceiverEndpoint, Sender<NetworkMessage<In>>)>,
) {
    let host_port = (address.0.as_str(), address.1);
    let address: Vec<_> = host_port
        .to_socket_addrs()
        .unwrap_or_else(|e| panic!("Failed to get the address for {}: {:?}", coord, e))
        .collect();
    // Resolution may legitimately yield zero addresses; do not index blindly; let
    // `TcpListener::bind` report that case with a descriptive error instead.
    if let Some(first) = address.first() {
        log::debug!("demux binding {}", first);
    }
    let listener = TcpListener::bind(&*address)
        .await
        .map_err(|e| {
            anyhow!(
                "Failed to bind socket for {} at {:?}: {:?}",
                coord,
                address,
                e
            )
        })
        .unwrap();
    let local_address = listener
        .local_addr()
        .map(|a| a.to_string())
        .unwrap_or_else(|_| "unknown".to_string());
    info!(
        "Remote receiver at {} is ready to accept {} connections to {}",
        coord, num_clients, local_address
    );

    let mut join_handles = Vec::with_capacity(num_clients);
    let mut tx_broadcast = Vec::with_capacity(num_clients);
    let mut connected_clients = 0;
    while connected_clients < num_clients {
        let (stream, peer_addr) = match listener.accept().await {
            Ok(accepted) => accepted,
            Err(e) => {
                warn!("Failed to accept incoming connection at {}: {:?}", coord, e);
                continue;
            }
        };
        connected_clients += 1;
        info!(
            "Remote receiver at {} accepted a new connection from {} ({} / {})",
            coord, peer_addr, connected_clients, num_clients
        );
        let (demux_tx, demux_rx) = flume::unbounded();
        let join_handle = tokio::spawn(async move {
            // Collect every registration first; the loop ends when all broadcast
            // senders are dropped below.
            let mut senders = HashMap::new();
            while let Ok((endpoint, sender)) = demux_rx.recv_async().await {
                senders.insert(endpoint, sender);
            }
            log::debug!("demux got senders");
            demux_thread::<In>(coord, senders, stream).await;
        });
        join_handles.push(join_handle);
        tx_broadcast.push(demux_tx);
    }

    log::debug!("All connection to {} started, waiting for senders", coord);
    // NOTE(review): if `rx_senders.recv()` is a blocking call (the per-connection tasks
    // above use `recv_async` instead), it parks a runtime worker thread until every
    // registration has arrived; on a current-thread runtime this could stall the spawned
    // tasks. Confirm against `crate::channel`.
    while let Ok(registration) = rx_senders.recv() {
        for tx in &tx_broadcast {
            tx.send(registration.clone()).unwrap();
        }
    }
    // Closing the broadcast channels lets each per-connection task stop collecting
    // senders and start demultiplexing.
    drop(tx_broadcast);
    for handle in join_handles {
        handle.await.unwrap();
    }
    log::debug!("all demuxes for {} finished", coord);
}
/// Reads messages from one accepted connection and forwards each one to the local sender
/// registered for its destination endpoint.
///
/// Runs until `remote_recv` yields `None`, then shuts down the write side of `stream`.
/// A failure to forward a message, or to shut the stream down, is logged rather than
/// treated as fatal.
///
/// # Panics
///
/// Panics if a message arrives for a `ReceiverEndpoint` that has no entry in `senders`,
/// since every destination is expected to have been registered beforehand.
#[cfg(feature = "async-tokio")]
async fn demux_thread<In: ExchangeData>(
    coord: DemuxCoord,
    senders: HashMap<ReceiverEndpoint, Sender<NetworkMessage<In>>>,
    mut stream: TcpStream,
) {
    let address = stream
        .peer_addr()
        .map(|a| a.to_string())
        .unwrap_or_else(|_| "unknown".to_string());
    log::debug!("{} started", coord);
    while let Some((dest, message)) = remote_recv(coord, &mut stream, &address).await {
        let sender = senders.get(&dest).unwrap_or_else(|| {
            panic!(
                "demux {} received a message for unregistered endpoint {}",
                coord, dest
            )
        });
        if let Err(e) = sender.send(message) {
            warn!("demux failed to send message to {}: {:?}", dest, e);
        }
    }
    // The peer may already have closed the connection (presumably why `remote_recv`
    // returned `None`), in which case shutdown can fail; that should not crash the
    // demultiplexer and, through the joined task, its parent.
    if let Err(e) = stream.shutdown().await {
        warn!(
            "demux {} failed to shut down connection with {}: {:?}",
            coord, address, e
        );
    }
    log::debug!("{} finished", coord);
}