use alloc::{collections::BTreeMap, sync::Arc};
use futures_util::{Stream, StreamExt};
use tokio::{
sync::{broadcast, mpsc},
task::JoinSet,
};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use url::Url;
use crate::{
ControlDialer, Error,
map_request_builder::MapRequestBuilder,
tokio::{
map_stream::{StateUpdate, map_stream, send_map_request},
ping::handle_ping,
},
};
/// Handle to a background control-plane session.
///
/// Created by [`AsyncControlClient::connect`], which spawns the [`run`] task.
/// The setter methods forward [`Command`]s to that task over `command_tx`;
/// map-stream updates published by the task are fanned out through the
/// broadcast channel whose sender is kept in `state_tx`.
#[derive(Debug)]
pub struct AsyncControlClient {
// Control base URL; `map_url` joins `machine/map` onto it.
base_url: Url,
// Sender side of the state-update broadcast; kept so `netmap_stream` can subscribe new receivers.
state_tx: broadcast::Sender<Arc<StateUpdate>>,
// Sender half of the command channel consumed by the background `run` task.
command_tx: mpsc::Sender<Command>,
// Owns the spawned background task; never read, only kept alive alongside the client.
// NOTE(review): tokio documents that dropping a `JoinSet` aborts its tasks. Fields drop in
// declaration order, so `command_tx` is closed before the task is aborted.
_tasks: JoinSet<()>,
}
impl AsyncControlClient {
    /// Checks that this node can connect to and register with control.
    ///
    /// Dials `config.server_url`, registers using `node_keys` and the optional
    /// `auth_key`, and then drops the connection. No map poll or background
    /// task is started.
    ///
    /// # Errors
    ///
    /// Returns an error if dialing or registration fails.
    pub async fn check_auth(
        config: &crate::Config,
        node_keys: &ts_keys::NodeState,
        auth_key: Option<&str>,
    ) -> Result<(), Error> {
        let control_url = &config.server_url;
        let h2_client = crate::tokio::connect(
            control_url,
            &node_keys.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
        crate::tokio::register(config, control_url, auth_key, node_keys, &h2_client).await?;
        Ok(())
    }

    /// Connects to control, registers, and spawns the background session task.
    ///
    /// Dialing and registration are done inline first so that their failures
    /// are returned to the caller. Afterwards [`run`] is spawned. It dials its
    /// own connection (the one used here is dropped when this function
    /// returns) and keeps the map poll alive, reconnecting with backoff.
    ///
    /// Returns the client handle together with a stream of [`StateUpdate`]s
    /// from the map poll. If the consumer falls behind the 32-entry broadcast
    /// buffer, missed updates are logged and skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the initial dial or registration fails.
    #[tracing::instrument(skip_all, fields(control_url = %config.server_url))]
    pub async fn connect(
        config: &crate::Config,
        node_keys: &ts_keys::NodeState,
        auth_key: Option<&str>,
    ) -> Result<
        (
            Self,
            impl Stream<Item = Arc<StateUpdate>> + Send + Sync + use<>,
        ),
        Error,
    > {
        let control_url = &config.server_url;
        let mut tasks = JoinSet::new();
        let h2_client = crate::tokio::connect(
            control_url,
            &node_keys.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
        tracing::info!("connected to control, registering");
        crate::tokio::register(config, control_url, auth_key, node_keys, &h2_client).await?;
        tracing::info!("registered, starting netmap stream");

        let (state_tx, state_rx) = broadcast::channel(32);
        let (command_tx, command_rx) = mpsc::channel(32);

        // `run` takes everything by value, so calling it yields a `'static`
        // future directly; no wrapper block or extra clones are needed.
        tasks.spawn(run(
            state_tx.clone(),
            command_rx,
            control_url.clone(),
            node_keys.clone(),
            auth_key.map(ToOwned::to_owned),
            config.clone(),
        ));

        Ok((
            Self {
                base_url: control_url.clone(),
                state_tx,
                command_tx,
                _tasks: tasks,
            },
            netmap_stream(state_rx),
        ))
    }

    /// Queues a report of the preferred DERP region and its latency samples.
    ///
    /// The `(name, sample)` pairs are copied into owned strings before being
    /// handed to the background task. If the task can no longer receive
    /// commands, the failure is logged and the report is dropped.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url(), %region_id), level = "trace")]
    pub async fn set_home_region<'c>(
        &mut self,
        region_id: ts_derp::RegionId,
        latencies: impl IntoIterator<Item = (&'c str, f64)>,
    ) {
        tracing::trace!(region = %region_id, "reporting home derp to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetDerpHomeRegion {
                id: region_id,
                latencies: latencies
                    .into_iter()
                    .map(|(name, sample)| (name.to_owned(), sample))
                    .collect(),
            })
            .await
        {
            tracing::error!(error = %e, "setting home derp region");
        }
    }

    /// Queues a report of this node's current endpoints.
    ///
    /// If the background task can no longer receive commands, the failure is
    /// logged and the report is dropped.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url(), n_endpoints), level = "trace")]
    pub async fn set_endpoints(&mut self, endpoints: Vec<ts_control_serde::Endpoint>) {
        tracing::Span::current().record("n_endpoints", endpoints.len());
        tracing::trace!("reporting magicsock endpoints to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetEndpoints { endpoints })
            .await
        {
            tracing::error!(error = %e, "setting endpoints");
        }
    }

    /// Queues a report of the IP prefixes this node advertises as routable.
    ///
    /// If the background task can no longer receive commands, the failure is
    /// logged and the report is dropped.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url(), n_routes = routes.len()), level = "trace")]
    pub async fn set_routable_ips(&mut self, routes: Vec<ipnet::IpNet>) {
        tracing::trace!("reporting routable IPs to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetRoutableIPs { routes })
            .await
        {
            tracing::error!(error = %e, "setting routable IPs");
        }
    }

    /// Queues a report of this node's hostname.
    ///
    /// If the background task can no longer receive commands, the failure is
    /// logged and the report is dropped.
    #[tracing::instrument(skip_all, fields(map_url = %self.map_url()), level = "trace")]
    pub async fn set_hostname(&mut self, hostname: String) {
        tracing::trace!("reporting hostname to control server");
        if let Err(e) = self
            .command_tx
            .send(Command::SetHostname { hostname })
            .await
        {
            tracing::error!(error = %e, "setting hostname");
        }
    }

    /// Returns the `machine/map` endpoint, resolved against the control base URL.
    ///
    /// # Panics
    ///
    /// Panics if the base URL cannot be joined with `machine/map`.
    pub fn map_url(&self) -> Url {
        self.base_url
            .join("machine/map")
            .expect("map_url was parsed without issue before")
    }

    /// Returns a new stream of state updates from the background map poll.
    ///
    /// Each call makes a fresh broadcast subscription. Per tokio's broadcast
    /// semantics, it only yields updates published after this call.
    pub fn netmap_stream(&self) -> impl Stream<Item = Arc<StateUpdate>> + Send + Sync + use<> {
        netmap_stream(self.state_tx.subscribe())
    }
}
/// A request from [`AsyncControlClient`] to the background control task.
///
/// The task turns each command into a one-shot (non-streaming) map request
/// on the current control connection.
// Every variant is deliberately a "set" operation; clippy's
// `enum_variant_names` lint objects to the shared `Set` prefix.
#[allow(clippy::enum_variant_names)]
#[derive(Debug)]
pub enum Command {
    /// Report the preferred DERP region plus latency samples keyed by name.
    SetDerpHomeRegion { id: ts_derp::RegionId, latencies: BTreeMap<String, f64> },
    /// Report the node's current endpoints.
    SetEndpoints { endpoints: Vec<ts_control_serde::Endpoint> },
    /// Report the prefixes the node advertises as routable.
    SetRoutableIPs { routes: Vec<ipnet::IpNet> },
    /// Report the node's hostname.
    SetHostname { hostname: String },
}
/// Resume cursor for the streaming map poll.
///
/// `run_once` passes `handle` and `seq` to `MapRequestBuilder::map_session`
/// on every streaming request, and `advance_session` updates them from each
/// received update. The `Default` value (empty handle, `seq` 0) is what the
/// very first request sends.
#[derive(Clone, Debug, Default)]
struct MapSession {
    /// Session handle most recently accepted from control.
    handle: String,
    /// Latest non-zero sequence number seen; reset to 0 when `handle` changes.
    seq: i64,
}
/// Longest map-session handle (in bytes) that will be adopted from control.
const MAX_SESSION_HANDLE_LEN: usize = 256;

/// Folds one streamed update into the resume cursor.
///
/// A handle is adopted only if it is non-empty, at most
/// [`MAX_SESSION_HANDLE_LEN`] bytes, and entirely printable non-space ASCII.
/// Otherwise it is logged and ignored. Adopting a *different* handle restarts
/// `seq` at 0. Independently of the handle, any non-zero `update.seq`
/// becomes the new `session.seq`.
fn advance_session(session: &mut MapSession, update: &StateUpdate) {
    match update.session_handle.as_deref() {
        Some(candidate)
            if candidate.is_empty()
                || candidate.len() > MAX_SESSION_HANDLE_LEN
                || !candidate.bytes().all(|b| b.is_ascii_graphic()) =>
        {
            tracing::warn!(
                handle_len = candidate.len(),
                "control sent an invalid map-session handle; ignoring it"
            );
        }
        Some(candidate) if candidate != session.handle => {
            session.handle = candidate.to_owned();
            session.seq = 0;
        }
        // No handle in this update, or the same handle again: keep the cursor.
        _ => {}
    }

    if update.seq != 0 {
        session.seq = update.seq;
    }
}
/// Counter driving the reconnect delay between map-poll attempts.
#[derive(Debug, Default)]
struct ControlBackoff {
    /// Number of delays handed out since the last reset.
    n: u32,
}

/// Ceiling applied to the unjittered reconnect delay.
const MAP_BACKOFF_MAX: core::time::Duration = core::time::Duration::from_secs(30);

impl ControlBackoff {
    /// Rewinds the schedule so the next delay is zero.
    fn reset(&mut self) {
        self.n = 0;
    }

    /// Returns the next delay and advances the counter.
    ///
    /// The unjittered delay is `n * n * 10` ms (saturating), clamped to
    /// [`MAP_BACKOFF_MAX`], where `n` is the counter *before* this call (so
    /// the first delay after a reset is zero). It is then scaled by
    /// `0.5 + U`, with `U` drawn uniformly from `rng`, to spread out reconnects.
    fn next_delay(&mut self, rng: &mut impl rand::RngExt) -> core::time::Duration {
        let step = u64::from(self.n);
        self.n = self.n.saturating_add(1);

        let unjittered_ms = step.saturating_mul(step).saturating_mul(10);
        let unjittered = core::time::Duration::from_millis(unjittered_ms);
        let bounded = if unjittered > MAP_BACKOFF_MAX {
            MAP_BACKOFF_MAX
        } else {
            unjittered
        };

        let jitter = 0.5 + rng.random::<f64>();
        bounded.mul_f64(jitter)
    }
}
/// Picks the delay before the next map-poll attempt.
///
/// A poll that delivered at least one frame resets `backoff` first, so the
/// next reconnect is immediate. A poll that delivered nothing keeps climbing
/// the schedule.
fn reconnect_delay_after_poll(
    received_frame: bool,
    backoff: &mut ControlBackoff,
    rng: &mut impl rand::RngExt,
) -> core::time::Duration {
    if !received_frame {
        return backoff.next_delay(rng);
    }
    backoff.reset();
    backoff.next_delay(rng)
}
/// Drives the control session for as long as the task is alive.
///
/// Repeatedly calls [`run_once`], carrying the dialer, the map-session resume
/// cursor, and the reconnect backoff across attempts. Every attempt end,
/// clean or failed, is logged and followed by a backoff sleep. This function
/// never returns on its own.
pub async fn run(
    state_tx: broadcast::Sender<Arc<StateUpdate>>,
    mut command_rx: mpsc::Receiver<Command>,
    control_url: Url,
    node_keys: ts_keys::NodeState,
    auth_key: Option<String>,
    config: crate::Config,
) {
    let mut dialer = ControlDialer::default();
    let mut session = MapSession::default();
    let mut backoff = ControlBackoff::default();

    loop {
        let mut saw_frame = false;
        let result = run_once(
            &state_tx,
            &mut command_rx,
            &control_url,
            &node_keys,
            auth_key.as_deref(),
            &config,
            &mut dialer,
            &mut session,
            &mut saw_frame,
        )
        .await;

        // The thread-local RNG is created and dropped inside this one
        // statement, so it is never held across an `.await` (rand's
        // `ThreadRng` is not `Send`).
        let delay = reconnect_delay_after_poll(saw_frame, &mut backoff, &mut rand::rng());
        let backoff_ms = delay.as_millis() as u64;

        if let Err(e) = result {
            tracing::error!(
                error = %e,
                resume_handle = %session.handle,
                resume_seq = session.seq,
                backoff_ms,
                "netmap stream failed, attempting restart"
            );
        } else {
            tracing::warn!(
                resume_handle = %session.handle,
                resume_seq = session.seq,
                backoff_ms,
                "netmap stream ended without error, attempting restart"
            );
        }

        tokio::time::sleep(delay).await;
    }
}
/// Runs one control session: dial, register, then serve a streaming map poll.
///
/// While the poll is open, this function does two things concurrently:
/// - Each streamed [`StateUpdate`] advances `session`, is passed to ping
///   handling, may update the dialer's dial plan, and is then broadcast on
///   `state_tx`. Send errors are ignored, since having no subscribers is fine.
/// - Each [`Command`] from `command_rx` is sent as a one-shot lite map request
///   on the same connection.
///
/// If every command sender has been dropped, the command branch is disabled
/// and the stream keeps being served, instead of panicking on the closed
/// channel.
///
/// `received_frame` is set to `true` as soon as the stream yields an update.
/// The caller uses it to decide whether to reset its backoff.
///
/// Returns `Ok(())` when the map stream ends.
///
/// # Errors
///
/// Returns an error if dialing, registration, the streaming map request, or
/// any lite update request fails.
///
/// # Panics
///
/// Panics if `control_url` cannot be joined with `machine/map`.
async fn run_once(
    state_tx: &broadcast::Sender<Arc<StateUpdate>>,
    command_rx: &mut mpsc::Receiver<Command>,
    control_url: &Url,
    node_keys: &ts_keys::NodeState,
    auth_key: Option<&str>,
    config: &crate::Config,
    control_dialer: &mut ControlDialer,
    session: &mut MapSession,
    received_frame: &mut bool,
) -> Result<(), Error> {
    let h2_client = control_dialer
        .full_connect_next(
            control_url,
            &node_keys.machine_keys,
            config.allow_http_key_fetch,
        )
        .await?;
    crate::tokio::register(config, control_url, auth_key, node_keys, &h2_client).await?;

    let client_name = config.format_client_name();
    let advertised_vip_services = config.advertised_vip_services();
    let services_hash = crate::services_hash(&advertised_vip_services);
    let builder = MapRequestBuilder::new(node_keys)
        .keep_alive(true)
        .omit_peers(false)
        .stream(true)
        .routable_ips(config.advertised_routes())
        .client_info(&client_name, crate::PKG_VERSION)
        .request_tags(config.tags.iter().map(String::as_str))
        .services(config.advertised_services())
        .services_hash(&services_hash)
        .wire_ingress(config.wire_ingress)
        .ingress_enabled(
            config
                .ingress_active
                .load(core::sync::atomic::Ordering::Relaxed),
        )
        // Resume from the cursor carried over from previous attempts.
        .map_session(&session.handle, session.seq);
    let request = if let Some(hostname) = &config.hostname {
        builder.hostname(hostname)
    } else {
        builder
    }
    .build();

    let map_url = control_url
        .join("machine/map")
        .expect("control URL must accept the relative path `machine/map`");
    let reader = send_map_request(request, &map_url, &h2_client).await?;
    let mut stream = core::pin::pin!(map_stream(reader));
    tracing::info!("netmap stream started");

    // Cleared once every `Command` sender is gone (e.g. the owning client is
    // being dropped). After that, `recv()` would return `None` immediately
    // forever, so the branch is disabled rather than polled or unwrapped.
    let mut commands_open = true;

    loop {
        tokio::select! {
            state_update = stream.next() => {
                let Some(state_update) = state_update else {
                    break;
                };
                *received_frame = true;
                advance_session(session, &state_update);
                // Ping handling is best-effort; its result is deliberately ignored.
                let _ = handle_ping(&state_update, control_url, &h2_client, config).await;
                if let Some(dial_plan) = &state_update.dial_plan
                    && control_dialer.update_dial_plan(dial_plan)
                {
                    tracing::trace!(new_dial_plan = ?dial_plan);
                }
                // An error only means there are no subscribers right now.
                let _ignore = state_tx.send(Arc::new(state_update));
            }
            command = command_rx.recv(), if commands_open => {
                let Some(command) = command else {
                    tracing::debug!(
                        "control command channel closed; no further lite updates will be sent"
                    );
                    commands_open = false;
                    continue;
                };
                // NOTE(review): the lite updates below re-send
                // `config.advertised_routes()` / `config.hostname` rather than
                // any value previously set via `SetRoutableIPs`/`SetHostname`.
                // Confirm that control does not treat these as overrides.
                match command {
                    Command::SetDerpHomeRegion { id, latencies } => {
                        let mut builder = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(config.advertised_routes())
                            .preferred_derp(id)
                            .derp_latencies(latencies.iter().map(|(k, v)| (k.as_str(), *v)));
                        if let Some(hostname) = &config.hostname {
                            builder = builder.hostname(hostname);
                        }
                        let req = builder.build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    }
                    Command::SetEndpoints { endpoints } => {
                        let mut builder = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(config.advertised_routes())
                            .endpoints(endpoints);
                        if let Some(hostname) = &config.hostname {
                            builder = builder.hostname(hostname);
                        }
                        let req = builder.build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    }
                    Command::SetRoutableIPs { routes } => {
                        let mut builder = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(routes);
                        if let Some(hostname) = &config.hostname {
                            builder = builder.hostname(hostname);
                        }
                        let req = builder.build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    }
                    Command::SetHostname { hostname } => {
                        let req = MapRequestBuilder::new(node_keys)
                            .keep_alive(false)
                            .omit_peers(true)
                            .stream(false)
                            .routable_ips(config.advertised_routes())
                            .hostname(&hostname)
                            .build();
                        drop(send_map_request(req, &map_url, &h2_client).await?);
                    }
                }
            }
        }
    }
    Ok(())
}
/// Adapts a broadcast receiver into a stream of state updates.
///
/// If the receiver falls behind and misses updates, the lag is logged with
/// the number of missed messages and that error item is skipped. Only
/// successfully received updates are yielded.
fn netmap_stream(
    rx: broadcast::Receiver<Arc<StateUpdate>>,
) -> impl Stream<Item = Arc<StateUpdate>> + Send + Sync {
    let received = tokio_stream::wrappers::BroadcastStream::new(rx);
    received.filter_map(async |item| {
        if let Err(BroadcastStreamRecvError::Lagged(skipped)) = &item {
            tracing::warn!(messages_missed = skipped, "map_stream lagged");
        }
        item.ok()
    })
}
// Unit tests for the map-session cursor and the reconnect backoff schedule.
#[cfg(test)]
mod tests {
use super::*;
// Builds a StateUpdate carrying only a session handle and seq; every other field is empty.
fn update(handle: Option<&str>, seq: i64) -> StateUpdate {
StateUpdate {
session_handle: handle.map(ToOwned::to_owned),
seq,
derp: None,
node: None,
peer_update: None,
peer_patches: Vec::new(),
user_profiles: Vec::new(),
ping: None,
packetfilter: None,
cap_grants: None,
pop_browser_url: None,
dial_plan: None,
dns_config: None,
ssh_policy: None,
tka: None,
online_change: Default::default(),
peer_seen_change: Default::default(),
}
}
// A first update with a valid handle and seq populates both fields.
#[test]
fn advance_session_captures_handle_and_seq() {
let mut session = MapSession::default();
advance_session(&mut session, &update(Some("sess-1"), 5));
assert_eq!(session.handle, "sess-1");
assert_eq!(session.seq, 5);
}
// An update with no handle and seq 0 leaves the cursor untouched.
#[test]
fn advance_session_keepalive_preserves_cursor() {
let mut session = MapSession {
handle: "sess-1".to_owned(),
seq: 7,
};
advance_session(&mut session, &update(None, 0));
assert_eq!(session.handle, "sess-1");
assert_eq!(session.seq, 7);
}
// Switching to a different handle restarts seq at 0.
#[test]
fn advance_session_resets_seq_on_new_handle() {
let mut session = MapSession {
handle: "sess-1".to_owned(),
seq: 42,
};
advance_session(&mut session, &update(Some("sess-2"), 0));
assert_eq!(session.handle, "sess-2");
assert_eq!(session.seq, 0);
}
// Repeating the current handle does not reset seq.
#[test]
fn advance_session_same_handle_keeps_seq() {
let mut session = MapSession {
handle: "sess-1".to_owned(),
seq: 10,
};
advance_session(&mut session, &update(Some("sess-1"), 0));
assert_eq!(session.handle, "sess-1");
assert_eq!(session.seq, 10);
}
// An over-long handle is ignored, but a non-zero seq is still applied.
#[test]
fn advance_session_rejects_overlong_handle() {
let mut session = MapSession::default();
let huge = "a".repeat(MAX_SESSION_HANDLE_LEN + 1);
advance_session(&mut session, &update(Some(&huge), 3));
assert_eq!(session.handle, "");
assert_eq!(session.seq, 3);
}
// A handle containing a non-graphic byte (newline) is ignored; seq is still applied.
#[test]
fn advance_session_rejects_non_graphic_handle() {
let mut session = MapSession::default();
advance_session(&mut session, &update(Some("bad\nhandle"), 1));
assert_eq!(session.handle, "");
assert_eq!(session.seq, 1);
}
// For each counter value, repeated draws must stay within [0.5, 1.5) x the clamped n*n*10 ms delay.
#[test]
fn control_backoff_delay_is_within_the_go_jitter_envelope() {
let mut rng = rand::rng();
for n in 0u32..80 {
let unjittered_ms = u64::from(n)
.saturating_mul(u64::from(n))
.saturating_mul(10)
.min(MAP_BACKOFF_MAX.as_millis() as u64);
let unjittered = core::time::Duration::from_millis(unjittered_ms);
for _ in 0..100 {
let mut probe = ControlBackoff { n };
let d = probe.next_delay(&mut rng);
if unjittered.is_zero() {
assert_eq!(d, core::time::Duration::ZERO, "n=0 delay must be zero");
} else {
assert!(
d >= unjittered.mul_f64(0.5) && d < unjittered.mul_f64(1.5),
"n={n}: delay {d:?} outside [0.5,1.5) x {unjittered:?}"
);
}
}
}
}
// Very large counters clamp to MAP_BACKOFF_MAX (before jitter) instead of overflowing.
#[test]
fn control_backoff_saturates_at_the_cap() {
let mut rng = rand::rng();
let mut probe = ControlBackoff { n: 1000 };
let d = probe.next_delay(&mut rng);
assert!(
d >= MAP_BACKOFF_MAX.mul_f64(0.5) && d < MAP_BACKOFF_MAX.mul_f64(1.5),
"saturated delay {d:?} outside the cap's jitter envelope"
);
let mut probe = ControlBackoff { n: u32::MAX };
let d = probe.next_delay(&mut rng);
assert!(d < MAP_BACKOFF_MAX.mul_f64(1.5), "overflowed at u32::MAX");
}
// reset() zeroes the counter, so the next delay is zero and the counter then reads 1.
#[test]
fn control_backoff_reset_returns_to_bottom() {
let mut rng = rand::rng();
let mut bo = ControlBackoff::default();
for _ in 0..5 {
let _ = bo.next_delay(&mut rng);
}
assert!(bo.n > 0, "counter advanced");
bo.reset();
assert_eq!(bo.n, 0, "reset zeroes the counter");
let d = bo.next_delay(&mut rng);
assert_eq!(d, core::time::Duration::ZERO, "n=0 delay is zero");
assert_eq!(bo.n, 1, "counter advances after the n=0 draw");
}
// Frameless polls keep climbing the schedule; a poll with a frame resets to an immediate reconnect.
#[test]
fn reconnect_delay_resets_only_when_a_frame_arrived() {
let mut rng = rand::rng();
let mut backoff = ControlBackoff::default();
let mut last_n = backoff.n;
for i in 0..6 {
let d = reconnect_delay_after_poll(false, &mut backoff, &mut rng);
assert!(
backoff.n > last_n,
"frameless poll {i} must advance the counter (no reset)"
);
last_n = backoff.n;
if i > 0 {
assert!(
d > core::time::Duration::ZERO,
"frameless reconnect {i} must be delayed, not a 0ms spin"
);
}
}
let d = reconnect_delay_after_poll(true, &mut backoff, &mut rng);
assert_eq!(
d,
core::time::Duration::ZERO,
"a poll that delivered a frame resets to the immediate (n=0) reconnect"
);
assert_eq!(backoff.n, 1, "reset then one draw leaves the counter at 1");
}
}