// ipstack_geph/lib.rs

use std::time::{Duration, Instant};

use async_channel::{Receiver, Sender};
use async_executor::Executor;
use bytes::Bytes;
use log::trace;
use moka::{sync::Cache, Expiry};
use parking_lot::Mutex;

use crate::{
    packet::IpStackPacketProtocol,
    stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport},
};
use packet::{NetworkPacket, NetworkTuple};

14pub(crate) type PacketSender = Sender<NetworkPacket>;
15pub(crate) type PacketReceiver = Receiver<NetworkPacket>;
16pub(crate) type SessionCollection = Cache<NetworkTuple, PacketSender>;
17
18mod packet;
19pub mod stream;
20
21const DROP_TTL: u8 = 0;
22
23const TTL: u8 = 64;
24
/// Tunables for an [`IpStack`].
///
/// Derives `Debug` so a running configuration can be logged, and `Clone` so one
/// configuration value can be reused to build several stacks.
#[derive(Debug, Clone)]
pub struct IpStackConfig {
    /// MTU handed to every stream the stack creates.
    /// NOTE(review): presumably in bytes — confirm against the `stream` module.
    pub mtu: u16,

    /// Idle timeout for TCP sessions in the session table. It is also passed to
    /// each TCP stream on creation.
    pub tcp_timeout: Duration,
    /// Idle timeout for UDP sessions in the session table. It is also passed to
    /// each UDP stream on creation.
    pub udp_timeout: Duration,
}

impl Default for IpStackConfig {
    /// MTU 16384, TCP timeout one hour, UDP timeout ten minutes.
    fn default() -> Self {
        IpStackConfig {
            mtu: 16384,
            tcp_timeout: Duration::from_secs(3600),
            udp_timeout: Duration::from_secs(600),
        }
    }
}

43pub struct IpStack {
44    accept_receiver: Receiver<IpStackStream>,
45    exec: Executor<'static>,
46}
47
48impl IpStack {
49    pub fn new(
50        config: IpStackConfig,
51        recv_packet: Receiver<Bytes>,
52        send_packet: Sender<Bytes>,
53    ) -> IpStack {
54        let (accept_sender, accept_receiver) = async_channel::unbounded();
55        let exec = Executor::new();
56        exec.spawn(run(config, recv_packet, send_packet, accept_sender))
57            .detach();
58
59        IpStack {
60            accept_receiver,
61            exec,
62        }
63    }
64
65    pub async fn accept(&self) -> anyhow::Result<IpStackStream> {
66        self.exec
67            .run(async { Ok(self.accept_receiver.recv().await?) })
68            .await
69    }
70}
71
72async fn run(
73    config: IpStackConfig,
74    recv_packet: Receiver<Bytes>,
75    send_packet: Sender<Bytes>,
76    accept_sender: Sender<IpStackStream>,
77) -> anyhow::Result<()> {
78    let sessions: SessionCollection = Cache::builder()
79        .expire_after(SessionExpiry {
80            tcp_timeout: config.tcp_timeout,
81            udp_timeout: config.udp_timeout,
82        })
83        .build();
84    let sessions = Mutex::new(sessions);
85
86    let (pkt_sender, pkt_receiver) = async_channel::unbounded::<NetworkPacket>();
87
88    let accept_loop = async {
89        loop {
90            let packet = recv_packet.recv().await?;
91            let mut sessions = sessions.lock();
92            if let Some(stream) =
93                process_device_read(&packet, &mut sessions, pkt_sender.clone(), &config)
94            {
95                let _ = accept_sender.try_send(stream);
96            }
97        }
98    };
99
100    let inject_loop = async {
101        loop {
102            let packet = pkt_receiver.recv().await?;
103            let mut sessions = sessions.lock();
104            process_upstream_recv(packet, &mut sessions, send_packet.clone())?;
105        }
106    };
107
108    futures_lite::future::race(accept_loop, inject_loop).await
109}
110
111struct SessionExpiry {
112    tcp_timeout: Duration,
113    udp_timeout: Duration,
114}
115
116impl Expiry<NetworkTuple, PacketSender> for SessionExpiry {
117    fn expire_after_create(
118        &self,
119        key: &NetworkTuple,
120        _value: &PacketSender,
121        _created_at: Instant,
122    ) -> Option<Duration> {
123        Some(if key.tcp {
124            self.tcp_timeout
125        } else {
126            self.udp_timeout
127        })
128    }
129
130    fn expire_after_read(
131        &self,
132        key: &NetworkTuple,
133        _value: &PacketSender,
134        _read_at: Instant,
135        _duration_until_expiry: Option<Duration>,
136        _last_modified_at: Instant,
137    ) -> Option<Duration> {
138        self.expire_after_create(key, _value, _read_at)
139    }
140
141    fn expire_after_update(
142        &self,
143        key: &NetworkTuple,
144        _value: &PacketSender,
145        _updated_at: Instant,
146        _duration_until_expiry: Option<Duration>,
147    ) -> Option<Duration> {
148        self.expire_after_create(key, _value, _updated_at)
149    }
150}
151
152fn process_device_read(
153    data: &[u8],
154    sessions: &mut SessionCollection,
155    pkt_sender: PacketSender,
156    config: &IpStackConfig,
157) -> Option<IpStackStream> {
158    let Ok(packet) = NetworkPacket::parse(data) else {
159        return Some(IpStackStream::UnknownNetwork(data.to_owned()));
160    };
161
162    if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
163        return Some(IpStackStream::UnknownTransport(
164            IpStackUnknownTransport::new(
165                packet.src_addr().ip(),
166                packet.dst_addr().ip(),
167                packet.payload,
168                &packet.ip,
169                config.mtu,
170                pkt_sender,
171            ),
172        ));
173    }
174
175    if let Some(sender) = sessions.get(&packet.network_tuple()) {
176        let _ = sender.try_send(packet);
177        None
178    } else {
179        let (a, b) = create_stream(packet.clone(), config, pkt_sender)?;
180        sessions.insert(packet.network_tuple(), a);
181        Some(b)
182    }
183}
184
185fn create_stream(
186    packet: NetworkPacket,
187    config: &IpStackConfig,
188    pkt_sender: PacketSender,
189) -> Option<(PacketSender, IpStackStream)> {
190    match packet.transport_protocol() {
191        IpStackPacketProtocol::Tcp(h) => {
192            match IpStackTcpStream::new(
193                packet.src_addr(),
194                packet.dst_addr(),
195                h,
196                pkt_sender,
197                config.mtu,
198                config.tcp_timeout,
199            ) {
200                Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))),
201                Err(e) => {
202                    log::debug!("IpStackTcpStream::new failed \"{}\"", e);
203
204                    None
205                }
206            }
207        }
208        IpStackPacketProtocol::Udp => {
209            let stream = IpStackUdpStream::new(
210                packet.src_addr(),
211                packet.dst_addr(),
212                pkt_sender,
213                config.mtu,
214                config.udp_timeout,
215            );
216            let _ = stream.stream_sender().try_send(packet.clone());
217            Some((stream.stream_sender(), IpStackStream::Udp(stream)))
218        }
219        IpStackPacketProtocol::Unknown => {
220            unreachable!()
221        }
222    }
223}
224
225fn process_upstream_recv(
226    packet: NetworkPacket,
227    sessions: &mut SessionCollection,
228    device: Sender<Bytes>,
229) -> anyhow::Result<()> {
230    if packet.ttl() == 0 {
231        sessions.remove(&packet.reverse_network_tuple());
232        return Ok(());
233    }
234    #[allow(unused_mut)]
235    let Ok(mut packet_bytes) = packet.to_bytes() else {
236        trace!("to_bytes error");
237        return Ok(());
238    };
239
240    let _ = device.try_send(packet_bytes.into());
241    // device.flush().await.unwrap();
242
243    Ok(())
244}
245
/// A source of raw packets.
///
/// NOTE(review): nothing in this file implements or uses this trait. `IpStack`
/// takes packets from a `Receiver<Bytes>` instead, so confirm whether external
/// code still needs it.
pub trait Device {
    /// Returns the next raw packet. The signature is synchronous and infallible.
    fn read_packet(&self) -> Bytes;
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{tcp_flags, IpHeader, TransportHeader};
    use etherparse::{IpNumber, Ipv4Header, TcpHeader};
    use futures_lite::{
        future::{poll_fn, poll_once},
        AsyncRead, AsyncWrite,
    };
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    /// Builds a raw IPv4/UDP datagram from 127.0.0.1:`src_port` to 10.0.0.2:`dst_port`.
    fn udp_packet(src_port: u16, dst_port: u16, payload: &[u8]) -> Vec<u8> {
        let builder =
            etherparse::PacketBuilder::ipv4(Ipv4Addr::LOCALHOST.octets(), [10, 0, 0, 2], 64)
                .udp(src_port, dst_port);
        let mut buf = Vec::new();
        builder.write(&mut buf, payload).unwrap();
        buf
    }

    /// Builds a raw IPv4/TCP segment from 127.0.0.1:`src_port` to 10.0.0.2:`dst_port`
    /// with a valid checksum. Passing `Some(ack)` also sets the ACK flag.
    fn tcp_packet(
        src_port: u16,
        dst_port: u16,
        seq: u32,
        ack: Option<u32>,
        flags: u8,
        payload: &[u8],
    ) -> Vec<u8> {
        let mut ip = Ipv4Header::new(
            0,
            64,
            IpNumber::TCP,
            Ipv4Addr::LOCALHOST.octets(),
            [10, 0, 0, 2],
        )
        .unwrap();
        let mut tcp = TcpHeader::new(src_port, dst_port, seq, u16::MAX);
        tcp.syn = flags & tcp_flags::SYN != 0;
        tcp.fin = flags & tcp_flags::FIN != 0;
        tcp.rst = flags & tcp_flags::RST != 0;
        tcp.psh = flags & tcp_flags::PSH != 0;
        tcp.ack = ack.is_some() || flags & tcp_flags::ACK != 0;
        tcp.acknowledgment_number = ack.unwrap_or(0);
        ip.set_payload_len(payload.len() + tcp.header_len())
            .unwrap();
        tcp.checksum = tcp.calc_checksum_ipv4(&ip, payload).unwrap();

        NetworkPacket {
            ip: IpHeader::Ipv4(ip),
            transport: TransportHeader::Tcp(tcp),
            payload: payload.to_vec(),
        }
        .to_bytes()
        .unwrap()
    }

    /// Returns the TCP header of `packet`, panicking if it is not TCP.
    fn packet_tcp_header(packet: &NetworkPacket) -> &TcpHeader {
        let TransportHeader::Tcp(tcp) = &packet.transport else {
            panic!("expected TCP packet");
        };
        tcp
    }

    /// Builds a session table configured the same way `run` configures it.
    fn new_sessions(config: &IpStackConfig) -> SessionCollection {
        Cache::builder()
            .expire_after(SessionExpiry {
                tcp_timeout: config.tcp_timeout,
                udp_timeout: config.udp_timeout,
            })
            .build()
    }

    /// Tuple 127.0.0.1:1000 -> 10.0.0.2:2000, for TCP or UDP.
    fn sample_tuple(tcp: bool) -> NetworkTuple {
        NetworkTuple {
            src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1000)),
            dst: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 2000)),
            tcp,
        }
    }

    #[test]
    fn session_expiry_uses_protocol_specific_configured_timeout() {
        let expiry = SessionExpiry {
            tcp_timeout: Duration::from_secs(11),
            udp_timeout: Duration::from_secs(7),
        };
        let (sender, _receiver) = async_channel::unbounded();
        let tcp_tuple = sample_tuple(true);
        let udp_tuple = sample_tuple(false);

        assert_eq!(
            expiry.expire_after_create(&tcp_tuple, &sender, Instant::now()),
            Some(Duration::from_secs(11))
        );
        assert_eq!(
            expiry.expire_after_create(&udp_tuple, &sender, Instant::now()),
            Some(Duration::from_secs(7))
        );
    }

    #[test]
    fn session_expiry_read_and_update_reset_to_full_timeout() {
        let expiry = SessionExpiry {
            tcp_timeout: Duration::from_secs(11),
            udp_timeout: Duration::from_secs(7),
        };
        let (sender, _receiver) = async_channel::unbounded();
        let now = Instant::now();
        // Whatever time remains, an access resets the entry to the full timeout.
        let remaining = Some(Duration::from_secs(1));

        assert_eq!(
            expiry.expire_after_read(&sample_tuple(true), &sender, now, remaining, now),
            Some(Duration::from_secs(11))
        );
        assert_eq!(
            expiry.expire_after_read(&sample_tuple(false), &sender, now, remaining, now),
            Some(Duration::from_secs(7))
        );
        assert_eq!(
            expiry.expire_after_update(&sample_tuple(true), &sender, now, remaining),
            Some(Duration::from_secs(11))
        );
        assert_eq!(
            expiry.expire_after_update(&sample_tuple(false), &sender, now, remaining),
            Some(Duration::from_secs(7))
        );
    }

    #[test]
    fn process_device_read_surfaces_unparseable_bytes_as_unknown_network() {
        let config = IpStackConfig::default();
        let (packet_sender, _packet_receiver) = async_channel::unbounded();
        let mut sessions = new_sessions(&config);

        // IP version nibble 0xF: not a valid IPv4/IPv6 packet.
        let garbage = [0xff_u8; 4];
        let Some(IpStackStream::UnknownNetwork(raw)) =
            process_device_read(&garbage, &mut sessions, packet_sender, &config)
        else {
            panic!("expected unparseable bytes to be surfaced as UnknownNetwork");
        };
        assert_eq!(raw, garbage);
    }

    #[test]
    fn process_device_read_creates_udp_stream_and_routes_later_packets_to_it() {
        let config = IpStackConfig::default();
        let (packet_sender, _packet_receiver) = async_channel::unbounded();
        let mut sessions = new_sessions(&config);

        let first = udp_packet(1000, 2000, b"one");
        let Some(IpStackStream::Udp(stream)) =
            process_device_read(&first, &mut sessions, packet_sender.clone(), &config)
        else {
            panic!("expected first UDP packet to create stream");
        };

        let second = udp_packet(1000, 2000, b"two");
        assert!(process_device_read(&second, &mut sessions, packet_sender, &config).is_none());

        assert_eq!(&*pollster::block_on(stream.recv()).unwrap(), b"one");
        assert_eq!(&*pollster::block_on(stream.recv()).unwrap(), b"two");
    }

    #[test]
    fn process_upstream_recv_drop_ttl_removes_reverse_session() {
        let config = IpStackConfig::default();
        let mut sessions = new_sessions(&config);

        let raw = udp_packet(1000, 2000, b"payload");
        let packet = NetworkPacket::parse(&raw).unwrap();
        let (sender, _receiver) = async_channel::unbounded();
        let removed_tuple = packet.reverse_network_tuple();
        sessions.insert(removed_tuple, sender);
        assert!(sessions.get(&removed_tuple).is_some());

        let mut drop_packet = packet.clone();
        match &mut drop_packet.ip {
            IpHeader::Ipv4(ip) => ip.time_to_live = DROP_TTL,
            IpHeader::Ipv6(ip) => ip.hop_limit = DROP_TTL,
        }
        let (device_sender, _device_receiver) = async_channel::unbounded();

        process_upstream_recv(drop_packet, &mut sessions, device_sender).unwrap();
        assert!(sessions.get(&removed_tuple).is_none());
    }

    #[test]
    fn tcp_happy_path_handshake_write_ack_and_read_payload() {
        let config = IpStackConfig {
            mtu: 1500,
            tcp_timeout: Duration::from_secs(60),
            udp_timeout: Duration::from_secs(60),
        };
        let (packet_sender, packet_receiver) = async_channel::unbounded();
        let mut sessions = new_sessions(&config);

        // The client SYN opens the session and yields the stream.
        let syn = tcp_packet(1000, 2000, 1000, None, tcp_flags::SYN, &[]);
        let Some(IpStackStream::Tcp(stream)) =
            process_device_read(&syn, &mut sessions, packet_sender.clone(), &config)
        else {
            panic!("expected SYN to create TCP stream");
        };
        let mut stream = Box::pin(stream);

        // After one poll nothing is readable, and a SYN-ACK is expected upstream.
        let mut empty = [];
        let first_read = pollster::block_on(poll_once(poll_fn(|cx| {
            stream.as_mut().poll_read(cx, &mut empty)
        })));
        assert!(first_read.is_none());

        let syn_ack = packet_receiver.try_recv().unwrap();
        let syn_ack_tcp = packet_tcp_header(&syn_ack);
        assert!(syn_ack_tcp.syn);
        assert!(syn_ack_tcp.ack);
        assert_eq!(syn_ack_tcp.sequence_number, 100);
        assert_eq!(syn_ack_tcp.acknowledgment_number, 1001);

        // The client's ACK completes the handshake.
        let client_ack = tcp_packet(
            1000,
            2000,
            1001,
            Some(syn_ack_tcp.sequence_number + 1),
            tcp_flags::ACK,
            &[],
        );
        assert!(
            process_device_read(&client_ack, &mut sessions, packet_sender.clone(), &config)
                .is_none()
        );
        let establish = pollster::block_on(poll_once(poll_fn(|cx| {
            stream.as_mut().poll_read(cx, &mut empty)
        })));
        assert!(establish.is_none());

        // Server -> client data is emitted as a PSH|ACK segment.
        let written =
            pollster::block_on(poll_fn(|cx| stream.as_mut().poll_write(cx, b"server-data")))
                .unwrap();
        assert_eq!(written, b"server-data".len());

        let outbound = packet_receiver.try_recv().unwrap();
        let outbound_tcp = packet_tcp_header(&outbound);
        assert!(outbound_tcp.psh);
        assert!(outbound_tcp.ack);
        assert_eq!(outbound.payload, b"server-data");

        let server_next_seq = outbound_tcp.sequence_number + outbound.payload.len() as u32;
        let ack_server_data =
            tcp_packet(1000, 2000, 1001, Some(server_next_seq), tcp_flags::ACK, &[]);
        assert!(process_device_read(
            &ack_server_data,
            &mut sessions,
            packet_sender.clone(),
            &config
        )
        .is_none());
        let ack_poll = pollster::block_on(poll_once(poll_fn(|cx| {
            stream.as_mut().poll_read(cx, &mut empty)
        })));
        assert!(ack_poll.is_none());

        // Client -> server data becomes readable and is acknowledged upstream.
        let inbound = tcp_packet(
            1000,
            2000,
            1001,
            Some(server_next_seq),
            tcp_flags::PSH | tcp_flags::ACK,
            b"client-data",
        );
        assert!(process_device_read(&inbound, &mut sessions, packet_sender, &config).is_none());

        let mut read_buf = [0; 32];
        let read =
            pollster::block_on(poll_fn(|cx| stream.as_mut().poll_read(cx, &mut read_buf))).unwrap();
        assert_eq!(&read_buf[..read], b"client-data");

        let data_ack = packet_receiver.try_recv().unwrap();
        let data_ack_tcp = packet_tcp_header(&data_ack);
        assert!(data_ack_tcp.ack);
        assert_eq!(
            data_ack_tcp.acknowledgment_number,
            1001 + b"client-data".len() as u32
        );
    }
}