1#![cfg_attr(feature = "benchmark", feature(test))]
42
43pub mod packet;
44
45use bytes::{Bytes, BytesMut};
46use log::{error, info, trace, warn};
47use packet::*;
48use pnet::packet::{tcp, Packet};
49use rand::prelude::*;
50use std::collections::{HashMap, HashSet};
51use std::fmt;
52use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
53use std::sync::{
54 atomic::{AtomicU32, Ordering},
55 Arc, RwLock,
56};
57use tokio::sync::broadcast;
58use tokio::sync::mpsc;
59use tokio::time;
60use tokio_tun::Tun;
61
/// How long one handshake step waits for the peer's reply before the state machine
/// falls back to re-sending.
const TIMEOUT: time::Duration = time::Duration::from_secs(1);
/// Number of state-machine iterations `Socket::accept` / `Socket::connect` run before
/// giving up on the handshake.
const RETRIES: usize = 6;
/// Capacity of each socket's bounded queue of incoming packets.
const MPMC_BUFFER_LEN: usize = 512;
/// Capacity of the queue holding accepted sockets until `Stack::accept` picks them up.
const MPSC_BUFFER_LEN: usize = 128;
/// Number of received bytes that may go unacknowledged before `Socket::recv` sends an
/// ACK of its own accord.
const MAX_UNACKED_LEN: u32 = 128 * 1024 * 1024;

/// Key identifying one fake TCP connection: its local and remote socket addresses.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
struct AddrTuple {
    local_addr: SocketAddr,
    remote_addr: SocketAddr,
}

impl AddrTuple {
    /// Pairs `local_addr` with `remote_addr`.
    fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> Self {
        Self {
            local_addr,
            remote_addr,
        }
    }
}
82
83struct Shared {
84 tuples: RwLock<HashMap<AddrTuple, flume::Sender<Bytes>>>,
85 listening: RwLock<HashSet<u16>>,
86 tun: Vec<Arc<Tun>>,
87 ready: mpsc::Sender<Socket>,
88 tuples_purge: broadcast::Sender<AddrTuple>,
89}
90
91pub struct Stack {
92 shared: Arc<Shared>,
93 local_ip: Ipv4Addr,
94 local_ip6: Option<Ipv6Addr>,
95 ready: mpsc::Receiver<Socket>,
96}
97
/// Handshake progress of a [`Socket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    /// No handshake packet is outstanding; the next step (re)sends one.
    Idle,
    /// Client side: SYN was sent, waiting for the peer's SYN + ACK.
    SynSent,
    /// Server side: SYN + ACK was sent, waiting for the peer's ACK.
    SynReceived,
    /// Handshake finished; payload can be sent and received.
    Established,
}
104
105pub struct Socket {
106 shared: Arc<Shared>,
107 tun: Arc<Tun>,
108 incoming: flume::Receiver<Bytes>,
109 local_addr: SocketAddr,
110 remote_addr: SocketAddr,
111 seq: AtomicU32,
112 ack: AtomicU32,
113 last_ack: AtomicU32,
114 state: State,
115}
116
117impl Socket {
125 fn new(
126 shared: Arc<Shared>,
127 tun: Arc<Tun>,
128 local_addr: SocketAddr,
129 remote_addr: SocketAddr,
130 ack: Option<u32>,
131 state: State,
132 ) -> (Socket, flume::Sender<Bytes>) {
133 let (incoming_tx, incoming_rx) = flume::bounded(MPMC_BUFFER_LEN);
134
135 (
136 Socket {
137 shared,
138 tun,
139 incoming: incoming_rx,
140 local_addr,
141 remote_addr,
142 seq: AtomicU32::new(0),
143 ack: AtomicU32::new(ack.unwrap_or(0)),
144 last_ack: AtomicU32::new(ack.unwrap_or(0)),
145 state,
146 },
147 incoming_tx,
148 )
149 }
150
151 fn build_tcp_packet(&self, flags: u8, payload: Option<&[u8]>) -> Bytes {
152 let ack = self.ack.load(Ordering::Relaxed);
153 self.last_ack.store(ack, Ordering::Relaxed);
154
155 build_tcp_packet(
156 self.local_addr,
157 self.remote_addr,
158 self.seq.load(Ordering::Relaxed),
159 ack,
160 flags,
161 payload,
162 )
163 }
164
165 pub async fn send(&self, payload: &[u8]) -> Option<()> {
173 match self.state {
174 State::Established => {
175 let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, Some(payload));
176 self.seq.fetch_add(payload.len() as u32, Ordering::Relaxed);
177 self.tun.send(&buf).await.ok().and(Some(()))
178 }
179 _ => unreachable!(),
180 }
181 }
182
183 pub async fn recv(&self, buf: &mut [u8]) -> Option<usize> {
191 match self.state {
192 State::Established => {
193 self.incoming.recv_async().await.ok().and_then(|raw_buf| {
194 let (_v4_packet, tcp_packet) = parse_ip_packet(&raw_buf).unwrap();
195
196 if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
197 info!("Connection {} reset by peer", self);
198 return None;
199 }
200
201 let payload = tcp_packet.payload();
202
203 let new_ack = tcp_packet.get_sequence().wrapping_add(payload.len() as u32);
204 let last_ask = self.last_ack.load(Ordering::Relaxed);
205 self.ack.store(new_ack, Ordering::Relaxed);
206
207 if new_ack.overflowing_sub(last_ask).0 > MAX_UNACKED_LEN {
208 let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
209 if let Err(e) = self.tun.try_send(&buf) {
210 info!("Connection {} unable to send idling ACK back: {}", self, e)
213 }
214 }
215
216 buf[..payload.len()].copy_from_slice(payload);
217
218 Some(payload.len())
219 })
220 }
221 _ => unreachable!(),
222 }
223 }
224
225 async fn accept(mut self) {
226 for _ in 0..RETRIES {
227 match self.state {
228 State::Idle => {
229 let buf = self.build_tcp_packet(tcp::TcpFlags::SYN | tcp::TcpFlags::ACK, None);
230 self.tun.send(&buf).await.unwrap();
232 self.state = State::SynReceived;
233 info!("Sent SYN + ACK to client");
234 }
235 State::SynReceived => {
236 let res = time::timeout(TIMEOUT, self.incoming.recv_async()).await;
237 if let Ok(buf) = res {
238 let buf = buf.unwrap();
239 let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap();
240
241 if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
242 return;
243 }
244
245 if tcp_packet.get_flags() == tcp::TcpFlags::ACK
246 && tcp_packet.get_acknowledgement()
247 == self.seq.load(Ordering::Relaxed) + 1
248 {
249 self.seq.fetch_add(1, Ordering::Relaxed);
251 self.state = State::Established;
252
253 info!("Connection from {:?} established", self.remote_addr);
254 let ready = self.shared.ready.clone();
255 if let Err(e) = ready.send(self).await {
256 error!("Unable to send accepted socket to ready queue: {}", e);
257 }
258 return;
259 }
260 } else {
261 info!("Waiting for client ACK timed out");
262 self.state = State::Idle;
263 }
264 }
265 _ => unreachable!(),
266 }
267 }
268 }
269
270 async fn connect(&mut self) -> Option<()> {
271 for _ in 0..RETRIES {
272 match self.state {
273 State::Idle => {
274 let buf = self.build_tcp_packet(tcp::TcpFlags::SYN, None);
275 self.tun.send(&buf).await.unwrap();
276 self.state = State::SynSent;
277 info!("Sent SYN to server");
278 }
279 State::SynSent => {
280 match time::timeout(TIMEOUT, self.incoming.recv_async()).await {
281 Ok(buf) => {
282 let buf = buf.unwrap();
283 let (_v4_packet, tcp_packet) = parse_ip_packet(&buf).unwrap();
284
285 if (tcp_packet.get_flags() & tcp::TcpFlags::RST) != 0 {
286 return None;
287 }
288
289 if tcp_packet.get_flags() == tcp::TcpFlags::SYN | tcp::TcpFlags::ACK
290 && tcp_packet.get_acknowledgement()
291 == self.seq.load(Ordering::Relaxed) + 1
292 {
293 self.seq.fetch_add(1, Ordering::Relaxed);
295 self.ack
296 .store(tcp_packet.get_sequence() + 1, Ordering::Relaxed);
297
298 let buf = self.build_tcp_packet(tcp::TcpFlags::ACK, None);
300 self.tun.send(&buf).await.unwrap();
301
302 self.state = State::Established;
303
304 info!("Connection to {:?} established", self.remote_addr);
305 return Some(());
306 }
307 }
308 Err(_) => {
309 info!("Waiting for SYN + ACK timed out");
310 self.state = State::Idle;
311 }
312 }
313 }
314 _ => unreachable!(),
315 }
316 }
317
318 None
319 }
320}
321
322impl Drop for Socket {
323 fn drop(&mut self) {
325 let tuple = AddrTuple::new(self.local_addr, self.remote_addr);
326 assert!(self.shared.tuples.write().unwrap().remove(&tuple).is_some());
328 self.shared.tuples_purge.send(tuple).unwrap();
330
331 let buf = build_tcp_packet(
332 self.local_addr,
333 self.remote_addr,
334 self.seq.load(Ordering::Relaxed),
335 0,
336 tcp::TcpFlags::RST,
337 None,
338 );
339 if let Err(e) = self.tun.try_send(&buf) {
340 warn!("Unable to send RST to remote end: {}", e);
341 }
342
343 info!("Fake TCP connection to {} closed", self);
344 }
345}
346
347impl fmt::Display for Socket {
348 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350 write!(
351 f,
352 "(Fake TCP connection from {} to {})",
353 self.local_addr, self.remote_addr
354 )
355 }
356}
357
358impl Stack {
360 pub fn new(tun: Vec<Tun>, local_ip: Ipv4Addr, local_ip6: Option<Ipv6Addr>) -> Stack {
365 let tun: Vec<Arc<Tun>> = tun.into_iter().map(Arc::new).collect();
366 let (ready_tx, ready_rx) = mpsc::channel(MPSC_BUFFER_LEN);
367 let (tuples_purge_tx, _tuples_purge_rx) = broadcast::channel(16);
368 let shared = Arc::new(Shared {
369 tuples: RwLock::new(HashMap::new()),
370 tun: tun.clone(),
371 listening: RwLock::new(HashSet::new()),
372 ready: ready_tx,
373 tuples_purge: tuples_purge_tx.clone(),
374 });
375
376 for t in tun {
377 tokio::spawn(Stack::reader_task(
378 t,
379 shared.clone(),
380 tuples_purge_tx.subscribe(),
381 ));
382 }
383
384 Stack {
385 shared,
386 local_ip,
387 local_ip6,
388 ready: ready_rx,
389 }
390 }
391
392 pub fn listen(&mut self, port: u16) {
394 assert!(self.shared.listening.write().unwrap().insert(port));
395 }
396
397 pub async fn accept(&mut self) -> Socket {
399 self.ready.recv().await.unwrap()
400 }
401
402 pub async fn connect(&mut self, addr: SocketAddr) -> Option<Socket> {
405 let mut rng = SmallRng::from_os_rng();
406 for local_port in rng.random_range(32768..=60999)..=60999 {
407 let local_addr = SocketAddr::new(
408 if addr.is_ipv4() {
409 IpAddr::V4(self.local_ip)
410 } else {
411 IpAddr::V6(self.local_ip6.expect("IPv6 local address undefined"))
412 },
413 local_port,
414 );
415 let tuple = AddrTuple::new(local_addr, addr);
416 let mut sock;
417
418 {
419 let mut tuples = self.shared.tuples.write().unwrap();
420 if tuples.contains_key(&tuple) {
421 trace!(
422 "Fake TCP connection to {}, local port number {} already in use, trying another one",
423 addr, local_port
424 );
425 continue;
426 }
427
428 let incoming;
429 (sock, incoming) = Socket::new(
430 self.shared.clone(),
431 self.shared.tun.choose(&mut rng).unwrap().clone(),
432 local_addr,
433 addr,
434 None,
435 State::Idle,
436 );
437
438 assert!(tuples.insert(tuple, incoming).is_none());
439 }
440
441 return sock.connect().await.map(|_| sock);
442 }
443
444 error!(
445 "Fake TCP connection to {} failed, emphemeral port number exhausted",
446 addr
447 );
448 None
449 }
450
451 async fn reader_task(
452 tun: Arc<Tun>,
453 shared: Arc<Shared>,
454 mut tuples_purge: broadcast::Receiver<AddrTuple>,
455 ) {
456 let mut tuples: HashMap<AddrTuple, flume::Sender<Bytes>> = HashMap::new();
457
458 loop {
459 let mut buf = BytesMut::zeroed(MAX_PACKET_LEN);
460
461 tokio::select! {
462 size = tun.recv(&mut buf) => {
463 let size = size.unwrap();
464 buf.truncate(size);
465 let buf = buf.freeze();
466
467 match parse_ip_packet(&buf) {
468 Some((ip_packet, tcp_packet)) => {
469 let local_addr =
470 SocketAddr::new(ip_packet.get_destination(), tcp_packet.get_destination());
471 let remote_addr = SocketAddr::new(ip_packet.get_source(), tcp_packet.get_source());
472
473 let tuple = AddrTuple::new(local_addr, remote_addr);
474 if let Some(c) = tuples.get(&tuple) {
475 if c.send_async(buf).await.is_err() {
476 trace!("Cache hit, but receiver already closed, dropping packet");
477 }
478
479 continue;
480
481 } else {
484 trace!("Cache miss, checking the shared tuples table for connection");
485 let sender = {
486 let tuples = shared.tuples.read().unwrap();
487 tuples.get(&tuple).cloned()
488 };
489
490 if let Some(c) = sender {
491 trace!("Storing connection information into local tuples");
492 tuples.insert(tuple, c.clone());
493 c.send_async(buf).await.unwrap();
494 continue;
495 }
496 }
497
498 if tcp_packet.get_flags() == tcp::TcpFlags::SYN
499 && shared
500 .listening
501 .read()
502 .unwrap()
503 .contains(&tcp_packet.get_destination())
504 {
505 if tcp_packet.get_sequence() == 0 {
507 let (sock, incoming) = Socket::new(
508 shared.clone(),
509 tun.clone(),
510 local_addr,
511 remote_addr,
512 Some(tcp_packet.get_sequence() + 1),
513 State::Idle,
514 );
515 assert!(shared
516 .tuples
517 .write()
518 .unwrap()
519 .insert(tuple, incoming)
520 .is_none());
521 tokio::spawn(sock.accept());
522 } else {
523 trace!("Bad TCP SYN packet from {}, sending RST", remote_addr);
524 let buf = build_tcp_packet(
525 local_addr,
526 remote_addr,
527 0,
528 tcp_packet.get_sequence() + tcp_packet.payload().len() as u32 + 1, tcp::TcpFlags::RST | tcp::TcpFlags::ACK,
530 None,
531 );
532 shared.tun[0].try_send(&buf).unwrap();
533 }
534 } else if (tcp_packet.get_flags() & tcp::TcpFlags::RST) == 0 {
535 info!("Unknown TCP packet from {}, sending RST", remote_addr);
536 let buf = build_tcp_packet(
537 local_addr,
538 remote_addr,
539 tcp_packet.get_acknowledgement(),
540 tcp_packet.get_sequence() + tcp_packet.payload().len() as u32,
541 tcp::TcpFlags::RST | tcp::TcpFlags::ACK,
542 None,
543 );
544 shared.tun[0].try_send(&buf).unwrap();
545 }
546 }
547 None => {
548 continue;
549 }
550 }
551 },
552 tuple = tuples_purge.recv() => {
553 let tuple = tuple.unwrap();
554 tuples.remove(&tuple);
555 trace!("Removed cached tuple: {:?}", tuple);
556 }
557 }
558 }
559 }
560}