1#![doc = include_str!("../README.md")]
2
3use std::{collections::HashMap, sync::Arc, time::Instant};
4
5use ts_bart::RoutingTable;
6use ts_overlay_router as or;
7use ts_packet::PacketMut;
8use ts_packetfilter::{FilterExt, IpProto};
9use ts_time::{Handle, Scheduler};
10use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId};
11use ts_tunnel::{Endpoint, NodeKeyPair};
12use ts_underlay_router as ur;
13
14pub mod async_tokio;
15
/// Subsystems of the data plane that can have timer events scheduled on
/// [`DataPlane::events`].
///
/// Derives the same traits as [`CapturePath`] so the value can be logged
/// (`Debug`) and compared or copied freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Subsystem {
    /// WireGuard tunnel timers; handled by calling
    /// `Endpoint::dispatch_events` on [`DataPlane::wireguard`].
    Wireguard,
}
21
/// Point in the data plane at which a packet was handed to the
/// [`CaptureHook`].
///
/// Every variant has a fixed numeric code (see [`CapturePath::code`]).
/// NOTE(review): these codes look like they feed an external capture format —
/// confirm before renumbering any variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapturePath {
    /// Packet supplied by the local side to [`DataPlane::process_outbound`],
    /// captured before overlay routing and encryption.
    FromLocal = 0,
    /// Packet decrypted from a peer in [`DataPlane::process_inbound`],
    /// captured before source-address and packet filtering.
    FromPeer = 1,
    /// Presumably a packet synthesized by the data plane for local delivery
    /// (not emitted anywhere in this file).
    SynthesizedToLocal = 2,
    /// Presumably a packet synthesized by the data plane for a peer
    /// (not emitted anywhere in this file).
    SynthesizedToPeer = 3,
}

impl CapturePath {
    /// Returns this path's numeric code, i.e. the variant's explicit
    /// discriminant (`FromLocal` = 0 … `SynthesizedToPeer` = 3).
    ///
    /// This is a `const fn`, so codes can be used in constant contexts.
    #[must_use]
    pub const fn code(self) -> u16 {
        // Fieldless enum → integer cast yields the declared discriminant.
        self as u16
    }
}
45
46pub type CaptureHook = std::sync::Arc<dyn Fn(CapturePath, &[u8]) + Send + Sync>;
52
/// Packet-processing core: overlay/underlay routing, WireGuard
/// encryption/decryption, inbound filtering, and timer scheduling.
///
/// Fields are public so callers can install routes, filters, and hooks
/// directly after [`DataPlane::new`].
pub struct DataPlane {
    /// WireGuard endpoint: encrypts outbound packets (`send`), decrypts
    /// inbound ones (`recv`), and owns the tunnel timers.
    pub wireguard: Endpoint,

    /// Overlay router for locally-originated packets; splits them into
    /// packets for WireGuard peers and loopback packets.
    pub or_out: or::outbound::Router,
    /// Underlay router mapping per-peer packets to
    /// `(UnderlayTransportId, PeerId)` destinations.
    pub ur_out: ur::outbound::Router,

    /// Source-address filter for decrypted inbound packets: a routing table
    /// looked up by source address, yielding the only peer allowed to use it.
    pub src_filter_in: Arc<ts_bart::Table<PeerId>>,
    /// Overlay router that delivers filtered inbound packets to local
    /// overlay transports.
    pub or_in: or::inbound::Router,

    /// Filter consulted for every inbound packet that passed the source
    /// filter. [`DataPlane::new`] installs `DropAllFilter`.
    pub packet_filter: Arc<dyn ts_packetfilter::Filter + Send + Sync>,

    /// Timer scheduler; drained by [`DataPlane::process_events`].
    pub events: Scheduler<Subsystem>,

    /// Handle of the most recently scheduled WireGuard timer event, kept so
    /// it can be cancelled when a newer deadline is scheduled.
    pub wg_next: Option<Handle<Subsystem>>,

    /// Optional packet-capture hook; `None` disables capture.
    pub capture: Option<CaptureHook>,
}
82
83impl DataPlane {
84 pub fn new(my_key: NodeKeyPair) -> Self {
86 DataPlane {
87 wireguard: Endpoint::new(my_key),
88 or_out: Default::default(),
89 ur_out: Default::default(),
90 src_filter_in: Default::default(),
91 or_in: Default::default(),
92 events: Default::default(),
93 packet_filter: Arc::new(ts_packetfilter::DropAllFilter),
94 wg_next: None,
95 capture: None,
96 }
97 }
98
99 #[tracing::instrument(skip_all, fields(n_packets = packets.len()))]
101 pub fn process_outbound(&mut self, packets: Vec<PacketMut>) -> OutboundResult {
102 if let Some(hook) = &self.capture {
103 for p in &packets {
104 hook(CapturePath::FromLocal, p.as_ref());
105 }
106 }
107
108 let or::outbound::Result {
109 to_wireguard,
110 loopback,
111 } = self.or_out.route(packets);
112
113 let to_wireguard = to_wireguard
114 .into_iter()
115 .map(|(k, v)| (ts_tunnel::PeerId(k.0), v))
116 .collect::<Vec<_>>();
117
118 let ts_tunnel::SendResult {
119 to_peers: encrypted,
120 } = self.wireguard.send(to_wireguard);
121
122 let to_peers = self
123 .ur_out
124 .route(encrypted.into_iter().map(|(k, v)| (PeerId(k.0), v)));
125
126 if let Some(next) = self.wireguard.next_event()
127 && let Some(prev) = self
128 .wg_next
129 .replace(self.events.add(next, Subsystem::Wireguard))
130 {
131 prev.cancel();
132 }
133
134 OutboundResult { to_peers, loopback }
135 }
136
137 pub fn process_inbound(
139 &mut self,
140 packets: impl IntoIterator<Item = PacketMut>,
141 ) -> InboundResult {
142 let ts_tunnel::RecvResult { to_local, to_peers } = self.wireguard.recv(packets);
143
144 if let Some(hook) = &self.capture {
145 for packets in to_local.values() {
146 for p in packets {
147 hook(CapturePath::FromPeer, p.as_ref());
148 }
149 }
150 }
151
152 let to_local = to_local
153 .into_iter()
154 .map(|(peer_id, mut packets)| -> Vec<PacketMut> {
155 let _span = tracing::trace_span!(
156 "src_filter_inbound",
157 peer_id = ?peer_id,
158 n_packet = packets.len(),
159 )
160 .entered();
161
162 packets.retain(|packet| {
163 let Some(src) = packet.get_src_addr() else {
164 tracing::trace!("does not look like ip packet");
165 return false;
166 };
167 let verdict = if let Some(allowed_peer) = self.src_filter_in.lookup(src) {
168 *allowed_peer == PeerId(peer_id.0)
169 } else {
170 tracing::trace!(remote_ip = %src, "unknown peer address");
171 false
172 };
173 tracing::trace!(?src, verdict);
174 verdict
175 });
176
177 packets
178 })
179 .map(|mut v| {
180 let _span =
181 tracing::trace_span!("packet_filter_inbound", n_packet = v.len()).entered();
182
183 v.retain(|pkt| {
184 let Ok(pkt) = etherparse::SlicedPacket::from_ip(pkt.as_ref()) else {
185 tracing::trace!("does not look like ip packet");
186 return false;
187 };
188
189 let (proto, src, dst) = match pkt.net {
190 Some(etherparse::NetSlice::Ipv4(ipv4)) => (
191 IpProto::new(ipv4.payload().ip_number.0 as _),
192 ipv4.header().source_addr().into(),
193 ipv4.header().destination_addr().into(),
194 ),
195 Some(etherparse::NetSlice::Ipv6(ipv6)) => (
196 IpProto::new(ipv6.payload().ip_number.0 as _),
197 ipv6.header().source_addr().into(),
198 ipv6.header().destination_addr().into(),
199 ),
200 _ => {
201 unreachable!("unexpected packet kind");
202 }
203 };
204
205 let (_src_port, dst_port) = match pkt.transport {
206 Some(etherparse::TransportSlice::Udp(udp)) => {
207 (udp.source_port(), udp.destination_port())
208 }
209 Some(etherparse::TransportSlice::Tcp(tcp)) => {
210 (tcp.source_port(), tcp.destination_port())
211 }
212 _ => (0, 0),
213 };
214
215 let info = ts_packetfilter::PacketInfo {
216 ip_proto: proto,
217 port: dst_port,
218 src,
219 dst,
220 };
221
222 let caps = [];
224 let verdict = self.packet_filter.can_access(&info, caps);
225
226 tracing::trace!(?info, ?caps, verdict);
227
228 verdict
229 });
230
231 v
232 });
233
234 let to_peers = to_peers
235 .into_iter()
236 .map(|(k, v)| (ts_transport::PeerId(k.0), v));
237
238 let to_local = self.or_in.route(to_local.flatten());
239 let to_peers = self.ur_out.route(to_peers);
240
241 if let Some(next) = self.wireguard.next_event()
242 && let Some(prev) = self
243 .wg_next
244 .replace(self.events.add(next, Subsystem::Wireguard))
245 {
246 prev.cancel();
247 }
248
249 InboundResult { to_local, to_peers }
250 }
251
252 pub fn next_event(&self) -> Option<Instant> {
259 self.events.next_dispatch()
260 }
261
262 pub fn process_events(&mut self) -> EventResult {
267 let mut to_peers = HashMap::new();
268 let now = Instant::now();
269 for event in self.events.dispatch(now) {
270 match event {
271 Subsystem::Wireguard => {
272 let res = self.wireguard.dispatch_events(now);
273 to_peers.extend(
274 res.to_peers
275 .into_iter()
276 .map(|(id, pkts)| (ts_transport::PeerId(id.0), pkts)),
277 );
278 }
279 }
280 }
281 let to_peers = self.ur_out.route(to_peers);
282
283 if let Some(next) = self.wireguard.next_event()
284 && let Some(prev) = self
285 .wg_next
286 .replace(self.events.add(next, Subsystem::Wireguard))
287 {
288 prev.cancel();
289 }
290
291 EventResult { to_peers }
292 }
293}
294
295pub struct OutboundResult {
297 pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
299 pub loopback: HashMap<OverlayTransportId, Vec<PacketMut>>,
301}
302
303pub struct InboundResult {
305 pub to_local: HashMap<OverlayTransportId, Vec<PacketMut>>,
307 pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
309}
310
311#[derive(Default)]
313pub struct EventResult {
314 pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec<PacketMut>>,
316}
317
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Every `(path, bytes)` pair seen by a test capture hook, in call order.
    type CaptureLog = Arc<Mutex<Vec<(CapturePath, Vec<u8>)>>>;

    /// The numeric codes are fixed per variant.
    #[test]
    fn capture_path_codes() {
        let table = [
            (CapturePath::FromLocal, 0),
            (CapturePath::FromPeer, 1),
            (CapturePath::SynthesizedToLocal, 2),
            (CapturePath::SynthesizedToPeer, 3),
        ];
        for (path, expected) in table {
            assert_eq!(path.code(), expected);
        }
    }

    /// A locally-originated packet reaches the hook exactly once, tagged
    /// `FromLocal`, with its bytes unchanged.
    #[test]
    fn capture_hook_fires_on_outbound() {
        let mut dp = DataPlane::new(NodeKeyPair::new());

        let log: CaptureLog = Arc::default();
        let hook_log = Arc::clone(&log);
        dp.capture = Some(Arc::new(move |path: CapturePath, bytes: &[u8]| {
            hook_log.lock().unwrap().push((path, bytes.to_vec()));
        }));

        let payload = vec![0xde_u8, 0xad, 0xbe, 0xef];
        let _ = dp.process_outbound(vec![PacketMut::from(payload.clone())]);

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1, "hook must fire exactly once per packet");
        let (path, bytes) = &entries[0];
        assert_eq!(*path, CapturePath::FromLocal);
        assert_eq!(*bytes, payload);
    }
}