use super::time_util::sleep_for_secs;
use std::io::{Read, Write};
use std::net::{Ipv4Addr, SocketAddrV4, TcpListener, TcpStream};
/// Resolves `host` to an IPv4 address.
///
/// The name `localhost` (matched case-insensitively) maps to `127.0.0.1`;
/// any other value must be an IPv4 literal such as `10.0.0.1`.
///
/// # Panics
///
/// Panics if `host` is neither `localhost` nor a valid IPv4 literal.
fn parse_host(host: &str) -> Ipv4Addr {
    if host.eq_ignore_ascii_case("localhost") {
        return Ipv4Addr::LOCALHOST;
    }
    match host.parse::<Ipv4Addr>() {
        Ok(addr) => addr,
        Err(err) => panic!("Unable to parse '{}' as an IPV4 address: {}", host, err),
    }
}
/// Builds the IPv4 socket address `host:port`.
///
/// `host` is resolved with `parse_host`, so this panics for anything other
/// than `localhost` or an IPv4 literal.
fn socket(host: &str, port: u16) -> SocketAddrV4 {
    let ip = parse_host(host);
    SocketAddrV4::new(ip, port)
}
/// Seconds to wait between failed connection attempts on the connecting side
/// of [`handshake`].
const HANDSHAKE_SLEEP: u64 = 2;

/// Performs a one-shot TCP rendezvous on `s0`.
///
/// When `listener` is `true`, binds a listener on `s0`, accepts exactly one
/// connection and reads whatever the peer sends (the payload is ignored).
///
/// When `listener` is `false`, repeatedly tries to connect to `s0`, sleeping
/// `HANDSHAKE_SLEEP` seconds after each failed attempt. Once connected, it
/// sends a single byte. This side retries indefinitely, so it blocks until
/// something is listening on `s0`.
///
/// # Panics
///
/// Panics if binding, accepting, reading, or writing fails.
pub fn handshake(s0: SocketAddrV4, listener: bool) {
    if listener {
        let tcp_listener = TcpListener::bind(s0).unwrap_or_else(|err| {
            panic!("Failed to start TCP connection at {:?}: {}", s0, err);
        });
        // `accept` blocks until one peer connects. Unlike `incoming().next()`
        // it has no (unreachable) `None` case to handle.
        let (mut stream, _peer) = tcp_listener.accept().unwrap_or_else(|err| {
            panic!("Listener failed to get message from stream: {}", err);
        });
        let mut data = [0_u8; 50];
        // The payload is irrelevant; a single read just waits for the peer.
        // NOTE(review): `Ok(0)` (peer closed without sending) is also accepted.
        let _bytes_read = stream.read(&mut data).unwrap_or_else(|err| {
            panic!("Listener failed to read message from stream: {}", err);
        });
    } else {
        // Keep retrying until the listening side has bound `s0`.
        let mut stream = loop {
            match TcpStream::connect(s0) {
                Ok(stream) => break stream,
                Err(_) => {
                    sleep_for_secs(HANDSHAKE_SLEEP);
                }
            }
        };
        stream.write_all(&[1]).unwrap_or_else(|err| {
            panic!("Failed to send handshake byte to {:?}: {}", s0, err);
        });
    }
}
/// Two-phase TCP barrier across `num_nodes` nodes, coordinated by node 0.
///
/// In each phase, every node `i` in `1..num_nodes` performs a [`handshake`]
/// with node 0 on port `start_port + phase * num_nodes + i`. Node 0 listens,
/// one port at a time in order of `i`, and node `i` connects.
///
/// Why two phases:
/// - After phase 0, node 0 has heard from every other node.
/// - In phase 1, the other nodes can only complete their handshakes once
///   node 0 is listening on the phase-1 ports, which happens only after
///   phase 0 has finished.
///
/// `host` is resolved via `socket` by every node, so it is presumably the
/// coordinator's address (TODO confirm with callers). The ports used are
/// `start_port + 1 ..= start_port + 2 * num_nodes - 1`.
///
/// # Panics
///
/// - Panics if `this_node >= num_nodes`.
/// - Panics if the highest port used would exceed `u16::MAX`.
/// - Panics on any failure inside `handshake`.
pub fn barrier(host: &str, num_nodes: u64, this_node: u64, start_port: u16) {
    assert!(this_node < num_nodes);
    // Validate the highest port (phase 1, node num_nodes - 1) before any
    // networking so every node fails immediately and consistently, instead of
    // part-way through the protocol. Port offsets grow monotonically, so if
    // this one fits, all smaller ones fit too.
    if num_nodes > 1 {
        let _highest_port = barrier_port(start_port, num_nodes, 1, num_nodes - 1);
    }
    for phase in 0..2_u64 {
        if this_node == 0 {
            for i in 1..num_nodes {
                let port = barrier_port(start_port, num_nodes, phase, i);
                handshake(socket(host, port), true);
            }
        } else {
            let port = barrier_port(start_port, num_nodes, phase, this_node);
            handshake(socket(host, port), false);
        }
    }
}

/// Returns the port `start_port + phase * num_nodes + node` used by `barrier`.
///
/// The arithmetic is done in checked `u64`, so a large `num_nodes` cannot be
/// silently truncated by an `as u16` cast or wrap around in `u16` arithmetic.
///
/// # Panics
///
/// Panics if the result does not fit in a `u16`.
fn barrier_port(start_port: u16, num_nodes: u64, phase: u64, node: u64) -> u16 {
    let max_port = u64::from(u16::MAX);
    match phase
        .checked_mul(num_nodes)
        .and_then(|offset| offset.checked_add(node))
        .and_then(|offset| offset.checked_add(u64::from(start_port)))
    {
        // The guard guarantees the value fits, so this cast is lossless.
        Some(port) if port <= max_port => port as u16,
        _ => panic!(
            "Barrier port out of range: {} + {} * {} + {} exceeds {}",
            start_port, phase, num_nodes, node, max_port
        ),
    }
}