use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use anyhow::{bail, Context, Result};
use netwatch::UdpSocket;
use tokio::{sync::oneshot, time::Instant};
use tokio_util::task::AbortOnDropHandle;
use tracing::{debug, error, info_span, trace, warn, Instrument};
use crate::{
defaults::timeouts::HAIRPIN_CHECK_TIMEOUT,
netcheck::{self, reportgen, Inflight},
stun,
};
/// Handle to a hairpin-check actor running on a spawned tokio task.
///
/// The check is triggered at most once via [`Client::start_check`]; dropping the client
/// aborts the actor task.
#[derive(Debug)]
pub(super) struct Client {
    /// One-shot sender used to start the check; `None` once the start message has been
    /// taken (i.e. after [`Client::start_check`]).
    addr: Option<oneshot::Sender<Message>>,
    /// Aborts the spawned actor task when the client is dropped.
    _drop_guard: AbortOnDropHandle<()>,
}
impl Client {
    /// Spawns the hairpin actor and returns a handle to it.
    ///
    /// Once scheduled, the actor binds its socket and sends its preparation packet, then
    /// waits for [`Client::start_check`] before performing the actual check. The spawned
    /// task runs inside the `hairpin.actor` span and is aborted when the returned client is
    /// dropped.
    pub(super) fn new(netcheck: netcheck::Addr, reportgen: reportgen::Addr) -> Self {
        let (start_tx, start_rx) = oneshot::channel();
        let actor = Actor {
            msg_rx: start_rx,
            netcheck,
            reportgen,
        };
        let handle = tokio::spawn(actor.run().instrument(info_span!("hairpin.actor")));
        Self {
            addr: Some(start_tx),
            _drop_guard: AbortOnDropHandle::new(handle),
        }
    }

    /// Returns `true` once the start message has been handed off, i.e. after
    /// [`Client::start_check`] was called.
    pub(super) fn has_started(&self) -> bool {
        self.addr.is_none()
    }

    /// Starts the hairpin check, sending the STUN request to `dst`.
    ///
    /// Only the first call has any effect; subsequent calls are no-ops. If the actor has
    /// already exited, the start message is silently discarded.
    pub(super) fn start_check(&mut self, dst: SocketAddrV4) {
        let Some(start_tx) = self.addr.take() else {
            return;
        };
        // A send error only means the actor is already gone; nothing left to notify.
        let _ = start_tx.send(Message::StartCheck(dst));
    }
}
/// Messages sent from a [`Client`] to its [`Actor`].
#[derive(Debug)]
enum Message {
    /// Start the hairpin check, sending the STUN binding request to this address.
    StartCheck(SocketAddrV4),
}
/// State owned by the spawned hairpin-check task.
#[derive(Debug)]
struct Actor {
    /// Receives the single [`Message::StartCheck`] from the [`Client`].
    msg_rx: oneshot::Receiver<Message>,
    /// Netcheck actor, used to register the in-flight STUN transaction.
    netcheck: netcheck::Addr,
    /// Report-generation actor, which receives the hairpin result.
    reportgen: reportgen::Addr,
}
impl Actor {
    /// Runs the actor to completion and logs the outcome.
    ///
    /// This is the body of the spawned task, so errors from [`Actor::run_inner`] are not
    /// propagated; they are logged at `error` level with their full cause chain.
    async fn run(self) {
        match self.run_inner().await {
            Ok(_) => trace!("hairpin actor finished successfully"),
            Err(err) => error!("Hairpin actor failed: {err:#}"),
        }
    }

    /// Performs the hairpin check.
    ///
    /// 1. Binds a fresh IPv4 UDP socket and sends a throw-away packet from it
    ///    ([`Actor::prepare_hairpin`]); a failure there is only logged.
    /// 2. Waits for [`Message::StartCheck`]. If the [`Client`] is dropped without sending
    ///    it, returns `Ok(())` without doing anything else.
    /// 3. Registers an [`Inflight`] STUN transaction with the netcheck actor and waits for
    ///    the acknowledgement.
    /// 4. Sends a STUN binding request to `dst` and waits up to [`HAIRPIN_CHECK_TIMEOUT`]
    ///    for the netcheck actor to deliver a response on the in-flight channel.
    /// 5. Reports whether a response arrived to the reportgen actor.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be bound, the netcheck or reportgen actor is gone, the
    /// STUN request cannot be sent to `dst`, or the netcheck actor drops the response
    /// channel before the timeout.
    async fn run_inner(self) -> Result<()> {
        let socket = UdpSocket::bind_v4(0).context("Failed to bind hairpin socket on 0.0.0.0:0")?;

        if let Err(err) = Self::prepare_hairpin(&socket).await {
            warn!("unable to send hairpin prep: {err:#}");
        }

        let Ok(Message::StartCheck(dst)) = self.msg_rx.await else {
            return Ok(());
        };

        let txn = stun::TransactionId::default();
        trace!(%txn, "Sending hairpin with transaction ID");
        let (stun_tx, stun_rx) = oneshot::channel();
        let inflight = Inflight {
            txn,
            start: Instant::now(),
            s: stun_tx,
        };
        let (msg_response_tx, msg_response_rx) = oneshot::channel();
        self.netcheck
            .send(netcheck::Message::InFlightStun(inflight, msg_response_tx))
            .await
            .context("netcheck actor gone")?;
        // Only send the request once netcheck knows about `txn`, so the response can be
        // matched up.
        msg_response_rx.await.context("netcheck actor died")?;

        // Attach the destination and keep the I/O cause; `run` logs the whole chain once.
        socket
            .send_to(&stun::request(txn), dst)
            .await
            .with_context(|| format!("failed to send hairpin check to {dst}"))?;

        let sent_at = Instant::now();
        let hairpinning_works = match tokio::time::timeout(HAIRPIN_CHECK_TIMEOUT, stun_rx).await {
            // The netcheck actor delivered a response for `txn`.
            Ok(Ok(_)) => true,
            Ok(Err(_)) => bail!("netcheck actor dropped stun response channel"),
            // No response within the timeout is treated as hairpinning not working.
            Err(_) => false,
        };
        debug!(
            "hairpinning done in {:?}, res: {:?}",
            sent_at.elapsed(),
            hairpinning_works
        );

        self.reportgen
            .send(super::Message::HairpinResult(hairpinning_works))
            .await
            .context("Failed to send hairpin result to reportgen actor")?;
        trace!("reportgen notified");
        Ok(())
    }

    /// Sends a throw-away datagram from `socket` to `203.0.113.1:12345`.
    ///
    /// `203.0.113.0/24` is the TEST-NET-3 documentation range (RFC 5737), so nothing is
    /// expected to answer. The payload points at the tailscale issue this step comes from.
    /// NOTE(review): presumably the purpose is to give `socket` an outbound NAT mapping
    /// before the hairpin request is sent; confirm against that issue.
    ///
    /// # Errors
    ///
    /// Returns the error from `send_to`.
    async fn prepare_hairpin(socket: &UdpSocket) -> Result<()> {
        let documentation_ip = SocketAddr::from((Ipv4Addr::new(203, 0, 113, 1), 12345));
        socket
            .send_to(
                b"tailscale netcheck; see https://github.com/tailscale/tailscale/issues/188",
                documentation_ip,
            )
            .await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bytes::BytesMut;
    use tokio::sync::mpsc;
    use tracing::info;

    use super::*;

    /// Runs the successful-hairpin scenario 100 times in a row.
    ///
    /// NOTE(review): the repetition presumably exists to surface flakiness; confirm it is
    /// still wanted.
    #[tokio::test]
    async fn test_hairpin_success() {
        for i in 0..100 {
            let now = Instant::now();
            test_hairpin(true).await;
            println!("done round {} in {:?}", i + 1, now.elapsed());
        }
    }

    /// Runs the scenario where no hairpin response is delivered.
    #[tokio::test]
    async fn test_hairpin_failure() {
        test_hairpin(false).await;
    }

    /// Drives a [`Client`] against stub netcheck and reportgen channels.
    ///
    /// A local UDP socket stands in for the address the check is sent to. A stub netcheck
    /// task acknowledges the in-flight registration, receives the STUN binding request on
    /// that socket, and checks its transaction ID. If `hairpinning_works` is true, it then
    /// delivers a response; otherwise it withholds one. The test asserts that the reportgen
    /// side receives `HairpinResult(hairpinning_works)`.
    async fn test_hairpin(hairpinning_works: bool) {
        let _guard = iroh_test::logging::setup();

        // Stub actor addresses whose receiving ends are held by the test.
        let (netcheck_tx, mut netcheck_rx) = mpsc::channel(32);
        let netcheck_addr = netcheck::Addr {
            sender: netcheck_tx,
        };
        let (reportstate_tx, mut reportstate_rx) = mpsc::channel(32);
        let reportstate_addr = reportgen::Addr {
            sender: reportstate_tx,
        };
        let mut actor = Client::new(netcheck_addr, reportstate_addr);

        // Target of the hairpin check: a local IPv4 socket the stub reads from.
        let public_sock = UdpSocket::bind_local_v4(0).unwrap();
        let ipp_v4 = match public_sock.local_addr().unwrap() {
            SocketAddr::V4(ipp) => ipp,
            SocketAddr::V6(_) => unreachable!(),
        };
        actor.start_check(ipp_v4);

        let dummy_netcheck = tokio::spawn(
            async move {
                let netcheck::Message::InFlightStun(inflight, resp_tx) =
                    netcheck_rx.recv().await.unwrap()
                else {
                    panic!("Wrong message received");
                };
                // Acknowledge the registration so the actor goes on to send the request.
                resp_tx.send(()).unwrap();
                let mut buf = BytesMut::zeroed(64 << 10);
                let (count, addr) = public_sock.recv_from(&mut buf).await.unwrap();
                info!(
                    addr=?public_sock.local_addr().unwrap(),
                    %count,
                    "Forwarding payload to hairpin actor",
                );
                let payload = buf.split_to(count).freeze();
                let txn = stun::parse_binding_request(&payload).unwrap();
                assert_eq!(txn, inflight.txn);
                if hairpinning_works {
                    inflight.s.send((Duration::new(0, 1), addr)).unwrap();
                } else {
                    info!("Received hairpin request, not sending response");
                    // Keep `inflight` (and its response sender) alive past the actor's
                    // timeout, so the actor reports `false` instead of failing on a
                    // closed channel.
                    tokio::time::sleep(HAIRPIN_CHECK_TIMEOUT * 8).await;
                }
            }
            .instrument(info_span!("dummy-netcheck")),
        );

        match reportstate_rx.recv().await {
            Some(reportgen::Message::HairpinResult(val)) => assert_eq!(val, hairpinning_works),
            Some(msg) => panic!("Unexpected reportstate message: {msg:?}"),
            None => panic!("reportstate mpsc has no senders"),
        }
        // Surface any panic from the stub; in the failure case this also waits out its sleep.
        dummy_netcheck.await.expect("error in dummy netcheck actor");
    }

    /// Dropping the [`Client`] must abort the actor task. That drops the actor's receiver,
    /// so a start message sent afterwards is rejected.
    #[tokio::test]
    async fn test_client_drop() {
        let _guard = iroh_test::logging::setup();
        let (netcheck_tx, _netcheck_rx) = mpsc::channel(32);
        let netcheck_addr = netcheck::Addr {
            sender: netcheck_tx,
        };
        let (reportstate_tx, _reportstate_rx) = mpsc::channel(32);
        let reportstate_addr = reportgen::Addr {
            sender: reportstate_tx,
        };
        let mut client = Client::new(netcheck_addr, reportstate_addr);
        // Take the start sender out before dropping the client, so it can be used afterwards.
        let addr = client.addr.take();
        drop(client);
        // Presumably gives the runtime a chance to tear down the aborted task.
        tokio::task::yield_now().await;
        let ipp_v4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 10);
        match addr.unwrap().send(Message::StartCheck(ipp_v4)) {
            Err(_) => (),
            _ => panic!("actor still running"),
        }
    }
}