use alloc::collections::BTreeMap;
use bytes::Bytes;
use futures_util::Stream;
use tokio::io::{AsyncRead, AsyncReadExt};
use ts_control_serde::{MapRequest, MapResponse, PingRequest};
use ts_http_util::{BytesBody, ClientExt, Http2, ResponseExt};
use ts_packet::PacketMut;
use ts_packetfilter as pf;
use ts_packetfilter_state as pf_state;
use url::Url;
use crate::{DialPlan, NodeId};
#[derive(Debug, thiserror::Error, Clone, Copy, Eq, PartialEq)]
pub enum MapStreamError {
#[error("serialization error")]
SerDe,
#[error("unsuccessful HTTP request or upgrade")]
Http,
#[error("Network error")]
NetworkError,
}
impl From<serde_json::Error> for MapStreamError {
    /// Logs the serde failure, since [`MapStreamError::SerDe`] keeps no detail,
    /// then collapses it to that variant.
    fn from(error: serde_json::Error) -> Self {
        tracing::error!(%error, "serialization error sending map request");
        Self::SerDe
    }
}
impl From<ts_http_util::Error> for MapStreamError {
fn from(error: ts_http_util::Error) -> Self {
tracing::error!(%error, "http error sending map request");
if crate::http_error_is_recoverable(error) {
MapStreamError::NetworkError
} else {
MapStreamError::Http
}
}
}
impl From<MapStreamError> for crate::Error {
fn from(e: MapStreamError) -> Self {
match e {
MapStreamError::SerDe => crate::Error::Internal(
crate::InternalErrorKind::SerDe,
crate::Operation::MapRequest,
),
MapStreamError::Http => {
crate::Error::Internal(crate::InternalErrorKind::Http, crate::Operation::MapRequest)
}
MapStreamError::NetworkError => {
crate::Error::NetworkError(crate::Operation::MapRequest)
}
}
}
}
/// Change to the peer set carried by a single map response.
///
/// `map_stream` produces `Full` when the response's `peers` list is non-empty.
/// It produces `Delta` when `peers` is absent or empty but `peers_removed` or
/// `peers_changed` is non-empty.
#[derive(Debug)]
pub enum PeerUpdate {
    /// The complete peer set, converted from a non-empty `peers` list.
    Full(Vec<crate::Node>),
    /// An incremental change to the peer set.
    Delta {
        /// Peers converted from `peers_changed` (empty when absent); presumably
        /// inserted or updated by the consumer — confirm against the caller.
        upsert: Vec<crate::Node>,
        /// Node IDs taken from `peers_removed` (empty when absent).
        remove: Vec<NodeId>,
    },
}
/// Packet-filter payload of a map response, as built by `packet_filter`.
///
/// The first element is the converted unnamed filter, if the response had one.
/// The second maps each named filter to its converted ruleset. It keeps `None`
/// for names whose value was absent in the response.
pub type FilterUpdate = (Option<pf::Ruleset>, BTreeMap<String, Option<pf::Ruleset>>);
#[derive(Debug)]
pub struct StateUpdate {
pub session_handle: Option<alloc::string::String>,
pub seq: i64,
pub keep_alive: bool,
pub derp: Option<crate::DerpMap>,
pub node: Option<crate::Node>,
pub peer_update: Option<PeerUpdate>,
pub peer_patches: Vec<crate::PeerChange>,
pub user_profiles: Vec<crate::UserProfile>,
pub ping: Option<PingRequest>,
pub packetfilter: Option<FilterUpdate>,
pub cap_grants: Option<Vec<ts_packetfilter_state::CapGrant>>,
pub pop_browser_url: Option<Url>,
pub dial_plan: Option<DialPlan>,
pub dns_config: Option<crate::DnsConfig>,
pub ssh_policy: Option<crate::SshPolicy>,
pub tka: Option<crate::TkaStatus>,
pub online_change: alloc::collections::BTreeMap<crate::NodeId, bool>,
pub peer_seen_change: alloc::collections::BTreeMap<crate::NodeId, bool>,
}
/// Largest frame length (16 MiB) accepted from a frame's 4-byte length prefix.
/// Anything larger ends the stream before the body is allocated or read.
const MAX_NETMAP_FRAME: u32 = 16 << 20;
/// Largest size (64 MiB) a zstd frame may decompress to. This keeps a small
/// compressed frame from forcing an unbounded allocation.
const MAX_DECODED_NETMAP: u64 = 64 << 20;
/// Longest wait for either a frame's length prefix or its body. When it
/// expires, the stream ends so the caller can reconnect.
const MAP_READ_WATCHDOG: core::time::Duration = core::time::Duration::from_secs(2 * 60);
/// The zstd frame magic number `0xFD2FB528`, in its on-the-wire
/// (little-endian) byte order. It distinguishes compressed frames from plain
/// JSON.
const ZSTD_MAGIC: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();
/// Decodes one netmap frame body, transparently handling zstd compression.
///
/// Frames that start with [`ZSTD_MAGIC`] are decompressed with a streaming
/// decoder. Output is capped at [`MAX_DECODED_NETMAP`] bytes, so a small
/// compressed frame cannot expand into an unbounded allocation. Any other frame
/// is treated as already-plain JSON and returned borrowed, without copying.
///
/// Returns `None`, after logging, when:
/// - the decoder cannot be initialised,
/// - the compressed data is malformed, or
/// - the decoded output exceeds the bound.
fn decompress_frame(frame: &[u8]) -> Option<alloc::borrow::Cow<'_, [u8]>> {
    use alloc::borrow::Cow;
    use std::io::Read as _;

    if !frame.starts_with(&ZSTD_MAGIC) {
        // Control may ignore the compression request; hand the bytes straight
        // to the JSON decoder instead of copying up to MAX_NETMAP_FRAME bytes.
        return Some(Cow::Borrowed(frame));
    }
    let mut decoder = ruzstd::decoding::StreamingDecoder::new(frame)
        .inspect_err(|e| tracing::error!(error = %e, "initializing zstd decoder for netmap frame"))
        .ok()?;
    let mut decoded = alloc::vec::Vec::new();
    // Read one byte past the bound: an oversized frame is then detectable
    // without decoding (and allocating) the rest of it.
    decoder
        .by_ref()
        .take(MAX_DECODED_NETMAP + 1)
        .read_to_end(&mut decoded)
        .inspect_err(|e| tracing::error!(error = %e, "decompressing netmap frame"))
        .ok()?;
    if decoded.len() as u64 > MAX_DECODED_NETMAP {
        tracing::error!(
            max = MAX_DECODED_NETMAP,
            "decompressed netmap frame exceeds bound; ending stream"
        );
        return None;
    }
    Some(Cow::Owned(decoded))
}
pub fn map_stream(reader: impl AsyncRead + Unpin) -> impl Stream<Item = StateUpdate> {
futures_util::stream::unfold(reader, async |mut reader| {
let msg_len = match tokio::time::timeout(MAP_READ_WATCHDOG, reader.read_u32_le()).await {
Ok(res) => res
.inspect_err(|e| {
tracing::error!(error = %e, "could not read netmap length");
})
.ok()?,
Err(_elapsed) => {
tracing::error!(
watchdog_secs = MAP_READ_WATCHDOG.as_secs(),
"no netmap frame within the keep-alive watchdog; ending stream to reconnect"
);
return None;
}
};
if msg_len > MAX_NETMAP_FRAME {
tracing::error!(
?msg_len,
max = MAX_NETMAP_FRAME,
"netmap frame too large; ending stream"
);
return None;
}
let mut buf = PacketMut::new(msg_len as usize);
tracing::trace!(?msg_len, "reading netmap");
match tokio::time::timeout(MAP_READ_WATCHDOG, reader.read_exact(buf.as_mut())).await {
Ok(res) => res
.inspect_err(|e| {
tracing::error!(error = %e, "could not read netmap");
})
.ok()?,
Err(_elapsed) => {
tracing::error!(
watchdog_secs = MAP_READ_WATCHDOG.as_secs(),
"netmap body did not arrive within the watchdog; ending stream to reconnect"
);
return None;
}
};
let decoded = decompress_frame(buf.as_ref())?;
let map_response: MapResponse = serde_json::from_slice(&decoded)
.inspect_err(|e| {
tracing::error!(error = %e, "deserializing netmap");
})
.ok()?;
tracing::trace!(?msg_len, ?map_response);
let packetfilter = packet_filter(&map_response);
let cap_grants = cap_grants(&map_response);
fn nonempty<T>(x: &Option<Vec<T>>) -> bool {
x.as_ref().is_some_and(|x| !x.is_empty())
}
let peer_patches: Vec<crate::PeerChange> = map_response
.peers_changed_patch
.iter()
.flatten()
.map(crate::PeerChange::from)
.collect();
let peer_update = if nonempty(&map_response.peers) {
let full_map = map_response.peers.unwrap_or_default();
Some(PeerUpdate::Full(full_map.iter().map(Into::into).collect()))
} else if nonempty(&map_response.peers_removed) || nonempty(&map_response.peers_changed) {
Some(PeerUpdate::Delta {
remove: map_response.peers_removed.unwrap_or_default(),
upsert: map_response
.peers_changed
.unwrap_or_default()
.iter()
.map(Into::into)
.collect(),
})
} else {
None
};
Some((
StateUpdate {
session_handle: (!map_response.map_session_handle.is_empty())
.then(|| map_response.map_session_handle.to_owned()),
seq: map_response.seq,
keep_alive: map_response.keep_alive.unwrap_or(false),
peer_update,
peer_patches,
user_profiles: map_response
.user_profiles
.iter()
.map(crate::UserProfile::from)
.collect(),
node: map_response.node.as_ref().map(Into::into),
derp: map_response
.derp_map
.as_ref()
.map(|x| crate::convert_derp_map(x).collect()),
ping: map_response.ping_request,
packetfilter,
cap_grants,
pop_browser_url: map_response.pop_browser_url.and_then(|u| {
u.parse()
.inspect_err(|e| tracing::error!(error = %e, "invalid pop browser url"))
.ok()
}),
dial_plan: map_response.control_dial_plan.map(Into::into),
dns_config: map_response
.dns_config
.as_ref()
.map(crate::DnsConfig::from_serde),
ssh_policy: map_response
.ssh_policy
.as_ref()
.map(crate::SshPolicy::from_serde),
tka: map_response
.tka_info
.as_ref()
.map(crate::TkaStatus::from_serde),
online_change: map_response.online_change.clone(),
peer_seen_change: map_response.peer_seen_change.clone(),
},
reader,
))
})
}
fn packet_filter(map_response: &MapResponse<'_>) -> Option<FilterUpdate> {
if map_response.packet_filter.is_none() && map_response.packet_filters.is_empty() {
return None;
}
Some((
map_response
.packet_filter
.as_ref()
.map(|x| pf_state::rules_to_pf(x).collect()),
map_response
.packet_filters
.iter()
.map(|(rule_name, rules)| {
(
rule_name.to_string(),
rules
.as_ref()
.map(|x| Some(pf_state::rules_to_pf(x).collect()))
.unwrap_or_default(),
)
})
.collect(),
))
}
/// Collects the capability grants found in this response's packet filters.
///
/// Returns `None` when the response carries no filter information at all
/// (neither `packet_filter` nor any `packet_filters` entry), the same condition
/// used by [`packet_filter`]. Otherwise it returns, possibly empty, the grants
/// from the unnamed filter followed by those from every present named filter.
fn cap_grants(map_response: &MapResponse<'_>) -> Option<Vec<ts_packetfilter_state::CapGrant>> {
    let carries_filters =
        map_response.packet_filter.is_some() || !map_response.packet_filters.is_empty();
    carries_filters.then(|| {
        let mut collected = Vec::new();
        if let Some(unnamed) = &map_response.packet_filter {
            collected.extend(pf_state::retain_cap_grants(unnamed));
        }
        for named in map_response.packet_filters.values().filter_map(Option::as_ref) {
            collected.extend(pf_state::retain_cap_grants(named));
        }
        collected
    })
}
/// POSTs `map_request` as JSON to `url` over `http2_conn`.
///
/// On success, returns the response body as a reader suitable for
/// [`map_stream`].
///
/// # Errors
///
/// - [`MapStreamError::SerDe`] if the request cannot be serialized.
/// - An error converted from the HTTP layer if the request itself fails.
/// - [`MapStreamError::Http`] if control answers with a non-success status.
#[tracing::instrument(skip_all, fields(map_url = %url.as_str()))]
pub async fn send_map_request(
    map_request: MapRequest<'_>,
    url: &Url,
    http2_conn: &Http2<BytesBody>,
) -> Result<impl AsyncRead + 'static, MapStreamError> {
    tracing::debug!("sending map request to control server...");
    // Debug builds pretty-print so the trace log below stays readable.
    let serialized = if cfg!(debug_assertions) {
        serde_json::to_string_pretty(&map_request)
    } else {
        serde_json::to_string(&map_request)
    };
    let body = serialized?;
    tracing::trace!(%body, "sending map request");

    let response = http2_conn.post(url, None, Bytes::from(body).into()).await?;
    let status = response.status();
    tracing::trace!(?status, "received map response");
    if status.is_success() {
        return Ok(response.into_read());
    }
    tracing::error!(
        status = status.as_u16(),
        "failed to register map updates with unsuccessful HTTP status code"
    );
    Err(MapStreamError::Http)
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use futures_util::StreamExt;
    use super::*;

    /// Builds a map-stream body from `bodies`. Each body is zstd-compressed and
    /// preceded by its compressed length as a little-endian `u32`.
    fn frame(bodies: &[&str]) -> Vec<u8> {
        let mut buf = Vec::new();
        for body in bodies {
            let compressed = ruzstd::encoding::compress_to_vec(
                body.as_bytes(),
                ruzstd::encoding::CompressionLevel::Fastest,
            );
            buf.extend_from_slice(&(compressed.len() as u32).to_le_bytes());
            buf.extend_from_slice(&compressed);
        }
        buf
    }

    /// Like `frame`, but writes each body verbatim, without compression.
    fn frame_uncompressed(bodies: &[&str]) -> Vec<u8> {
        let mut buf = Vec::new();
        for body in bodies {
            buf.extend_from_slice(&(body.len() as u32).to_le_bytes());
            buf.extend_from_slice(body.as_bytes());
        }
        buf
    }

    /// Reader that yields `prefix` and then returns `Pending` forever, without
    /// registering a waker. It models a connection that goes silent, so only the
    /// read watchdog can end a read against it. Tests using it run on tokio's
    /// paused, auto-advancing clock.
    struct StallAfter {
        /// Bytes still to hand out before stalling.
        prefix: alloc::collections::VecDeque<u8>,
    }

    impl StallAfter {
        fn new(prefix: &[u8]) -> Self {
            Self {
                prefix: prefix.iter().copied().collect(),
            }
        }
    }

    impl tokio::io::AsyncRead for StallAfter {
        fn poll_read(
            mut self: core::pin::Pin<&mut Self>,
            _cx: &mut core::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> core::task::Poll<std::io::Result<()>> {
            if self.prefix.is_empty() {
                return core::task::Poll::Pending;
            }
            while buf.remaining() > 0 {
                let Some(b) = self.prefix.pop_front() else {
                    break;
                };
                buf.put_slice(&[b]);
            }
            core::task::Poll::Ready(Ok(()))
        }
    }

    /// After one good frame, a silent connection must end the stream via the watchdog.
    #[tokio::test(start_paused = true)]
    async fn map_stream_watchdog_ends_stream_on_silent_connection() {
        let reader = StallAfter::new(&frame(&[r#"{"MapSessionHandle":"sess-1","Seq":1}"#]));
        let mut stream = core::pin::pin!(map_stream(reader));
        let update = stream.next().await.expect("first frame");
        assert_eq!(update.seq, 1);
        assert!(
            stream.next().await.is_none(),
            "watchdog must end the stream on a silent connection"
        );
    }

    /// A connection that never sends a length prefix must still end.
    #[tokio::test(start_paused = true)]
    async fn map_stream_watchdog_ends_stream_when_no_frame_ever_arrives() {
        let reader = StallAfter::new(&[]);
        let mut stream = core::pin::pin!(map_stream(reader));
        assert!(
            stream.next().await.is_none(),
            "watchdog must end a stream that never produces a frame"
        );
    }

    /// A body that stops short of its announced length must end via the watchdog.
    #[tokio::test(start_paused = true)]
    async fn map_stream_watchdog_ends_stream_on_partial_body() {
        let mut bytes = 64u32.to_le_bytes().to_vec();
        bytes.extend_from_slice(b"abc");
        let reader = StallAfter::new(&bytes);
        let mut stream = core::pin::pin!(map_stream(reader));
        assert!(
            stream.next().await.is_none(),
            "watchdog must end the stream when the body never completes"
        );
    }

    #[tokio::test]
    async fn map_stream_carries_session_handle_and_seq() {
        let buf = frame(&[r#"{"MapSessionHandle":"sess-xyz","Seq":12}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert_eq!(update.session_handle.as_deref(), Some("sess-xyz"));
        assert_eq!(update.seq, 12);
    }

    #[tokio::test]
    async fn map_stream_empty_handle_maps_to_none() {
        let buf = frame(&[r#"{"KeepAlive":true}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert_eq!(update.session_handle, None);
        assert_eq!(update.seq, 0);
        assert!(
            update.keep_alive,
            "a KeepAlive response must surface keep_alive=true"
        );
    }

    #[tokio::test]
    async fn map_stream_substantive_response_has_keep_alive_false() {
        let buf = frame(&[r#"{ "Node": { "Name": "n" } }"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert_eq!(update.seq, 0, "this fixture omits Seq (Headscale-style)");
        assert!(
            !update.keep_alive,
            "a response without KeepAlive must surface keep_alive=false (substantive)"
        );
    }

    #[tokio::test]
    async fn map_stream_surfaces_peers_changed_patch() {
        let buf = frame(&[r#"{
"Seq": 7,
"PeersChangedPatch": [
{ "NodeID": 42, "Endpoints": ["203.0.113.7:41641"], "DERPRegion": 5 }
]
}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert!(
            update.peer_update.is_none(),
            "no whole-node set in this response"
        );
        assert_eq!(update.peer_patches.len(), 1);
        assert_eq!(update.peer_patches[0].id, 42);
        assert_eq!(
            update.peer_patches[0].underlay_addresses.as_deref(),
            Some(&["203.0.113.7:41641".parse().unwrap()][..])
        );
        assert_eq!(
            update.peer_patches[0].derp_region,
            Some(ts_derp::RegionId(core::num::NonZeroU32::new(5).unwrap()))
        );
    }

    #[tokio::test]
    async fn map_stream_carries_both_delta_and_patch_when_co_occurring() {
        let buf = frame(&[r#"{
"Seq": 8,
"PeersChanged": [
{ "ID": 1, "StableID": "n1", "Name": "a.ts.net.", "User": 1,
"Key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000" }
],
"PeersChangedPatch": [ { "NodeID": 1, "DERPRegion": 9 } ]
}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert!(matches!(update.peer_update, Some(PeerUpdate::Delta { .. })));
        assert_eq!(update.peer_patches.len(), 1, "patch must not be dropped");
        assert_eq!(update.peer_patches[0].id, 1);
        assert_eq!(
            update.peer_patches[0].derp_region,
            Some(ts_derp::RegionId(core::num::NonZeroU32::new(9).unwrap()))
        );
    }

    #[tokio::test]
    async fn empty_peers_array_is_noop_not_full_wipe() {
        let buf = frame(&[r#"{ "Seq": 9, "Peers": [] }"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert!(
            update.peer_update.is_none(),
            "an empty Peers:[] must be a no-op, NOT PeerUpdate::Full(empty) which would wipe all peers"
        );
    }

    #[tokio::test]
    async fn nonempty_peers_array_is_full_reset() {
        let buf = frame(&[r#"{
"Seq": 10,
"Peers": [
{ "ID": 1, "StableID": "n1", "Name": "a.ts.net.", "User": 1,
"Key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000" }
]
}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        match update.peer_update {
            Some(PeerUpdate::Full(peers)) => {
                assert_eq!(peers.len(), 1, "the one peer is the full set")
            }
            other => panic!("a non-empty Peers must be PeerUpdate::Full, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn empty_peers_with_delta_is_delta_not_noop() {
        let buf = frame(&[r#"{ "Seq": 11, "Peers": [], "PeersRemoved": [42] }"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        match update.peer_update {
            Some(PeerUpdate::Delta { remove, upsert }) => {
                assert_eq!(
                    remove.len(),
                    1,
                    "the PeersRemoved entry is honored as a delta removal"
                );
                assert!(upsert.is_empty(), "no PeersChanged ⇒ no upserts");
            }
            other => {
                panic!("empty Peers + PeersRemoved must be Delta (delta honored), got {other:?}")
            }
        }
    }

    /// Known-answer test for a zstd frame produced by another encoder.
    #[tokio::test]
    async fn decodes_foreign_zstd_frame_interop_kat() {
        const GOLDEN_ZSTD_FRAME: &[u8] = &[
            0x28, 0xb5, 0x2f, 0xfd, 0x04, 0x68, 0x51, 0x01, 0x00, 0x7b, 0x22, 0x4d, 0x61, 0x70,
            0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x48, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x22,
            0x3a, 0x22, 0x73, 0x65, 0x73, 0x73, 0x2d, 0x67, 0x6f, 0x6c, 0x64, 0x65, 0x6e, 0x22,
            0x2c, 0x22, 0x53, 0x65, 0x71, 0x22, 0x3a, 0x37, 0x7d, 0xaf, 0xf4, 0x50, 0x88,
        ];
        assert_eq!(&GOLDEN_ZSTD_FRAME[..4], &ZSTD_MAGIC);
        let mut buf = (GOLDEN_ZSTD_FRAME.len() as u32).to_le_bytes().to_vec();
        buf.extend_from_slice(GOLDEN_ZSTD_FRAME);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream
            .next()
            .await
            .expect("one update from the foreign zstd frame");
        assert_eq!(update.session_handle.as_deref(), Some("sess-golden"));
        assert_eq!(update.seq, 7);
    }

    #[tokio::test]
    async fn decodes_uncompressed_frame_when_control_ignores_compress() {
        let buf = frame_uncompressed(&[r#"{"MapSessionHandle":"sess-plain","Seq":3}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream
            .next()
            .await
            .expect("uncompressed frame must still decode");
        assert_eq!(update.session_handle.as_deref(), Some("sess-plain"));
        assert_eq!(update.seq, 3);
    }

    #[tokio::test]
    async fn decodes_self_compressed_zstd_frame() {
        let buf = frame(&[r#"{"MapSessionHandle":"sess-self","Seq":9}"#]);
        assert_eq!(&buf[4..8], &ZSTD_MAGIC);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream
            .next()
            .await
            .expect("self-compressed frame must decode");
        assert_eq!(update.session_handle.as_deref(), Some("sess-self"));
        assert_eq!(update.seq, 9);
    }

    #[tokio::test]
    async fn rejects_zstd_bomb_exceeding_decoded_bound() {
        let oversized = alloc::vec![0u8; (MAX_DECODED_NETMAP + 1) as usize];
        let compressed = ruzstd::encoding::compress_to_vec(
            oversized.as_slice(),
            ruzstd::encoding::CompressionLevel::Fastest,
        );
        assert!((compressed.len() as u32) < MAX_NETMAP_FRAME);
        let mut buf = (compressed.len() as u32).to_le_bytes().to_vec();
        buf.extend_from_slice(&compressed);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        assert!(
            stream.next().await.is_none(),
            "a frame decompressing past MAX_DECODED_NETMAP must end the stream, not allocate it"
        );
    }

    #[tokio::test]
    async fn accepts_large_in_bounds_decoded_frame() {
        let filler = "a".repeat(1024 * 1024);
        let body = alloc::format!(r#"{{"Pad":"{filler}","MapSessionHandle":"sess-big","Seq":5}}"#);
        let buf = frame(&[&body]);
        assert!(body.len() as u64 > 1_000_000 && (body.len() as u64) < MAX_DECODED_NETMAP);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream
            .next()
            .await
            .expect("a large in-bounds frame must be accepted");
        assert_eq!(update.session_handle.as_deref(), Some("sess-big"));
        assert_eq!(update.seq, 5);
    }

    #[tokio::test]
    async fn good_frame_then_malformed_ends_stream_after_first() {
        let mut buf = frame(&[r#"{"MapSessionHandle":"sess-1","Seq":1}"#]);
        let mut bad = ZSTD_MAGIC.to_vec();
        bad.extend_from_slice(&[0xff, 0xff, 0xff, 0xff, 0x00]);
        buf.extend_from_slice(&(bad.len() as u32).to_le_bytes());
        buf.extend_from_slice(&bad);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let first = stream
            .next()
            .await
            .expect("the first good frame must be delivered");
        assert_eq!(first.seq, 1);
        assert!(
            stream.next().await.is_none(),
            "the malformed second frame must end the stream after the first was delivered"
        );
    }

    #[tokio::test]
    async fn rejects_malformed_zstd_frame() {
        let mut body = ZSTD_MAGIC.to_vec();
        body.extend_from_slice(&[0xff, 0xff, 0xff, 0xff, 0x00]);
        let mut buf = (body.len() as u32).to_le_bytes().to_vec();
        buf.extend_from_slice(&body);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        assert!(
            stream.next().await.is_none(),
            "a malformed zstd frame must end the stream cleanly"
        );
    }

    /// An oversized length prefix must end the stream immediately, before any
    /// body read. A body read against `StallAfter` would otherwise stall until
    /// the watchdog fired, advancing the paused clock by `MAP_READ_WATCHDOG`.
    #[tokio::test(start_paused = true)]
    async fn rejects_length_prefix_exceeding_frame_bound() {
        let reader = StallAfter::new(&(MAX_NETMAP_FRAME + 1).to_le_bytes());
        let mut stream = core::pin::pin!(map_stream(reader));
        let started = tokio::time::Instant::now();
        assert!(
            stream.next().await.is_none(),
            "a length prefix above MAX_NETMAP_FRAME must end the stream"
        );
        assert!(
            started.elapsed() < MAP_READ_WATCHDOG,
            "the frame must be rejected on its length alone, not by the watchdog"
        );
    }
}