#![doc = include_str!("../README.md")]
use std::{collections::HashMap, sync::Arc, time::Instant};
use ts_bart::RoutingTable;
use ts_overlay_router as or;
use ts_packet::PacketMut;
use ts_packetfilter::{FilterExt, IpProto};
use ts_time::{Handle, Scheduler};
use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId};
use ts_tunnel::{Endpoint, NodeKeyPair};
use ts_underlay_router as ur;
pub mod async_tokio;
pub enum Subsystem {
Wireguard,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapturePath {
FromLocal = 0,
FromPeer = 1,
SynthesizedToLocal = 2,
SynthesizedToPeer = 3,
}
impl CapturePath {
pub fn code(self) -> u16 {
self as u16
}
}
pub type CaptureHook = std::sync::Arc<dyn Fn(CapturePath, &[u8]) + Send + Sync>;
pub struct DataPlane {
pub wireguard: Endpoint,
pub or_out: or::outbound::Router,
pub ur_out: ur::outbound::Router,
pub src_filter_in: Arc<ts_bart::Table<PeerId>>,
pub or_in: or::inbound::Router,
pub packet_filter: Arc<dyn ts_packetfilter::Filter + Send + Sync>,
pub events: Scheduler<Subsystem>,
pub wg_next: Option<Handle<Subsystem>>,
pub capture: Option<CaptureHook>,
}
impl DataPlane {
pub fn new(my_key: NodeKeyPair) -> Self {
DataPlane {
wireguard: Endpoint::new(my_key),
or_out: Default::default(),
ur_out: Default::default(),
src_filter_in: Default::default(),
or_in: Default::default(),
events: Default::default(),
packet_filter: Arc::new(ts_packetfilter::DropAllFilter),
wg_next: None,
capture: None,
}
}
#[tracing::instrument(skip_all, fields(n_packets = packets.len()))]
pub fn process_outbound(&mut self, packets: Vec<PacketMut>) -> OutboundResult {
if let Some(hook) = &self.capture {
for p in &packets {
hook(CapturePath::FromLocal, p.as_ref());
}
}
let or::outbound::Result {
to_wireguard,
loopback,
} = self.or_out.route(packets);
let to_wireguard = to_wireguard
.into_iter()
.map(|(k, v)| (ts_tunnel::PeerId(k.0), v))
.collect::<Vec<_>>();
let ts_tunnel::SendResult {
to_peers: encrypted,
} = self.wireguard.send(to_wireguard);
let to_peers = self
.ur_out
.route(encrypted.into_iter().map(|(k, v)| (PeerId(k.0), v)));
if let Some(next) = self.wireguard.next_event()
&& let Some(prev) = self
.wg_next
.replace(self.events.add(next, Subsystem::Wireguard))
{
prev.cancel();
}
OutboundResult { to_peers, loopback }
}
pub fn process_inbound(
&mut self,
packets: impl IntoIterator<Item = PacketMut>,
) -> InboundResult {
let ts_tunnel::RecvResult { to_local, to_peers } = self.wireguard.recv(packets);
if let Some(hook) = &self.capture {
for packets in to_local.values() {
for p in packets {
hook(CapturePath::FromPeer, p.as_ref());
}
}
}
let to_local = to_local
.into_iter()
.map(|(peer_id, mut packets)| -> Vec<PacketMut> {
let _span = tracing::trace_span!(
"src_filter_inbound",
peer_id = ?peer_id,
n_packet = packets.len(),
)
.entered();
packets.retain(|packet| {
let Some(src) = packet.get_src_addr() else {
tracing::trace!("does not look like ip packet");
return false;
};
let verdict = if let Some(allowed_peer) = self.src_filter_in.lookup(src) {
*allowed_peer == PeerId(peer_id.0)
} else {
tracing::trace!(remote_ip = %src, "unknown peer address");
false
};
tracing::trace!(?src, verdict);
verdict
});
packets
})
.map(|mut v| {
let _span =
tracing::trace_span!("packet_filter_inbound", n_packet = v.len()).entered();
v.retain(|pkt| {
let Ok(pkt) = etherparse::SlicedPacket::from_ip(pkt.as_ref()) else {
tracing::trace!("does not look like ip packet");
return false;
};
let (proto, src, dst) = match pkt.net {
Some(etherparse::NetSlice::Ipv4(ipv4)) => (
IpProto::new(ipv4.payload().ip_number.0 as _),
ipv4.header().source_addr().into(),
ipv4.header().destination_addr().into(),
),
Some(etherparse::NetSlice::Ipv6(ipv6)) => (
IpProto::new(ipv6.payload().ip_number.0 as _),
ipv6.header().source_addr().into(),
ipv6.header().destination_addr().into(),
),
_ => {
unreachable!("unexpected packet kind");
}
};
let (_src_port, dst_port) = match pkt.transport {
Some(etherparse::TransportSlice::Udp(udp)) => {
(udp.source_port(), udp.destination_port())
}
Some(etherparse::TransportSlice::Tcp(tcp)) => {
(tcp.source_port(), tcp.destination_port())
}
_ => (0, 0),
};
let info = ts_packetfilter::PacketInfo {
ip_proto: proto,
port: dst_port,
src,
dst,
};
let caps = [];
let verdict = self.packet_filter.can_access(&info, caps);
tracing::trace!(?info, ?caps, verdict);
verdict
});
v
});
let to_peers = to_peers
.into_iter()
.map(|(k, v)| (ts_transport::PeerId(k.0), v));
let to_local = self.or_in.route(to_local.flatten());
let to_peers = self.ur_out.route(to_peers);
if let Some(next) = self.wireguard.next_event()
&& let Some(prev) = self
.wg_next
.replace(self.events.add(next, Subsystem::Wireguard))
{
prev.cancel();
}
InboundResult { to_local, to_peers }
}
pub fn next_event(&self) -> Option<Instant> {
self.events.next_dispatch()
}
pub fn process_events(&mut self) -> EventResult {
let mut to_peers = HashMap::new();
let now = Instant::now();
for event in self.events.dispatch(now) {
match event {
Subsystem::Wireguard => {
let res = self.wireguard.dispatch_events(now);
to_peers.extend(
res.to_peers
.into_iter()
.map(|(id, pkts)| (ts_transport::PeerId(id.0), pkts)),
);
}
}
}
let to_peers = self.ur_out.route(to_peers);
if let Some(next) = self.wireguard.next_event()
&& let Some(prev) = self
.wg_next
.replace(self.events.add(next, Subsystem::Wireguard))
{
prev.cancel();
}
EventResult { to_peers }
}
}
pub struct OutboundResult {
pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
pub loopback: HashMap<OverlayTransportId, Vec<PacketMut>>,
}
pub struct InboundResult {
pub to_local: HashMap<OverlayTransportId, Vec<PacketMut>>,
pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
}
#[derive(Default)]
pub struct EventResult {
pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
}
#[cfg(test)]
mod tests {
use std::sync::Mutex;
use super::*;
type CaptureLog = Arc<Mutex<Vec<(CapturePath, Vec<u8>)>>>;
#[test]
fn capture_path_codes() {
assert_eq!(CapturePath::FromLocal.code(), 0);
assert_eq!(CapturePath::FromPeer.code(), 1);
assert_eq!(CapturePath::SynthesizedToLocal.code(), 2);
assert_eq!(CapturePath::SynthesizedToPeer.code(), 3);
}
#[test]
fn capture_hook_fires_on_outbound() {
let mut dp = DataPlane::new(NodeKeyPair::new());
let recorded: CaptureLog = Arc::new(Mutex::new(Vec::new()));
let sink = recorded.clone();
dp.capture = Some(Arc::new(move |path: CapturePath, bytes: &[u8]| {
sink.lock().unwrap().push((path, bytes.to_vec()));
}));
let payload: Vec<u8> = vec![0xde, 0xad, 0xbe, 0xef];
let packet = PacketMut::from(payload.clone());
drop(dp.process_outbound(vec![packet]));
let captured = recorded.lock().unwrap();
assert_eq!(captured.len(), 1, "hook must fire exactly once per packet");
assert_eq!(captured[0].0, CapturePath::FromLocal);
assert_eq!(captured[0].1, payload);
}
}