1#![doc = include_str!("../README.md")]
2
3use std::{collections::HashMap, sync::Arc, time::Instant};
4
5use ts_bart::RoutingTable;
6use ts_overlay_router as or;
7use ts_packet::PacketMut;
8use ts_packetfilter::{FilterExt, IpProto};
9use ts_time::{Handle, Scheduler};
10use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId};
11use ts_tunnel::{Endpoint, NodeKeyPair};
12use ts_underlay_router as ur;
13
14pub mod async_tokio;
15
/// Subsystems that can schedule timer events on the data plane's [`Scheduler`].
///
/// Used as the payload of scheduled events; `DataPlane::process_events` matches on it to
/// decide which subsystem's timers to run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Subsystem {
    /// Timers of the WireGuard endpoint, run via `Endpoint::dispatch_events`.
    Wireguard,
}
21
/// Where in the data plane a packet was observed when it was handed to a [`CaptureHook`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapturePath {
    /// Packet entering `DataPlane::process_outbound` from the local side.
    FromLocal = 0,
    /// Packet just decrypted from a WireGuard peer in `DataPlane::process_inbound`.
    FromPeer = 1,
    /// Packet synthesized by the data plane for local delivery.
    /// NOTE(review): not emitted anywhere in this module.
    SynthesizedToLocal = 2,
    /// Packet synthesized by the data plane for delivery to a peer.
    /// NOTE(review): not emitted anywhere in this module.
    SynthesizedToPeer = 3,
}

impl CapturePath {
    /// Numeric code identifying this capture path; equal to the enum discriminant.
    pub fn code(self) -> u16 {
        match self {
            CapturePath::FromLocal => 0,
            CapturePath::FromPeer => 1,
            CapturePath::SynthesizedToLocal => 2,
            CapturePath::SynthesizedToPeer => 3,
        }
    }
}

/// Packet-capture callback: receives the [`CapturePath`] and the raw bytes of each packet.
///
/// Reference-counted so it can be shared cheaply, and `Send + Sync` so it may be invoked from
/// any thread that owns the data plane.
pub type CaptureHook = Arc<dyn Fn(CapturePath, &[u8]) + Send + Sync>;
52
53pub struct DataPlane {
55 pub wireguard: Endpoint,
57
58 pub or_out: or::outbound::Router,
60 pub ur_out: ur::outbound::Router,
62
63 pub src_filter_in: Arc<ts_bart::Table<PeerId>>,
65 pub or_in: or::inbound::Router,
67
68 pub packet_filter: Arc<dyn ts_packetfilter::Filter + Send + Sync>,
70
71 pub events: Scheduler<Subsystem>,
73
74 pub wg_next: Option<Handle<Subsystem>>,
76
77 pub capture: Option<CaptureHook>,
81}
82
83impl DataPlane {
84 pub fn new(my_key: NodeKeyPair) -> Self {
86 DataPlane {
87 wireguard: Endpoint::new(my_key),
88 or_out: Default::default(),
89 ur_out: Default::default(),
90 src_filter_in: Default::default(),
91 or_in: Default::default(),
92 events: Default::default(),
93 packet_filter: Arc::new(ts_packetfilter::DropAllFilter),
94 wg_next: None,
95 capture: None,
96 }
97 }
98
99 #[tracing::instrument(skip_all, fields(n_packets = packets.len()))]
101 pub fn process_outbound(&mut self, packets: Vec<PacketMut>) -> OutboundResult {
102 if let Some(hook) = &self.capture {
103 for p in &packets {
104 hook(CapturePath::FromLocal, p.as_ref());
105 }
106 }
107
108 let or::outbound::Result {
109 to_wireguard,
110 loopback,
111 } = self.or_out.route(packets);
112
113 let to_wireguard = to_wireguard
114 .into_iter()
115 .map(|(k, v)| (ts_tunnel::PeerId(k.0), v))
116 .collect::<Vec<_>>();
117
118 let ts_tunnel::SendResult {
119 to_peers: encrypted,
120 } = self.wireguard.send(to_wireguard);
121
122 let to_peers = self
123 .ur_out
124 .route(encrypted.into_iter().map(|(k, v)| (PeerId(k.0), v)));
125
126 if let Some(next) = self.wireguard.next_event()
127 && let Some(prev) = self
128 .wg_next
129 .replace(self.events.add(next, Subsystem::Wireguard))
130 {
131 prev.cancel();
132 }
133
134 OutboundResult { to_peers, loopback }
135 }
136
137 pub fn process_inbound(
139 &mut self,
140 packets: impl IntoIterator<Item = PacketMut>,
141 ) -> InboundResult {
142 let ts_tunnel::RecvResult { to_local, to_peers } = self.wireguard.recv(packets);
143
144 if let Some(hook) = &self.capture {
145 for packets in to_local.values() {
146 for p in packets {
147 hook(CapturePath::FromPeer, p.as_ref());
148 }
149 }
150 }
151
152 let to_local = to_local
153 .into_iter()
154 .map(|(peer_id, mut packets)| -> Vec<PacketMut> {
155 let _span = tracing::trace_span!(
156 "src_filter_inbound",
157 peer_id = ?peer_id,
158 n_packet = packets.len(),
159 )
160 .entered();
161
162 packets.retain(|packet| {
163 let Some(src) = packet.get_src_addr() else {
164 tracing::trace!("does not look like ip packet");
165 return false;
166 };
167 let verdict = if let Some(allowed_peer) = self.src_filter_in.lookup(src) {
168 *allowed_peer == PeerId(peer_id.0)
169 } else {
170 tracing::trace!(remote_ip = %src, "unknown peer address");
171 false
172 };
173 tracing::trace!(?src, verdict);
174 verdict
175 });
176
177 packets
178 })
179 .map(|mut v| {
180 let _span =
181 tracing::trace_span!("packet_filter_inbound", n_packet = v.len()).entered();
182
183 v.retain(|pkt| {
184 let Ok(pkt) = etherparse::SlicedPacket::from_ip(pkt.as_ref()) else {
185 tracing::trace!("does not look like ip packet");
186 return false;
187 };
188
189 let (proto, src, dst) = match pkt.net {
190 Some(etherparse::NetSlice::Ipv4(ipv4)) => (
191 IpProto::new(ipv4.payload().ip_number.0 as _),
192 ipv4.header().source_addr().into(),
193 ipv4.header().destination_addr().into(),
194 ),
195 Some(etherparse::NetSlice::Ipv6(ipv6)) => (
196 IpProto::new(ipv6.payload().ip_number.0 as _),
197 ipv6.header().source_addr().into(),
198 ipv6.header().destination_addr().into(),
199 ),
200 _ => {
201 tracing::trace!("parsed packet is neither IPv4 nor IPv6; dropping");
207 return false;
208 }
209 };
210
211 let (_src_port, dst_port) = match pkt.transport {
212 Some(etherparse::TransportSlice::Udp(udp)) => {
213 (udp.source_port(), udp.destination_port())
214 }
215 Some(etherparse::TransportSlice::Tcp(tcp)) => {
216 (tcp.source_port(), tcp.destination_port())
217 }
218 _ => (0, 0),
219 };
220
221 let info = ts_packetfilter::PacketInfo {
222 ip_proto: proto,
223 port: dst_port,
224 src,
225 dst,
226 };
227
228 let caps = [];
230 let verdict = self.packet_filter.can_access(&info, caps);
231
232 tracing::trace!(?info, ?caps, verdict);
233
234 verdict
235 });
236
237 v
238 });
239
240 let to_peers = to_peers
241 .into_iter()
242 .map(|(k, v)| (ts_transport::PeerId(k.0), v));
243
244 let to_local = self.or_in.route(to_local.flatten());
245 let to_peers = self.ur_out.route(to_peers);
246
247 if let Some(next) = self.wireguard.next_event()
248 && let Some(prev) = self
249 .wg_next
250 .replace(self.events.add(next, Subsystem::Wireguard))
251 {
252 prev.cancel();
253 }
254
255 InboundResult { to_local, to_peers }
256 }
257
258 pub fn next_event(&self) -> Option<Instant> {
265 self.events.next_dispatch()
266 }
267
268 pub fn process_events(&mut self) -> EventResult {
273 let mut to_peers = HashMap::new();
274 let now = Instant::now();
275 for event in self.events.dispatch(now) {
276 match event {
277 Subsystem::Wireguard => {
278 let res = self.wireguard.dispatch_events(now);
279 to_peers.extend(
280 res.to_peers
281 .into_iter()
282 .map(|(id, pkts)| (ts_transport::PeerId(id.0), pkts)),
283 );
284 }
285 }
286 }
287 let to_peers = self.ur_out.route(to_peers);
288
289 if let Some(next) = self.wireguard.next_event()
290 && let Some(prev) = self
291 .wg_next
292 .replace(self.events.add(next, Subsystem::Wireguard))
293 {
294 prev.cancel();
295 }
296
297 EventResult { to_peers }
298 }
299}
300
301pub struct OutboundResult {
303 pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
305 pub loopback: HashMap<OverlayTransportId, Vec<PacketMut>>,
307}
308
309pub struct InboundResult {
311 pub to_local: HashMap<OverlayTransportId, Vec<PacketMut>>,
313 pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
315}
316
317#[derive(Default)]
319pub struct EventResult {
320 pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
322}
323
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Shared buffer that the test capture hook appends `(path, bytes)` pairs to.
    type CaptureLog = Arc<Mutex<Vec<(CapturePath, Vec<u8>)>>>;

    /// Pins the numeric code returned for each capture path.
    #[test]
    fn capture_path_codes() {
        let table = [
            (CapturePath::FromLocal, 0u16),
            (CapturePath::FromPeer, 1),
            (CapturePath::SynthesizedToLocal, 2),
            (CapturePath::SynthesizedToPeer, 3),
        ];
        for (path, expected) in table {
            assert_eq!(path.code(), expected);
        }
    }

    /// A packet handed to `process_outbound` is reported once, verbatim, as `FromLocal`.
    #[test]
    fn capture_hook_fires_on_outbound() {
        let mut dp = DataPlane::new(NodeKeyPair::new());
        let log: CaptureLog = Arc::default();

        let writer = Arc::clone(&log);
        dp.capture = Some(Arc::new(move |path: CapturePath, bytes: &[u8]| {
            writer.lock().unwrap().push((path, bytes.to_vec()));
        }));

        let payload = vec![0xde_u8, 0xad, 0xbe, 0xef];
        let _ = dp.process_outbound(vec![PacketMut::from(payload.clone())]);

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1, "hook must fire exactly once per packet");
        let (path, bytes) = &entries[0];
        assert_eq!(*path, CapturePath::FromLocal);
        assert_eq!(*bytes, payload);
    }
}
368}