use {
crate::{
ip_echo_server::{IpEchoServerMessage, IpEchoServerResponse},
HEADER_LENGTH, IP_ECHO_SERVER_RESPONSE_LENGTH, MAX_PORT_COUNT_PER_MESSAGE,
},
anyhow::bail,
bytes::{BufMut, BytesMut},
itertools::Itertools,
log::*,
std::{
collections::{BTreeMap, HashMap, HashSet},
net::{IpAddr, SocketAddr, TcpListener, TcpStream, UdpSocket},
sync::{Arc, RwLock},
time::{Duration, Instant},
},
tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpSocket,
sync::oneshot,
task::JoinSet,
},
};
/// Time limit for one complete request/response exchange with the ip echo server
/// (connect, send the request, read the reply, shut down).
pub(crate) const TIMEOUT: Duration = Duration::from_millis(5_000);
/// Sends `msg` to the ip echo server at `ip_echo_server_addr` from a socket bound to
/// `bind_address` (with an OS-assigned port) and returns the parsed response.
///
/// The socket's address family follows `bind_address`, so both IPv4 and IPv6 bind addresses
/// are accepted. The whole exchange performed by `make_request` must finish within [`TIMEOUT`].
///
/// # Errors
///
/// Returns an error if the socket cannot be created or bound, if the exchange fails or times
/// out, or if the reply cannot be parsed by `parse_response`.
pub(crate) async fn ip_echo_server_request_with_binding(
    ip_echo_server_addr: SocketAddr,
    msg: IpEchoServerMessage,
    bind_address: IpAddr,
) -> anyhow::Result<IpEchoServerResponse> {
    // A socket can only be bound to an address of its own family; a hard-coded IPv4 socket
    // would make every IPv6 `bind_address` fail at `bind`.
    let socket = if bind_address.is_ipv4() {
        TcpSocket::new_v4()?
    } else {
        TcpSocket::new_v6()?
    };
    socket.bind(SocketAddr::new(bind_address, 0))?;
    let response =
        tokio::time::timeout(TIMEOUT, make_request(socket, ip_echo_server_addr, msg)).await??;
    parse_response(response, ip_echo_server_addr)
}
/// Sends `msg` to the ip echo server at `ip_echo_server_addr` from an unbound socket (the OS
/// picks the local address) and returns the parsed response.
///
/// The socket's address family follows `ip_echo_server_addr`, so both IPv4 and IPv6 servers
/// can be contacted. The whole exchange performed by `make_request` must finish within
/// [`TIMEOUT`].
///
/// # Errors
///
/// Returns an error if the socket cannot be created, if the exchange fails or times out, or if
/// the reply cannot be parsed by `parse_response`.
pub(crate) async fn ip_echo_server_request(
    ip_echo_server_addr: SocketAddr,
    msg: IpEchoServerMessage,
) -> anyhow::Result<IpEchoServerResponse> {
    // Match the server's family; a hard-coded IPv4 socket cannot connect to an IPv6 server.
    let socket = if ip_echo_server_addr.is_ipv4() {
        TcpSocket::new_v4()?
    } else {
        TcpSocket::new_v6()?
    };
    let response =
        tokio::time::timeout(TIMEOUT, make_request(socket, ip_echo_server_addr, msg)).await??;
    parse_response(response, ip_echo_server_addr)
}
/// Performs one request/response exchange with the ip echo server over `socket`.
///
/// Connects to `ip_echo_server_addr` and writes a frame made of `HEADER_LENGTH` zero bytes, the
/// bincode-encoded `msg`, and a trailing `b'\n'`. It then performs a single read of the reply
/// and shuts the stream down.
///
/// The returned buffer holds whatever that one read produced; it may be short or empty. The
/// caller is responsible for validating it.
async fn make_request(
    socket: TcpSocket,
    ip_echo_server_addr: SocketAddr,
    msg: IpEchoServerMessage,
) -> anyhow::Result<BytesMut> {
    let mut stream = socket.connect(ip_echo_server_addr).await?;

    // One allocation carries the outgoing frame and is then reused to receive the reply.
    let mut buf = {
        let mut frame = BytesMut::with_capacity(IP_ECHO_SERVER_RESPONSE_LENGTH);
        // NOTE(review): the zero header presumably keeps the request from resembling an HTTP
        // method line, and the trailing newline presumably prompts an HTTP server to answer
        // (parse_response checks for an `HTTP` reply header) — confirm.
        frame.extend_from_slice(&[0u8; HEADER_LENGTH]);
        frame.extend_from_slice(&bincode::serialize(&msg)?);
        frame.put_u8(b'\n');
        frame
    };
    stream.write_all(&buf).await?;
    stream.flush().await?;

    buf.clear();
    stream.read_buf(&mut buf).await?;
    stream.shutdown().await?;
    Ok(buf)
}
fn parse_response(
response: BytesMut,
ip_echo_server_addr: SocketAddr,
) -> anyhow::Result<IpEchoServerResponse> {
if response.len() < HEADER_LENGTH {
bail!("Response too short, received {} bytes", response.len());
}
let (response_header, body) =
response
.split_first_chunk::<HEADER_LENGTH>()
.ok_or(anyhow::anyhow!(
"Not enough data in the response from {ip_echo_server_addr}!"
))?;
let payload = match response_header {
[0, 0, 0, 0] => bincode::deserialize(body)?,
[b'H', b'T', b'T', b'P'] => {
let http_response = std::str::from_utf8(body);
match http_response {
Ok(r) => bail!(
"Invalid gossip entrypoint. {ip_echo_server_addr} looks to be an HTTP port \
replying with {r}"
),
Err(_) => bail!(
"Invalid gossip entrypoint. {ip_echo_server_addr} looks to be an HTTP port."
),
}
}
_ => {
bail!(
"Invalid gossip entrypoint. {ip_echo_server_addr} provided unexpected header \
bytes {response_header:?} "
);
}
};
Ok(payload)
}
/// Default retry count (5).
///
/// NOTE(review): the users of this constant are not visible in this section. Presumably it is
/// passed as the `retry_count` of `verify_all_reachable_udp`, where the value bounds the
/// *total* number of attempts per port chunk (`0..retry_count`), not the retries after a first
/// attempt — confirm.
pub(crate) const DEFAULT_RETRY_COUNT: usize = 5;
pub(crate) async fn verify_all_reachable_tcp(
ip_echo_server_addr: SocketAddr,
listeners: Vec<TcpListener>,
timeout: Duration,
) -> bool {
if listeners.is_empty() {
warn!("No ports provided for verify_all_reachable_tcp to check");
return true;
}
let bind_address = listeners[0]
.local_addr()
.expect("Sockets should be bound")
.ip();
for listener in listeners.iter() {
let local_binding = listener.local_addr().expect("Sockets should be bound");
assert_eq!(
local_binding.ip(),
bind_address,
"All sockets should be bound to the same IP"
);
}
let mut checkers = Vec::new();
let mut ok = true;
for chunk in &listeners.into_iter().chunks(MAX_PORT_COUNT_PER_MESSAGE) {
let listeners = chunk.collect_vec();
let ports = listeners
.iter()
.map(|l| l.local_addr().expect("Sockets should be bound").port())
.collect_vec();
info!(
"Checking that tcp ports {:?} are reachable from {:?}",
&ports, ip_echo_server_addr
);
let _ = ip_echo_server_request_with_binding(
ip_echo_server_addr,
IpEchoServerMessage::new(&ports, &[]),
bind_address,
)
.await
.map_err(|err| warn!("ip_echo_server request failed: {err}"));
for (port, tcp_listener) in ports.into_iter().zip(listeners) {
let listening_addr = tcp_listener.local_addr().unwrap();
let (sender, receiver) = oneshot::channel();
let thread_handle = tokio::task::spawn_blocking(move || {
debug!("Waiting for incoming connection on tcp/{port}");
match tcp_listener.incoming().next() {
Some(_) => {
let _ = sender.send(());
}
None => warn!("tcp incoming failed"),
}
});
let receiver = tokio::time::timeout(timeout, receiver);
checkers.push((listening_addr, thread_handle, receiver));
}
}
for (listening_addr, thread_handle, receiver) in checkers.drain(..) {
match receiver.await {
Ok(Ok(_)) => {
info!("tcp/{} is reachable", listening_addr.port());
}
Ok(Err(_v)) => {
unreachable!("The receive on oneshot channel should never fail");
}
Err(_t) => {
error!(
"Received no response at tcp/{}, check your port configuration",
listening_addr.port()
);
TcpStream::connect_timeout(&listening_addr, timeout).unwrap();
ok = false;
}
}
thread_handle.await.expect("Thread should exit cleanly");
}
ok
}
/// Checks that every UDP socket's port receives a datagram after the ip echo server at
/// `ip_echo_server_addr` is asked (via `IpEchoServerMessage`) to reach those ports.
///
/// How the check proceeds:
/// - Sockets are grouped by local IP and then by port; several sockets may share a port.
/// - For each IP, ports are checked in chunks of at most `MAX_PORT_COUNT_PER_MESSAGE`, and each
///   chunk gets at most `retry_count` attempts in total.
/// - In one attempt, every socket of the chunk waits up to roughly `timeout` for a datagram. A
///   port counts as reachable once any socket bound to it receives one.
/// - A failed echo-server request is only logged.
///
/// Returns `true` if `sockets` is empty or every port was reached. Returns `false` as soon as a
/// chunk exhausts its attempts, so a `retry_count` of 0 fails any non-empty input.
///
/// # Panics
///
/// Panics if a socket is unbound or cannot be cloned, if its read timeout cannot be queried or
/// set, or if a waiting thread does not exit cleanly.
pub(crate) async fn verify_all_reachable_udp(
    ip_echo_server_addr: SocketAddr,
    sockets: &[&UdpSocket],
    timeout: Duration,
    retry_count: usize,
) -> bool {
    if sockets.is_empty() {
        warn!("No ports provided for verify_all_reachable_udp to check");
        return true;
    }
    let mut ip_to_ports: HashMap<IpAddr, BTreeMap<u16, Vec<&UdpSocket>>> = HashMap::new();
    for &socket in sockets {
        let local_addr = socket.local_addr().expect("Socket must be bound");
        ip_to_ports
            .entry(local_addr.ip())
            .or_default()
            .entry(local_addr.port())
            .or_default()
            .push(socket);
    }
    for (bind_ip, ports_to_socks_map) in ip_to_ports {
        let ports: Vec<u16> = ports_to_socks_map.keys().copied().collect();
        info!("Checking that udp ports {ports:?} are reachable from bind IP {bind_ip:?}");
        'outer: for ports_to_check in ports.chunks(MAX_PORT_COUNT_PER_MESSAGE) {
            for attempt in 0..retry_count {
                if attempt > 0 {
                    error!("There are some udp ports with no response!! Retrying...");
                }
                // Owned clones of every socket bound to a port in this chunk, so each blocking
                // waiter can take one by value.
                let sockets_to_check: Vec<UdpSocket> = ports_to_check
                    .iter()
                    .flat_map(|port| ports_to_socks_map.get(port).unwrap())
                    .map(|&s| s.try_clone().expect("Unable to clone UDP socket"))
                    .collect();
                let _ = ip_echo_server_request_with_binding(
                    ip_echo_server_addr,
                    IpEchoServerMessage::new(&[], ports_to_check),
                    bind_ip,
                )
                .await
                .map_err(|err| warn!("ip_echo_server request failed: {err}"));
                let reachable_ports = Arc::new(RwLock::new(HashSet::new()));
                let mut checkers = JoinSet::new();
                for socket in sockets_to_check {
                    let port = socket.local_addr().expect("Socket should be bound").port();
                    let reachable_ports = Arc::clone(&reachable_ports);
                    checkers.spawn_blocking(move || {
                        wait_for_udp_datagram(&socket, port, &reachable_ports, timeout)
                    });
                }
                while let Some(result) = checkers.join_next().await {
                    result.expect("Threads should exit cleanly");
                }
                // Every waiter has finished and dropped its clone of the Arc.
                let reachable_ports = Arc::into_inner(reachable_ports)
                    .expect("Single owner expected")
                    .into_inner()
                    .expect("No threads should hold the lock");
                info!(
                    "checked udp ports: {ports_to_check:?}, reachable udp ports: \
                     {reachable_ports:?}"
                );
                if reachable_ports.len() == ports_to_check.len() {
                    continue 'outer;
                }
            }
            error!("Maximum retry count reached. Some ports for IP {bind_ip} unreachable.");
            return false;
        }
    }
    true
}

/// Blocks until `socket` (bound to `port`) receives a datagram, another socket has already
/// marked `port` in `reachable_ports`, or `timeout` has elapsed. A received datagram marks
/// `port` as reachable.
///
/// The read timeout is shortened to 250ms while waiting, so the shared set and the deadline
/// are re-checked periodically. The original read timeout is restored before returning.
/// Because of that polling, the total wait can exceed `timeout` by up to one poll interval.
fn wait_for_udp_datagram(
    socket: &UdpSocket,
    port: u16,
    reachable_ports: &RwLock<HashSet<u16>>,
    timeout: Duration,
) {
    let start = Instant::now();
    let original_read_timeout = socket.read_timeout().unwrap();
    socket
        .set_read_timeout(Some(Duration::from_millis(250)))
        .unwrap();
    loop {
        if reachable_ports.read().unwrap().contains(&port) || start.elapsed() >= timeout {
            break;
        }
        let recv_result = socket.recv(&mut [0; 1]);
        debug!("Waited for incoming datagram on udp/{port}: {recv_result:?}");
        if recv_result.is_ok() {
            reachable_ports.write().unwrap().insert(port);
            break;
        }
    }
    socket.set_read_timeout(original_read_timeout).unwrap();
}