#![doc = include_str!("../README.md")]
pub mod async_tokio;
use std::{collections::HashMap, sync::Arc, time::Instant};
use ts_bart::RoutingTable;
use ts_keys::NodePublicKey;
use ts_overlay_router as or;
use ts_packet::PacketMut;
use ts_packetfilter::{FilterExt, IpProto};
use ts_time::{Handle, Scheduler};
use ts_transport::{OverlayTransportId, UnderlayTransportId};
use ts_tunnel::{Endpoint, NodeKeyPair, PeerConfig};
use ts_underlay_router as ur;
pub enum Subsystem {
Wireguard,
}
pub struct DataPlane {
pub wireguard: Endpoint,
pub or_out: or::outbound::Router,
pub ur_out: ur::outbound::Router,
pub src_filter_in: Arc<ts_bart::Table<NodePublicKey>>,
pub or_in: or::inbound::Router,
pub packet_filter: Arc<dyn ts_packetfilter::Filter + Send + Sync>,
pub events: Scheduler<Subsystem>,
pub wg_next: Option<Handle<Subsystem>>,
}
impl DataPlane {
pub fn new(my_key: NodeKeyPair) -> Self {
DataPlane {
wireguard: Endpoint::new(my_key),
or_out: Default::default(),
ur_out: Default::default(),
src_filter_in: Default::default(),
or_in: Default::default(),
events: Default::default(),
packet_filter: Arc::new(ts_packetfilter::DropAllFilter),
wg_next: None,
}
}
#[tracing::instrument(skip_all, fields(n_packets = packets.len()))]
pub fn process_outbound(&mut self, packets: Vec<PacketMut>) -> OutboundResult {
let or::outbound::Result {
to_wireguard,
loopback,
} = self.or_out.route(packets);
let to_wireguard = to_wireguard
.into_iter()
.map(|(k, v)| {
let id = self
.wireguard
.peer_id(k)
.or_else(|| {
self.wireguard.add_peer(PeerConfig {
key: k,
psk: [0u8; 32].into(),
})
})
.unwrap();
(id, v)
})
.collect::<Vec<_>>();
let ts_tunnel::SendResult {
to_peers: encrypted,
} = self.wireguard.send(to_wireguard);
let to_peers = self.ur_out.route(
encrypted
.into_iter()
.filter_map(|(k, v)| Some((self.wireguard.peer_key(k)?, v))),
);
if let Some(next) = self.wireguard.next_event()
&& let Some(prev) = self
.wg_next
.replace(self.events.add(next, Subsystem::Wireguard))
{
prev.cancel();
}
OutboundResult { to_peers, loopback }
}
pub fn process_inbound(
&mut self,
packets: impl IntoIterator<Item = PacketMut>,
) -> InboundResult {
let ts_tunnel::RecvResult { to_local, to_peers } = self.wireguard.recv(packets);
let to_local = to_local
.into_iter()
.map(|(peer_id, mut packets)| {
let span = tracing::trace_span!("src_filter_inbound", peer_id = ?peer_id, n_packet = packets.len(), peer_key = tracing::field::Empty).entered();
let Some(key) = self.wireguard.peer_key(peer_id) else {
tracing::warn!("no nodekey for peer");
return (peer_id, vec![]);
};
span.record("peer_key", tracing::field::display(key));
packets.retain(|packet| {
let Some(src) = packet.get_src_addr() else {
tracing::trace!("does not look like ip packet");
return false;
};
let verdict = if let Some(allowed_key) = self.src_filter_in.lookup(src) {
*allowed_key == key
} else {
false
};
tracing::trace!(?src, verdict);
verdict
});
(peer_id, packets)
})
.map(|(k, mut v)| {
let span = tracing::trace_span!("packet_filter_inbound", peer_id = ?k, n_packet = v.len(), peer_key = tracing::field::Empty).entered();
let Some(key) = self.wireguard.peer_key(k) else {
tracing::warn!("no nodekey for peer");
return (k, vec![]);
};
span.record("peer_key", tracing::field::display(key));
v.retain(|pkt| {
let Ok(pkt) = etherparse::SlicedPacket::from_ip(pkt.as_ref()) else {
tracing::trace!("does not look like ip packet");
return false;
};
let (proto, src, dst) = match pkt.net {
Some(etherparse::NetSlice::Ipv4(ipv4)) => {
(IpProto::new(ipv4.payload().ip_number.0 as _), ipv4.header().source_addr().into(), ipv4.header().destination_addr().into())
}
Some(etherparse::NetSlice::Ipv6(ipv6)) => {
(IpProto::new(ipv6.payload().ip_number.0 as _), ipv6.header().source_addr().into(), ipv6.header().destination_addr().into())
}
_ => {
unreachable!("unexpected packet kind");
}
};
let (_src_port, dst_port) = match pkt.transport {
Some(etherparse::TransportSlice::Udp(udp)) => {
(udp.source_port(), udp.destination_port())
}
Some(etherparse::TransportSlice::Tcp(tcp)) => {
(tcp.source_port(), tcp.destination_port())
}
_ => (0, 0),
};
let info = ts_packetfilter::PacketInfo {
ip_proto: proto,
port: dst_port,
src,
dst,
};
let caps = [];
let verdict = self.packet_filter
.can_access(&info, caps);
tracing::trace!(?info, ?caps, verdict);
verdict
});
(k, v)
});
let to_peers = to_peers
.into_iter()
.filter_map(|(k, v)| Some((self.wireguard.peer_key(k)?, v)));
let to_local = self
.or_in
.route(to_local.flat_map(|(_id, packets)| packets));
let to_peers = self.ur_out.route(to_peers);
if let Some(next) = self.wireguard.next_event()
&& let Some(prev) = self
.wg_next
.replace(self.events.add(next, Subsystem::Wireguard))
{
prev.cancel();
}
InboundResult { to_local, to_peers }
}
pub fn next_event(&self) -> Option<Instant> {
self.events.next_dispatch()
}
pub fn process_events(&mut self) -> EventResult {
let mut to_peers = HashMap::new();
let now = Instant::now();
for event in self.events.dispatch(now) {
match event {
Subsystem::Wireguard => {
let res = self.wireguard.dispatch_events(now);
to_peers.extend(
res.to_peers
.into_iter()
.filter_map(|(id, pkts)| Some((self.wireguard.peer_key(id)?, pkts))),
);
}
}
}
let to_peers = self.ur_out.route(to_peers);
if let Some(next) = self.wireguard.next_event()
&& let Some(prev) = self
.wg_next
.replace(self.events.add(next, Subsystem::Wireguard))
{
prev.cancel();
}
EventResult { to_peers }
}
}
pub struct OutboundResult {
pub to_peers: HashMap<(UnderlayTransportId, NodePublicKey), Vec<PacketMut>>,
pub loopback: HashMap<OverlayTransportId, Vec<PacketMut>>,
}
pub struct InboundResult {
pub to_local: HashMap<OverlayTransportId, Vec<PacketMut>>,
pub to_peers: HashMap<(UnderlayTransportId, NodePublicKey), Vec<PacketMut>>,
}
#[derive(Default)]
pub struct EventResult {
pub to_peers: HashMap<(UnderlayTransportId, NodePublicKey), Vec<PacketMut>>,
}