use std::net::{Shutdown, TcpListener, TcpStream};
use std::thread::JoinHandle;
use anyhow::anyhow;
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use crate::channel::{self, Sender, UnboundedReceiver, UnboundedSender};
use crate::network::remote::remote_recv;
use crate::network::{DemuxCoord, NetworkMessage, ReceiverEndpoint};
use crate::operator::ExchangeData;
/// Handle used to register local receivers with a demultiplexer.
///
/// Registrations are not applied directly: they are sent over a channel to the background
/// thread that is spawned when the handle is created.
#[derive(Debug)]
pub(crate) struct DemuxHandle<In>
where
    In: ExchangeData,
{
    /// Coordinates of this demultiplexer, used in log and panic messages.
    coord: DemuxCoord,
    /// Sending side of the channel that carries the `(endpoint, sender)` registrations.
    tx_senders: UnboundedSender<(ReceiverEndpoint, Sender<NetworkMessage<In>>)>,
}
impl<In: ExchangeData> DemuxHandle<In> {
    /// Create the handle and start the background thread that binds the socket.
    ///
    /// The spawned thread is named `reg-<host_id>:<prev_block_id>-<block_id>` and runs
    /// `bind_remotes` with the given `address` and `num_clients`. The returned [`JoinHandle`]
    /// refers to that thread.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be spawned.
    pub fn new(
        coord: DemuxCoord,
        address: (String, u16),
        num_clients: usize,
    ) -> (Self, JoinHandle<()>) {
        let thread_name = format!(
            "reg-{}:{}-{}",
            coord.coord.host_id, coord.prev_block_id, coord.coord.block_id
        );
        let (tx_senders, rx_senders) = channel::unbounded();
        let binder = move || bind_remotes::<In>(coord, address, num_clients, rx_senders);
        let join_handle = std::thread::Builder::new()
            .name(thread_name)
            .spawn(binder)
            .unwrap();
        let handle = Self { coord, tx_senders };
        (handle, join_handle)
    }

    /// Register `sender` as the local destination of the messages addressed to
    /// `receiver_endpoint`.
    ///
    /// The pair is handed over to the background thread through the registration channel.
    ///
    /// # Panics
    ///
    /// Panics if the registration cannot be sent over the channel.
    pub fn register(
        &mut self,
        receiver_endpoint: ReceiverEndpoint,
        sender: Sender<NetworkMessage<In>>,
    ) {
        log::debug!("demux register {} to {}", receiver_endpoint, self.coord);
        let registration = (receiver_endpoint, sender);
        if self.tx_senders.send(registration).is_err() {
            panic!("register for {:?} failed", self.coord);
        }
    }
}
/// Accept the connections of the remote senders and run one demultiplexing thread for each.
///
/// The function goes through the following steps:
///
/// 1. resolve `address` and bind a [`TcpListener`] to it;
/// 2. accept exactly `num_clients` connections (failed accepts are logged and retried),
///    spawning a `demux-<host_id>:<prev_block_id>-<block_id>` thread per connection;
/// 3. forward every `(endpoint, sender)` pair read from `rx_senders` to all the spawned threads,
///    until `rx_senders.recv()` fails (presumably when all the sending sides have been dropped);
/// 4. drop the per-thread channels and join the threads.
///
/// Each spawned thread first collects the registrations it receives and calls `demux_thread`
/// on its stream only after its channel has been closed, i.e. after step 3 is complete.
///
/// # Panics
///
/// Panics if the address cannot be resolved, if the socket cannot be bound, if a thread cannot
/// be spawned, if a registration cannot be forwarded to a spawned thread, or if a spawned thread
/// panicked.
fn bind_remotes<In: ExchangeData>(
    coord: DemuxCoord,
    address: (String, u16),
    num_clients: usize,
    rx_senders: UnboundedReceiver<(ReceiverEndpoint, Sender<NetworkMessage<In>>)>,
) {
    let socket_addrs: Vec<_> = (address.0.as_str(), address.1)
        .to_socket_addrs()
        .map_err(|e| format!("Failed to get the address for {coord}: {e:?}"))
        .unwrap()
        .collect();
    // Do not index the list: if the resolution is empty, `bind` below reports it with a
    // descriptive error instead of an out-of-bounds panic here.
    if let Some(first) = socket_addrs.first() {
        log::debug!("{coord} binding {first}");
    }
    let listener = TcpListener::bind(&*socket_addrs)
        .map_err(|e| {
            anyhow!(
                "Failed to bind socket for {} at {:?}: {:?}",
                coord,
                socket_addrs,
                e
            )
        })
        .unwrap();
    let local_addr = listener
        .local_addr()
        .map(|a| a.to_string())
        .unwrap_or_else(|_| "unknown".to_string());
    log::debug!(
        "{} ready at {}, waiting for {} clients",
        coord,
        local_addr,
        num_clients
    );

    // All the demux threads of this coordinate share the same name.
    let thread_name = format!(
        "demux-{}:{}-{}",
        coord.coord.host_id, coord.prev_block_id, coord.coord.block_id
    );
    let mut join_handles = Vec::with_capacity(num_clients);
    let mut tx_broadcast = Vec::with_capacity(num_clients);
    let mut incoming = listener.incoming();
    let mut connected_clients = 0;
    while connected_clients < num_clients {
        // The `incoming` iterator never yields `None`, so this `unwrap` cannot fail.
        let stream = match incoming.next().unwrap() {
            Ok(stream) => stream,
            Err(e) => {
                log::warn!("{} failed to accept incoming connection: {:?}", coord, e);
                continue;
            }
        };
        connected_clients += 1;
        // The peer address is needed only for logging: do not panic if it is unavailable.
        let peer_addr = stream
            .peer_addr()
            .map(|a| a.to_string())
            .unwrap_or_else(|_| "unknown".to_string());
        log::debug!(
            "{} new connection from {} ({} / {})",
            coord,
            peer_addr,
            connected_clients,
            num_clients
        );

        let (demux_tx, demux_rx) = channel::unbounded();
        let join_handle = std::thread::Builder::new()
            .name(thread_name.clone())
            .spawn(move || {
                // Collect the registrations until the broadcasting side is dropped, only then
                // start reading from the socket.
                let mut senders = HashMap::new();
                while let Ok((endpoint, sender)) = demux_rx.recv() {
                    senders.insert(endpoint, sender);
                }
                log::debug!("{coord} got senders");
                demux_thread::<In>(coord, senders, stream);
            })
            .unwrap();
        join_handles.push(join_handle);
        tx_broadcast.push(demux_tx);
    }
    log::debug!("{} all clients connected", coord);
    // No more connections are expected: stop listening.
    drop(listener);

    // Every demux thread needs the complete set of registrations.
    while let Ok(registration) = rx_senders.recv() {
        for tx in &tx_broadcast {
            tx.send(registration.clone()).unwrap();
        }
    }
    // Closing the channels tells the demux threads that the registrations are complete.
    drop(tx_broadcast);
    for handle in join_handles {
        handle.join().unwrap();
    }
    log::debug!("{} finished", coord);
}
/// Read the messages coming from one remote connection and dispatch them to the local receivers.
///
/// Messages are read from `stream` with `remote_recv` until it returns `None`. Each message is
/// sent through the sender registered in `senders` for its destination endpoint; a failed send
/// is logged and the message is discarded. Before returning, the stream is shut down in both
/// directions, ignoring any error.
///
/// # Panics
///
/// Panics if a message is addressed to an endpoint that is not present in `senders`.
fn demux_thread<In: ExchangeData>(
    coord: DemuxCoord,
    senders: HashMap<ReceiverEndpoint, Sender<NetworkMessage<In>>>,
    mut stream: TcpStream,
) {
    // The peer address is passed to `remote_recv`; fall back to a placeholder if unavailable.
    let address = stream
        .peer_addr()
        .map(|a| a.to_string())
        .unwrap_or_else(|_| "unknown".to_string());
    log::debug!("{} started", coord);
    let mut r = &mut stream;
    while let Some((dest, message)) = remote_recv(coord, &mut r, &address) {
        // The destination comes from the network: fail with a message that names it rather
        // than with the generic panic of `HashMap` indexing.
        let sender = senders.get(&dest).unwrap_or_else(|| {
            panic!(
                "{} received a message for unregistered endpoint {}",
                coord, dest
            )
        });
        if let Err(e) = sender.send(message) {
            log::warn!("demux failed to send message to {}: {:?}", dest, e);
        }
    }
    let _ = stream.shutdown(Shutdown::Both);
    log::debug!("{} finished", coord);
}