// wireguard_netstack/netstack.rs
//! Userspace TCP/IP network stack using smoltcp.
//!
//! This module provides a TCP/IP stack that runs entirely in userspace,
//! routing packets through our WireGuard tunnel.

use crate::error::{Error, Result};
use crate::wireguard::WireGuardTunnel;
use bytes::BytesMut;
use parking_lot::Mutex;
use smoltcp::iface::{Config, Interface, PollResult, SocketHandle, SocketSet};
use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
use smoltcp::socket::tcp::{Socket as TcpSocket, SocketBuffer, State as TcpState};
use smoltcp::time::Instant;
use smoltcp::wire::{HardwareAddress, IpAddress, IpCidr, Ipv4Address, Ipv4Packet, TcpPacket};
use std::collections::VecDeque;
use std::net::{SocketAddr, SocketAddrV4};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
20
/// MTU for the virtual interface.
/// Some networks drop large UDP packets, especially when WireGuard overhead is added.
/// We use a conservative MTU that results in ~600 byte UDP packets after WireGuard
/// encapsulation (MTU + 40 IP/TCP headers + 48 WG overhead ≈ 548 byte UDP).
/// This works around networks that filter large UDP packets.
pub const DEFAULT_MTU: usize = 460;

/// Size of TCP socket buffers (one rx and one tx buffer of this size per socket).
const TCP_BUFFER_SIZE: usize = 65535;
30
/// A virtual network device that sends/receives through the WireGuard tunnel.
///
/// smoltcp drives this device through the `Device` trait; the two queues are
/// the hand-off points between the tunnel and the interface poll loop.
struct VirtualDevice {
    /// Packets ready to be received by smoltcp (from WireGuard).
    rx_queue: VecDeque<BytesMut>,
    /// Packets ready to be sent (to WireGuard).
    tx_queue: VecDeque<BytesMut>,
    /// MTU for this device.
    mtu: usize,
}
40
41impl VirtualDevice {
42    fn new(mtu: usize) -> Self {
43        Self {
44            rx_queue: VecDeque::new(),
45            tx_queue: VecDeque::new(),
46            mtu,
47        }
48    }
49
50    /// Add a packet to the receive queue (from WireGuard).
51    fn push_rx(&mut self, packet: BytesMut) {
52        self.rx_queue.push_back(packet);
53    }
54
55    /// Take all packets from the transmit queue (to send via WireGuard).
56    fn drain_tx(&mut self) -> Vec<BytesMut> {
57        self.tx_queue.drain(..).collect()
58    }
59}
60
/// RxToken for smoltcp: owns one inbound frame and yields it exactly once.
struct VirtualRxToken {
    // The raw IP packet handed to smoltcp in consume().
    buffer: BytesMut,
}

impl RxToken for VirtualRxToken {
    fn consume<R, F>(self, f: F) -> R
    where
        F: FnOnce(&[u8]) -> R,
    {
        // smoltcp parses the frame in place; the buffer is dropped afterwards.
        f(&self.buffer)
    }
}
74
75/// TxToken for smoltcp.
76struct VirtualTxToken<'a> {
77    tx_queue: &'a mut VecDeque<BytesMut>,
78}
79
80impl<'a> TxToken for VirtualTxToken<'a> {
81    fn consume<R, F>(self, len: usize, f: F) -> R
82    where
83        F: FnOnce(&mut [u8]) -> R,
84    {
85        let mut buffer = BytesMut::zeroed(len);
86        let result = f(&mut buffer);
87        self.tx_queue.push_back(buffer);
88        result
89    }
90
91    fn set_meta(&mut self, _meta: smoltcp::phy::PacketMeta) {
92        // No metadata handling needed for virtual device
93    }
94}
95
96impl Device for VirtualDevice {
97    type RxToken<'a> = VirtualRxToken;
98    type TxToken<'a> = VirtualTxToken<'a>;
99
100    fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
101        if let Some(buffer) = self.rx_queue.pop_front() {
102            Some((
103                VirtualRxToken { buffer },
104                VirtualTxToken {
105                    tx_queue: &mut self.tx_queue,
106                },
107            ))
108        } else {
109            None
110        }
111    }
112
113    fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
114        Some(VirtualTxToken {
115            tx_queue: &mut self.tx_queue,
116        })
117    }
118
119    fn capabilities(&self) -> DeviceCapabilities {
120        let mut caps = DeviceCapabilities::default();
121        caps.medium = Medium::Ip;
122        caps.max_transmission_unit = self.mtu;
123        caps
124    }
125}
126
/// Shared state for the network stack, guarded by a single mutex so the
/// interface, device, and socket set are always mutated together.
struct NetStackInner {
    // smoltcp interface bound to the virtual device.
    interface: Interface,
    // Virtual device bridging smoltcp and the WireGuard tunnel.
    device: VirtualDevice,
    // All TCP sockets managed by this stack.
    sockets: SocketSet<'static>,
}
133
/// A userspace TCP/IP network stack.
pub struct NetStack {
    // Mutex-guarded interface/device/socket state; locked on every operation.
    inner: Mutex<NetStackInner>,
    // The WireGuard tunnel backing this stack (source of the tunnel IP).
    wg_tunnel: Arc<WireGuardTunnel>,
    /// Sender to queue packets for transmission through WireGuard.
    wg_tx: mpsc::Sender<BytesMut>,
}
141
142impl NetStack {
143    /// Create a new network stack backed by a WireGuard tunnel.
144    pub fn new(wg_tunnel: Arc<WireGuardTunnel>) -> Arc<Self> {
145        let tunnel_ip = wg_tunnel.tunnel_ip();
146        let mtu = wg_tunnel.mtu() as usize;
147        let wg_tx = wg_tunnel.outgoing_sender();
148
149        // Create the virtual device with the configured MTU
150        let mut device = VirtualDevice::new(mtu);
151
152        // Create the interface configuration
153        let config = Config::new(HardwareAddress::Ip);
154
155        // Create the interface
156        let mut interface = Interface::new(config, &mut device, Instant::now());
157
158        // Configure the interface with our tunnel IP
159        interface.update_ip_addrs(|addrs| {
160            addrs
161                .push(IpCidr::new(
162                    IpAddress::v4(
163                        tunnel_ip.octets()[0],
164                        tunnel_ip.octets()[1],
165                        tunnel_ip.octets()[2],
166                        tunnel_ip.octets()[3],
167                    ),
168                    32,
169                ))
170                .unwrap();
171        });
172
173        // Set up routing - route everything through this interface
174        interface
175            .routes_mut()
176            .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 0))
177            .unwrap();
178
179        // Create socket set
180        let sockets = SocketSet::new(vec![]);
181
182        let inner = NetStackInner {
183            interface,
184            device,
185            sockets,
186        };
187
188        Arc::new(Self {
189            inner: Mutex::new(inner),
190            wg_tunnel,
191            wg_tx,
192        })
193    }
194
195    /// Create a new TCP socket and return its handle.
196    pub fn create_tcp_socket(&self) -> SocketHandle {
197        let mut inner = self.inner.lock();
198
199        let rx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
200        let tx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
201        let socket = TcpSocket::new(rx_buffer, tx_buffer);
202
203        inner.sockets.add(socket)
204    }
205
206    /// Connect a TCP socket to the given address.
207    pub fn connect(&self, handle: SocketHandle, addr: SocketAddr) -> Result<()> {
208        let mut inner = self.inner.lock();
209
210        let local_port = 49152 + (rand::random::<u16>() % 16384);
211        let local_addr = SocketAddrV4::new(self.wg_tunnel.tunnel_ip(), local_port);
212
213        let remote = match addr {
214            SocketAddr::V4(v4) => smoltcp::wire::IpEndpoint::new(
215                IpAddress::v4(
216                    v4.ip().octets()[0],
217                    v4.ip().octets()[1],
218                    v4.ip().octets()[2],
219                    v4.ip().octets()[3],
220                ),
221                v4.port(),
222            ),
223            SocketAddr::V6(_) => return Err(Error::Ipv6NotSupported),
224        };
225
226        let local = smoltcp::wire::IpEndpoint::new(
227            IpAddress::v4(
228                local_addr.ip().octets()[0],
229                local_addr.ip().octets()[1],
230                local_addr.ip().octets()[2],
231                local_addr.ip().octets()[3],
232            ),
233            local_addr.port(),
234        );
235
236        // Use destructuring to avoid split borrow issues
237        let NetStackInner {
238            ref mut interface,
239            ref mut sockets,
240            ..
241        } = *inner;
242        let cx = interface.context();
243        let socket = sockets.get_mut::<TcpSocket>(handle);
244        socket
245            .connect(cx, remote, local)
246            .map_err(|e| Error::TcpConnectGeneric(format!("TCP connect failed: {}", e)))?;
247
248        log::debug!("TCP socket connecting to {} from {}", addr, local_addr);
249
250        Ok(())
251    }
252
253    /// Check if a TCP socket is connected.
254    pub fn is_connected(&self, handle: SocketHandle) -> bool {
255        let inner = self.inner.lock();
256        let socket = inner.sockets.get::<TcpSocket>(handle);
257        socket.state() == TcpState::Established
258    }
259
260    /// Check if a TCP socket can send data.
261    pub fn can_send(&self, handle: SocketHandle) -> bool {
262        let inner = self.inner.lock();
263        let socket = inner.sockets.get::<TcpSocket>(handle);
264        socket.can_send()
265    }
266
267    /// Check if a TCP socket can receive data.
268    pub fn can_recv(&self, handle: SocketHandle) -> bool {
269        let inner = self.inner.lock();
270        let socket = inner.sockets.get::<TcpSocket>(handle);
271        let can = socket.can_recv();
272        let recv_queue = socket.recv_queue();
273        if recv_queue > 0 {
274            log::debug!(
275                "Socket can_recv={}, recv_queue={}, state={:?}",
276                can,
277                recv_queue,
278                socket.state()
279            );
280        }
281        can
282    }
283
284    /// Check if a TCP socket may send data (connection in progress or established).
285    pub fn may_send(&self, handle: SocketHandle) -> bool {
286        let inner = self.inner.lock();
287        let socket = inner.sockets.get::<TcpSocket>(handle);
288        socket.may_send()
289    }
290
291    /// Check if a TCP socket may receive data.
292    pub fn may_recv(&self, handle: SocketHandle) -> bool {
293        let inner = self.inner.lock();
294        let socket = inner.sockets.get::<TcpSocket>(handle);
295        socket.may_recv()
296    }
297
298    /// Get the TCP socket state.
299    pub fn socket_state(&self, handle: SocketHandle) -> TcpState {
300        let inner = self.inner.lock();
301        let socket = inner.sockets.get::<TcpSocket>(handle);
302        socket.state()
303    }
304
305    /// Send data on a TCP socket.
306    pub fn send(&self, handle: SocketHandle, data: &[u8]) -> Result<usize> {
307        let mut inner = self.inner.lock();
308        let socket = inner.sockets.get_mut::<TcpSocket>(handle);
309
310        socket
311            .send_slice(data)
312            .map_err(|e| Error::TcpSend(e.to_string()))
313    }
314
315    /// Receive data from a TCP socket.
316    pub fn recv(&self, handle: SocketHandle, buffer: &mut [u8]) -> Result<usize> {
317        let mut inner = self.inner.lock();
318        let socket = inner.sockets.get_mut::<TcpSocket>(handle);
319
320        socket
321            .recv_slice(buffer)
322            .map_err(|e| Error::TcpRecv(e.to_string()))
323    }
324
325    /// Close a TCP socket.
326    pub fn close(&self, handle: SocketHandle) {
327        let mut inner = self.inner.lock();
328        let socket = inner.sockets.get_mut::<TcpSocket>(handle);
329        socket.close();
330    }
331
332    /// Remove a socket from the socket set.
333    pub fn remove_socket(&self, handle: SocketHandle) {
334        let mut inner = self.inner.lock();
335        inner.sockets.remove(handle);
336    }
337
338    /// Poll the network stack, processing packets and updating socket states.
339    /// Returns true if there was any activity.
340    pub fn poll(&self) -> bool {
341        let mut inner = self.inner.lock();
342
343        let timestamp = Instant::now();
344
345        // Destructure to allow split borrows
346        let NetStackInner {
347            ref mut interface,
348            ref mut device,
349            ref mut sockets,
350        } = *inner;
351
352        // Check if there are packets waiting
353        let rx_queue_len = device.rx_queue.len();
354        if rx_queue_len > 0 {
355            log::trace!("NetStack poll: {} packets in rx_queue", rx_queue_len);
356        }
357
358        // Poll the interface
359        let poll_result = interface.poll(timestamp, device, sockets);
360        let processed = poll_result != PollResult::None;
361
362        if processed {
363            log::trace!("NetStack poll processed packets");
364        }
365
366        // Drain transmitted packets and send through WireGuard
367        let tx_packets = device.drain_tx();
368        let tx_count = tx_packets.len();
369        drop(inner); // Release lock before async operations
370
371        if tx_count > 0 {
372            log::trace!("NetStack poll sending {} packets", tx_count);
373        }
374
375        for packet in tx_packets {
376            // Log outgoing TCP packets at debug level
377            if log::log_enabled!(log::Level::Debug) {
378                if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
379                    let protocol = ip_packet.next_header();
380                    if protocol == smoltcp::wire::IpProtocol::Tcp {
381                        if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
382                            let dst_port = tcp_packet.dst_port();
383                            let payload_len = tcp_packet.payload().len();
384
385                            let mut flags = String::new();
386                            if tcp_packet.syn() {
387                                flags.push_str("SYN ");
388                            }
389                            if tcp_packet.ack() {
390                                flags.push_str("ACK ");
391                            }
392                            if tcp_packet.fin() {
393                                flags.push_str("FIN ");
394                            }
395                            if tcp_packet.rst() {
396                                flags.push_str("RST ");
397                            }
398                            if tcp_packet.psh() {
399                                flags.push_str("PSH ");
400                            }
401
402                            log::debug!(
403                                "TX: {}:{} [{}] {} bytes",
404                                ip_packet.dst_addr(),
405                                dst_port,
406                                flags.trim(),
407                                payload_len
408                            );
409                        }
410                    }
411                }
412            }
413
414            let tx = self.wg_tx.clone();
415            tokio::spawn(async move {
416                if let Err(e) = tx.send(packet).await {
417                    log::error!("Failed to queue packet for WireGuard: {}", e);
418                }
419            });
420        }
421
422        processed
423    }
424
425    /// Push a received packet (from WireGuard) into the network stack.
426    pub fn push_rx_packet(&self, packet: BytesMut) {
427        // Parse and log TCP packet details for debugging
428        if log::log_enabled!(log::Level::Debug) {
429            if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
430                let protocol = ip_packet.next_header();
431                if protocol == smoltcp::wire::IpProtocol::Tcp {
432                    if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
433                        let src_port = tcp_packet.src_port();
434                        let payload_len = tcp_packet.payload().len();
435
436                        let mut flags = String::new();
437                        if tcp_packet.syn() {
438                            flags.push_str("SYN ");
439                        }
440                        if tcp_packet.ack() {
441                            flags.push_str("ACK ");
442                        }
443                        if tcp_packet.fin() {
444                            flags.push_str("FIN ");
445                        }
446                        if tcp_packet.rst() {
447                            flags.push_str("RST ");
448                        }
449                        if tcp_packet.psh() {
450                            flags.push_str("PSH ");
451                        }
452
453                        log::debug!(
454                            "RX: {}:{} [{}] {} bytes",
455                            ip_packet.src_addr(),
456                            src_port,
457                            flags.trim(),
458                            payload_len
459                        );
460                    }
461                }
462            }
463        }
464
465        let mut inner = self.inner.lock();
466        inner.device.push_rx(packet);
467    }
468
469    /// Run the network stack polling loop.
470    pub async fn run_poll_loop(self: &Arc<Self>) -> Result<()> {
471        let mut interval = tokio::time::interval(Duration::from_millis(1));
472
473        loop {
474            interval.tick().await;
475            self.poll();
476        }
477    }
478
479    /// Run the receive loop that takes packets from WireGuard and feeds them to the stack.
480    pub async fn run_rx_loop(self: &Arc<Self>, mut rx: mpsc::Receiver<BytesMut>) -> Result<()> {
481        while let Some(packet) = rx.recv().await {
482            log::debug!("NetStack received packet ({} bytes)", packet.len());
483            self.push_rx_packet(packet);
484            self.poll();
485        }
486
487        Ok(())
488    }
489}
490
/// A TCP connection through our network stack.
pub struct TcpConnection {
    /// The network stack backing this connection.
    pub netstack: Arc<NetStack>,
    /// The socket handle for this connection.
    pub handle: SocketHandle,
}
498
499impl TcpConnection {
500    /// Create a new TCP connection.
501    pub async fn connect(netstack: Arc<NetStack>, addr: SocketAddr) -> Result<Self> {
502        let handle = netstack.create_tcp_socket();
503        netstack.connect(handle, addr)?;
504
505        // Poll until connected or timeout
506        let start = std::time::Instant::now();
507        let timeout = Duration::from_secs(30);
508
509        loop {
510            netstack.poll();
511
512            let state = netstack.socket_state(handle);
513            log::trace!("TCP state: {:?}", state);
514
515            if state == TcpState::Established {
516                log::info!("TCP connection established to {}", addr);
517                return Ok(Self { netstack, handle });
518            }
519
520            if state == TcpState::Closed || state == TcpState::TimeWait {
521                netstack.remove_socket(handle);
522                return Err(Error::TcpConnect {
523                    addr,
524                    message: format!("Connection failed (state: {:?})", state),
525                });
526            }
527
528            if start.elapsed() > timeout {
529                netstack.remove_socket(handle);
530                return Err(Error::TcpTimeout);
531            }
532
533            tokio::time::sleep(Duration::from_millis(1)).await;
534        }
535    }
536
537    /// Read data from the connection.
538    pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
539        let timeout = Duration::from_secs(30);
540        let start = std::time::Instant::now();
541
542        loop {
543            self.netstack.poll();
544
545            if self.netstack.can_recv(self.handle) {
546                match self.netstack.recv(self.handle, buf) {
547                    Ok(n) if n > 0 => return Ok(n),
548                    Ok(_) => {}
549                    Err(e) => return Err(e),
550                }
551            }
552
553            if !self.netstack.may_recv(self.handle) {
554                // Connection closed
555                return Ok(0);
556            }
557
558            if start.elapsed() > timeout {
559                return Err(Error::ReadTimeout);
560            }
561
562            tokio::time::sleep(Duration::from_millis(1)).await;
563        }
564    }
565
566    /// Write data to the connection.
567    pub async fn write(&self, data: &[u8]) -> Result<usize> {
568        let timeout = Duration::from_secs(30);
569        let start = std::time::Instant::now();
570
571        let mut written = 0;
572
573        while written < data.len() {
574            self.netstack.poll();
575
576            if self.netstack.can_send(self.handle) {
577                match self.netstack.send(self.handle, &data[written..]) {
578                    Ok(n) => {
579                        written += n;
580                        log::trace!("Wrote {} bytes (total: {})", n, written);
581                    }
582                    Err(e) => return Err(e),
583                }
584            }
585
586            if !self.netstack.may_send(self.handle) {
587                // Connection closed
588                return Err(Error::ConnectionClosed);
589            }
590
591            if start.elapsed() > timeout {
592                return Err(Error::WriteTimeout);
593            }
594
595            if written < data.len() {
596                tokio::time::sleep(Duration::from_millis(1)).await;
597            }
598        }
599
600        self.netstack.poll();
601        Ok(written)
602    }
603
604    /// Write all data to the connection.
605    pub async fn write_all(&self, data: &[u8]) -> Result<()> {
606        let n = self.write(data).await?;
607        if n != data.len() {
608            return Err(Error::ShortWrite {
609                written: n,
610                expected: data.len(),
611            });
612        }
613        Ok(())
614    }
615
616    /// Shutdown the connection.
617    pub fn shutdown(&self) {
618        self.netstack.close(self.handle);
619    }
620
621    /// Get the socket handle.
622    pub fn handle(&self) -> SocketHandle {
623        self.handle
624    }
625}
626
impl Drop for TcpConnection {
    fn drop(&mut self) {
        // Initiate a graceful close (FIN) for the socket.
        self.netstack.close(self.handle);
        // Give time for FIN to be sent
        // NOTE(review): the handle is never removed from the SocketSet after
        // a successful connection, so closed sockets (and their 64 KiB
        // buffers) accumulate in NetStackInner.sockets — consider reaping
        // handles once they reach the Closed state.
        self.netstack.poll();
    }
}