use alloc::collections::BTreeMap;
use bytes::Bytes;
use futures_util::Stream;
use tokio::io::{AsyncRead, AsyncReadExt};
use ts_control_serde::{MapRequest, MapResponse, PingRequest};
use ts_http_util::{BytesBody, ClientExt, Http2, ResponseExt};
use ts_packet::PacketMut;
use ts_packetfilter as pf;
use ts_packetfilter_state as pf_state;
use url::Url;
use crate::{DialPlan, NodeId};
/// Errors from issuing a map request (see [`send_map_request`]).
#[derive(Debug, thiserror::Error, Clone, Copy, Eq, PartialEq)]
pub enum MapStreamError {
/// Produced from a `serde_json::Error`; `send_map_request` hits this when serializing the map
/// request fails.
#[error("serialization error")]
SerDe,
/// An HTTP error that `crate::http_error_is_recoverable` classifies as non-recoverable, or a
/// non-success HTTP status returned for the map request.
#[error("unsuccessful HTTP request or upgrade")]
Http,
/// An HTTP error that `crate::http_error_is_recoverable` classifies as recoverable.
#[error("Network error")]
NetworkError,
}
impl From<serde_json::Error> for MapStreamError {
fn from(error: serde_json::Error) -> Self {
tracing::error!(%error, "serialization error sending map request");
MapStreamError::SerDe
}
}
impl From<ts_http_util::Error> for MapStreamError {
fn from(error: ts_http_util::Error) -> Self {
tracing::error!(%error, "http error sending map request");
if crate::http_error_is_recoverable(error) {
MapStreamError::NetworkError
} else {
MapStreamError::Http
}
}
}
impl From<MapStreamError> for crate::Error {
fn from(e: MapStreamError) -> Self {
match e {
MapStreamError::SerDe => crate::Error::Internal(
crate::InternalErrorKind::SerDe,
crate::Operation::MapRequest,
),
MapStreamError::Http => {
crate::Error::Internal(crate::InternalErrorKind::Http, crate::Operation::MapRequest)
}
MapStreamError::NetworkError => {
crate::Error::NetworkError(crate::Operation::MapRequest)
}
}
}
}
/// How one map response changed the peer set.
#[derive(Debug)]
pub enum PeerUpdate {
/// The response carried a complete peer list (built from the response's `peers`).
/// Presumably consumers replace all previously known peers with it — confirm with callers.
Full(Vec<crate::Node>),
/// The response carried only incremental peer changes.
Delta {
/// Peers added or changed (built from the response's `peers_changed`).
upsert: Vec<crate::Node>,
/// IDs of peers removed (taken from the response's `peers_removed`).
remove: Vec<NodeId>,
},
}
/// A packet-filter change carried by one map response.
///
/// - `.0`: the converted single `packet_filter` ruleset, if the response carried one.
/// - `.1`: every entry of the response's named `packet_filters`, converted. A `None` value means
///   the response had no rules for that name (presumably a request to delete that named
///   filter — confirm with the consumer).
pub type FilterUpdate = (Option<pf::Ruleset>, BTreeMap<String, Option<pf::Ruleset>>);
/// Everything [`map_stream`] extracted from one decoded map response frame.
///
/// Fields are `None`/empty when the corresponding part was absent from the frame. Frames are not
/// filtered, so a keep-alive frame still yields a `StateUpdate`.
#[derive(Debug)]
pub struct StateUpdate {
/// The response's map session handle, or `None` when the server sent an empty one.
pub session_handle: Option<alloc::string::String>,
/// Copied verbatim from the response's `seq` field.
pub seq: i64,
/// Converted DERP map, when the response included one.
pub derp: Option<crate::DerpMap>,
/// Converted `node` entry from the response (presumably this client's own node — confirm).
pub node: Option<crate::Node>,
/// Full or delta peer-set change; `None` when the response carried no such change.
pub peer_update: Option<PeerUpdate>,
/// Per-peer patches from the response's `peers_changed_patch`; may be non-empty alongside
/// `peer_update`.
pub peer_patches: Vec<crate::PeerChange>,
/// User profiles included in the response (possibly empty).
pub user_profiles: Vec<crate::UserProfile>,
/// Ping request from the response, forwarded unchanged.
pub ping: Option<PingRequest>,
/// Packet-filter change; `None` when the response had neither a single nor any named filter.
pub packetfilter: Option<FilterUpdate>,
/// Capability grants retained from the same filter rules as `packetfilter`; `None` under the
/// same condition.
pub cap_grants: Option<Vec<ts_packetfilter_state::CapGrant>>,
/// Parsed `pop_browser_url` (presumably a URL to open in a browser); an unparsable URL is
/// logged and dropped.
pub pop_browser_url: Option<Url>,
/// Converted `control_dial_plan`, when present.
pub dial_plan: Option<DialPlan>,
/// Converted DNS configuration, when present.
pub dns_config: Option<crate::DnsConfig>,
/// Converted SSH policy, when present.
pub ssh_policy: Option<crate::SshPolicy>,
/// Converted `tka_info`, when present.
pub tka: Option<crate::TkaStatus>,
/// The response's `online_change` map, keyed by node ID.
pub online_change: alloc::collections::BTreeMap<crate::NodeId, bool>,
/// The response's `peer_seen_change` map, keyed by node ID.
pub peer_seen_change: alloc::collections::BTreeMap<crate::NodeId, bool>,
}
/// Largest netmap frame body, in bytes, that [`map_stream`] accepts. A larger length prefix ends
/// the stream before any buffer is allocated.
const MAX_NETMAP_FRAME: u32 = 16 * 1024 * 1024;
/// How long [`map_stream`] waits for a frame's length prefix, and separately for its body,
/// before ending the stream (the log messages say "to reconnect").
const MAP_READ_WATCHDOG: core::time::Duration = core::time::Duration::from_secs(120);
pub fn map_stream(reader: impl AsyncRead + Unpin) -> impl Stream<Item = StateUpdate> {
futures_util::stream::unfold(reader, async |mut reader| {
let msg_len = match tokio::time::timeout(MAP_READ_WATCHDOG, reader.read_u32_le()).await {
Ok(res) => res
.inspect_err(|e| {
tracing::error!(error = %e, "could not read netmap length");
})
.ok()?,
Err(_elapsed) => {
tracing::error!(
watchdog_secs = MAP_READ_WATCHDOG.as_secs(),
"no netmap frame within the keep-alive watchdog; ending stream to reconnect"
);
return None;
}
};
if msg_len > MAX_NETMAP_FRAME {
tracing::error!(
?msg_len,
max = MAX_NETMAP_FRAME,
"netmap frame too large; ending stream"
);
return None;
}
let mut buf = PacketMut::new(msg_len as usize);
tracing::trace!(?msg_len, "reading netmap");
match tokio::time::timeout(MAP_READ_WATCHDOG, reader.read_exact(buf.as_mut())).await {
Ok(res) => res
.inspect_err(|e| {
tracing::error!(error = %e, "could not read netmap");
})
.ok()?,
Err(_elapsed) => {
tracing::error!(
watchdog_secs = MAP_READ_WATCHDOG.as_secs(),
"netmap body did not arrive within the watchdog; ending stream to reconnect"
);
return None;
}
};
let map_response: MapResponse = serde_json::from_slice(buf.as_ref())
.inspect_err(|e| {
tracing::error!(error = %e, "deserializing netmap");
})
.ok()?;
tracing::trace!(?msg_len, ?map_response);
let packetfilter = packet_filter(&map_response);
let cap_grants = cap_grants(&map_response);
fn nonempty<T>(x: &Option<Vec<T>>) -> bool {
x.as_ref().is_some_and(|x| !x.is_empty())
}
let peer_patches: Vec<crate::PeerChange> = map_response
.peers_changed_patch
.iter()
.flatten()
.map(crate::PeerChange::from)
.collect();
let peer_update = if let Some(full_map) = map_response.peers {
Some(PeerUpdate::Full(full_map.iter().map(Into::into).collect()))
} else if nonempty(&map_response.peers_removed) || nonempty(&map_response.peers_changed) {
Some(PeerUpdate::Delta {
remove: map_response.peers_removed.unwrap_or_default(),
upsert: map_response
.peers_changed
.unwrap_or_default()
.iter()
.map(Into::into)
.collect(),
})
} else {
None
};
Some((
StateUpdate {
session_handle: (!map_response.map_session_handle.is_empty())
.then(|| map_response.map_session_handle.to_owned()),
seq: map_response.seq,
peer_update,
peer_patches,
user_profiles: map_response
.user_profiles
.iter()
.map(crate::UserProfile::from)
.collect(),
node: map_response.node.as_ref().map(Into::into),
derp: map_response
.derp_map
.as_ref()
.map(|x| crate::convert_derp_map(x).collect()),
ping: map_response.ping_request,
packetfilter,
cap_grants,
pop_browser_url: map_response.pop_browser_url.and_then(|u| {
u.parse()
.inspect_err(|e| tracing::error!(error = %e, "invalid pop browser url"))
.ok()
}),
dial_plan: map_response.control_dial_plan.map(Into::into),
dns_config: map_response
.dns_config
.as_ref()
.map(crate::DnsConfig::from_serde),
ssh_policy: map_response
.ssh_policy
.as_ref()
.map(crate::SshPolicy::from_serde),
tka: map_response
.tka_info
.as_ref()
.map(crate::TkaStatus::from_serde),
online_change: map_response.online_change.clone(),
peer_seen_change: map_response.peer_seen_change.clone(),
},
reader,
))
})
}
fn packet_filter(map_response: &MapResponse<'_>) -> Option<FilterUpdate> {
if map_response.packet_filter.is_none() && map_response.packet_filters.is_empty() {
return None;
}
Some((
map_response
.packet_filter
.as_ref()
.map(|x| pf_state::rules_to_pf(x).collect()),
map_response
.packet_filters
.iter()
.map(|(rule_name, rules)| {
(
rule_name.to_string(),
rules
.as_ref()
.map(|x| Some(pf_state::rules_to_pf(x).collect()))
.unwrap_or_default(),
)
})
.collect(),
))
}
/// Collects capability grants from the filter rules carried by `map_response`.
///
/// Returns `None` when the response has neither a single `packet_filter` nor any named
/// `packet_filters` entry. Otherwise grants come first from the single filter, then from each
/// named filter that has rules, in the map's iteration order. Named entries without rules
/// contribute nothing.
///
/// NOTE(review): grants from every filter in this one response are merged into a single list, so
/// if `packet_filters` carries only the changed entries the result covers only those — confirm
/// this matches what the consumer of `StateUpdate::cap_grants` expects.
fn cap_grants(map_response: &MapResponse<'_>) -> Option<Vec<ts_packetfilter_state::CapGrant>> {
    let has_filters =
        map_response.packet_filter.is_some() || !map_response.packet_filters.is_empty();
    if !has_filters {
        return None;
    }

    let from_single = map_response
        .packet_filter
        .iter()
        .flat_map(|rules| pf_state::retain_cap_grants(rules));
    let from_named = map_response
        .packet_filters
        .values()
        .flatten()
        .flat_map(|rules| pf_state::retain_cap_grants(rules));

    Some(from_single.chain(from_named).collect())
}
/// POSTs `map_request` as JSON to `url` over `http2_conn` and returns the response body reader.
///
/// The request is pretty-printed in debug builds and compact otherwise. On a success status the
/// response body is returned as an [`AsyncRead`], which can be fed to [`map_stream`].
///
/// # Errors
///
/// - [`MapStreamError::SerDe`] if serializing the request fails.
/// - [`MapStreamError::NetworkError`] or [`MapStreamError::Http`] if the HTTP request itself
///   fails, depending on whether the error is classified as recoverable.
/// - [`MapStreamError::Http`] if the server answers with a non-success status.
#[tracing::instrument(skip_all, fields(map_url = %url.as_str()))]
pub async fn send_map_request(
    map_request: MapRequest<'_>,
    url: &Url,
    http2_conn: &Http2<BytesBody>,
) -> Result<impl AsyncRead + 'static, MapStreamError> {
    tracing::debug!("sending map request to control server...");

    let serialized = if cfg!(debug_assertions) {
        serde_json::to_string_pretty(&map_request)
    } else {
        serde_json::to_string(&map_request)
    };
    let body = serialized?;
    tracing::trace!(%body, "sending map request");

    let response = http2_conn.post(url, None, Bytes::from(body).into()).await?;
    let status = response.status();
    tracing::trace!(?status, "received map response");

    if status.is_success() {
        return Ok(response.into_read());
    }

    tracing::error!(
        status = status.as_u16(),
        "failed to register map updates with unsuccessful HTTP status code"
    );
    Err(MapStreamError::Http)
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use futures_util::StreamExt;

    use super::*;

    /// Encodes each body as one netmap frame: a little-endian `u32` length, then the bytes.
    fn frame(bodies: &[&str]) -> Vec<u8> {
        let mut buf = Vec::new();
        for body in bodies {
            buf.extend_from_slice(&(body.len() as u32).to_le_bytes());
            buf.extend_from_slice(body.as_bytes());
        }
        buf
    }

    /// Reader that yields `prefix` and then returns `Poll::Pending` forever without registering
    /// the waker, simulating a connection that goes silent. Tests using it rely on the
    /// (paused-clock) watchdog timeout to make progress.
    struct StallAfter {
        prefix: alloc::collections::VecDeque<u8>,
    }

    impl StallAfter {
        fn new(prefix: &[u8]) -> Self {
            Self {
                prefix: prefix.iter().copied().collect(),
            }
        }
    }

    impl tokio::io::AsyncRead for StallAfter {
        fn poll_read(
            mut self: core::pin::Pin<&mut Self>,
            _cx: &mut core::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> core::task::Poll<std::io::Result<()>> {
            if self.prefix.is_empty() {
                return core::task::Poll::Pending;
            }
            while buf.remaining() > 0 {
                let Some(b) = self.prefix.pop_front() else {
                    break;
                };
                buf.put_slice(&[b]);
            }
            core::task::Poll::Ready(Ok(()))
        }
    }

    #[tokio::test(start_paused = true)]
    async fn map_stream_watchdog_ends_stream_on_silent_connection() {
        let reader = StallAfter::new(&frame(&[r#"{"MapSessionHandle":"sess-1","Seq":1}"#]));
        let mut stream = core::pin::pin!(map_stream(reader));
        let update = stream.next().await.expect("first frame");
        assert_eq!(update.seq, 1);
        assert!(
            stream.next().await.is_none(),
            "watchdog must end the stream on a silent connection"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn map_stream_watchdog_ends_stream_when_no_frame_ever_arrives() {
        let reader = StallAfter::new(&[]);
        let mut stream = core::pin::pin!(map_stream(reader));
        assert!(
            stream.next().await.is_none(),
            "watchdog must end a stream that never produces a frame"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn map_stream_watchdog_ends_stream_on_partial_body() {
        let mut bytes = 64u32.to_le_bytes().to_vec();
        bytes.extend_from_slice(b"abc");
        let reader = StallAfter::new(&bytes);
        let mut stream = core::pin::pin!(map_stream(reader));
        assert!(
            stream.next().await.is_none(),
            "watchdog must end the stream when the body never completes"
        );
    }

    #[tokio::test]
    async fn map_stream_ends_on_oversized_frame() {
        // Only the length prefix is supplied: the stream must bail out before reading a body.
        let buf = (MAX_NETMAP_FRAME + 1).to_le_bytes();
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        assert!(
            stream.next().await.is_none(),
            "frames larger than MAX_NETMAP_FRAME must end the stream"
        );
    }

    #[tokio::test]
    async fn map_stream_ends_on_eof_mid_body() {
        let mut bytes = 64u32.to_le_bytes().to_vec();
        bytes.extend_from_slice(b"abc");
        let mut stream = core::pin::pin!(map_stream(&bytes[..]));
        assert!(
            stream.next().await.is_none(),
            "EOF before the body completes must end the stream"
        );
    }

    #[tokio::test]
    async fn map_stream_ends_on_malformed_json() {
        let buf = frame(&["not json"]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        assert!(
            stream.next().await.is_none(),
            "an undecodable frame must end the stream"
        );
    }

    #[tokio::test]
    async fn map_stream_yields_frames_in_order_then_ends_at_eof() {
        let buf = frame(&[r#"{"Seq":1}"#, r#"{"Seq":2}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));

        let first = stream.next().await.expect("first frame");
        assert_eq!(first.seq, 1);
        assert!(first.peer_update.is_none());
        assert!(first.peer_patches.is_empty());
        assert!(first.packetfilter.is_none(), "no filters in this response");
        assert!(first.cap_grants.is_none(), "no filters in this response");

        let second = stream.next().await.expect("second frame");
        assert_eq!(second.seq, 2);

        assert!(
            stream.next().await.is_none(),
            "EOF after the last frame ends the stream"
        );
    }

    #[tokio::test]
    async fn map_stream_carries_session_handle_and_seq() {
        let buf = frame(&[r#"{"MapSessionHandle":"sess-xyz","Seq":12}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert_eq!(update.session_handle.as_deref(), Some("sess-xyz"));
        assert_eq!(update.seq, 12);
    }

    #[tokio::test]
    async fn map_stream_empty_handle_maps_to_none() {
        let buf = frame(&[r#"{"KeepAlive":true}"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert_eq!(update.session_handle, None);
        assert_eq!(update.seq, 0);
    }

    #[tokio::test]
    async fn map_stream_surfaces_peers_changed_patch() {
        let buf = frame(&[r#"{
            "Seq": 7,
            "PeersChangedPatch": [
                { "NodeID": 42, "Endpoints": ["203.0.113.7:41641"], "DERPRegion": 5 }
            ]
        }"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert!(
            update.peer_update.is_none(),
            "no whole-node set in this response"
        );
        assert_eq!(update.peer_patches.len(), 1);
        assert_eq!(update.peer_patches[0].id, 42);
        assert_eq!(
            update.peer_patches[0].underlay_addresses.as_deref(),
            Some(&["203.0.113.7:41641".parse().unwrap()][..])
        );
        assert_eq!(
            update.peer_patches[0].derp_region,
            Some(ts_derp::RegionId(core::num::NonZeroU32::new(5).unwrap()))
        );
    }

    #[tokio::test]
    async fn map_stream_carries_both_delta_and_patch_when_co_occurring() {
        let buf = frame(&[r#"{
            "Seq": 8,
            "PeersChanged": [
                { "ID": 1, "StableID": "n1", "Name": "a.ts.net.", "User": 1,
                  "Key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000" }
            ],
            "PeersChangedPatch": [ { "NodeID": 1, "DERPRegion": 9 } ]
        }"#]);
        let mut stream = core::pin::pin!(map_stream(&buf[..]));
        let update = stream.next().await.expect("one update");
        assert!(matches!(update.peer_update, Some(PeerUpdate::Delta { .. })));
        assert_eq!(update.peer_patches.len(), 1, "patch must not be dropped");
        assert_eq!(update.peer_patches[0].id, 1);
        assert_eq!(
            update.peer_patches[0].derp_region,
            Some(ts_derp::RegionId(core::num::NonZeroU32::new(9).unwrap()))
        );
    }
}