Skip to main content

fake_tcp/
lib.rs

//! A minimal, userspace, TCP-based datagram stack
2//!
3//! # Overview
4//!
//! `fake-tcp` is a reusable library that implements a minimal TCP stack in
//! user space using the Tun interface. It allows programs to send datagrams
//! as if they were part of a TCP connection. `fake-tcp` has been tested to
//! be able to pass through a variety of NAT and stateful firewalls while
//! fully preserving certain desirable behaviors, such as out-of-order delivery
//! and the absence of congestion/flow control.
11//!
12//! # Core Concepts
13//!
//! The core of the `fake-tcp` crate is composed of two structures: [`Stack`] and
//! [`Socket`].
16//!
17//! ## [`Stack`]
18//!
19//! [`Stack`] represents a virtual TCP stack that operates at
20//! Layer 3. It is responsible for:
21//!
//! * TCP active and passive open and handshake
//! * `RST` handling
//! * Interacting with the Tun interface at Layer 3
//! * Distributing incoming datagrams to the corresponding [`Socket`]
26//!
27//! ## [`Socket`]
28//!
//! [`Socket`] represents a TCP connection. It registers the identifying
//! tuple `(src_ip, src_port, dest_ip, dest_port)` inside the [`Stack`] so
//! that incoming packets can be distributed to the right [`Socket`] using
//! a channel. It is also what the client should use for
//! sending/receiving datagrams.
34//!
35//! # Examples
36//!
37//! Please see [`client.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/client.rs)
38//! and [`server.rs`](https://github.com/dndx/phantun/blob/main/phantun/src/bin/server.rs) files
39//! from the `phantun` crate for how to use this library in client/server mode, respectively.
40
41#![cfg_attr(feature = "benchmark", feature(test))]
42
43pub mod packet;
44
45use bytes::{Bytes, BytesMut};
46use log::{error, info, trace, warn};
47use packet::*;
48use pnet::packet::{tcp, Packet};
49use rand::prelude::*;
50use std::collections::{HashMap, HashSet};
51use std::fmt;
52use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
53use std::sync::{
54    atomic::{AtomicU32, Ordering},
55    Arc, RwLock,
56};
57use tokio::sync::broadcast;
58use tokio::sync::mpsc;
59use tokio::time;
60use tokio_tun::Tun;
61
/// How long a handshake step waits for the peer before the previous segment is sent again.
const TIMEOUT: time::Duration = time::Duration::from_millis(1_000);
/// Maximum number of handshake state-machine steps (sends and waits) before giving up.
const RETRIES: usize = 6;
/// Capacity, in packets, of each socket's incoming queue.
const MPMC_BUFFER_LEN: usize = 512;
/// Capacity of the queue of accepted sockets waiting to be returned by `Stack::accept`.
const MPSC_BUFFER_LEN: usize = 128;
/// Bytes that may be received without acknowledging them before `Socket::recv`
/// sends a bare ACK (128 MiB).
const MAX_UNACKED_LEN: u32 = 128 << 20;
67
/// Identifies one fake TCP connection by its two endpoints, as seen from this host.
///
/// Used as the key of the tables that route incoming packets to sockets.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct AddrTuple {
    /// Our end of the connection.
    local_addr: SocketAddr,
    /// The peer's end of the connection.
    remote_addr: SocketAddr,
}

impl AddrTuple {
    /// Pairs a local and a remote endpoint into a lookup key.
    fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> Self {
        Self {
            remote_addr,
            local_addr,
        }
    }
}
82
/// State shared between the [`Stack`], its reader tasks and every [`Socket`].
struct Shared {
    /// Dispatch table: maps each registered connection to the sender of its
    /// socket's incoming packet queue.
    tuples: RwLock<HashMap<AddrTuple, flume::Sender<Bytes>>>,
    /// Local ports registered through `Stack::listen`, on which incoming
    /// connections are accepted.
    listening: RwLock<HashSet<u16>>,
    /// Every Tun queue handed to `Stack::new`; one reader task is spawned per entry.
    tun: Vec<Arc<Tun>>,
    /// Sockets that completed a passive open, waiting to be returned by `Stack::accept`.
    ready: mpsc::Sender<Socket>,
    /// Announces tuples removed from `tuples` so reader tasks can evict them from
    /// their local caches.
    tuples_purge: broadcast::Sender<AddrTuple>,
}
90
/// A userspace TCP stack operating at Layer 3 on top of one or more Tun queues.
///
/// It performs handshakes, answers stray packets with `RST`, and hands out
/// [`Socket`]s through `Stack::connect` and `Stack::accept`.
pub struct Stack {
    /// State shared with the reader tasks and every socket of this stack.
    shared: Arc<Shared>,
    /// Source address used by `Stack::connect` for IPv4 peers.
    local_ip: Ipv4Addr,
    /// Source address used by `Stack::connect` for IPv6 peers, if IPv6 is configured.
    local_ip6: Option<Ipv6Addr>,
    /// Receiving end of the queue of sockets that completed a passive open.
    ready: mpsc::Receiver<Socket>,
}
97
/// Handshake state of a [`Socket`].
///
/// Sockets start in `Idle` and reach `Established` once the three-way handshake
/// completes. `Socket::send` and `Socket::recv` may only be used on an
/// `Established` socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    /// Nothing sent yet, or the previous handshake step timed out and must be retried.
    Idle,
    /// Active open: `SYN` sent, waiting for `SYN + ACK`.
    SynSent,
    /// Passive open: `SYN + ACK` sent, waiting for the final `ACK`.
    SynReceived,
    /// Handshake complete; datagrams can be exchanged.
    Established,
}
104
/// One fake TCP connection, identified by `(local_addr, remote_addr)`.
///
/// Datagrams are exchanged with `Socket::send` and `Socket::recv`; dropping the
/// socket unregisters it and sends a best-effort `RST` to the peer.
pub struct Socket {
    /// State shared with the owning stack.
    shared: Arc<Shared>,
    /// Tun queue this socket writes its segments to.
    tun: Arc<Tun>,
    /// Packets the reader tasks dispatched to this connection.
    incoming: flume::Receiver<Bytes>,
    local_addr: SocketAddr,
    remote_addr: SocketAddr,
    /// Sequence number of the next segment to send.
    seq: AtomicU32,
    /// Acknowledgement number to put in outgoing segments (updated on every receive).
    ack: AtomicU32,
    /// Acknowledgement number carried by the most recently built segment.
    last_ack: AtomicU32,
    /// Handshake state; `send`/`recv` require `State::Established`.
    state: State,
}
116
117/// A socket that represents a unique TCP connection between a server and client.
118///
119/// The `Socket` object itself satisfies `Sync` and `Send`, which means it can
120/// be safely called within an async future.
121///
122/// To close a TCP connection that is no longer needed, simply drop this object
123/// out of scope.
124impl Socket {
125    fn new(
126        shared: Arc<Shared>,
127        tun: Arc<Tun>,
128        local_addr: SocketAddr,
129        remote_addr: SocketAddr,
130        ack: Option<u32>,
131        state: State,
132    ) -> (Socket, flume::Sender<Bytes>) {
133        let (incoming_tx, incoming_rx) = flume::bounded(MPMC_BUFFER_LEN);
134
135        (
136            Socket {
137                shared,
138                tun,
139                incoming: incoming_rx,
140                local_addr,
141                remote_addr,
142                seq: AtomicU32::new(0),
143                ack: AtomicU32::new(ack.unwrap_or(0)),
144                last_ack: AtomicU32::new(ack.unwrap_or(0)),
145                state,
146            },
147            incoming_tx,
148        )
149    }
150
151    fn build_tcp_packet(&self, flags: u8, payload: Option<&[u8]>) -> Bytes {
152        let ack = self.ack.load(Ordering::Relaxed);
153        self.last_ack.store(ack, Ordering::Relaxed);
154
155        build_tcp_packet(
156            self.local_addr,
157            self.remote_addr,
158            self.seq.load(Ordering::Relaxed),
159            ack,
160            flags,
161            payload,
162        )
163    }
164
165    /// Sends a datagram to the other end.
166    ///
167    /// This method takes `&self`, and it can be called safely by multiple threads
168    /// at the same time.
169    ///
170    /// A return of `None` means the Tun socket returned an error
171    /// and this socket must be closed.
172    pub async fn send(&self, payload: &[u8]) -> Option<()> {
173        match self.state {
174            State::Established => {
175                let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
176                self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed);
177                self.tun.send(&buf).await.ok().and(Some(()))
178            }
179            _ => unreachable!(),
180        }
181    }
182
183    /// Attempt to receive a datagram from the other end.
184    ///
185    /// This method takes `&self`, and it can be called safely by multiple threads
186    /// at the same time.
187    ///
188    /// A return of `None` means the TCP connection is broken
189    /// and this socket must be closed.
190    pub async fn recv(&self, buf: &mut [u8]) -> Option<usize> {
191        match self.state {
192            State::Established => {
193                self.incoming.recv_async().await.ok().and_then(|raw_buf| {
194                    let (_v4_packet, tcp_packet) = parse_ip_packet(&raw_buf).unwrap();
195
196                    if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
197                        info!("Connection {} reset by peer", self);
198                        return None;
199                    }
200
201                    let payload = tcp_packet.payload();
202
203                    let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32);
204                    let last_ask = self.last_ack.load(Ordering::Relaxed);
205                    self.ack.store(new_ack, Ordering::Relaxed);
206
207                    if new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN {
208                        let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
209                        if let Err(e) = self.tun.try_send(&buf) {
210                            // This should not really happen as we have not sent anything for
211                            // quite some time...
212                            info!("Connection {} unable to send idling ACK back: {}", self, e)
213                        }
214                    }
215
216                    buf[..payload.len()].copy_from_slice(payload);
217
218                    Some(payload.len())
219                })
220            }
221            _ => unreachable!(),
222        }
223    }
224
225    async fn accept(mut self) {
226        for _ in 0..RETRIES {
227            match self.state {
228                State::Idle => {
229                    let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None);
230                    // ACK set by constructor
231                    self.tun.send(&buf).await.unwrap();
232                    self.state = State::SynReceived;
233                    info!("Sent SYN + ACK to client");
234                }
235                State::SynReceived => {
236                    let res = time::timeout(TIMEOUT, self.incoming.recv_async()).await;
237                    if let Ok(buf) = res {
238                        let buf = buf.unwrap();
239                        let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap();
240
241                        if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
242                            return;
243                        }
244
245                        if tcp_packet.get_flags() == tcp::TcpFlags::ACK
246                            && tcp_packet.get_acknowledgement()
247                                == self.seq.load(Ordering::Relaxed) + 1
248                        {
249                            // found our ACK
250                            self.seq.fetch_add(1, Ordering::Relaxed);
251                            self.state = State::Established;
252
253                            info!("Connection from {:?} established", self.remote_addr);
254                            let ready = self.shared.ready.clone();
255                            if let Err(e) = ready.send(self).await {
256                                error!("Unable to send accepted socket to ready queue: {}", e);
257                            }
258                            return;
259                        }
260                    } else {
261                        info!("Waiting for client ACK timed out");
262                        self.state = State::Idle;
263                    }
264                }
265                _ => unreachable!(),
266            }
267        }
268    }
269
270    async fn connect(&mut self) -> Option<()> {
271        for _ in 0..RETRIES {
272            match self.state {
273                State::Idle => {
274                    let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None);
275                    self.tun.send(&buf).await.unwrap();
276                    self.state = State::SynSent;
277                    info!("Sent SYN to server");
278                }
279                State::SynSent => {
280                    match time::timeout(TIMEOUT, self.incoming.recv_async()).await {
281                        Ok(buf) => {
282                            let buf = buf.unwrap();
283                            let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap();
284
285                            if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
286                                return None;
287                            }
288
289                            if tcp_packet.get_flags() == tcp::TcpFlags::SYN | tcp::TcpFlags::ACK
290                                && tcp_packet.get_acknowledgement()
291                                    == self.seq.load(Ordering::Relaxed) + 1
292                            {
293                                // found our SYN + ACK
294                                self.seq.fetch_add(1, Ordering::Relaxed);
295                                self.ack
296                                    .store(tcp_packet.get_sequence() + 1, Ordering::Relaxed);
297
298                                // send ACK to finish handshake
299                                let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
300                                self.tun.send(&buf).await.unwrap();
301
302                                self.state = State::Established;
303
304                                info!("Connection to {:?} established", self.remote_addr);
305                                return Some(());
306                            }
307                        }
308                        Err(_) => {
309                            info!("Waiting for SYN + ACK timed out");
310                            self.state = State::Idle;
311                        }
312                    }
313                }
314                _ => unreachable!(),
315            }
316        }
317
318        None
319    }
320}
321
322impl Drop for Socket {
323    /// Drop the socket and close the TCP connection
324    fn drop(&mut self) {
325        let tuple = AddrTuple::new(self.local_addr, self.remote_addr);
326        // dissociates ourself from the dispatch map
327        assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some());
328        // purge cache
329        self.shared.tuples_purge.send(tuple).unwrap();
330
331        let buf = build_tcp_packet(
332            self.local_addr,
333            self.remote_addr,
334            self.seq.load(Ordering::Relaxed),
335            0,
336            tcp::TcpFlags::RST,
337            None,
338        );
339        if let Err(e) = self.tun.try_send(&buf) {
340            warn!("Unable to send RST to remote end: {}", e);
341        }
342
343        info!("Fake TCP connection to {} closed", self);
344    }
345}
346
347impl fmt::Display for Socket {
348    /// User-friendly string representation of the socket
349    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350        write!(
351            f,
352            "(Fake TCP connection from {} to {})",
353            self.local_addr, self.remote_addr
354        )
355    }
356}
357
358/// A userspace TCP state machine
359impl Stack {
360    /// Create a new stack, `tun` is an array of [`Tun`](tokio_tun::Tun).
361    /// When more than one [`Tun`](tokio_tun::Tun) object is passed in, same amount
362    /// of reader will be spawned later. This allows user to utilize the performance
363    /// benefit of Multiqueue Tun support on machines with SMP.
364    pub fn new(tun: Vec<Tun>, local_ip: Ipv4Addr, local_ip6: Option<Ipv6Addr>) -> Stack {
365        let tun: Vec<Arc<Tun>> = tun.into_iter().map(Arc::new).collect();
366        let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
367        let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
368        let shared = Arc::new(Shared {
369            tuples: RwLock::new(HashMap::new()),
370            tun: tun.clone(),
371            listening: RwLock::new(HashSet::new()),
372            ready: ready_tx,
373            tuples_purge: tuples_purge_tx.clone(),
374        });
375
376        for t in tun {
377            tokio::spawn(Stack::reader_task(
378                t,
379                shared.clone(),
380                tuples_purge_tx.subscribe(),
381            ));
382        }
383
384        Stack {
385            shared,
386            local_ip,
387            local_ip6,
388            ready: ready_rx,
389        }
390    }
391
392    /// Listens for incoming connections on the given `port`.
393    pub fn listen(&mut self, port: u16) {
394        assert!(self.shared.listening.write().unwrap().insert(port));
395    }
396
397    /// Accepts an incoming connection.
398    pub async fn accept(&mut self) -> Socket {
399        self.ready.recv().await.unwrap()
400    }
401
402    /// Connects to the remote end. `None` returned means
403    /// the connection attempt failed.
404    pub async fn connect(&mut self, addr: SocketAddr) -> Option<Socket> {
405        let mut rng = SmallRng::from_os_rng();
406        for local_port in rng.random_range(32768..=60999)..=60999 {
407            let local_addr = SocketAddr::new(
408                if addr.is_ipv4() {
409                    IpAddr::V4(self.local_ip)
410                } else {
411                    IpAddr::V6(self.local_ip6.expect("IPv6 local address undefined"))
412                },
413                local_port,
414            );
415            let tuple = AddrTuple::new(local_addr, addr);
416            let mut sock;
417
418            {
419                let mut tuples = self.shared.tuples.write().unwrap();
420                if tuples.contains_key(&tuple) {
421                    trace!(
422                        "Fake TCP connection to {}, local port number {} already in use, trying another one",
423                        addr, local_port
424                    );
425                    continue;
426                }
427
428                let incoming;
429                (sock, incoming) = Socket::new(
430                    self.shared.clone(),
431                    self.shared.tun.choose(&mut rng).unwrap().clone(),
432                    local_addr,
433                    addr,
434                    None,
435                    State::Idle,
436                );
437
438                assert!(tuples.insert(tuple, incoming).is_none());
439            }
440
441            return sock.connect().await.map(|_| sock);
442        }
443
444        error!(
445            "Fake TCP connection to {} failed, emphemeral port number exhausted",
446            addr
447        );
448        None
449    }
450
451    async fn reader_task(
452        tun: Arc<Tun>,
453        shared: Arc<Shared>,
454        mut tuples_purge: broadcast::Receiver<AddrTuple>,
455    ) {
456        let mut tuples: HashMap<AddrTuple, flume::Sender<Bytes>> = HashMap::new();
457
458        loop {
459            let mut buf = BytesMut::zeroed(MAX_PACKET_LEN);
460
461            tokio::select! {
462                size = tun.recv(&mut buf) => {
463                    let size = size.unwrap();
464                    buf.truncate(size);
465                    let buf = buf.freeze();
466
467                    match parse_ip_packet(&buf) {
468                        Some((ip_packet, tcp_packet)) => {
469                            let local_addr =
470                                SocketAddr::new(ip_packet.get_destination(), tcp_packet.get_destination());
471                            let remote_addr = SocketAddr::new(ip_packet.get_source(), tcp_packet.get_source());
472
473                            let tuple = AddrTuple::new(local_addr, remote_addr);
474                            if let Some(c) = tuples.get(&tuple) {
475                                if c.send_async(buf).await.is_err() {
476                                    trace!("Cache hit, but receiver already closed, dropping packet");
477                                }
478
479                                continue;
480
481                                // If not Ok, receiver has been closed and just fall through to the slow
482                                // path below
483                            } else {
484                                trace!("Cache miss, checking the shared tuples table for connection");
485                                let sender = {
486                                    let tuples = shared.tuples.read().unwrap();
487                                    tuples.get(&tuple).cloned()
488                                };
489
490                                if let Some(c) = sender {
491                                    trace!("Storing connection information into local tuples");
492                                    tuples.insert(tuple, c.clone());
493                                    c.send_async(buf).await.unwrap();
494                                    continue;
495                                }
496                            }
497
498                            if tcp_packet.get_flags() == tcp::TcpFlags::SYN
499                                && shared
500                                    .listening
501                                    .read()
502                                    .unwrap()
503                                    .contains(&tcp_packet.get_destination())
504                            {
505                                // SYN seen on listening socket
506                                if tcp_packet.get_sequence() == 0 {
507                                    let (sock, incoming) = Socket::new(
508                                        shared.clone(),
509                                        tun.clone(),
510                                        local_addr,
511                                        remote_addr,
512                                        Some(tcp_packet.get_sequence() + 1),
513                                        State::Idle,
514                                    );
515                                    assert!(shared
516                                        .tuples
517                                        .write()
518                                        .unwrap()
519                                        .insert(tuple, incoming)
520                                        .is_none());
521                                    tokio::spawn(sock.accept());
522                                } else {
523                                    trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);
524                                    let buf = build_tcp_packet(
525                                        local_addr,
526                                        remote_addr,
527                                        0,
528                                        tcp_packet.get_sequence() + tcp_packet.payload().len() as u32 + 1, // +1 because of SYN flag set
529                                        tcp::TcpFlags::RST | tcp::TcpFlags::ACK,
530                                        None,
531                                    );
532                                    shared.tun[0].try_send(&buf).unwrap();
533                                }
534                            } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
535                                info!("Unknown TCP packet from {}, sending RST", remote_addr);
536                                let buf = build_tcp_packet(
537                                    local_addr,
538                                    remote_addr,
539                                    tcp_packet.get_acknowledgement(),
540                                    tcp_packet.get_sequence() + tcp_packet.payload().len() as u32,
541                                    tcp::TcpFlags::RST | tcp::TcpFlags::ACK,
542                                    None,
543                                );
544                                shared.tun[0].try_send(&buf).unwrap();
545                            }
546                        }
547                        None => {
548                            continue;
549                        }
550                    }
551                },
552                tuple = tuples_purge.recv() => {
553                    let tuple = tuple.unwrap();
554                    tuples.remove(&tuple);
555                    trace!("Removed cached tuple: {:?}", tuple);
556                }
557            }
558        }
559    }
560}