use core::net::{SocketAddr, SocketAddrV4};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
time::{Duration, Instant},
};
use kameo::{
actor::ActorRef,
error::SendError,
message::{Context, Message},
};
use tokio::{
sync::{mpsc, watch},
task::JoinSet,
};
use ts_control::DerpRegion;
use ts_derp::RegionId;
use ts_keys::{NodeKeyPair, NodePublicKey};
use ts_magicsock::MagicSock;
use ts_transport::{
BatchRecvIter, PeerId, UnderlayTransport, UnderlayTransportExt, UnderlayTransportId,
};
use crate::{
Env, Error,
dataplane::{DataplaneActor, NewUnderlayTransport, UnderlayFromDataplane, UnderlayToDataplane},
derp_latency::DerpLatencyMeasurement,
peer_tracker::{PeerDb, PeerState},
};
/// Actor that owns one DERP (relay) client per region and routes relayed traffic.
///
/// NOTE: fields are dropped in declaration order, so `derps` (and with it each region's
/// control channels) is dropped before `tasks`.
pub struct Multiderp {
    // Shared environment: supplies the node keys (`env.keys.node_keys`), the shutdown
    // signal (`env.shutdown`) and the event subscriptions made in `on_start`.
    env: Env,
    // Dataplane actor, asked for a new underlay transport for each DERP region.
    dataplane: ActorRef<DataplaneActor>,
    // Per-region transport id plus the control channels of that region's runner task.
    // Entries are only ever inserted (by `ensure_region`) in this file.
    derps: HashMap<RegionId, RegionEntry>,
    // Latest known DERP region descriptions, keyed by region id (from state updates).
    regions: HashMap<RegionId, DerpRegion>,
    // Region currently selected as home from latency measurements, if any.
    current_home_derp: Option<RegionId>,
    // Latest peer database (from `PeerState` updates); `None` until the first update.
    // Shared with the runner tasks for key <-> peer-id translation.
    peer_db: Arc<RwLock<Option<Arc<PeerDb>>>>,
    // Region on which DERP traffic was last received from each peer. Written by the
    // runner tasks, pruned to live peers on `PeerState` updates.
    observed_routes: Arc<RwLock<HashMap<PeerId, RegionId>>>,
    // Direct-path socket used to consume disco frames relayed over DERP; `None` until set.
    direct_sock: Arc<RwLock<Option<Arc<MagicSock>>>>,
    // Background runner tasks, one spawned per region by `ensure_region`
    // (tokio aborts any still-running tasks when a `JoinSet` is dropped).
    tasks: JoinSet<()>,
}
/// Bookkeeping for one DERP region whose runner task has been spawned.
struct RegionEntry {
    // Underlay transport id allocated by the dataplane for this region.
    transport_id: UnderlayTransportId,
    // `true` while this region is the home region; the runner keeps its DERP connection
    // open (no inactivity timeout) while this is set.
    home_derp: watch::Sender<bool>,
    // Bounded queue (capacity 8) of `(destination node key, disco frame)` to relay.
    disco_tx: mpsc::Sender<(NodePublicKey, Vec<u8>)>,
}
impl Multiderp {
    /// Ensures a DERP client transport exists for region `id`, creating it if needed.
    ///
    /// If `self.derps` already has an entry for `id` this is a no-op. Otherwise it asks the
    /// dataplane actor for a new underlay transport, spawns a task on `self.tasks` that
    /// repeatedly drives [`run_derp_once`] for the region, and records the region's
    /// transport id and control channels in `self.derps`.
    ///
    /// The spawned task stops when:
    /// - `shutdown` changes or its sender is dropped;
    /// - `run_derp_once` returns `Ok(())`;
    /// - either dataplane channel is observed closed.
    ///
    /// After a DERP error the task waits 500ms, interruptible by shutdown, and retries.
    ///
    /// If the dataplane cannot create a transport, the region is skipped and nothing is
    /// inserted into `self.derps`. Callers must not assume an entry exists afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the dataplane actor has stopped while `shutdown` is still `false`.
    #[tracing::instrument(skip_all, fields(region_id = %id))]
    async fn ensure_region(
        &mut self,
        id: RegionId,
        region: &DerpRegion,
        mut shutdown: watch::Receiver<bool>,
    ) {
        if self.derps.contains_key(&id) {
            tracing::trace!("region already existed");
            return;
        }
        let region = region.clone();
        let keys = self.env.keys.node_keys;
        let (transport_id, mut up, down) = match self.dataplane.ask(NewUnderlayTransport).await {
            Ok(val) => val,
            Err(SendError::ActorNotRunning(..) | SendError::ActorStopped) => {
                if !*shutdown.borrow() {
                    panic!("dataplane has stopped but we're not shutting down");
                }
                return;
            }
            Err(e) => {
                tracing::error!(error = %e, "multiderp: failed to set up DERP region; skipping");
                return;
            }
        };
        let (home_derp_tx, mut home_derp_rx) = watch::channel(false);
        let (disco_tx, mut disco_rx) = mpsc::channel::<(NodePublicKey, Vec<u8>)>(8);
        let peer_db = self.peer_db.clone();
        let direct_sock = self.direct_sock.clone();
        let observed_routes = self.observed_routes.clone();
        self.tasks.spawn(async move {
            while !*shutdown.borrow() {
                tokio::select! {
                    _ = shutdown.changed() => break,
                    ret = run_derp_once(
                        id,
                        &region,
                        keys,
                        &down,
                        &mut up,
                        &mut home_derp_rx,
                        &mut disco_rx,
                        &peer_db,
                        &direct_sock,
                        &observed_routes,
                    ) => match ret {
                        // `Ok` means the runner stopped for good (e.g. a channel it depends
                        // on closed). Retrying would just spin, so stop the region task.
                        Ok(()) => {
                            tracing::debug!(region_id = %id, "derp runner exited; stopping region task");
                            break;
                        }
                        Err(e) => {
                            tracing::error!(error = %e, region_id = %id, "running derp client");
                            // Back off before reconnecting, without delaying shutdown.
                            tokio::select! {
                                _ = tokio::time::sleep(Duration::from_millis(500)) => {},
                                _ = shutdown.changed() => break,
                            }
                        }
                    },
                }
                if up.is_closed() {
                    tracing::warn!(region_id = %id, "underlay up channel closed!");
                    break;
                }
                if down.is_closed() {
                    tracing::warn!(region_id = %id, "underlay down channel closed!");
                    break;
                }
            }
        });
        self.derps.insert(
            id,
            RegionEntry {
                transport_id,
                home_derp: home_derp_tx,
                disco_tx,
            },
        );
    }
}
#[kameo::messages]
impl Multiderp {
    /// Returns the underlay transport id for DERP region `id`, or `None` if no transport
    /// has been set up for that region.
    #[message]
    pub fn transport_id_for_region(&self, id: RegionId) -> Option<UnderlayTransportId> {
        Some(self.derps.get(&id)?.transport_id)
    }

    /// Picks the DERP region through which to relay traffic for `peer`.
    ///
    /// The region on which we last received traffic from the peer wins over our home
    /// region. The result is `None` unless the chosen region has a transport in
    /// `self.derps`.
    #[message]
    pub fn region_for_peer(&self, peer: PeerId) -> Option<RegionId> {
        let observed = poisoned_read(&self.observed_routes).get(&peer).copied();
        resolve_region_for_peer(observed, self.current_home_derp, |r| {
            self.derps.contains_key(&r)
        })
    }

    /// Same as `region_for_peer`, but keyed by the peer's node public key.
    ///
    /// Returns `None` if no peer database has been received yet or `node` is unknown.
    #[message]
    pub fn region_for_node(&self, node: NodePublicKey) -> Option<RegionId> {
        let peer = {
            // Scope the read guard so it is released before taking the routes lock.
            let db = poisoned_read(&self.peer_db);
            let (id, _) = db.as_ref()?.get(&node)?;
            id
        };
        let observed = poisoned_read(&self.observed_routes).get(&peer).copied();
        resolve_region_for_peer(observed, self.current_home_derp, |r| {
            self.derps.contains_key(&r)
        })
    }

    /// Lists the IPv4 STUN servers of all known regions.
    ///
    /// Only servers with a fixed IPv4 address and a STUN port are included. The list is
    /// returned as a 1-tuple.
    #[message]
    pub fn stun_servers_v4(&self) -> (Vec<SocketAddr>,) {
        (stun_servers_from_regions(self.regions.values()),)
    }

    /// Installs the direct-path socket used to consume disco frames relayed over DERP.
    #[message]
    pub fn set_direct_sock(&mut self, sock: Arc<MagicSock>) {
        *poisoned_write(&self.direct_sock) = Some(sock);
    }

    /// Relays the disco `frame` to `peer` over DERP region `region`. Best-effort.
    ///
    /// The region's transport is created on demand. The frame is dropped (and logged) if:
    /// - the region is unknown;
    /// - its transport could not be established;
    /// - the region's relay queue is full or closed.
    #[message]
    pub async fn send_disco(&mut self, peer: NodePublicKey, region: RegionId, frame: Vec<u8>) {
        // Cloned because `ensure_region` needs `&mut self`.
        let Some(region_info) = self.regions.get(&region).cloned() else {
            tracing::warn!(region_id = %region, "no derp region info, dropping disco frame");
            return;
        };
        self.ensure_region(region, &region_info, self.env.shutdown.clone())
            .await;
        let Some(entry) = self.derps.get(&region) else {
            tracing::warn!(region_id = %region, "region not established, dropping disco frame");
            return;
        };
        if let Err(e) = entry.disco_tx.try_send((peer, frame)) {
            tracing::trace!(error = %e, region_id = %region, "disco relay queue full or closed, dropping frame");
        }
    }
}
/// Collects the IPv4 STUN endpoints advertised by `regions`.
///
/// A server contributes an address only if it has a STUN port and a fixed IPv4 address
/// (`IpUsage::FixedAddr`). DNS-resolved or disabled IPv4 servers are skipped. Output
/// order follows the iteration order of `regions` and of each region's `servers`.
fn stun_servers_from_regions<'a>(
    regions: impl IntoIterator<Item = &'a DerpRegion>,
) -> Vec<SocketAddr> {
    regions
        .into_iter()
        .flat_map(|region| &region.servers)
        .filter_map(|srv| {
            let stun_port = srv.stun_port?;
            let ts_derp::IpUsage::FixedAddr(v4) = srv.ipv4 else {
                return None;
            };
            Some(SocketAddr::V4(SocketAddrV4::new(v4, stun_port)))
        })
        .collect()
}
/// Takes a shared read guard on `lock`, recovering the guard if the lock is poisoned.
///
/// A panic in another thread while it held the lock does not make the data unreadable
/// here; the caller sees whatever state the panicking thread left behind.
fn poisoned_read<T>(lock: &RwLock<T>) -> std::sync::RwLockReadGuard<'_, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Takes an exclusive write guard on `lock`, recovering the guard if the lock is poisoned.
///
/// Recovering does not clear the poison flag; it only lets this caller keep working.
fn poisoned_write<T>(lock: &RwLock<T>) -> std::sync::RwLockWriteGuard<'_, T> {
    lock.write().unwrap_or_else(std::sync::PoisonError::into_inner)
}
/// Chooses the DERP region to relay through for a peer.
///
/// The peer's `observed` region takes precedence over our `home` region. The chosen
/// candidate is returned only if `region_is_live` accepts it. There is intentionally no
/// fallback to `home` when the observed region is not live.
fn resolve_region_for_peer(
    observed: Option<RegionId>,
    home: Option<RegionId>,
    region_is_live: impl Fn(RegionId) -> bool,
) -> Option<RegionId> {
    let candidate = match observed {
        Some(region) => region,
        None => home?,
    };
    region_is_live(candidate).then_some(candidate)
}
/// Adapter that lets a transport translate between `PeerId`s and node public keys using
/// the shared peer database. Lookups yield `None` until a database has been installed.
struct PeerDbLookup<'a>(&'a RwLock<Option<Arc<PeerDb>>>);
impl ts_transport::PeerLookup<PeerId, NodePublicKey> for PeerDbLookup<'_> {
    /// Maps a transport `PeerId` to that peer's node public key.
    ///
    /// Returns `None` if no peer database is installed yet or `id` is unknown.
    fn lookup_key(&self, id: PeerId) -> Option<NodePublicKey> {
        let guard = poisoned_read(self.0);
        match guard.as_ref() {
            Some(db) => db.get(&id).map(|(_, node)| node.node_key),
            None => None,
        }
    }
}
impl ts_transport::PeerLookup<NodePublicKey, PeerId> for PeerDbLookup<'_> {
    /// Maps a node public key to the transport `PeerId` of that peer.
    ///
    /// Returns `None` if no peer database is installed yet or `key` is unknown.
    fn lookup_key(&self, key: NodePublicKey) -> Option<PeerId> {
        let guard = poisoned_read(self.0);
        let db = guard.as_ref()?;
        db.get(&key).map(|(id, _)| id)
    }
}
/// Drives the DERP client for region `id`, connecting lazily and reconnecting on demand.
///
/// While this region is not the home region, no connection is held. The runner waits
/// until either:
/// - the region becomes home, or
/// - a packet or disco frame needs sending (that item is sent right after connecting).
///
/// Once connected, it does three things:
/// - forwards DERP traffic to the dataplane, consuming relayed disco frames via the
///   direct socket and recording the peer's observed route;
/// - relays disco frames from `disco_rx`;
/// - sends dataplane packets from `from_dataplane`.
///
/// A non-home connection idle for `INACTIVITY_TIMEOUT` is closed, and the function goes
/// back to waiting.
///
/// # Returns
///
/// `Ok(())` once any channel this runner depends on has closed:
/// - `from_dataplane` or `to_dataplane`;
/// - `disco_rx`;
/// - the `home_derp_rx` sender.
///
/// Callers should stop rather than retry, since every further attempt would fail
/// immediately.
///
/// # Errors
///
/// Returns any DERP connect, send, or receive error. Callers may retry.
#[tracing::instrument(skip_all, fields(region_id = %id), name = "derp runner")]
async fn run_derp_once(
    id: RegionId,
    region: &DerpRegion,
    keys: NodeKeyPair,
    to_dataplane: &UnderlayToDataplane,
    from_dataplane: &mut UnderlayFromDataplane,
    home_derp_rx: &mut watch::Receiver<bool>,
    disco_rx: &mut mpsc::Receiver<(NodePublicKey, Vec<u8>)>,
    peer_db: &RwLock<Option<Arc<PeerDb>>>,
    direct_sock: &RwLock<Option<Arc<MagicSock>>>,
    observed_routes: &RwLock<HashMap<PeerId, RegionId>>,
) -> Result<(), ts_derp::Error> {
    const INACTIVITY_TIMEOUT: Duration = Duration::from_secs(10);
    loop {
        let mut pending = None;
        let mut pending_disco = None;
        tracing::trace!("waiting for packet activity or for this to become home derp");
        while !*home_derp_rx.borrow_and_update() {
            tokio::select! {
                changed = home_derp_rx.changed() => {
                    // A dropped sender would make `changed()` fail instantly forever.
                    if changed.is_err() {
                        tracing::trace!("home derp sender dropped, stopping derp runner");
                        return Ok(());
                    }
                    tracing::trace!(is_home_derp = *home_derp_rx.borrow());
                },
                from_net = from_dataplane.recv() => {
                    // A closed channel resolves `recv()` immediately. Treating that as
                    // activity would reconnect to DERP in a tight loop.
                    let Some(from_net) = from_net else {
                        tracing::warn!("transport queue closed");
                        return Ok(());
                    };
                    tracing::trace!("received packet to send");
                    pending = Some(from_net);
                    break;
                },
                disco = disco_rx.recv() => {
                    let Some(disco) = disco else {
                        tracing::warn!("disco relay queue closed");
                        return Ok(());
                    };
                    tracing::trace!("received disco frame to relay, waking connection");
                    pending_disco = Some(disco);
                    break;
                },
            }
        }
        tracing::trace!("establishing derp connection");
        let client = Arc::new(ts_derp::DefaultClient::connect(&region.servers, &keys).await?);
        let transport = client.clone().with_key_lookup(PeerDbLookup(peer_db));
        if let Some(pending) = pending {
            tracing::trace!("sending queued packet");
            transport.send([pending]).await?;
        }
        if let Some((node_key, frame)) = pending_disco {
            tracing::trace!("relaying queued disco frame");
            client.send_one(node_key, &frame).await?;
        }
        let span = tracing::trace_span!("derp_loop");
        let mut last_activity = Instant::now();
        loop {
            // Only non-home connections are subject to the inactivity timeout.
            let inactivity_timeout =
                (!*home_derp_rx.borrow()).then(|| last_activity + INACTIVITY_TIMEOUT);
            tokio::select! {
                from_derp = transport.recv() => {
                    last_activity = Instant::now();
                    let sock = poisoned_read(direct_sock).clone();
                    for ret in from_derp.batch_iter() {
                        let (peer_id, pkts) = ret?;
                        // Check under the read lock first so the common (unchanged) case
                        // does not contend on the write lock.
                        if poisoned_read(observed_routes).get(&peer_id) != Some(&id) {
                            poisoned_write(observed_routes).insert(peer_id, id);
                            tracing::trace!(parent: &span, %peer_id, region_id = %id, "learned observed derp route for peer");
                        }
                        let data = demux_relayed_disco(pkts, sock.as_deref());
                        if data.is_empty() {
                            continue;
                        }
                        tracing::trace!(parent: &span, %peer_id, len = data.len(), "packet from derp server");
                        let Ok(()) = to_dataplane.send((peer_id, data)) else {
                            // The dataplane is gone; stop instead of merely leaving the
                            // batch loop and keeping the connection alive.
                            tracing::error!(parent: &span, "underlay receive channel closed");
                            return Ok(());
                        };
                    }
                },
                disco = disco_rx.recv() => {
                    last_activity = Instant::now();
                    let Some((node_key, frame)) = disco else {
                        tracing::warn!(parent: &span, "disco relay queue closed");
                        return Ok(());
                    };
                    tracing::trace!(parent: &span, "relaying disco frame over derp");
                    client.send_one(node_key, &frame).await?;
                },
                from_net = from_dataplane.recv() => {
                    last_activity = Instant::now();
                    let Some(from_net) = from_net else {
                        tracing::warn!(parent: &span, "transport queue closed");
                        return Ok(());
                    };
                    tracing::trace!(parent: &span, peer = %from_net.0, packets = from_net.1.len(), "packets to derp server");
                    transport.send([from_net]).await?;
                },
                _ = option_timeout(inactivity_timeout) => {
                    if !*home_derp_rx.borrow_and_update() {
                        tracing::trace!(parent: &span, "timed out and not home derp, closing derp conn");
                        break;
                    }
                },
                changed = home_derp_rx.changed() => {
                    if changed.is_err() {
                        tracing::trace!(parent: &span, "home derp sender dropped, stopping derp runner");
                        return Ok(());
                    }
                    tracing::trace!(is_home_derp = *home_derp_rx.borrow());
                },
            }
        }
    }
}
/// Splits a batch of packets relayed over DERP into dataplane traffic and disco frames.
///
/// A packet is removed from the batch only when three conditions hold:
/// - it looks like a disco frame;
/// - a direct socket is available;
/// - that socket's `handle_relayed_call_me_maybe` reports it handled the frame.
///
/// Every other packet is returned in its original order for delivery to the dataplane.
fn demux_relayed_disco(
    pkts: impl IntoIterator<Item = ts_packet::PacketMut>,
    sock: Option<&MagicSock>,
) -> Vec<ts_packet::PacketMut> {
    pkts.into_iter()
        .filter_map(|mut pkt| {
            let consumed = ts_magicsock::looks_like_disco(pkt.as_ref())
                && sock.is_some_and(|direct| direct.handle_relayed_call_me_maybe(pkt.as_mut()));
            (!consumed).then_some(pkt)
        })
        .collect()
}
/// Sleeps until the deadline in `duration`, or never completes when it is `None`.
///
/// Used as a `tokio::select!` branch so that "no deadline" simply disables the timeout.
async fn option_timeout(duration: Option<Instant>) {
    if let Some(deadline) = duration {
        tokio::time::sleep_until(deadline.into()).await;
    } else {
        core::future::pending::<()>().await;
    }
}
impl kameo::Actor for Multiderp {
type Args = (Env, ActorRef<DataplaneActor>);
type Error = Error;
async fn on_start(
(env, dataplane): Self::Args,
slf: ActorRef<Self>,
) -> Result<Self, Self::Error> {
env.subscribe::<Arc<ts_control::StateUpdate>>(&slf).await?;
env.subscribe::<Arc<PeerState>>(&slf).await?;
env.subscribe::<DerpLatencyMeasurement>(&slf).await?;
Ok(Self {
env,
dataplane,
peer_db: Default::default(),
direct_sock: Default::default(),
observed_routes: Default::default(),
derps: Default::default(),
regions: Default::default(),
tasks: JoinSet::new(),
current_home_derp: None,
})
}
}
impl Message<Arc<ts_control::StateUpdate>> for Multiderp {
    type Reply = ();

    /// Applies the DERP map carried by a control-plane state update, if any.
    ///
    /// Each region in the map is recorded in `self.regions` and gets a transport via
    /// `ensure_region`. The current home region's runner is then flagged as home.
    ///
    /// `ensure_region` may skip a region (dataplane error or shutdown). A skipped home
    /// region is logged instead of panicking on a missing `self.derps` entry.
    #[tracing::instrument(skip_all, name = "multiderp map update")]
    async fn handle(
        &mut self,
        msg: Arc<ts_control::StateUpdate>,
        _ctx: &mut Context<Self, Self::Reply>,
    ) {
        let Some(derp_map) = &msg.derp else {
            return;
        };
        for (id, region) in derp_map {
            self.regions.insert(*id, region.clone());
            self.ensure_region(*id, region, self.env.shutdown.clone())
                .await;
            if self.current_home_derp != Some(*id) {
                continue;
            }
            match self.derps.get(id) {
                Some(entry) => {
                    entry.home_derp.send_replace(true);
                }
                None => {
                    tracing::warn!(region_id = %id, "home derp region has no transport; not marking it as home");
                }
            }
        }
    }
}
impl Message<Arc<PeerState>> for Multiderp {
    type Reply = ();

    /// Installs the latest peer database.
    ///
    /// Observed DERP routes for peers that are no longer present in the new database are
    /// dropped first.
    async fn handle(&mut self, msg: Arc<PeerState>, _ctx: &mut Context<Self, Self::Reply>) {
        {
            let mut routes = poisoned_write(&self.observed_routes);
            routes.retain(|peer_id, _| msg.peers.peers().contains_key(peer_id));
        }
        let mut peer_db = poisoned_write(&self.peer_db);
        *peer_db = Some(msg.peers.clone());
    }
}
impl Message<DerpLatencyMeasurement> for Multiderp {
    type Reply = ();

    /// Updates the home DERP region from a latency measurement.
    ///
    /// The first entry of `msg.measurement` is taken as the best region. When it differs
    /// from the current home region:
    /// - the previous home's runner is told it is no longer home;
    /// - `current_home_derp` is updated;
    /// - the change is logged.
    ///
    /// In every case the chosen region's runner ends up flagged as home if it has a
    /// transport. Re-measuring the same region therefore no longer demotes it.
    ///
    /// Regions without an entry in `self.derps` are tolerated: they are simply not
    /// signalled, rather than causing a panic.
    async fn handle(&mut self, msg: DerpLatencyMeasurement, _ctx: &mut Context<Self, Self::Reply>) {
        let Some(result) = msg.measurement.as_ref().first() else {
            tracing::trace!("received home derp measurement message but none was set");
            return;
        };
        let new_home = result.id;

        if self.current_home_derp != Some(new_home) {
            if let Some(old) = self.current_home_derp.and_then(|old| self.derps.get(&old)) {
                old.home_derp.send_replace(false);
            }
            self.current_home_derp = Some(new_home);
            tracing::info!(
                region_id = %new_home,
                latency_ms = result.latency.as_secs_f32() * 1000.,
                "new home derp region selected"
            );
        }

        // Only notify the runner when the flag actually flips, so repeated measurements
        // of the same home region don't wake it needlessly.
        if let Some(entry) = self.derps.get(&new_home) {
            entry
                .home_derp
                .send_if_modified(|is_home| !std::mem::replace(is_home, true));
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers in this file:
    // - relayed-disco demuxing (`demux_relayed_disco`);
    // - STUN server selection (`stun_servers_from_regions`);
    // - relay-region resolution (`resolve_region_for_peer`).
    use ts_keys::DiscoPrivateKey;
    use ts_packet::PacketMut;

    use super::*;

    // Bind address for test sockets: IPv4 loopback, port 0 (OS-assigned).
    fn localhost() -> std::net::SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    // Binding verifier that accepts every (disco key, node key) pair.
    fn allow_all() -> ts_magicsock::BindingVerifier {
        Arc::new(|_: &ts_keys::DiscoPublicKey, _: Option<&NodePublicKey>| true)
    }

    // In a mixed batch, the relayed CallMeMaybe is consumed by the direct socket, which
    // learns its endpoint. The non-disco frame is passed through unchanged.
    #[tokio::test]
    async fn relayed_call_me_maybe_is_demuxed_not_forwarded() {
        let our_disco = DiscoPrivateKey::random();
        let our_node = ts_keys::NodePrivateKey::random().public_key();
        let sock = MagicSock::bind(localhost(), our_disco, our_node)
            .await
            .unwrap()
            .with_binding_verifier(allow_all());
        let peer_disco = DiscoPrivateKey::random();
        let peer_ep: std::net::SocketAddr = "203.0.113.7:41641".parse().unwrap();
        let cmm =
            ts_magicsock::seal_call_me_maybe(&peer_disco, &our_disco.public_key(), &[peer_ep])
                .unwrap();
        // Non-disco stand-in for a data frame.
        let wg = PacketMut::from(&[0x04u8, 0, 0, 0, 1, 2, 3, 4][..]);
        let batch = vec![PacketMut::from(&cmm[..]), wg];
        let to_dataplane = demux_relayed_disco(batch, Some(&sock));
        assert_eq!(
            to_dataplane.len(),
            1,
            "only the data frame reaches the dataplane"
        );
        assert_eq!(to_dataplane[0].as_ref(), &[0x04u8, 0, 0, 0, 1, 2, 3, 4]);
        assert_eq!(
            sock.candidate_addrs(&peer_disco.public_key()),
            vec![peer_ep],
            "the relayed CallMeMaybe's endpoint should be learned"
        );
    }

    // Without a direct socket nothing is demuxed, so even disco frames are forwarded.
    #[tokio::test]
    async fn without_direct_sock_all_frames_forwarded() {
        let our_disco = DiscoPrivateKey::random();
        let peer_disco = DiscoPrivateKey::random();
        let cmm = ts_magicsock::seal_call_me_maybe(
            &peer_disco,
            &our_disco.public_key(),
            &["203.0.113.7:41641".parse().unwrap()],
        )
        .unwrap();
        let wg = PacketMut::from(&[0x04u8, 9, 9][..]);
        let batch = vec![PacketMut::from(&cmm[..]), wg];
        let out = demux_relayed_disco(batch, None);
        assert_eq!(
            out.len(),
            2,
            "no demux without a direct socket; all frames pass through"
        );
    }

    // A relayed Ping is kept off the dataplane and must not add a candidate path.
    // NOTE(review): despite the name, this does not assert that no Pong is sent.
    #[tokio::test]
    async fn relayed_ping_is_dropped_not_ponged() {
        let our_disco = DiscoPrivateKey::random();
        let our_node = ts_keys::NodePrivateKey::random().public_key();
        let sock = MagicSock::bind(localhost(), our_disco, our_node)
            .await
            .unwrap()
            .with_binding_verifier(allow_all());
        let peer_disco = DiscoPrivateKey::random();
        let peer_node = ts_keys::NodePrivateKey::random().public_key();
        let tx = ts_magicsock::random_tx_id();
        let ping =
            ts_magicsock::seal_ping(&peer_disco, peer_node, &our_disco.public_key(), tx).unwrap();
        let out = demux_relayed_disco(vec![PacketMut::from(&ping[..])], Some(&sock));
        assert!(
            out.is_empty(),
            "a relayed disco Ping is consumed (kept off the dataplane)"
        );
        assert!(
            sock.candidate_addrs(&peer_disco.public_key()).is_empty(),
            "a relayed Ping must not learn a candidate path"
        );
    }

    // Loopback and private endpoints advertised in a relayed CallMeMaybe are filtered
    // out. Only the public endpoint becomes a candidate.
    #[tokio::test]
    async fn relayed_call_me_maybe_forbidden_endpoints_filtered() {
        let our_disco = DiscoPrivateKey::random();
        let our_node = ts_keys::NodePrivateKey::random().public_key();
        let sock = MagicSock::bind(localhost(), our_disco, our_node)
            .await
            .unwrap()
            .with_binding_verifier(allow_all());
        let peer_disco = DiscoPrivateKey::random();
        let loopback: std::net::SocketAddr = "127.0.0.1:41641".parse().unwrap();
        let private: std::net::SocketAddr = "10.1.2.3:41641".parse().unwrap();
        let public: std::net::SocketAddr = "203.0.113.50:41641".parse().unwrap();
        let cmm = ts_magicsock::seal_call_me_maybe(
            &peer_disco,
            &our_disco.public_key(),
            &[loopback, private, public],
        )
        .unwrap();
        let out = demux_relayed_disco(vec![PacketMut::from(&cmm[..])], Some(&sock));
        assert!(out.is_empty(), "the CallMeMaybe is consumed, not forwarded");
        assert_eq!(
            sock.candidate_addrs(&peer_disco.public_key()),
            vec![public],
            "only the public candidate survives the pingable-candidate filter"
        );
    }

    // Builds a DERP server description for the given IPv4 usage and optional STUN port.
    fn server(
        ipv4: ts_derp::IpUsage<core::net::Ipv4Addr>,
        stun_port: Option<u16>,
    ) -> ts_derp::ServerConnInfo {
        ts_derp::ServerConnInfo {
            hostname: "derp.example".to_string(),
            ipv4,
            ipv6: ts_derp::IpUsage::Disable,
            tls_validation_config: ts_derp::TlsValidationConfig::CommonName {
                common_name: "derp.example".to_string(),
            },
            https_port: 443,
            stun_port,
            stun_only: false,
            supports_port_80: false,
        }
    }

    // Builds a DERP region containing `servers`.
    fn region(servers: Vec<ts_derp::ServerConnInfo>) -> DerpRegion {
        DerpRegion {
            info: ts_derp::RegionInfo {
                name: "r".to_string(),
                code: "r".to_string(),
                no_measure_no_home: false,
            },
            servers,
        }
    }

    #[test]
    fn stun_servers_from_regions_returns_only_fixed_v4_with_port() {
        let fixed = core::net::Ipv4Addr::new(203, 0, 113, 5);
        let r = region(vec![
            server(ts_derp::IpUsage::FixedAddr(fixed), Some(3478)),
            server(ts_derp::IpUsage::UseDns, Some(3478)),
            server(ts_derp::IpUsage::Disable, Some(3478)),
            server(
                ts_derp::IpUsage::FixedAddr(core::net::Ipv4Addr::new(198, 51, 100, 9)),
                None,
            ),
        ]);
        let got = stun_servers_from_regions([&r]);
        assert_eq!(
            got,
            vec![SocketAddr::V4(SocketAddrV4::new(fixed, 3478))],
            "only the FixedAddr-v4-with-port server must be probed (UseDns/Disable/no-port skipped)"
        );
    }

    #[test]
    fn stun_servers_from_regions_empty_when_no_fixed_v4() {
        let r = region(vec![
            server(ts_derp::IpUsage::UseDns, Some(3478)),
            server(ts_derp::IpUsage::Disable, None),
        ]);
        assert!(
            stun_servers_from_regions([&r]).is_empty(),
            "no FixedAddr-v4 STUN server => empty probe list"
        );
    }

    // Shorthand for constructing a `RegionId`; panics if `n` is 0.
    fn rid(n: u32) -> RegionId {
        RegionId(core::num::NonZeroU32::new(n).unwrap())
    }

    #[test]
    fn resolve_region_prefers_observed_then_home() {
        let live = |_: RegionId| true;
        assert_eq!(
            resolve_region_for_peer(Some(rid(7)), Some(rid(19)), live),
            Some(rid(7)),
            "an observed route must win over the home-region fallback"
        );
        assert_eq!(
            resolve_region_for_peer(None, Some(rid(19)), live),
            Some(rid(19)),
            "with no observed route, relay via our own home region"
        );
        assert_eq!(
            resolve_region_for_peer(Some(rid(7)), None, live),
            Some(rid(7)),
            "an observed route is usable even before a home region is known"
        );
        assert_eq!(
            resolve_region_for_peer(None, None, live),
            None,
            "with neither an observed route nor a home region there is no relay route"
        );
    }

    // Pins the current contract: a dead observed region yields `None`; it does NOT fall
    // back to a live home region.
    #[test]
    fn resolve_region_skips_region_without_live_transport() {
        assert_eq!(
            resolve_region_for_peer(None, Some(rid(19)), |_| false),
            None,
            "a home region with no live transport must not be returned"
        );
        assert_eq!(
            resolve_region_for_peer(Some(rid(7)), Some(rid(19)), |r| r == rid(19)),
            None,
            "an observed region without a live transport is skipped even if home is live-but-not-chosen"
        );
        assert_eq!(
            resolve_region_for_peer(Some(rid(7)), Some(rid(19)), |r| r == rid(7)),
            Some(rid(7)),
            "the observed route is returned when its transport is live"
        );
    }

    // NOTE(review): this re-implements the `retain` predicate locally rather than driving
    // the `PeerState` handler, so it only checks the pruning rule itself.
    #[test]
    fn observed_routes_prune_to_live_peers() {
        let mut routes: HashMap<PeerId, RegionId> = HashMap::new();
        routes.insert(PeerId(1), rid(19));
        routes.insert(PeerId(2), rid(7));
        routes.insert(PeerId(3), rid(19));
        let live: std::collections::HashSet<PeerId> = [PeerId(1), PeerId(3)].into_iter().collect();
        routes.retain(|peer_id, _| live.contains(peer_id));
        assert_eq!(routes.get(&PeerId(1)), Some(&rid(19)), "live peer kept");
        assert_eq!(routes.get(&PeerId(3)), Some(&rid(19)), "live peer kept");
        assert!(
            !routes.contains_key(&PeerId(2)),
            "a peer no longer in the netmap must have its observed route pruned"
        );
    }
}