use core::{net::SocketAddr, time::Duration};
use std::{
collections::HashSet,
sync::{Arc, RwLock},
};
use kameo::{
actor::ActorRef,
message::{Context, Message},
};
use tokio::task::JoinSet;
use ts_keys::{DiscoPublicKey, NodePublicKey};
use ts_magicsock::{BindingVerifier, DirectTransport, MagicSock, SelfEndpoint};
use ts_transport::{
BatchRecvIter, PeerId, PeerLookup, UnderlayTransport, UnderlayTransportExt, UnderlayTransportId,
};
use crate::{
Env, Error,
dataplane::{DataplaneActor, NewUnderlayTransport, UnderlayFromDataplane, UnderlayToDataplane},
multiderp::{self, Multiderp},
peer_tracker::{PeerDb, PeerState},
};
/// Period between rounds of disco pings sent by `run_pinger`.
const PING_INTERVAL: Duration = Duration::from_millis(2_000);
/// Period between STUN probe rounds in `run_stun_prober`.
const STUN_PROBE_INTERVAL: Duration = Duration::from_millis(30_000);
/// Period of the endpoint advertiser; also reused as the call-me-maybe cadence.
const ADVERTISE_INTERVAL: Duration = Duration::from_millis(5_000);
/// Event carrying this node's current self endpoints.
///
/// As produced by `run_advertiser`, the list is sorted by (address, endpoint
/// type), and it is only published when it differs from the previous one.
#[derive(Clone)]
pub struct EndpointAdvertisement {
    /// Endpoint list behind an `Arc`, so cloning the event does not copy it.
    pub endpoints: Arc<Vec<SelfEndpoint>>,
}
/// Unspecified IPv4 address with an OS-assigned port (the default underlay bind).
const BIND_ADDR: &str = "0.0.0.0:0";
/// Unspecified IPv6 address with an OS-assigned port, tried first when IPv6 is enabled.
const BIND_ADDR_V6: &str = "[::]:0";
/// Binds the magicsock used for the direct UDP underlay.
///
/// With `enable_ipv6` set, a bind on [`BIND_ADDR_V6`] is attempted first. If
/// it fails, a warning is logged and the function falls through to the
/// IPv4-only bind on [`BIND_ADDR`]. That IPv4 bind is also the only attempt
/// when `enable_ipv6` is false.
///
/// # Errors
/// Returns the error of the final (IPv4) bind attempt. A failed IPv6 attempt
/// is logged, never returned.
async fn bind_underlay_addr(
    enable_ipv6: bool,
    our_disco: ts_keys::DiscoPrivateKey,
    our_node_key: NodePublicKey,
) -> Result<MagicSock, ts_magicsock::Error> {
    if enable_ipv6 {
        let v6: SocketAddr = BIND_ADDR_V6.parse().expect("valid bind address");
        match MagicSock::bind(v6, our_disco, our_node_key).await {
            Ok(sock) => return Ok(sock),
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    %v6,
                    "dual-stack underlay bind failed (host IPv6 disabled?); falling back to IPv4-only",
                );
            }
        }
    }
    // Either IPv6 is disabled or the IPv6 bind just failed.
    let v4: SocketAddr = BIND_ADDR.parse().expect("valid bind address");
    MagicSock::bind(v4, our_disco, our_node_key).await
}
/// Actor owning the direct (UDP) underlay: the magicsock, its dataplane
/// transport registration, and the background tasks that drive it.
pub struct DirectManager {
    /// Bound socket. `None` when the UDP bind failed and the manager is inert.
    sock: Option<Arc<MagicSock>>,
    /// Dataplane id of the direct transport. `None` in the inert case.
    transport_id: Option<UnderlayTransportId>,
    /// Latest peer database. It is shared with the binding verifier, the
    /// transport's peer-key lookup and the call-me-maybe loop, and replaced on
    /// every `PeerState` update.
    peer_db: Arc<RwLock<Option<Arc<PeerDb>>>>,
    // Never read: it is held so the spawned loops live exactly as long as the
    // actor. Dropping a `JoinSet` aborts every task still in it.
    #[allow(dead_code)]
    tasks: JoinSet<()>,
}
#[kameo::messages]
impl DirectManager {
    /// Dataplane id of the direct transport, or `None` when the UDP bind
    /// failed and the manager runs DERP-only.
    #[message]
    pub fn direct_transport_id(&self) -> Option<UnderlayTransportId> {
        self.transport_id
    }

    /// Returns the subset of `ids` for which the magicsock currently has a
    /// best direct address.
    ///
    /// A peer is skipped if it is missing from the current peer db or has no
    /// disco key. The result is empty when there is no socket or no peer db.
    #[message]
    pub fn peers_with_direct_path(&self, ids: Vec<PeerId>) -> HashSet<PeerId> {
        let Some(sock) = self.sock.as_ref() else {
            return HashSet::new();
        };
        let guard = poisoned_read(&self.peer_db);
        let Some(peers) = guard.as_ref() else {
            return HashSet::new();
        };
        ids.into_iter()
            .filter(|id| {
                peers
                    .get(id)
                    .and_then(|(_, node)| node.disco_key)
                    .is_some_and(|disco| sock.best_addr(&disco).is_some())
            })
            .collect()
    }

    /// Asks the magicsock to rebind its socket. When the manager is inert
    /// (no socket), this succeeds without doing anything.
    #[message]
    pub async fn rebind(&self) -> Result<(), ts_magicsock::Error> {
        let Some(sock) = self.sock.as_ref() else {
            return Ok(());
        };
        sock.rebind().await
    }
}
/// Binding check that `on_start` installs on the magicsock.
///
/// `disco` is accepted only if it belongs to a peer in the current peer db.
/// When `claimed_node_key` is present, it must also equal that peer's node
/// key. With `None`, membership alone is enough. With no peer db loaded, every
/// check fails (fail closed).
fn verify_binding(
    peer_db: &RwLock<Option<Arc<PeerDb>>>,
    disco: &DiscoPublicKey,
    claimed_node_key: Option<&NodePublicKey>,
) -> bool {
    let guard = poisoned_read(peer_db);
    let Some((_, node)) = guard.as_ref().and_then(|db| db.get(disco)) else {
        return false;
    };
    // Reject only an explicit claim that names a different node key.
    !matches!(claimed_node_key, Some(claimed) if node.node_key != *claimed)
}
/// Takes a read guard on `lock`, recovering the guard if the lock is poisoned.
///
/// Poisoning only records that some writer panicked while holding the lock.
/// Here the value is still read, so callers see the last written state
/// instead of propagating the panic.
///
/// This is generic over the protected type; the existing peer-db call sites
/// infer `T` unchanged.
fn poisoned_read<T: ?Sized>(lock: &RwLock<T>) -> std::sync::RwLockReadGuard<'_, T> {
    lock.read().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Takes a write guard on `lock`, recovering the guard if the lock is poisoned.
///
/// This mirrors `poisoned_read`: a writer that panicked earlier does not stop
/// later writers from replacing the value.
///
/// This is generic over the protected type; the existing peer-db call sites
/// infer `T` unchanged.
fn poisoned_write<T: ?Sized>(lock: &RwLock<T>) -> std::sync::RwLockWriteGuard<'_, T> {
    lock.write().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Maps between dataplane `PeerId`s and disco keys in both directions, using
/// the shared peer db that is replaced on every `PeerState` update.
/// Every lookup misses while no peer db has been loaded.
struct DiscoPeerLookup(Arc<RwLock<Option<Arc<PeerDb>>>>);

impl PeerLookup<PeerId, DiscoPublicKey> for DiscoPeerLookup {
    /// Disco key of peer `id`, if the peer is known and has one.
    fn lookup_key(&self, id: PeerId) -> Option<DiscoPublicKey> {
        let guard = poisoned_read(&self.0);
        let (_, peer) = guard.as_ref()?.get(&id)?;
        peer.disco_key
    }
}

impl PeerLookup<DiscoPublicKey, PeerId> for DiscoPeerLookup {
    /// `PeerId` of the peer owning disco key `key`, if any.
    fn lookup_key(&self, key: DiscoPublicKey) -> Option<PeerId> {
        let guard = poisoned_read(&self.0);
        guard.as_ref()?.get(&key).map(|(peer_id, _)| peer_id)
    }
}
/// Moves packets between the direct UDP transport and the dataplane until
/// shutdown is signalled or one of the channels closes.
///
/// Inbound: each decoded `(peer, packets)` batch goes to the dataplane.
/// Undecodable entries are dropped with a trace log. The task exits if the
/// dataplane receiver is gone.
/// Outbound: packets from the dataplane are sent one per call. Send errors
/// are trace-logged and otherwise ignored.
async fn run_direct(
    transport: impl UnderlayTransport<PeerKey = PeerId, Error = ts_magicsock::Error>,
    mut from_dataplane: UnderlayFromDataplane,
    to_dataplane: UnderlayToDataplane,
    mut shutdown: tokio::sync::watch::Receiver<bool>,
) {
    while !*shutdown.borrow() {
        tokio::select! {
            _ = shutdown.changed() => break,
            inbound = transport.recv() => {
                for item in inbound.batch_iter() {
                    let (peer_id, pkts) = match item {
                        Ok(decoded) => decoded,
                        Err(e) => {
                            tracing::trace!(error = %e, "ignoring undecodable direct packet");
                            continue;
                        }
                    };
                    let batch: Vec<_> = pkts.into_iter().collect();
                    if to_dataplane.send((peer_id, batch)).is_err() {
                        tracing::error!("underlay receive channel closed");
                        return;
                    }
                }
            }
            outbound = from_dataplane.recv() => {
                let Some(packet) = outbound else {
                    tracing::warn!("direct underlay queue closed");
                    break;
                };
                if let Err(e) = transport.send([packet]).await {
                    tracing::trace!(error = %e, "sending direct packet");
                }
            }
        }
    }
}
/// Sends a round of disco pings every [`PING_INTERVAL`] until shutdown.
/// A failed round is trace-logged and the loop keeps going.
async fn run_pinger(sock: Arc<MagicSock>, mut shutdown: tokio::sync::watch::Receiver<bool>) {
    let mut ticker = tokio::time::interval(PING_INTERVAL);
    // After a stall, keep the regular spacing rather than bursting to catch up.
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    while !*shutdown.borrow() {
        tokio::select! {
            _ = shutdown.changed() => break,
            _ = ticker.tick() => {
                let outcome = sock.send_pings().await;
                if let Err(e) = outcome {
                    tracing::trace!(error = %e, "sending disco pings");
                }
            }
        }
    }
}
/// Every [`STUN_PROBE_INTERVAL`], asks multiderp for the current IPv4 STUN
/// servers and sends each one a binding request through `sock`, until
/// shutdown. If the multiderp query fails, that round is skipped (trace log).
async fn run_stun_prober(
    sock: Arc<MagicSock>,
    multiderp: ActorRef<Multiderp>,
    mut shutdown: tokio::sync::watch::Receiver<bool>,
) {
    let mut ticker = tokio::time::interval(STUN_PROBE_INTERVAL);
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    while !*shutdown.borrow() {
        tokio::select! {
            _ = shutdown.changed() => break,
            _ = ticker.tick() => {
                let stun_servers = match multiderp.ask(multiderp::StunServersV4).await {
                    Ok((list,)) => list,
                    Err(e) => {
                        tracing::trace!(error = %e, "querying stun servers from multiderp");
                        continue;
                    }
                };
                probe_stun_servers_once(&sock, &stun_servers).await;
            }
        }
    }
}
async fn probe_stun_servers_once(sock: &MagicSock, servers: &[SocketAddr]) {
for &s in servers {
if let Err(e) = sock.send_stun_request(s).await {
tracing::trace!(error = %e, server = %s, "sending stun binding request");
}
}
}
/// Publishes this node's self endpoints as an [`EndpointAdvertisement`]
/// whenever the sorted set changes, checking every [`ADVERTISE_INTERVAL`]
/// until shutdown.
///
/// Only a successfully published set is remembered as "last advertised". If
/// `publish` fails, the same set counts as new on the next tick and is
/// retried, instead of being silently treated as delivered until the
/// endpoints happen to change. As before, nothing is published while the set
/// is still empty at startup.
async fn run_advertiser(
    sock: Arc<MagicSock>,
    env: Env,
    mut shutdown: tokio::sync::watch::Receiver<bool>,
) {
    let mut interval = tokio::time::interval(ADVERTISE_INTERVAL);
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    let mut last: Vec<SelfEndpoint> = Vec::new();
    while !*shutdown.borrow() {
        tokio::select! {
            _ = shutdown.changed() => break,
            _ = interval.tick() => {
                let mut eps = sock.self_endpoints();
                // Canonical order, so a mere reordering is not treated as a change.
                eps.sort_by_key(|e| (e.addr, e.ty as u8));
                if eps == last {
                    continue;
                }
                let advertisement = EndpointAdvertisement {
                    endpoints: Arc::new(eps.clone()),
                };
                match env.publish(advertisement).await {
                    // Remember the set only once it has actually been published.
                    Ok(_) => last = eps,
                    Err(e) => tracing::error!(error = %e, "publishing endpoint advertisement"),
                }
            }
        }
    }
}
async fn run_call_me_maybe(
sock: Arc<MagicSock>,
peer_db: Arc<RwLock<Option<Arc<PeerDb>>>>,
multiderp: ActorRef<Multiderp>,
mut shutdown: tokio::sync::watch::Receiver<bool>,
) {
let mut interval = tokio::time::interval(ADVERTISE_INTERVAL);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
while !*shutdown.borrow() {
tokio::select! {
_ = shutdown.changed() => break,
_ = interval.tick() => {
let have_reflexive = sock
.self_endpoints()
.iter()
.any(|e| e.ty == ts_magicsock::SelfEndpointType::Stun);
if !have_reflexive {
continue;
}
let targets: Vec<(ts_keys::NodePublicKey, DiscoPublicKey, Option<ts_derp::RegionId>)> = {
let db = poisoned_read(&peer_db);
let Some(db) = db.as_ref() else { continue; };
db.peers()
.values()
.filter_map(|node| {
let disco = node.disco_key?;
if sock.best_addr(&disco).is_some() {
return None;
}
Some((node.node_key, disco, node.derp_region))
})
.collect()
};
for (node_key, disco, netmap_region) in targets {
let region = match netmap_region {
Some(region) => Some(region),
None => {
match multiderp.ask(multiderp::RegionForNode { node: node_key }).await {
Ok(region) => region,
Err(e) => {
tracing::trace!(error = %e, "inferring call-me-maybe relay region");
None
}
}
}
};
let Some(region) = region else {
continue;
};
let frame = match sock.seal_call_me_maybe(&disco) {
Ok(frame) => frame,
Err(e) => {
tracing::trace!(error = %e, "sealing call-me-maybe");
continue;
}
};
if let Err(e) = multiderp
.tell(multiderp::SendDisco {
peer: node_key,
region,
frame,
})
.await
{
tracing::trace!(error = %e, "relaying call-me-maybe to multiderp");
}
}
}
}
}
}
impl kameo::Actor for DirectManager {
    type Args = (Env, ActorRef<DataplaneActor>, ActorRef<Multiderp>);
    type Error = Error;

    /// Subscribes to `PeerState` updates and binds the direct UDP underlay.
    /// If the bind succeeds, it also registers the transport with the
    /// dataplane, spawns the background loops, and hands the socket to
    /// multiderp.
    ///
    /// A failed bind is not an error: the actor starts inert (`sock` and
    /// `transport_id` are `None`) and the node stays DERP-only.
    ///
    /// # Errors
    /// Fails if subscribing to `PeerState` or registering the transport with
    /// the dataplane fails.
    async fn on_start(
        (env, dataplane, multiderp): Self::Args,
        slf: ActorRef<Self>,
    ) -> Result<Self, Self::Error> {
        env.subscribe::<Arc<PeerState>>(&slf).await?;
        let peer_db: Arc<RwLock<Option<Arc<PeerDb>>>> = Default::default();
        let mut tasks = JoinSet::new();

        // The verifier keeps its own handle on the shared peer db, so it sees
        // every later `PeerState` update.
        let verifier_db = Arc::clone(&peer_db);
        let binding_verifier: BindingVerifier = Arc::new(move |disco, claimed_node_key| {
            verify_binding(&verifier_db, disco, claimed_node_key)
        });

        let bound = match bind_underlay_addr(
            env.enable_ipv6,
            env.keys.disco_keys.private,
            env.keys.node_keys.public,
        )
        .await
        {
            Ok(bound) => bound,
            Err(e) => {
                tracing::error!(
                    error = %e,
                    enable_ipv6 = env.enable_ipv6,
                    "direct underlay udp bind failed; direct manager inert, staying DERP-only",
                );
                return Ok(Self {
                    sock: None,
                    transport_id: None,
                    peer_db,
                    tasks,
                });
            }
        };
        let sock = Arc::new(
            bound
                .with_enable_ipv6(env.enable_ipv6)
                .with_binding_verifier(binding_verifier),
        );

        let (transport_id, from_dataplane, to_dataplane) =
            dataplane.ask(NewUnderlayTransport).await?;
        let key_lookup = DiscoPeerLookup(Arc::clone(&peer_db));
        let transport = DirectTransport::new(Arc::clone(&sock)).with_key_lookup(key_lookup);

        tasks.spawn(run_direct(
            transport,
            from_dataplane,
            to_dataplane,
            env.shutdown.clone(),
        ));
        tasks.spawn(run_pinger(Arc::clone(&sock), env.shutdown.clone()));
        tasks.spawn(run_advertiser(
            Arc::clone(&sock),
            env.clone(),
            env.shutdown.clone(),
        ));
        tasks.spawn(run_stun_prober(
            Arc::clone(&sock),
            multiderp.clone(),
            env.shutdown.clone(),
        ));

        let installed = multiderp
            .tell(multiderp::SetDirectSock {
                sock: Arc::clone(&sock),
            })
            .await;
        if let Err(e) = installed {
            tracing::warn!(error = %e, "could not install direct socket on multiderp");
        }

        tasks.spawn(run_call_me_maybe(
            Arc::clone(&sock),
            Arc::clone(&peer_db),
            multiderp,
            env.shutdown.clone(),
        ));

        Ok(Self {
            sock: Some(sock),
            transport_id: Some(transport_id),
            peer_db,
            tasks,
        })
    }
}
impl Message<Arc<PeerState>> for DirectManager {
    type Reply = ();

    /// Applies a new peer snapshot: swaps in the shared peer db, then syncs
    /// the magicsock's per-peer netmap endpoints.
    ///
    /// The db is published first. The binding verifier and the transport's
    /// `DiscoPeerLookup` read it from other tasks, so they can already
    /// resolve every peer whose endpoints are installed below. Removed peers
    /// are rejected as soon as the db is swapped. Previously, new endpoints
    /// were live on the socket while those readers still saw the old db.
    async fn handle(&mut self, msg: Arc<PeerState>, _ctx: &mut Context<Self, Self::Reply>) {
        // One statement, so the write guard is dropped before the socket is
        // touched (the socket may call back into the verifier, which reads
        // this lock).
        *poisoned_write(&self.peer_db) = Some(msg.peers.clone());

        let Some(sock) = self.sock.as_ref() else {
            return;
        };
        let mut live = HashSet::new();
        for node in msg.peers.peers().values() {
            let Some(disco) = node.disco_key else {
                continue;
            };
            live.insert(disco);
            sock.set_netmap_endpoints(disco, node.underlay_addresses.iter().copied());
        }
        // Drop socket state for peers (disco keys) no longer in the snapshot.
        sock.retain_peers(&live);
    }
}
#[cfg(test)]
mod tests {
    use ts_control::{Node, StableNodeId, TailnetAddress};
    use ts_keys::{DiscoPrivateKey, NodePrivateKey};

    use super::*;
    use crate::peer_tracker::PeerDb;

    /// Minimal netmap peer with disco key `disco` and node key `node_key`.
    /// Every other field is a fixed placeholder.
    fn node_with_keys(disco: DiscoPublicKey, node_key: NodePublicKey, stable: &str) -> Node {
        let tailnet_address = TailnetAddress {
            ipv4: "100.64.0.9/32".parse().unwrap(),
            ipv6: "fd7a::9/128".parse().unwrap(),
        };
        Node {
            id: 1,
            stable_id: StableNodeId(stable.to_string()),
            hostname: "peer".to_string(),
            user_id: 0,
            tailnet: Some("ts.net".to_string()),
            tags: vec![],
            tailnet_address,
            node_key,
            node_key_expiry: None,
            key_signature: vec![],
            machine_key: None,
            disco_key: Some(disco),
            accepted_routes: vec![],
            underlay_addresses: vec![],
            derp_region: None,
            cap: Default::default(),
            cap_map: Default::default(),
            peerapi_port: None,
            peerapi_dns_proxy: false,
            is_wireguard_only: false,
            exit_node_dns_resolvers: vec![],
            peer_relay: false,
            service_vips: Default::default(),
        }
    }

    /// Wraps a single-peer `PeerDb` in the shared-lock shape the manager uses.
    fn db_with(node: Node) -> Arc<RwLock<Option<Arc<PeerDb>>>> {
        let mut peers = PeerDb::default();
        peers.upsert(&node);
        Arc::new(RwLock::new(Some(Arc::new(peers))))
    }

    /// Binds a throwaway IPv4 magicsock with freshly generated keys.
    async fn bound_v4_sock() -> Arc<MagicSock> {
        let sock = MagicSock::bind(
            BIND_ADDR.parse().unwrap(),
            DiscoPrivateKey::random(),
            NodePrivateKey::random().public_key(),
        )
        .await
        .unwrap();
        Arc::new(sock)
    }

    #[test]
    fn verify_binding_ping_requires_exact_node_key() {
        let disco = DiscoPrivateKey::random().public_key();
        let bound_key = NodePrivateKey::random().public_key();
        let impostor_key = NodePrivateKey::random().public_key();
        let db = db_with(node_with_keys(disco, bound_key, "n1"));

        assert!(
            verify_binding(&db, &disco, Some(&bound_key)),
            "correct disco<->node-key binding must be accepted"
        );
        assert!(
            !verify_binding(&db, &disco, Some(&impostor_key)),
            "a claimed node key that is not the bound one must be rejected"
        );

        let unknown_disco = DiscoPrivateKey::random().public_key();
        assert!(
            !verify_binding(&db, &unknown_disco, Some(&bound_key)),
            "a disco key not in the netmap must be rejected"
        );

        let no_netmap: Arc<RwLock<Option<Arc<PeerDb>>>> = Default::default();
        assert!(
            !verify_binding(&no_netmap, &disco, Some(&bound_key)),
            "with no netmap loaded the verifier fails closed"
        );
    }

    #[test]
    fn verify_binding_call_me_maybe_is_membership_only() {
        let disco = DiscoPrivateKey::random().public_key();
        let node_key = NodePrivateKey::random().public_key();
        let db = db_with(node_with_keys(disco, node_key, "n1"));

        assert!(
            verify_binding(&db, &disco, None),
            "a netmap-member disco key must be accepted for a CallMeMaybe"
        );

        let outsider = DiscoPrivateKey::random().public_key();
        assert!(
            !verify_binding(&db, &outsider, None),
            "a non-member disco key must be rejected for a CallMeMaybe"
        );
    }

    #[tokio::test]
    async fn probe_stun_servers_once_sends_binding_request() {
        let sock = bound_v4_sock().await;
        let sink = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server: SocketAddr = sink.local_addr().unwrap();

        probe_stun_servers_once(&sock, &[server]).await;

        let mut buf = [0u8; 64];
        let received = tokio::time::timeout(Duration::from_secs(2), sink.recv_from(&mut buf))
            .await
            .expect("a STUN binding request must arrive at the v4 server");
        let (len, _from) = received.unwrap();

        assert_eq!(
            len, 20,
            "a STUN Binding Request is exactly the 20-byte header"
        );
        assert_eq!(
            &buf[0..2],
            &0x0001u16.to_be_bytes(),
            "message type must be Binding Request (0x0001)"
        );
        assert_eq!(
            &buf[4..8],
            &0x2112_A442u32.to_be_bytes(),
            "the STUN magic cookie must be present at bytes[4..8]"
        );
    }

    #[tokio::test]
    async fn bind_underlay_addr_v4_default_is_unchanged() {
        let sock = bind_underlay_addr(
            false,
            DiscoPrivateKey::random(),
            NodePrivateKey::random().public_key(),
        )
        .await
        .expect("the IPv4 underlay bind must succeed");
        let local = sock.local_addr().expect("a bound socket has a local addr");

        assert!(
            local.is_ipv4(),
            "with enable_ipv6 == false the underlay must bind the v4 family, got {local}"
        );
        assert_eq!(
            local.ip(),
            "0.0.0.0".parse::<core::net::IpAddr>().unwrap(),
            "the v4 default binds the unspecified v4 address"
        );
    }

    #[tokio::test]
    async fn bind_underlay_addr_v6_attempts_dual_stack_or_falls_back() {
        let sock = bind_underlay_addr(
            true,
            DiscoPrivateKey::random(),
            NodePrivateKey::random().public_key(),
        )
        .await
        .expect("bind must succeed (dual-stack, else inert IPv4 fallback) and never error");
        let local = sock.local_addr().expect("a bound socket has a local addr");

        // Probe whether this host can bind the v6 wildcard at all.
        let host_has_v6 = tokio::net::UdpSocket::bind("[::]:0").await.is_ok();
        if host_has_v6 {
            assert!(
                local.is_ipv6(),
                "on a v6-capable host enable_ipv6 == true must bind the v6 (dual-stack) family, \
                 got {local}"
            );
        } else {
            assert!(
                local.is_ipv4(),
                "on a host that cannot bind v6 the inert fallback must yield a v4 socket, got \
                 {local}"
            );
        }
    }

    #[tokio::test]
    async fn probe_stun_servers_once_empty_list_is_noop() {
        let sock = bound_v4_sock().await;
        probe_stun_servers_once(&sock, &[]).await;
    }
}