use std::{
net::{IpAddr, SocketAddr},
time::{Duration, Instant},
};
use local_channel::mpsc::{Receiver, Sender, channel};
use tokio::{
net::TcpStream,
task::{JoinHandle, spawn_local},
time,
};
use tokio_util::sync::CancellationToken;
use mousehop_ipc::ClientHandle;
use crate::client::ClientManager;
/// Delay between the start of consecutive probe rounds.
const PROBE_INTERVAL: Duration = Duration::from_millis(5_000);
/// Longest a single connect attempt may run before the probe gives up on it.
/// Kept well below `PROBE_INTERVAL` so one round's probes finish before the
/// next round starts.
const PROBE_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Outcome of one latency probe against a single address of a client.
#[derive(Debug)]
pub(crate) struct ProbeResult {
    /// Client whose address was probed.
    pub handle: ClientHandle,
    /// Address that was probed.
    pub ip: IpAddr,
    /// Microseconds until the TCP connect either completed or was actively
    /// refused, saturating at `u32::MAX`. `None` means the attempt timed out or
    /// failed with any other error.
    pub rtt_micros: Option<u32>,
}
/// Owns the background latency-probing task and the channel its results arrive on.
///
/// Results are read with [`LatencyProber::event`]. [`LatencyProber::terminate`]
/// stops the background task and waits for it. Dropping the prober also
/// cancels the task.
pub(crate) struct LatencyProber {
    /// Cancelled by `terminate` (or on drop) to stop the background task.
    cancellation_token: CancellationToken,
    /// Handle to the background task; `None` once `terminate` has awaited it.
    task: Option<JoinHandle<()>>,
    /// Receives one [`ProbeResult`] per probed address.
    event_rx: Receiver<ProbeResult>,
}

impl LatencyProber {
    /// Starts periodically probing the targets reported by `client_manager`.
    ///
    /// # Panics
    ///
    /// The background task is started with [`spawn_local`], which panics when
    /// called outside a `tokio::task::LocalSet`.
    pub(crate) fn new(client_manager: ClientManager) -> Self {
        let (event_tx, event_rx) = channel();
        let cancellation_token = CancellationToken::new();
        let task = LatencyTask {
            client_manager,
            event_tx,
            cancellation_token: cancellation_token.clone(),
        };
        Self {
            cancellation_token,
            task: Some(spawn_local(task.run())),
            event_rx,
        }
    }

    /// Waits for the next probe result.
    ///
    /// # Panics
    ///
    /// Panics if the result channel has closed. The channel closes once every
    /// sender has been dropped, including the one held by the background task
    /// and the clones held by in-flight probes. That can happen after
    /// [`LatencyProber::terminate`].
    pub(crate) async fn event(&mut self) -> ProbeResult {
        self.event_rx.recv().await.expect("channel closed")
    }

    /// Cancels the background task and waits for it to finish.
    ///
    /// Calling this more than once is harmless: later calls find no task to
    /// await and return right after re-cancelling the token.
    pub(crate) async fn terminate(&mut self) {
        self.cancellation_token.cancel();
        if let Some(task) = self.task.take() {
            // A panic in the task surfaces here as a `JoinError`. It is
            // deliberately ignored: shutdown should proceed regardless.
            let _ = task.await;
        }
    }
}

impl Drop for LatencyProber {
    fn drop(&mut self) {
        // Dropping a `JoinHandle` detaches its task rather than stopping it.
        // Without this, a prober dropped before `terminate` would leave the
        // probe loop running in the background with nobody reading results.
        self.cancellation_token.cancel();
    }
}
/// Background worker that probes every target reported by the
/// [`ClientManager`] once per `PROBE_INTERVAL` and forwards each outcome over
/// `event_tx`.
struct LatencyTask {
    client_manager: ClientManager,
    event_tx: Sender<ProbeResult>,
    cancellation_token: CancellationToken,
}

impl LatencyTask {
    /// Runs the probe loop until `cancellation_token` is cancelled.
    async fn run(self) {
        let cancellation_token = self.cancellation_token.clone();
        tokio::select! {
            _ = self.probe_loop() => {},
            _ = cancellation_token.cancelled() => {},
        }
    }

    /// Once per tick, spawns one probe for every (client, address) pair.
    /// Never returns; it is stopped by `run` dropping it on cancellation.
    async fn probe_loop(&self) {
        let mut tick = time::interval(PROBE_INTERVAL);
        // If the executor falls behind, skip missed ticks instead of firing a
        // burst of back-to-back probe rounds to catch up.
        tick.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
        loop {
            tick.tick().await;
            let targets = self.client_manager.probe_targets();
            for (handle, port, ips) in targets {
                for ip in ips {
                    let tx = self.event_tx.clone();
                    let cancellation_token = self.cancellation_token.clone();
                    // Each probe runs as a detached local task, so one slow
                    // address cannot hold up the others or the next tick.
                    // Because the task is detached, it must watch the
                    // cancellation token itself. Otherwise probes in flight at
                    // `terminate` time would keep their sockets open for up to
                    // `PROBE_TIMEOUT` and still emit results afterwards.
                    spawn_local(async move {
                        let addr = SocketAddr::new(ip, port);
                        let outcome = tokio::select! {
                            rtt_micros = probe_addr(addr) => Some(rtt_micros),
                            _ = cancellation_token.cancelled() => None,
                        };
                        if let Some(rtt_micros) = outcome {
                            // The receiver may already be gone during shutdown;
                            // dropping a result then is harmless.
                            let _ = tx.send(ProbeResult {
                                handle,
                                ip,
                                rtt_micros,
                            });
                        }
                    });
                }
            }
        }
    }
}
/// Times how long `addr` takes to answer a TCP connection attempt.
///
/// Returns the elapsed microseconds when the handshake completes or when the
/// peer actively refuses the connection. A refusal still proves the host is up
/// and answering. Returns `None` when the attempt fails with any other error or
/// does not finish within `PROBE_TIMEOUT`.
async fn probe_addr(addr: SocketAddr) -> Option<u32> {
    let started = Instant::now();
    // Outer `Err` = timed out, which is reported as unreachable.
    let connect_result = time::timeout(PROBE_TIMEOUT, TcpStream::connect(addr))
        .await
        .ok()?;
    let answered = match &connect_result {
        Ok(_) => true,
        Err(err) => err.kind() == std::io::ErrorKind::ConnectionRefused,
    };
    answered.then(|| elapsed_micros(started))
}
/// Microseconds elapsed since `start`, saturating at `u32::MAX` (roughly 71
/// minutes) instead of truncating.
fn elapsed_micros(start: Instant) -> u32 {
    let micros: u128 = start.elapsed().as_micros();
    u32::try_from(micros).unwrap_or(u32::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, TcpListener};

    /// A completed handshake is the primary "reachable" signal.
    #[tokio::test]
    async fn accepted_connection_counts_as_reachable() {
        // The listener must stay alive for the whole probe. The OS completes
        // the handshake into the accept backlog, so no `accept` call is needed.
        let listener =
            TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).expect("failed to bind a local listener");
        let addr = listener.local_addr().expect("listener has no local address");
        assert!(
            probe_addr(addr).await.is_some(),
            "an accepted connection should be treated as reachable"
        );
    }

    /// Assumes nothing listens on localhost:1. If something does, the probe
    /// succeeds instead and the assertion still holds.
    // NOTE(review): presumably skipped on Windows because a refused connect
    // there can take longer than PROBE_TIMEOUT to be reported — confirm.
    #[cfg(not(target_os = "windows"))]
    #[tokio::test]
    async fn refused_connection_counts_as_reachable() {
        let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 1));
        assert!(
            probe_addr(addr).await.is_some(),
            "a refused (RST) connection should be treated as reachable"
        );
    }

    /// 192.0.2.0/24 is TEST-NET-1 (RFC 5737) and is not routed. The connect
    /// either times out (this test may take up to PROBE_TIMEOUT) or fails at
    /// once with an unreachable error; both must map to `None`.
    #[tokio::test]
    async fn unroutable_address_times_out_as_unreachable() {
        let addr = SocketAddr::from((Ipv4Addr::new(192, 0, 2, 1), 4252));
        assert!(
            probe_addr(addr).await.is_none(),
            "an unroutable address should be reported as unreachable"
        );
    }
}