1use crate::{
2 packet::IpStackPacketProtocol,
3 stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport},
4};
5use async_channel::{Receiver, Sender};
6use async_executor::Executor;
7use bytes::Bytes;
8use log::trace;
9use moka::{sync::Cache, Expiry};
10use packet::{NetworkPacket, NetworkTuple};
11use parking_lot::Mutex;
12use std::time::{Duration, Instant};
13
14pub(crate) type PacketSender = Sender<NetworkPacket>;
15pub(crate) type PacketReceiver = Receiver<NetworkPacket>;
16pub(crate) type SessionCollection = Cache<NetworkTuple, PacketSender>;
17
18mod packet;
19pub mod stream;
20
/// Hop limit value of 64; not referenced in this file — presumably stamped on packets
/// built by the stream modules (TODO confirm).
const TTL: u8 = 64;

/// TTL value (0) that marks an outbound packet as an internal teardown signal: such a
/// packet is never written to the device, and the session keyed by its reverse tuple is
/// removed instead.
const DROP_TTL: u8 = 0;
24
/// Tuning knobs for an [`IpStack`].
#[derive(Debug, Clone)]
pub struct IpStackConfig {
    /// MTU passed to every stream the stack creates.
    pub mtu: u16,

    /// Timeout for TCP sessions: passed to each TCP stream and also used as the expiry of
    /// TCP entries in the session table.
    pub tcp_timeout: Duration,
    /// Timeout for UDP sessions: passed to each UDP stream and also used as the expiry of
    /// UDP entries in the session table.
    pub udp_timeout: Duration,
}

impl Default for IpStackConfig {
    /// MTU of 16384, a one-hour TCP timeout and a ten-minute UDP timeout.
    fn default() -> Self {
        IpStackConfig {
            mtu: 16384,

            tcp_timeout: Duration::from_secs(3600),
            udp_timeout: Duration::from_secs(600),
        }
    }
}
42
43pub struct IpStack {
44 accept_receiver: Receiver<IpStackStream>,
45 exec: Executor<'static>,
46}
47
48impl IpStack {
49 pub fn new(
50 config: IpStackConfig,
51 recv_packet: Receiver<Bytes>,
52 send_packet: Sender<Bytes>,
53 ) -> IpStack {
54 let (accept_sender, accept_receiver) = async_channel::unbounded();
55 let exec = Executor::new();
56 exec.spawn(run(config, recv_packet, send_packet, accept_sender))
57 .detach();
58
59 IpStack {
60 accept_receiver,
61 exec,
62 }
63 }
64
65 pub async fn accept(&self) -> anyhow::Result<IpStackStream> {
66 self.exec
67 .run(async { Ok(self.accept_receiver.recv().await?) })
68 .await
69 }
70}
71
72async fn run(
73 config: IpStackConfig,
74 recv_packet: Receiver<Bytes>,
75 send_packet: Sender<Bytes>,
76 accept_sender: Sender<IpStackStream>,
77) -> anyhow::Result<()> {
78 let sessions: SessionCollection = Cache::builder()
79 .expire_after(SessionExpiry {
80 tcp_timeout: config.tcp_timeout,
81 udp_timeout: config.udp_timeout,
82 })
83 .build();
84 let sessions = Mutex::new(sessions);
85
86 let (pkt_sender, pkt_receiver) = async_channel::unbounded::<NetworkPacket>();
87
88 let accept_loop = async {
89 loop {
90 let packet = recv_packet.recv().await?;
91 let mut sessions = sessions.lock();
92 if let Some(stream) =
93 process_device_read(&packet, &mut sessions, pkt_sender.clone(), &config)
94 {
95 let _ = accept_sender.try_send(stream);
96 }
97 }
98 };
99
100 let inject_loop = async {
101 loop {
102 let packet = pkt_receiver.recv().await?;
103 let mut sessions = sessions.lock();
104 process_upstream_recv(packet, &mut sessions, send_packet.clone())?;
105 }
106 };
107
108 futures_lite::future::race(accept_loop, inject_loop).await
109}
110
111struct SessionExpiry {
112 tcp_timeout: Duration,
113 udp_timeout: Duration,
114}
115
116impl Expiry<NetworkTuple, PacketSender> for SessionExpiry {
117 fn expire_after_create(
118 &self,
119 key: &NetworkTuple,
120 _value: &PacketSender,
121 _created_at: Instant,
122 ) -> Option<Duration> {
123 Some(if key.tcp {
124 self.tcp_timeout
125 } else {
126 self.udp_timeout
127 })
128 }
129
130 fn expire_after_read(
131 &self,
132 key: &NetworkTuple,
133 _value: &PacketSender,
134 _read_at: Instant,
135 _duration_until_expiry: Option<Duration>,
136 _last_modified_at: Instant,
137 ) -> Option<Duration> {
138 self.expire_after_create(key, _value, _read_at)
139 }
140
141 fn expire_after_update(
142 &self,
143 key: &NetworkTuple,
144 _value: &PacketSender,
145 _updated_at: Instant,
146 _duration_until_expiry: Option<Duration>,
147 ) -> Option<Duration> {
148 self.expire_after_create(key, _value, _updated_at)
149 }
150}
151
152fn process_device_read(
153 data: &[u8],
154 sessions: &mut SessionCollection,
155 pkt_sender: PacketSender,
156 config: &IpStackConfig,
157) -> Option<IpStackStream> {
158 let Ok(packet) = NetworkPacket::parse(data) else {
159 return Some(IpStackStream::UnknownNetwork(data.to_owned()));
160 };
161
162 if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
163 return Some(IpStackStream::UnknownTransport(
164 IpStackUnknownTransport::new(
165 packet.src_addr().ip(),
166 packet.dst_addr().ip(),
167 packet.payload,
168 &packet.ip,
169 config.mtu,
170 pkt_sender,
171 ),
172 ));
173 }
174
175 if let Some(sender) = sessions.get(&packet.network_tuple()) {
176 let _ = sender.try_send(packet);
177 None
178 } else {
179 let (a, b) = create_stream(packet.clone(), config, pkt_sender)?;
180 sessions.insert(packet.network_tuple(), a);
181 Some(b)
182 }
183}
184
185fn create_stream(
186 packet: NetworkPacket,
187 config: &IpStackConfig,
188 pkt_sender: PacketSender,
189) -> Option<(PacketSender, IpStackStream)> {
190 match packet.transport_protocol() {
191 IpStackPacketProtocol::Tcp(h) => {
192 match IpStackTcpStream::new(
193 packet.src_addr(),
194 packet.dst_addr(),
195 h,
196 pkt_sender,
197 config.mtu,
198 config.tcp_timeout,
199 ) {
200 Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))),
201 Err(e) => {
202 log::debug!("IpStackTcpStream::new failed \"{}\"", e);
203
204 None
205 }
206 }
207 }
208 IpStackPacketProtocol::Udp => {
209 let stream = IpStackUdpStream::new(
210 packet.src_addr(),
211 packet.dst_addr(),
212 pkt_sender,
213 config.mtu,
214 config.udp_timeout,
215 );
216 let _ = stream.stream_sender().try_send(packet.clone());
217 Some((stream.stream_sender(), IpStackStream::Udp(stream)))
218 }
219 IpStackPacketProtocol::Unknown => {
220 unreachable!()
221 }
222 }
223}
224
225fn process_upstream_recv(
226 packet: NetworkPacket,
227 sessions: &mut SessionCollection,
228 device: Sender<Bytes>,
229) -> anyhow::Result<()> {
230 if packet.ttl() == 0 {
231 sessions.remove(&packet.reverse_network_tuple());
232 return Ok(());
233 }
234 #[allow(unused_mut)]
235 let Ok(mut packet_bytes) = packet.to_bytes() else {
236 trace!("to_bytes error");
237 return Ok(());
238 };
239
240 let _ = device.try_send(packet_bytes.into());
241 Ok(())
244}
245
246pub trait Device {
247 fn read_packet(&self) -> Bytes;
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{tcp_flags, IpHeader, TransportHeader};
    use etherparse::{IpNumber, Ipv4Header, TcpHeader};
    use futures_lite::{
        future::{poll_fn, poll_once},
        AsyncRead, AsyncWrite,
    };
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    /// Builds a raw IPv4/UDP packet (TTL 64) from `127.0.0.1:src_port` to
    /// `10.0.0.2:dst_port` carrying `payload`.
    fn udp_packet(src_port: u16, dst_port: u16, payload: &[u8]) -> Vec<u8> {
        let builder =
            etherparse::PacketBuilder::ipv4(Ipv4Addr::LOCALHOST.octets(), [10, 0, 0, 2], 64)
                .udp(src_port, dst_port);
        let mut buf = Vec::new();
        builder.write(&mut buf, payload).unwrap();
        buf
    }

    /// Builds a raw IPv4/TCP packet (TTL 64, window `u16::MAX`) from `127.0.0.1:src_port`
    /// to `10.0.0.2:dst_port`.
    ///
    /// `flags` is a bitmask of `tcp_flags` constants. The ACK flag is also set whenever
    /// `ack` is `Some`, and that value becomes the acknowledgment number. The IP payload
    /// length and the TCP checksum are filled in, then the packet is serialized through
    /// the crate's own `NetworkPacket`.
    fn tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack: Option<u32>,
        flags: u8,
        payload: &[u8],
    ) -> Vec<u8> {
        let mut ip = Ipv4Header::new(
            0,
            64,
            IpNumber::TCP,
            Ipv4Addr::LOCALHOST.octets(),
            [10, 0, 0, 2],
        )
        .unwrap();
        let mut tcp = TcpHeader::new(src_port, dst_port, seq, u16::MAX);
        tcp.syn = flags & tcp_flags::SYN != 0;
        tcp.fin = flags & tcp_flags::FIN != 0;
        tcp.rst = flags & tcp_flags::RST != 0;
        tcp.psh = flags & tcp_flags::PSH != 0;
        tcp.ack = ack.is_some() || flags & tcp_flags::ACK != 0;
        tcp.acknowledgment_number = ack.unwrap_or(0);
        // The IP payload covers the TCP header plus the data.
        ip.set_payload_len(payload.len() + tcp.header_len())
            .unwrap();
        tcp.checksum = tcp.calc_checksum_ipv4(&ip, payload).unwrap();

        NetworkPacket {
            ip: IpHeader::Ipv4(ip),
            transport: TransportHeader::Tcp(tcp),
            payload: payload.to_vec(),
        }
        .to_bytes()
        .unwrap()
    }

    /// Returns the TCP header of `packet`, panicking if it is not a TCP packet.
    fn packet_tcp_header(packet: &NetworkPacket) -> &TcpHeader {
        let TransportHeader::Tcp(tcp) = &packet.transport else {
            panic!("expected TCP packet");
        };
        tcp
    }

    /// `SessionExpiry` picks the TCP or the UDP timeout based on the tuple's `tcp` flag.
    #[test]
    fn session_expiry_uses_protocol_specific_configured_timeout() {
        let expiry = SessionExpiry {
            tcp_timeout: Duration::from_secs(11),
            udp_timeout: Duration::from_secs(7),
        };
        let (sender, _receiver) = async_channel::unbounded();
        let tcp_tuple = NetworkTuple {
            src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1000)),
            dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 2000)),
            tcp: true,
        };
        // Same endpoints, different transport.
        let udp_tuple = NetworkTuple {
            tcp: false,
            ..tcp_tuple
        };

        assert_eq!(
            expiry.expire_after_create(&tcp_tuple, &sender, Instant::now()),
            Some(Duration::from_secs(11))
        );
        assert_eq!(
            expiry.expire_after_create(&udp_tuple, &sender, Instant::now()),
            Some(Duration::from_secs(7))
        );
    }

    /// The first UDP packet of a flow creates a stream. The second packet is routed into
    /// that same stream instead of creating another, and both payloads are read in order.
    #[test]
    fn process_device_read_creates_udp_stream_and_routes_later_packets_to_it() {
        let config = IpStackConfig::default();
        let (packet_sender, _packet_receiver) = async_channel::unbounded();
        let mut sessions = Cache::builder()
            .expire_after(SessionExpiry {
                tcp_timeout: config.tcp_timeout,
                udp_timeout: config.udp_timeout,
            })
            .build();

        let first = udp_packet(1000, 2000, b"one");
        let Some(IpStackStream::Udp(stream)) =
            process_device_read(&first, &mut sessions, packet_sender.clone(), &config)
        else {
            panic!("expected first UDP packet to create stream");
        };

        // Same flow: must be delivered to the existing session, not returned as a new stream.
        let second = udp_packet(1000, 2000, b"two");
        assert!(process_device_read(&second, &mut sessions, packet_sender, &config).is_none());

        assert_eq!(&*pollster::block_on(stream.recv()).unwrap(), b"one");
        assert_eq!(&*pollster::block_on(stream.recv()).unwrap(), b"two");
    }

    /// An outbound packet whose TTL is `DROP_TTL` removes the session keyed by its reverse
    /// tuple.
    #[test]
    fn process_upstream_recv_drop_ttl_removes_reverse_session() {
        let config = IpStackConfig::default();
        let mut sessions: SessionCollection = Cache::builder()
            .expire_after(SessionExpiry {
                tcp_timeout: config.tcp_timeout,
                udp_timeout: config.udp_timeout,
            })
            .build();

        let raw = udp_packet(1000, 2000, b"payload");
        let packet = NetworkPacket::parse(&raw).unwrap();
        let (sender, _receiver) = async_channel::unbounded();
        let removed_tuple = packet.reverse_network_tuple();
        sessions.insert(removed_tuple, sender);
        assert!(sessions.get(&removed_tuple).is_some());

        // Turn a copy of the packet into a teardown signal, whichever IP version it is.
        let mut drop_packet = packet.clone();
        match &mut drop_packet.ip {
            packet::IpHeader::Ipv4(ip) => ip.time_to_live = DROP_TTL,
            packet::IpHeader::Ipv6(ip) => ip.hop_limit = DROP_TTL,
        }
        let (device_sender, _device_receiver) = async_channel::unbounded();

        process_upstream_recv(drop_packet, &mut sessions, device_sender).unwrap();
        assert!(sessions.get(&removed_tuple).is_none());
    }

    /// Drives a full TCP exchange through the stack:
    /// SYN -> SYN-ACK -> ACK, then a server write acknowledged by the client, then client
    /// data read on the server side and acknowledged by the stack.
    #[test]
    fn tcp_happy_path_handshake_write_ack_and_read_payload() {
        let config = IpStackConfig {
            mtu: 1500,
            tcp_timeout: Duration::from_secs(60),
            udp_timeout: Duration::from_secs(60),
        };
        // `packet_receiver` observes every packet the stream sends back to the stack.
        let (packet_sender, packet_receiver) = async_channel::unbounded();
        let mut sessions = Cache::builder()
            .expire_after(SessionExpiry {
                tcp_timeout: config.tcp_timeout,
                udp_timeout: config.udp_timeout,
            })
            .build();

        // Client SYN with sequence number 1000 opens a new TCP stream.
        let syn = tcp_packet(1000, 2000, 1000, None, tcp_flags::SYN, &[]);
        let Some(IpStackStream::Tcp(stream)) =
            process_device_read(&syn, &mut sessions, packet_sender.clone(), &config)
        else {
            panic!("expected SYN to create TCP stream");
        };
        let mut stream = Box::pin(stream);

        // Zero-length buffer: these polls only let the stream make progress, they read nothing.
        let mut empty = [];
        let first_read = pollster::block_on(poll_once(poll_fn(|cx| {
            stream.as_mut().poll_read(cx, &mut empty)
        })));
        assert!(first_read.is_none());

        // SYN-ACK: the stream's initial sequence number is 100, acknowledging SYN seq + 1.
        let syn_ack = packet_receiver.try_recv().unwrap();
        let syn_ack_tcp = packet_tcp_header(&syn_ack);
        assert!(syn_ack_tcp.syn);
        assert!(syn_ack_tcp.ack);
        assert_eq!(syn_ack_tcp.sequence_number, 100);
        assert_eq!(syn_ack_tcp.acknowledgment_number, 1001);

        // Client completes the handshake; the ACK is routed to the existing session.
        let client_ack = tcp_packet(
            1000,
            2000,
            1001,
            Some(syn_ack_tcp.sequence_number + 1),
            tcp_flags::ACK,
            &[],
        );
        assert!(
            process_device_read(&client_ack, &mut sessions, packet_sender.clone(), &config)
                .is_none()
        );
        let establish = pollster::block_on(poll_once(poll_fn(|cx| {
            stream.as_mut().poll_read(cx, &mut empty)
        })));
        assert!(establish.is_none());

        // Server-side write is accepted in full...
        let written =
            pollster::block_on(poll_fn(|cx| stream.as_mut().poll_write(cx, b"server-data")))
                .unwrap();
        assert_eq!(written, b"server-data".len());

        // ...and emitted as one PSH|ACK segment carrying all of it.
        let outbound = packet_receiver.try_recv().unwrap();
        let outbound_tcp = packet_tcp_header(&outbound);
        assert!(outbound_tcp.psh);
        assert!(outbound_tcp.ack);
        assert_eq!(outbound.payload, b"server-data");

        // Client acknowledges everything the server sent.
        let server_next_seq = outbound_tcp.sequence_number + outbound.payload.len() as u32;
        let ack_server_data =
            tcp_packet(1000, 2000, 1001, Some(server_next_seq), tcp_flags::ACK, &[]);
        assert!(process_device_read(
            &ack_server_data,
            &mut sessions,
            packet_sender.clone(),
            &config
        )
        .is_none());
        let ack_poll = pollster::block_on(poll_once(poll_fn(|cx| {
            stream.as_mut().poll_read(cx, &mut empty)
        })));
        assert!(ack_poll.is_none());

        // Client sends its own data, which becomes readable on the server side.
        let inbound = tcp_packet(
            1000,
            2000,
            1001,
            Some(server_next_seq),
            tcp_flags::PSH | tcp_flags::ACK,
            b"client-data",
        );
        assert!(process_device_read(&inbound, &mut sessions, packet_sender, &config).is_none());

        let mut read_buf = [0; 32];
        let read =
            pollster::block_on(poll_fn(|cx| stream.as_mut().poll_read(cx, &mut read_buf))).unwrap();
        assert_eq!(&read_buf[..read], b"client-data");

        // The stack acknowledges the client bytes it received.
        let data_ack = packet_receiver.try_recv().unwrap();
        let data_ack_tcp = packet_tcp_header(&data_ack);
        assert!(data_ack_tcp.ack);
        assert_eq!(
            data_ack_tcp.acknowledgment_number,
            1001 + b"client-data".len() as u32
        );
    }
}