use alloc::collections::BTreeMap;
use bytes::Bytes;
use futures_util::Stream;
use tokio::{
io::{AsyncRead, AsyncReadExt},
sync::watch,
};
use ts_control_serde::{MapRequest, MapResponse, PingRequest};
use ts_http_util::{BytesBody, ClientExt, Http2, ResponseExt};
use ts_packet::PacketMut;
use ts_packetfilter as pf;
use ts_packetfilter_state as pf_state;
use url::Url;
use crate::{DialPlan, NodeId};
/// Errors that can occur while requesting or consuming the control server's map stream.
#[derive(Debug, thiserror::Error)]
pub enum MapStreamError {
    // -- Building and sending the map request --
    /// The request could not be constructed.
    #[error("failed to construct request")]
    Request,
    /// The outgoing [`MapRequest`] could not be serialized to JSON.
    #[error("failed to serialize map request message: {0}")]
    SerializeFailed(serde_json::Error),
    /// Control answered the map request with the contained non-success HTTP status code.
    #[error("failed to register for map updates; control returned HTTP {0}")]
    MapRequestFailed(u16),
    /// An error returned by the `ts_http_util` HTTP layer.
    #[error(transparent)]
    Http(#[from] ts_http_util::Error),

    // -- Reading and decoding the map stream --
    /// A map stream message's length could not be read.
    #[error("map stream message was missing length: {0}")]
    LengthMissing(std::io::Error),
    /// The map stream ended before a complete message was read.
    #[error("map stream encountered unexpected EOF: {0}")]
    UnexpectedEof(std::io::Error),
    /// A map stream message could not be deserialized as a [`MapResponse`].
    #[error("failed to deserialize map response message: {0}")]
    DeserializeFailed(serde_json::Error),
    /// Bytes that were expected to be UTF-8 were not.
    #[error(transparent)]
    Utf8Error(#[from] core::str::Utf8Error),

    // -- Handler / task plumbing --
    /// A ping request could not be forwarded because its channel was not initialized.
    #[error("failed to forward ping request to handler; channel was not initialized")]
    InvalidPingChannel,
    /// A spawned tokio task did not run to completion.
    #[error(transparent)]
    JoinFailed(#[from] tokio::task::JoinError),
    /// Receiving from a tokio `watch` channel failed.
    #[error(transparent)]
    WatchRecv(#[from] watch::error::RecvError),
}
impl From<MapStreamError> for std::io::Error {
    /// Wraps a [`MapStreamError`] in an I/O error of kind `Other`.
    ///
    /// Only the error's rendered `Display` message is carried over; the original value is
    /// not kept as the I/O error's inner error.
    fn from(err: MapStreamError) -> Self {
        let message = err.to_string();
        Self::other(message)
    }
}
/// A change to the peer set, derived from a single map response by [`map_stream`].
#[derive(Debug)]
pub enum PeerUpdate {
    /// Every peer from the response's `peers` field.
    ///
    /// Presumably the complete peer set, superseding any earlier one; confirm against
    /// consumers.
    Full(Vec<crate::Node>),
    /// An incremental change.
    ///
    /// Produced only when the response had no `peers` field but `peers_changed` or
    /// `peers_removed` was non-empty. A list that was absent from the response is
    /// represented here as empty.
    Delta {
        /// Peers converted from the response's `peers_changed` (presumably added or
        /// updated peers).
        upsert: Vec<crate::Node>,
        /// Node IDs taken from the response's `peers_removed`.
        remove: Vec<NodeId>,
    },
}
/// Packet-filter information extracted from one map response.
///
/// - `.0` is the rule set built from the response's top-level `packet_filter`, or `None` if
///   the response did not include one.
/// - `.1` maps each name from the response's `packet_filters` to its rule set. The value is
///   `None` when that name was sent without rules; what `None` means to consumers (e.g.
///   removal) is not established here.
pub type FilterUpdate = (Option<pf::Ruleset>, BTreeMap<String, Option<pf::Ruleset>>);
/// Everything [`map_stream`] extracts from a single map response.
///
/// A field is `None` when the response did not carry that piece of information.
#[derive(Debug)]
pub struct StateUpdate {
    /// DERP map converted from the response's `derp_map`.
    pub derp: Option<crate::DerpMap>,
    /// Node converted from the response's `node` field (presumably this node's own entry).
    pub node: Option<crate::Node>,
    /// Peer changes.
    ///
    /// `None` when the response had neither a `peers` list nor a non-empty
    /// `peers_changed` / `peers_removed`.
    pub peer_update: Option<PeerUpdate>,
    /// The response's `ping_request`, passed through unchanged.
    pub ping: Option<PingRequest>,
    /// Packet filters.
    ///
    /// `None` when the response had neither `packet_filter` nor any `packet_filters`
    /// entries.
    pub packetfilter: Option<FilterUpdate>,
    /// URL parsed from the response's `pop_browser_url`.
    ///
    /// Also `None` if that value failed to parse; the parse error is logged.
    pub pop_browser_url: Option<Url>,
    /// Dial plan converted from the response's `control_dial_plan`.
    pub dial_plan: Option<DialPlan>,
}
pub fn map_stream(reader: impl AsyncRead + Unpin) -> impl Stream<Item = StateUpdate> {
futures_util::stream::unfold(reader, async |mut reader| {
let msg_len = reader
.read_u32_le()
.await
.inspect_err(|e| {
tracing::error!(error = %e, "could not read netmap length");
})
.ok()?;
let mut buf = PacketMut::new(msg_len as usize);
tracing::trace!(?msg_len, "reading netmap");
reader
.read_exact(buf.as_mut())
.await
.inspect_err(|e| {
tracing::error!(error = %e, "could not read netmap");
})
.ok()?;
let map_response: MapResponse = serde_json::from_slice(buf.as_ref())
.inspect_err(|e| {
tracing::error!(error = %e, "deserializing netmap");
})
.ok()?;
tracing::trace!(?msg_len, ?map_response);
let packetfilter = packet_filter(&map_response);
fn nonempty<T>(x: &Option<Vec<T>>) -> bool {
x.as_ref().is_some_and(|x| !x.is_empty())
}
let peer_update = if let Some(full_map) = map_response.peers {
Some(PeerUpdate::Full(full_map.iter().map(Into::into).collect()))
} else if nonempty(&map_response.peers_removed) || nonempty(&map_response.peers_changed) {
Some(PeerUpdate::Delta {
remove: map_response.peers_removed.unwrap_or_default(),
upsert: map_response
.peers_changed
.unwrap_or_default()
.iter()
.map(Into::into)
.collect(),
})
} else {
None
};
Some((
StateUpdate {
peer_update,
node: map_response.node.as_ref().map(Into::into),
derp: map_response
.derp_map
.as_ref()
.map(|x| crate::convert_derp_map(x).collect()),
ping: map_response.ping_request,
packetfilter,
pop_browser_url: map_response.pop_browser_url.and_then(|u| {
u.parse()
.inspect_err(|e| tracing::error!(error = %e, "invalid pop browser url"))
.ok()
}),
dial_plan: map_response.control_dial_plan.map(Into::into),
},
reader,
))
})
}
fn packet_filter(map_response: &MapResponse<'_>) -> Option<FilterUpdate> {
if map_response.packet_filter.is_none() && map_response.packet_filters.is_empty() {
return None;
}
Some((
map_response
.packet_filter
.as_ref()
.map(|x| pf_state::rules_to_pf(x).collect()),
map_response
.packet_filters
.iter()
.map(|(rule_name, rules)| {
(
rule_name.to_string(),
rules
.as_ref()
.map(|x| Some(pf_state::rules_to_pf(x).collect()))
.unwrap_or_default(),
)
})
.collect(),
))
}
/// POSTs `map_request` as JSON to `url` over `http2_conn` and returns a reader over the
/// response obtained via `ResponseExt::into_read`.
///
/// The JSON is pretty-printed when `debug_assertions` is enabled and compact otherwise.
///
/// # Errors
///
/// - [`MapStreamError::SerializeFailed`] if `map_request` cannot be serialized.
/// - Any error from the HTTP POST, propagated with `?`; presumably
///   [`MapStreamError::Http`].
/// - [`MapStreamError::MapRequestFailed`] if control replies with a non-success status.
#[tracing::instrument(skip_all, fields(map_url = %url.as_str()))]
pub async fn send_map_request(
    map_request: MapRequest<'_>,
    url: &Url,
    http2_conn: &Http2<BytesBody>,
) -> Result<impl AsyncRead + 'static, MapStreamError> {
    tracing::debug!("sending map request to control server...");

    // Presumably pretty-printed in debug builds so the trace log below is easier to read.
    let encoded = if cfg!(debug_assertions) {
        serde_json::to_string_pretty(&map_request)
    } else {
        serde_json::to_string(&map_request)
    };
    let body = encoded.map_err(MapStreamError::SerializeFailed)?;
    tracing::trace!(%body, "sending map request");

    let response = http2_conn.post(url, None, Bytes::from(body).into()).await?;
    let status = response.status();
    tracing::trace!(?status, "received map response");

    if status.is_success() {
        Ok(response.into_read())
    } else {
        Err(MapStreamError::MapRequestFailed(status.as_u16()))
    }
}