use alloc::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
};
use futures_util::{Stream, StreamExt};
use tokio::{
sync::{broadcast, mpsc, watch},
task::JoinSet,
};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use url::Url;
use crate::{
ControlDialer, Error,
map_request_builder::MapRequestBuilder,
tokio::{
map_stream::{StateUpdate, map_stream, send_map_request},
ping::handle_ping,
},
};
/// NetInfo fields remembered across map requests so that each request re-sends the
/// whole known set rather than only the field that just changed.
///
/// Every field starts as `None` and stays that way until the matching command supplies a
/// value; [`CarriedNetInfo::apply`] only touches builder fields that are `Some`.
#[derive(Debug, Clone, Default)]
struct CarriedNetInfo {
    /// Home DERP region, recorded when a `Command::SetDerpHomeRegion` is handled.
    preferred_derp: Option<ts_derp::RegionId>,
    /// Per-region DERP latency samples keyed by region name, recorded alongside
    /// `preferred_derp`. NOTE(review): unit is not established here (tests use values
    /// like 0.012) — presumably seconds; confirm against the builder.
    derp_latency: Option<BTreeMap<String, f64>>,
    /// Whether any reported endpoint was discovered via STUN; set on `SetEndpoints`.
    working_udp: Option<bool>,
    /// Whether two distinct IPv4 STUN-reported addresses were seen; set on `SetEndpoints`.
    mapping_varies_by_dest_ip: Option<bool>,
}
impl CarriedNetInfo {
    /// Copies every remembered NetInfo field onto `builder` and hands it back.
    ///
    /// Fields that are still `None` are skipped, so applying an empty `CarriedNetInfo`
    /// returns `builder` unchanged. The builder borrows the latency names from `self`,
    /// which is why both share the lifetime `'a`.
    fn apply<'a>(&'a self, builder: MapRequestBuilder<'a>) -> MapRequestBuilder<'a> {
        let builder = match self.preferred_derp {
            Some(region) => builder.preferred_derp(region),
            None => builder,
        };
        let builder = match &self.derp_latency {
            Some(samples) => builder.derp_latencies(
                samples
                    .iter()
                    .map(|(region_name, latency)| (region_name.as_str(), *latency)),
            ),
            None => builder,
        };
        let builder = match self.working_udp {
            Some(udp_ok) => builder.working_udp(udp_ok),
            None => builder,
        };
        match self.mapping_varies_by_dest_ip {
            Some(varies) => builder.mapping_varies_by_dest_ip(varies),
            None => builder,
        }
    }
}
/// Handle to a background control-plane session started by [`AsyncControlClient::connect`].
///
/// The setter methods enqueue [`Command`]s for the background map-poll task, and state
/// updates from the poll are fanned out over a broadcast channel.
#[derive(Debug)]
pub struct AsyncControlClient {
    /// Control server URL; `map_url` resolves `machine/map` against it.
    base_url: Url,
    /// Broadcast sender, used here only to hand out new receivers via `netmap_stream`.
    state_tx: broadcast::Sender<Arc<StateUpdate>>,
    /// Command queue consumed by the background map-poll task.
    command_tx: mpsc::Sender<Command>,
    /// Owns the background task; dropping the `JoinSet` aborts it.
    _tasks: JoinSet<()>,
}
impl AsyncControlClient {
    /// Checks that this node can register with the control server.
    ///
    /// Dials `config.server_url` with the node's machine keys and performs a single
    /// registration using `auth_key` (if given). The connection is dropped afterwards;
    /// no map poll or background task is started.
    ///
    /// # Errors
    ///
    /// Returns any error from dialing control or from the registration request.
    pub async fn check_auth(
        config: &crate::Config,
        node_keys: &ts_keys::NodeState,
        auth_key: Option<&str>,
    ) -> Result<(), Error> {
        let control_url = &config.server_url;
        let h2_client = crate::tokio::connect(
            control_url,
            &node_keys.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
        crate::tokio::register(config, control_url, auth_key, node_keys, &h2_client).await?;
        Ok(())
    }

    /// Connects to control, registers, and spawns the background map-poll task.
    ///
    /// On success this returns the client handle and a stream of [`StateUpdate`]s
    /// produced by the map poll. The background task (see [`run`]) keeps reconnecting
    /// and re-registering by itself, and publishes any login URL that control hands back
    /// on `auth_url_tx`. The task is owned by the returned client's [`JoinSet`], which
    /// aborts it when the client is dropped.
    ///
    /// # Errors
    ///
    /// Returns any error from the initial dial or registration. In that case no
    /// background task is spawned.
    #[tracing::instrument(skip_all, fields(control_url = %config.server_url))]
    pub async fn connect(
        config: &crate::Config,
        node_keys: &ts_keys::NodeState,
        auth_key: Option<&str>,
        auth_url_tx: watch::Sender<Option<Url>>,
    ) -> Result<
        (
            Self,
            impl Stream<Item = Arc<StateUpdate>> + Send + Sync + use<>,
        ),
        Error,
    > {
        let control_url = &config.server_url;
        let mut tasks = JoinSet::new();
        let h2_client = crate::tokio::connect(
            control_url,
            &node_keys.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
        tracing::info!("connected to control, registering");
        crate::tokio::register(config, control_url, auth_key, node_keys, &h2_client).await?;
        tracing::info!("registered, starting netmap stream");

        let (state_tx, state_rx) = broadcast::channel(32);
        let (command_tx, command_rx) = mpsc::channel(32);
        // `run` takes only owned arguments, so its future is `'static` and can be spawned
        // directly. There is no need for a wrapping `async move` block or extra clones.
        tasks.spawn(run(
            state_tx.clone(),
            command_rx,
            control_url.clone(),
            node_keys.clone(),
            auth_key.map(ToOwned::to_owned),
            config.clone(),
            auth_url_tx,
        ));

        Ok((
            Self {
                base_url: control_url.clone(),
                state_tx,
                command_tx,
                _tasks: tasks,
            },
            netmap_stream(state_rx),
        ))
    }

    /// Queues a report of `region_id` as this node's home DERP region, together with
    /// per-region latency samples.
    ///
    /// The background task remembers both values and re-sends them with later update
    /// requests. If the command cannot be enqueued (the background task is gone), the
    /// failure is logged rather than returned.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url(), %region_id), level = "trace")]
    pub async fn set_home_region<'c>(
        &mut self,
        region_id: ts_derp::RegionId,
        latencies: impl IntoIterator<Item = (&'c str, f64)>,
    ) {
        tracing::trace!(region = %region_id, "reporting home derp to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetDerpHomeRegion {
                id: region_id,
                latencies: latencies
                    .into_iter()
                    .map(|(name, sample)| (name.to_owned(), sample))
                    .collect(),
            })
            .await
        {
            tracing::error!(error = %e, "setting home derp region");
        }
    }

    /// Queues a report of this node's current endpoints.
    ///
    /// The background task also derives the carried `working_udp` and
    /// `mapping_varies_by_dest_ip` NetInfo flags from these endpoints. Enqueue failures
    /// are logged, not returned.
    #[tracing::instrument(
        skip_all,
        fields(map_url = %self.map_url(), n_endpoints = endpoints.len()),
        level = "trace"
    )]
    pub async fn set_endpoints(&mut self, endpoints: Vec<ts_control_serde::Endpoint>) {
        tracing::trace!("reporting magicsock endpoints to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetEndpoints { endpoints })
            .await
        {
            tracing::error!(error = %e, "setting endpoints");
        }
    }

    /// Queues a one-off update request that advertises `routes` as routable IPs.
    ///
    /// Enqueue failures are logged, not returned.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url(), n_routes = routes.len()), level = "trace")]
    pub async fn set_routable_ips(&mut self, routes: Vec<ipnet::IpNet>) {
        tracing::trace!("reporting routable IPs to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetRoutableIPs { routes })
            .await
        {
            tracing::error!(error = %e, "setting routable IPs");
        }
    }

    /// Queues a one-off update request that reports `hostname`.
    ///
    /// Enqueue failures are logged, not returned.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url()), level = "trace")]
    pub async fn set_hostname(&mut self, hostname: String) {
        tracing::trace!("reporting hostname to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetHostname { hostname })
            .await
        {
            tracing::error!(error = %e, "setting hostname");
        }
    }

    /// Asks the background task to end its current map poll, rotate the node key, and
    /// re-register.
    ///
    /// Enqueue failures are logged at debug level only.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url()))]
    pub async fn reauth(&mut self) {
        tracing::info!("requesting node-key reauth on the live map-poll loop");
        if let Err(e) = self.command_tx.send(Command::Reauth).await {
            tracing::debug!(error = %e, "requesting reauth");
        }
    }

    /// Returns the `machine/map` endpoint resolved against the control base URL.
    ///
    /// # Panics
    ///
    /// Panics if `base_url` cannot act as a base for relative resolution (see
    /// [`Url::join`]).
    pub fn map_url(&self) -> Url {
        self.base_url
            .join("machine/map")
            .expect("map_url was parsed without issue before")
    }

    /// Returns a new subscription to the state updates published by the background task.
    ///
    /// The subscription only sees updates sent after this call. Updates it falls too
    /// far behind on are skipped with a warning.
    pub fn netmap_stream(&self) -> impl Stream<Item = Arc<StateUpdate>> + Send + Sync + use<> {
        netmap_stream(self.state_tx.subscribe())
    }
}
/// Requests that the background map-poll task applies on the live control connection.
///
/// Produced by the `AsyncControlClient` setter methods and consumed inside `run_once`.
#[derive(Debug)]
pub enum Command {
    /// Report a home DERP region plus per-region latency samples keyed by region name.
    /// Both are remembered and re-sent with later update requests.
    SetDerpHomeRegion {
        id: ts_derp::RegionId,
        latencies: BTreeMap<String, f64>,
    },
    /// Report the node's endpoints. The carried `working_udp` and
    /// `mapping_varies_by_dest_ip` NetInfo flags are derived from them.
    SetEndpoints {
        endpoints: Vec<ts_control_serde::Endpoint>,
    },
    /// Send one update request with these routable IPs. The routes are not stored, so
    /// later requests go back to `config.advertised_routes()`.
    SetRoutableIPs { routes: Vec<ipnet::IpNet> },
    /// Send one update request with this hostname. It is not stored, so later requests
    /// go back to `config.hostname`.
    SetHostname { hostname: String },
    /// End the current poll so that the node key is rotated and the node re-registers.
    Reauth,
}
/// Resume cursor for the streaming map poll, passed to `MapRequestBuilder::map_session`
/// on every reconnect.
#[derive(Clone, Default)]
struct MapSession {
    /// Last session handle accepted from control; empty until one has been received.
    handle: String,
    /// Latest non-zero sequence number seen; reset to 0 when a new handle is accepted.
    seq: i64,
}
/// Longest map-session handle accepted from control, in bytes.
const MAX_SESSION_HANDLE_LEN: usize = 256;

/// Advances the resumable map-session cursor using one received `update`.
///
/// An offered handle is accepted only if it is non-empty, at most
/// [`MAX_SESSION_HANDLE_LEN`] bytes long, and made up entirely of printable, non-space
/// ASCII. Anything else is logged and ignored, and the current handle is kept. Accepting
/// a handle that differs from the current one restarts `seq` at 0. Independently of the
/// handle, any non-zero `update.seq` becomes the new cursor; `seq == 0` leaves it alone.
fn advance_session(session: &mut MapSession, update: &StateUpdate) {
    if let Some(offered) = update.session_handle.as_deref() {
        let acceptable = (1..=MAX_SESSION_HANDLE_LEN).contains(&offered.len())
            && offered.bytes().all(|byte| byte.is_ascii_graphic());
        if !acceptable {
            tracing::warn!(
                handle_len = offered.len(),
                "control sent an invalid map-session handle; ignoring it"
            );
        } else if offered != session.handle {
            session.handle = offered.to_owned();
            session.seq = 0;
        }
    }
    if update.seq != 0 {
        session.seq = update.seq;
    }
}
fn frame_resets_backoff(update: &StateUpdate) -> bool {
!update.keep_alive
}
/// Reconnect backoff state for the map-poll loop.
///
/// The delay formula lives in `ControlBackoff::next_delay`.
#[derive(Debug, Default)]
struct ControlBackoff {
    /// Attempt counter. It is 0 after `reset` and is incremented (saturating) by each
    /// `next_delay` call.
    n: u32,
}
/// Cap on the un-jittered reconnect delay; jitter can stretch it to just under 1.5x.
const MAP_BACKOFF_MAX: core::time::Duration = core::time::Duration::from_secs(30);
impl ControlBackoff {
    /// Returns the backoff to its initial state, so the next drawn delay is zero.
    fn reset(&mut self) {
        *self = Self::default();
    }

    /// Draws the next reconnect delay and advances the attempt counter.
    ///
    /// With `n` being the counter *before* this call, the un-jittered delay is
    /// `n * n * 10ms`, capped at [`MAP_BACKOFF_MAX`]. It is then scaled by
    /// `0.5 + r`, where `r` is one `f64` drawn from `rng` (`[0, 1)` for rand's
    /// standard distribution). All integer arithmetic saturates, so very large counters
    /// stay at the cap.
    fn next_delay(&mut self, rng: &mut impl rand::RngExt) -> core::time::Duration {
        let attempt = u64::from(self.n);
        self.n = self.n.saturating_add(1);
        let unjittered_ms = attempt.saturating_mul(attempt).saturating_mul(10);
        let unjittered = core::cmp::min(
            core::time::Duration::from_millis(unjittered_ms),
            MAP_BACKOFF_MAX,
        );
        let jitter = 0.5 + rng.random::<f64>();
        unjittered.mul_f64(jitter)
    }
}
/// Picks the delay before the next reconnect attempt after a map poll ends.
///
/// A poll that delivered a backoff-resetting frame starts the backoff over (the next
/// delay is zero). A frameless poll keeps escalating. Either way, one delay is drawn
/// and the counter advances.
fn reconnect_delay_after_poll(
    received_frame: bool,
    backoff: &mut ControlBackoff,
    rng: &mut impl rand::RngExt,
) -> core::time::Duration {
    if !received_frame {
        return backoff.next_delay(rng);
    }
    backoff.reset();
    backoff.next_delay(rng)
}
fn surface_reauth_url(err: &Error, auth_url_tx: &watch::Sender<Option<Url>>) {
if let Error::MachineNotAuthorized(url) = err {
auth_url_tx.send_if_modified(|current| {
if current.as_ref() == Some(url) {
false
} else {
*current = Some(url.clone());
true
}
});
}
}
/// Withdraws any published login URL.
///
/// Watchers are notified only if a URL was actually present. An already-empty cell is
/// left untouched and does not notify anyone.
fn clear_reauth_url(auth_url_tx: &watch::Sender<Option<Url>>) {
    // `take` always leaves `None` behind, and reports whether there was anything to clear.
    auth_url_tx.send_if_modified(|published| published.take().is_some());
}
/// Derives the NetInfo `working_udp` flag: `true` iff at least one endpoint was
/// discovered via STUN. An empty slice yields `false`.
fn net_info_working_udp(endpoints: &[ts_control_serde::Endpoint]) -> bool {
    let is_stun = |candidate: &ts_control_serde::Endpoint| {
        candidate.ty == ts_control_serde::EndpointType::Stun
    };
    endpoints.iter().any(is_stun)
}
/// Derives the NetInfo `mapping_varies_by_dest_ip` flag.
///
/// Returns `true` as soon as two distinct IPv4 STUN-reported socket addresses (IP and
/// port) have been seen. IPv6 and non-STUN endpoints are ignored, as are exact
/// duplicates.
fn net_info_mapping_varies(endpoints: &[ts_control_serde::Endpoint]) -> bool {
    let mut distinct = BTreeSet::new();
    endpoints
        .iter()
        .filter(|candidate| {
            candidate.ty == ts_control_serde::EndpointType::Stun && candidate.endpoint.is_ipv4()
        })
        .any(|candidate| {
            distinct.insert(candidate.endpoint);
            distinct.len() > 1
        })
}
/// Long-running control-plane driver that reconnects, re-registers, and re-polls forever.
///
/// Each iteration makes one [`run_once`] attempt and then:
/// - drops the recorded old node key if that attempt's registration succeeded;
/// - on a reauth request, rotates the node key and reconnects immediately, with no delay;
/// - withdraws any published login URL if a backoff-resetting frame arrived;
/// - sleeps before retrying. After [`Error::RateLimited`] it waits the server-requested
///   duration; otherwise it waits the jittered [`ControlBackoff`] delay, which resets
///   when a backoff-resetting frame arrived.
///
/// The map-session cursor, the carried NetInfo, and the dialer survive across
/// iterations. The loop has no exit, so this future never completes by itself.
pub async fn run(
    state_tx: broadcast::Sender<Arc<StateUpdate>>,
    mut command_rx: mpsc::Receiver<Command>,
    control_url: Url,
    mut node_keys: ts_keys::NodeState,
    auth_key: Option<String>,
    config: crate::Config,
    auth_url_tx: watch::Sender<Option<Url>>,
) {
    let mut control_dialer = ControlDialer::default();
    let mut map_session = MapSession::default();
    let mut backoff = ControlBackoff::default();
    let mut carried_net_info = CarriedNetInfo::default();

    loop {
        // Per-attempt outcome flags. `run_once` sets them even when it returns an error.
        let (mut received_frame, mut reauth_requested, mut register_succeeded) =
            (false, false, false);
        let attempt = run_once(
            &state_tx,
            &mut command_rx,
            &control_url,
            &node_keys,
            auth_key.as_deref(),
            &config,
            &mut control_dialer,
            &mut map_session,
            &mut carried_net_info,
            &mut received_frame,
            &mut reauth_requested,
            &mut register_succeeded,
            &auth_url_tx,
        )
        .await;

        if register_succeeded {
            node_keys.clear_old_node_key();
        }

        if reauth_requested {
            node_keys.rotate_node_key();
            tracing::info!("rotated node key for reauth; reconnecting to re-register (Go doLogin)");
            continue;
        }

        if received_frame {
            clear_reauth_url(&auth_url_tx);
        }

        // Consume the outcome by value, computing the delay before logging it.
        let delay = match attempt {
            Err(Error::RateLimited(retry_after)) => {
                tracing::warn!(
                    ?retry_after,
                    resume_handle = %map_session.handle,
                    resume_seq = map_session.seq,
                    "control rate-limited the map-poll re-register; waiting the server-requested delay"
                );
                retry_after
            }
            Ok(()) => {
                let wait =
                    reconnect_delay_after_poll(received_frame, &mut backoff, &mut rand::rng());
                tracing::warn!(
                    resume_handle = %map_session.handle,
                    resume_seq = map_session.seq,
                    backoff_ms = wait.as_millis() as u64,
                    "netmap stream ended without error, attempting restart"
                );
                wait
            }
            Err(e) => {
                let wait =
                    reconnect_delay_after_poll(received_frame, &mut backoff, &mut rand::rng());
                tracing::error!(
                    error = %e,
                    resume_handle = %map_session.handle,
                    resume_seq = map_session.seq,
                    backoff_ms = wait.as_millis() as u64,
                    "netmap stream failed, attempting restart"
                );
                wait
            }
        };
        tokio::time::sleep(delay).await;
    }
}
/// Runs one connection attempt: dial, register, then serve a streaming map poll.
///
/// The poll loop does two things:
/// - For each received state update, it advances the map-session cursor, answers
///   pings, applies dial-plan changes, and forwards the update to `state_tx`.
/// - For each queued [`Command`], it sends a non-streaming map request carrying the
///   accumulated [`CarriedNetInfo`].
///
/// Outcome flags are reported through the `&mut bool` parameters, so the caller sees
/// them even when this returns an error:
/// - `received_frame`: at least one backoff-resetting (non-keep-alive) frame arrived.
/// - `reauth_requested`: a [`Command::Reauth`] ended the poll.
/// - `register_succeeded`: registration completed.
///
/// If every [`Command`] sender has been dropped, command handling stops and the poll
/// keeps serving state updates until the stream ends. Returns `Ok(())` when the stream
/// ends or a reauth is requested.
///
/// # Errors
///
/// Returns dial, registration, and map-request errors. A registration error that carries
/// a login URL also publishes that URL on `auth_url_tx`.
async fn run_once(
    state_tx: &broadcast::Sender<Arc<StateUpdate>>,
    command_rx: &mut mpsc::Receiver<Command>,
    control_url: &Url,
    node_keys: &ts_keys::NodeState,
    auth_key: Option<&str>,
    config: &crate::Config,
    control_dialer: &mut ControlDialer,
    session: &mut MapSession,
    net_info: &mut CarriedNetInfo,
    received_frame: &mut bool,
    reauth_requested: &mut bool,
    register_succeeded: &mut bool,
    auth_url_tx: &watch::Sender<Option<Url>>,
) -> Result<(), Error> {
    let h2_client = control_dialer
        .full_connect_next(
            control_url,
            &node_keys.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
    match crate::tokio::register(config, control_url, auth_key, node_keys, &h2_client).await {
        Ok(()) => {
            clear_reauth_url(auth_url_tx);
            *register_succeeded = true;
        }
        Err(e) => {
            let err = Error::from(e);
            surface_reauth_url(&err, auth_url_tx);
            return Err(err);
        }
    }
    let client_name = config.format_client_name();
    let host = crate::hostinfo::HostInfoData::detect();
    let advertised_vip_services = config.advertised_vip_services();
    let services_hash = crate::services_hash(&advertised_vip_services);
    let builder = MapRequestBuilder::new(node_keys)
        .keep_alive(true)
        .omit_peers(false)
        .stream(true)
        .routable_ips(config.advertised_routes())
        .client_info(&client_name, crate::PKG_VERSION)
        .host_environment(&host)
        .request_tags(config.tags.iter().map(String::as_str))
        .services(config.advertised_services())
        .services_hash(&services_hash)
        .wire_ingress(config.wire_ingress)
        .ingress_enabled(
            config
                .ingress_active
                .load(core::sync::atomic::Ordering::Relaxed),
        )
        .app_connector(config.advertise_app_connector)
        .allows_update(config.auto_update_apply == Some(true))
        .map_session(&session.handle, session.seq);
    let builder = net_info.apply(builder);
    let request = if let Some(hostname) = &config.hostname {
        builder.hostname(hostname)
    } else {
        builder
    }
    .build();
    let map_url = control_url
        .join("machine/map")
        .expect("control_url must be usable as a base URL for \"machine/map\"");
    let reader = send_map_request(request, &map_url, &h2_client).await?;
    let mut stream = core::pin::pin!(map_stream(reader));
    tracing::info!("netmap stream started");
    loop {
        tokio::select! {
            state_update = stream.next() => {
                let Some(state_update) = state_update else {
                    break;
                };
                if frame_resets_backoff(&state_update) {
                    *received_frame = true;
                }
                advance_session(session, &state_update);
                // Best effort: a failed ping reply must not tear down the poll.
                let _ = handle_ping(&state_update, control_url, &h2_client, config).await;
                if let Some(dial_plan) = &state_update.dial_plan
                    && control_dialer.update_dial_plan(dial_plan)
                {
                    tracing::trace!(new_dial_plan = ?dial_plan);
                }
                // No subscribers is not an error; updates are simply dropped.
                let _ignore = state_tx.send(Arc::new(state_update));
            }
            // `recv()` yields `None` once every `Command` sender has been dropped (e.g. the
            // owning `AsyncControlClient` is being torn down). The `Some(..)` pattern then
            // disables this branch for the current `select!`, instead of panicking the
            // task, and the poll keeps serving state updates.
            Some(command) = command_rx.recv() => {
                match command {
                    Command::SetDerpHomeRegion { id, latencies } => {
                        net_info.preferred_derp = Some(id);
                        net_info.derp_latency = Some(latencies);
                        let mut builder = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(config.advertised_routes())
                            .host_environment(&host);
                        builder = net_info.apply(builder);
                        if let Some(hostname) = &config.hostname {
                            builder = builder.hostname(hostname);
                        }
                        let req = builder.build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    },
                    Command::SetEndpoints { endpoints } => {
                        net_info.working_udp = Some(net_info_working_udp(&endpoints));
                        net_info.mapping_varies_by_dest_ip = Some(net_info_mapping_varies(&endpoints));
                        let mut builder = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(config.advertised_routes())
                            .endpoints(endpoints)
                            .host_environment(&host);
                        builder = net_info.apply(builder);
                        if let Some(hostname) = &config.hostname {
                            builder = builder.hostname(hostname);
                        }
                        let req = builder.build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    },
                    Command::SetRoutableIPs { routes } => {
                        let mut builder = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(routes)
                            .host_environment(&host);
                        builder = net_info.apply(builder);
                        if let Some(hostname) = &config.hostname {
                            builder = builder.hostname(hostname);
                        }
                        let req = builder.build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    },
                    Command::SetHostname { hostname } => {
                        let builder = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(config.advertised_routes())
                            .host_environment(&host)
                            .hostname(&hostname);
                        let req = net_info.apply(builder).build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    },
                    Command::Reauth => {
                        tracing::info!("reauth requested; breaking poll loop to rotate node key");
                        *reauth_requested = true;
                        break;
                    },
                }
            }
        }
    }
    Ok(())
}
/// Adapts a broadcast receiver into a stream of state updates.
///
/// If the receiver falls behind the sender, the skipped updates are lost. A warning
/// records how many were missed, and the stream continues with the next available
/// update.
fn netmap_stream(
    rx: broadcast::Receiver<Arc<StateUpdate>>,
) -> impl Stream<Item = Arc<StateUpdate>> + Send + Sync {
    let updates = tokio_stream::wrappers::BroadcastStream::new(rx);
    updates.filter_map(async |received| {
        if let Some(BroadcastStreamRecvError::Lagged(missed)) = received.as_ref().err() {
            tracing::warn!(messages_missed = missed, "map_stream lagged");
        }
        received.ok()
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a non-keep-alive update carrying the given session cursor.
    fn update(handle: Option<&str>, seq: i64) -> StateUpdate {
        update_ka(handle, seq, false)
    }

    // Builds an otherwise-empty update with an explicit keep-alive flag.
    fn update_ka(handle: Option<&str>, seq: i64, keep_alive: bool) -> StateUpdate {
        StateUpdate {
            session_handle: handle.map(ToOwned::to_owned),
            seq,
            keep_alive,
            derp: None,
            node: None,
            peer_update: None,
            peer_patches: Vec::new(),
            user_profiles: Vec::new(),
            ping: None,
            packetfilter: None,
            cap_grants: None,
            pop_browser_url: None,
            dial_plan: None,
            dns_config: None,
            ssh_policy: None,
            tka: None,
            online_change: Default::default(),
            peer_seen_change: Default::default(),
        }
    }

    #[test]
    fn advance_session_captures_handle_and_seq() {
        let mut session = MapSession::default();
        advance_session(&mut session, &update(Some("sess-1"), 5));
        assert_eq!(session.handle, "sess-1");
        assert_eq!(session.seq, 5);
    }

    #[test]
    fn advance_session_keepalive_preserves_cursor() {
        let mut session = MapSession {
            handle: "sess-1".to_owned(),
            seq: 7,
        };
        advance_session(&mut session, &update(None, 0));
        assert_eq!(session.handle, "sess-1");
        assert_eq!(session.seq, 7);
    }

    #[test]
    fn advance_session_resets_seq_on_new_handle() {
        let mut session = MapSession {
            handle: "sess-1".to_owned(),
            seq: 42,
        };
        advance_session(&mut session, &update(Some("sess-2"), 0));
        assert_eq!(session.handle, "sess-2");
        assert_eq!(session.seq, 0);
    }

    #[test]
    fn advance_session_new_handle_takes_its_own_seq() {
        let mut session = MapSession {
            handle: "sess-1".to_owned(),
            seq: 42,
        };
        advance_session(&mut session, &update(Some("sess-2"), 4));
        assert_eq!(session.handle, "sess-2");
        assert_eq!(session.seq, 4, "the new session's seq replaces the reset cursor");
    }

    #[test]
    fn advance_session_same_handle_keeps_seq() {
        let mut session = MapSession {
            handle: "sess-1".to_owned(),
            seq: 10,
        };
        advance_session(&mut session, &update(Some("sess-1"), 0));
        assert_eq!(session.handle, "sess-1");
        assert_eq!(session.seq, 10);
    }

    #[test]
    fn backoff_reset_keys_on_keepalive_not_seq() {
        assert!(
            !frame_resets_backoff(&update_ka(None, 0, true)),
            "a keep-alive must not reset the backoff"
        );
        assert!(
            !frame_resets_backoff(&update_ka(Some("sess-1"), 0, true)),
            "a session-opening keep-alive must not reset the backoff"
        );
        assert!(
            frame_resets_backoff(&update_ka(None, 0, false)),
            "a substantive netmap with seq==0 (Headscale-style) MUST reset the backoff"
        );
        assert!(
            frame_resets_backoff(&update_ka(Some("sess-1"), 0, false)),
            "a session-opening substantive netmap with seq==0 MUST reset the backoff"
        );
        assert!(
            frame_resets_backoff(&update_ka(Some("sess-1"), 1, false)),
            "a substantive response with a resume cursor (seq==1) must reset the backoff"
        );
    }

    #[test]
    fn advance_session_rejects_overlong_handle() {
        let mut session = MapSession::default();
        let huge = "a".repeat(MAX_SESSION_HANDLE_LEN + 1);
        advance_session(&mut session, &update(Some(&huge), 3));
        assert_eq!(session.handle, "");
        assert_eq!(session.seq, 3);
    }

    #[test]
    fn advance_session_rejects_non_graphic_handle() {
        let mut session = MapSession::default();
        advance_session(&mut session, &update(Some("bad\nhandle"), 1));
        assert_eq!(session.handle, "");
        assert_eq!(session.seq, 1);
    }

    #[test]
    fn control_backoff_delay_is_within_the_go_jitter_envelope() {
        let mut rng = rand::rng();
        for n in 0u32..80 {
            let unjittered_ms = u64::from(n)
                .saturating_mul(u64::from(n))
                .saturating_mul(10)
                .min(MAP_BACKOFF_MAX.as_millis() as u64);
            let unjittered = core::time::Duration::from_millis(unjittered_ms);
            for _ in 0..100 {
                let mut probe = ControlBackoff { n };
                let d = probe.next_delay(&mut rng);
                if unjittered.is_zero() {
                    assert_eq!(d, core::time::Duration::ZERO, "n=0 delay must be zero");
                } else {
                    assert!(
                        d >= unjittered.mul_f64(0.5) && d < unjittered.mul_f64(1.5),
                        "n={n}: delay {d:?} outside [0.5,1.5) x {unjittered:?}"
                    );
                }
            }
        }
    }

    #[test]
    fn control_backoff_saturates_at_the_cap() {
        let mut rng = rand::rng();
        let mut probe = ControlBackoff { n: 1000 };
        let d = probe.next_delay(&mut rng);
        assert!(
            d >= MAP_BACKOFF_MAX.mul_f64(0.5) && d < MAP_BACKOFF_MAX.mul_f64(1.5),
            "saturated delay {d:?} outside the cap's jitter envelope"
        );
        let mut probe = ControlBackoff { n: u32::MAX };
        let d = probe.next_delay(&mut rng);
        assert!(d < MAP_BACKOFF_MAX.mul_f64(1.5), "overflowed at u32::MAX");
    }

    #[test]
    fn control_backoff_reset_returns_to_bottom() {
        let mut rng = rand::rng();
        let mut bo = ControlBackoff::default();
        for _ in 0..5 {
            let _ = bo.next_delay(&mut rng);
        }
        assert!(bo.n > 0, "counter advanced");
        bo.reset();
        assert_eq!(bo.n, 0, "reset zeroes the counter");
        let d = bo.next_delay(&mut rng);
        assert_eq!(d, core::time::Duration::ZERO, "n=0 delay is zero");
        assert_eq!(bo.n, 1, "counter advances after the n=0 draw");
    }

    #[test]
    fn reconnect_delay_resets_only_when_a_frame_arrived() {
        let mut rng = rand::rng();
        let mut backoff = ControlBackoff::default();
        let mut last_n = backoff.n;
        for i in 0..6 {
            let d = reconnect_delay_after_poll(false, &mut backoff, &mut rng);
            assert!(
                backoff.n > last_n,
                "frameless poll {i} must advance the counter (no reset)"
            );
            last_n = backoff.n;
            if i > 0 {
                assert!(
                    d > core::time::Duration::ZERO,
                    "frameless reconnect {i} must be delayed, not a 0ms spin"
                );
            }
        }
        let d = reconnect_delay_after_poll(true, &mut backoff, &mut rng);
        assert_eq!(
            d,
            core::time::Duration::ZERO,
            "a poll that delivered a frame resets to the immediate (n=0) reconnect"
        );
        assert_eq!(backoff.n, 1, "reset then one draw leaves the counter at 1");
    }

    fn auth_url() -> Url {
        "https://login.example/a/abc123".parse().unwrap()
    }

    #[test]
    fn mid_session_machine_not_authorized_sets_auth_url_cell() {
        let (tx, rx) = watch::channel(None);
        let url = auth_url();
        surface_reauth_url(&Error::MachineNotAuthorized(url.clone()), &tx);
        assert_eq!(*rx.borrow(), Some(url));
    }

    #[test]
    fn machine_not_authorized_none_does_not_set_url_cell() {
        let (tx, rx) = watch::channel(None);
        let err =
            Error::from(crate::tokio::register::RegistrationError::MachineNotAuthorized(None));
        assert!(matches!(err, Error::NeedsMachineAuth));
        surface_reauth_url(&err, &tx);
        assert_eq!(
            *rx.borrow(),
            None,
            "no auth URL on offer must not set the cell"
        );
    }

    #[test]
    fn non_auth_error_does_not_set_url_cell() {
        let (tx, rx) = watch::channel(None);
        surface_reauth_url(&Error::NetworkError(crate::Operation::Registration), &tx);
        assert_eq!(*rx.borrow(), None);
    }

    #[test]
    fn clear_reauth_url_resets_a_pending_url() {
        let (tx, rx) = watch::channel(Some(auth_url()));
        clear_reauth_url(&tx);
        assert_eq!(*rx.borrow(), None);
    }

    #[test]
    fn clear_reauth_url_on_empty_cell_does_not_notify() {
        let (tx, rx) = watch::channel::<Option<Url>>(None);
        clear_reauth_url(&tx);
        assert!(!rx.has_changed().unwrap());
        assert_eq!(*rx.borrow(), None);
    }

    #[test]
    fn surface_then_clear_leaves_cell_empty() {
        let (tx, rx) = watch::channel(None);
        let url = auth_url();
        surface_reauth_url(&Error::MachineNotAuthorized(url.clone()), &tx);
        assert_eq!(*rx.borrow(), Some(url));
        clear_reauth_url(&tx);
        assert_eq!(*rx.borrow(), None);
    }

    fn ep(addr: &str, ty: ts_control_serde::EndpointType) -> ts_control_serde::Endpoint {
        ts_control_serde::Endpoint {
            endpoint: addr.parse().unwrap(),
            ty,
        }
    }

    #[test]
    fn working_udp_true_iff_a_stun_endpoint_is_present() {
        use ts_control_serde::EndpointType::{Local, Stun};
        assert!(net_info_working_udp(&[
            ep("192.168.1.2:41641", Local),
            ep("203.0.113.7:41641", Stun),
        ]));
        assert!(!net_info_working_udp(&[ep("192.168.1.2:41641", Local)]));
        assert!(!net_info_working_udp(&[]));
    }

    #[test]
    fn mapping_varies_iff_two_distinct_stun_reflexives() {
        use ts_control_serde::EndpointType::{Local, Stun, Stun4LocalPort};
        assert!(net_info_mapping_varies(&[
            ep("203.0.113.7:41641", Stun),
            ep("198.51.100.9:51000", Stun),
        ]));
        assert!(!net_info_mapping_varies(&[
            ep("203.0.113.7:41641", Stun),
            ep("203.0.113.7:41641", Stun),
        ]));
        assert!(!net_info_mapping_varies(&[
            ep("203.0.113.7:41641", Stun),
            ep("203.0.113.7:50000", Stun4LocalPort),
            ep("192.168.1.2:41641", Local),
        ]));
        assert!(!net_info_mapping_varies(&[
            ep("[2001:db8::1]:41641", Stun),
            ep("[2001:db8::2]:41641", Stun),
        ]));
        assert!(net_info_mapping_varies(&[
            ep("[2001:db8::1]:41641", Stun),
            ep("203.0.113.7:41641", Stun),
            ep("198.51.100.9:51000", Stun),
        ]));
        assert!(!net_info_mapping_varies(&[]));
    }

    #[test]
    fn carried_net_info_sends_whole_set_not_partial() {
        use ts_control_serde::EndpointType::Stun;
        let node_keys = ts_keys::NodeState::generate();
        let endpoints = [
            ep("203.0.113.7:41641", Stun),
            ep("198.51.100.9:51000", Stun),
        ];
        // Every field is set explicitly, so a struct literal replaces the former
        // default-then-reassign sequence (and its clippy allow).
        let carried = CarriedNetInfo {
            preferred_derp: Some(region(5)),
            derp_latency: Some(BTreeMap::from([("5-v4".to_owned(), 0.012)])),
            working_udp: Some(net_info_working_udp(&endpoints)),
            mapping_varies_by_dest_ip: Some(net_info_mapping_varies(&endpoints)),
        };
        let builder = carried.apply(MapRequestBuilder::new(&node_keys));
        let req = builder.build();
        let ni = req
            .host_info
            .as_ref()
            .and_then(|h| h.net_info.as_ref())
            .expect("net_info present");
        let want = ts_control_serde::DerpRegionId::from(core::num::NonZeroU32::new(5).unwrap());
        assert_eq!(
            ni.preferred_derp,
            Some(want),
            "home region must persist across the endpoints request"
        );
        assert!(ni.derp_latency.is_some(), "derp latency must persist too");
        assert_eq!(ni.working_udp, Some(true));
        assert_eq!(ni.mapping_varies_by_dest_ip, Some(true));
    }

    #[test]
    fn empty_carried_net_info_emits_no_net_info() {
        let node_keys = ts_keys::NodeState::generate();
        let carried = CarriedNetInfo::default();
        let req = carried.apply(MapRequestBuilder::new(&node_keys)).build();
        assert!(
            req.host_info
                .as_ref()
                .and_then(|h| h.net_info.as_ref())
                .is_none(),
            "empty carried NetInfo must leave net_info absent on the wire"
        );
    }

    fn region(n: u32) -> ts_derp::RegionId {
        ts_derp::RegionId(core::num::NonZeroU32::new(n).unwrap())
    }

    #[test]
    fn reauth_rotation_records_old_node_key_and_preserves_other_keys() {
        let mut node_keys = ts_keys::NodeState::generate();
        let prior_node = node_keys.node_keys.public;
        let disco_before = node_keys.disco_keys.public;
        let machine_before = node_keys.machine_keys.public;
        node_keys.rotate_node_key();
        assert_eq!(
            node_keys.old_node_key,
            Some(prior_node),
            "the prior node key must be recorded as OldNodeKey for the re-register"
        );
        assert_ne!(
            node_keys.node_keys.public, prior_node,
            "a fresh node key must replace the expired one"
        );
        assert_eq!(node_keys.disco_keys.public, disco_before);
        assert_eq!(node_keys.machine_keys.public, machine_before);
    }

    #[tokio::test]
    async fn reauth_enqueues_reauth_command() {
        let (command_tx, mut command_rx) = mpsc::channel(4);
        let (state_tx, _state_rx) = broadcast::channel(4);
        let mut client = AsyncControlClient {
            base_url: "https://control.example/".parse().unwrap(),
            state_tx,
            command_tx,
            _tasks: JoinSet::new(),
        };
        client.reauth().await;
        let cmd = command_rx.try_recv().expect("a command was enqueued");
        assert!(
            matches!(cmd, Command::Reauth),
            "reauth() must enqueue Command::Reauth, got {cmd:?}"
        );
    }
}