1use crate::error::{Error, Result};
7use crate::wireguard::WireGuardTunnel;
8use bytes::BytesMut;
9use parking_lot::Mutex;
10use smoltcp::iface::{Config, Interface, PollResult, SocketHandle, SocketSet};
11use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
12use smoltcp::socket::tcp::{Socket as TcpSocket, SocketBuffer, State as TcpState};
13use smoltcp::time::Instant;
14use smoltcp::wire::{HardwareAddress, IpAddress, IpCidr, Ipv4Address, Ipv4Packet, TcpPacket};
15use std::collections::VecDeque;
16use std::net::{SocketAddr, SocketAddrV4};
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::mpsc;
20
/// Default MTU for the virtual interface, in bytes.
/// NOTE(review): 460 is unusually small (WireGuard tunnels commonly use
/// ~1420) — presumably chosen to leave encapsulation headroom; confirm.
pub const DEFAULT_MTU: usize = 460;

/// Capacity of each TCP socket's RX and TX buffer (65535 bytes).
const TCP_BUFFER_SIZE: usize = 65535;
/// In-memory packet device backing the smoltcp interface.
///
/// smoltcp "receives" packets from `rx_queue` and "transmits" into
/// `tx_queue`; the surrounding `NetStack` shuttles both queues to and from
/// the WireGuard tunnel.
struct VirtualDevice {
    // Inbound IP packets waiting to be consumed by `Interface::poll`.
    rx_queue: VecDeque<BytesMut>,
    // Outbound IP packets produced by the stack, drained by `NetStack::poll`.
    tx_queue: VecDeque<BytesMut>,
    // Maximum transmission unit reported to smoltcp via `capabilities()`.
    mtu: usize,
}
40
41impl VirtualDevice {
42 fn new(mtu: usize) -> Self {
43 Self {
44 rx_queue: VecDeque::new(),
45 tx_queue: VecDeque::new(),
46 mtu,
47 }
48 }
49
50 fn push_rx(&mut self, packet: BytesMut) {
52 self.rx_queue.push_back(packet);
53 }
54
55 fn drain_tx(&mut self) -> Vec<BytesMut> {
57 self.tx_queue.drain(..).collect()
58 }
59}
60
/// RX token handing one queued packet to smoltcp for processing.
struct VirtualRxToken {
    // The raw IP packet to be parsed by the stack.
    buffer: BytesMut,
}
65
66impl RxToken for VirtualRxToken {
67 fn consume<R, F>(self, f: F) -> R
68 where
69 F: FnOnce(&[u8]) -> R,
70 {
71 f(&self.buffer)
72 }
73}
74
/// TX token that appends whatever frame smoltcp writes to the device's
/// outbound queue.
struct VirtualTxToken<'a> {
    // Borrowed outbound queue of the owning `VirtualDevice`.
    tx_queue: &'a mut VecDeque<BytesMut>,
}
79
80impl<'a> TxToken for VirtualTxToken<'a> {
81 fn consume<R, F>(self, len: usize, f: F) -> R
82 where
83 F: FnOnce(&mut [u8]) -> R,
84 {
85 let mut buffer = BytesMut::zeroed(len);
86 let result = f(&mut buffer);
87 self.tx_queue.push_back(buffer);
88 result
89 }
90
91 fn set_meta(&mut self, _meta: smoltcp::phy::PacketMeta) {
92 }
94}
95
96impl Device for VirtualDevice {
97 type RxToken<'a> = VirtualRxToken;
98 type TxToken<'a> = VirtualTxToken<'a>;
99
100 fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
101 if let Some(buffer) = self.rx_queue.pop_front() {
102 Some((
103 VirtualRxToken { buffer },
104 VirtualTxToken {
105 tx_queue: &mut self.tx_queue,
106 },
107 ))
108 } else {
109 None
110 }
111 }
112
113 fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
114 Some(VirtualTxToken {
115 tx_queue: &mut self.tx_queue,
116 })
117 }
118
119 fn capabilities(&self) -> DeviceCapabilities {
120 let mut caps = DeviceCapabilities::default();
121 caps.medium = Medium::Ip;
122 caps.max_transmission_unit = self.mtu;
123 caps
124 }
125}
126
/// State guarded by the `NetStack` mutex; interface, device and socket set
/// must be polled together.
struct NetStackInner {
    // The smoltcp network interface.
    interface: Interface,
    // Virtual packet device bridging the interface to WireGuard.
    device: VirtualDevice,
    // All TCP sockets managed by this stack (`'static`: buffers are owned).
    sockets: SocketSet<'static>,
}
133
/// User-space TCP/IP stack that exchanges raw IP packets with a WireGuard
/// tunnel instead of a kernel network interface.
pub struct NetStack {
    // Mutable smoltcp state, guarded by a parking_lot mutex.
    inner: Mutex<NetStackInner>,
    // Tunnel handle; used here for the local (tunnel) IP address.
    wg_tunnel: Arc<WireGuardTunnel>,
    // Channel on which outbound IP packets are handed to the tunnel.
    wg_tx: mpsc::Sender<BytesMut>,
}
141
142impl NetStack {
143 pub fn new(wg_tunnel: Arc<WireGuardTunnel>) -> Arc<Self> {
145 let tunnel_ip = wg_tunnel.tunnel_ip();
146 let mtu = wg_tunnel.mtu() as usize;
147 let wg_tx = wg_tunnel.outgoing_sender();
148
149 let mut device = VirtualDevice::new(mtu);
151
152 let config = Config::new(HardwareAddress::Ip);
154
155 let mut interface = Interface::new(config, &mut device, Instant::now());
157
158 interface.update_ip_addrs(|addrs| {
160 addrs
161 .push(IpCidr::new(
162 IpAddress::v4(
163 tunnel_ip.octets()[0],
164 tunnel_ip.octets()[1],
165 tunnel_ip.octets()[2],
166 tunnel_ip.octets()[3],
167 ),
168 32,
169 ))
170 .unwrap();
171 });
172
173 interface
175 .routes_mut()
176 .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 0))
177 .unwrap();
178
179 let sockets = SocketSet::new(vec![]);
181
182 let inner = NetStackInner {
183 interface,
184 device,
185 sockets,
186 };
187
188 Arc::new(Self {
189 inner: Mutex::new(inner),
190 wg_tunnel,
191 wg_tx,
192 })
193 }
194
195 pub fn create_tcp_socket(&self) -> SocketHandle {
197 let mut inner = self.inner.lock();
198
199 let rx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
200 let tx_buffer = SocketBuffer::new(vec![0u8; TCP_BUFFER_SIZE]);
201 let socket = TcpSocket::new(rx_buffer, tx_buffer);
202
203 inner.sockets.add(socket)
204 }
205
206 pub fn connect(&self, handle: SocketHandle, addr: SocketAddr) -> Result<()> {
208 let mut inner = self.inner.lock();
209
210 let local_port = 49152 + (rand::random::<u16>() % 16384);
211 let local_addr = SocketAddrV4::new(self.wg_tunnel.tunnel_ip(), local_port);
212
213 let remote = match addr {
214 SocketAddr::V4(v4) => smoltcp::wire::IpEndpoint::new(
215 IpAddress::v4(
216 v4.ip().octets()[0],
217 v4.ip().octets()[1],
218 v4.ip().octets()[2],
219 v4.ip().octets()[3],
220 ),
221 v4.port(),
222 ),
223 SocketAddr::V6(_) => return Err(Error::Ipv6NotSupported),
224 };
225
226 let local = smoltcp::wire::IpEndpoint::new(
227 IpAddress::v4(
228 local_addr.ip().octets()[0],
229 local_addr.ip().octets()[1],
230 local_addr.ip().octets()[2],
231 local_addr.ip().octets()[3],
232 ),
233 local_addr.port(),
234 );
235
236 let NetStackInner {
238 ref mut interface,
239 ref mut sockets,
240 ..
241 } = *inner;
242 let cx = interface.context();
243 let socket = sockets.get_mut::<TcpSocket>(handle);
244 socket
245 .connect(cx, remote, local)
246 .map_err(|e| Error::TcpConnectGeneric(format!("TCP connect failed: {}", e)))?;
247
248 log::debug!("TCP socket connecting to {} from {}", addr, local_addr);
249
250 Ok(())
251 }
252
253 pub fn is_connected(&self, handle: SocketHandle) -> bool {
255 let inner = self.inner.lock();
256 let socket = inner.sockets.get::<TcpSocket>(handle);
257 socket.state() == TcpState::Established
258 }
259
260 pub fn can_send(&self, handle: SocketHandle) -> bool {
262 let inner = self.inner.lock();
263 let socket = inner.sockets.get::<TcpSocket>(handle);
264 socket.can_send()
265 }
266
267 pub fn can_recv(&self, handle: SocketHandle) -> bool {
269 let inner = self.inner.lock();
270 let socket = inner.sockets.get::<TcpSocket>(handle);
271 let can = socket.can_recv();
272 let recv_queue = socket.recv_queue();
273 if recv_queue > 0 {
274 log::debug!(
275 "Socket can_recv={}, recv_queue={}, state={:?}",
276 can,
277 recv_queue,
278 socket.state()
279 );
280 }
281 can
282 }
283
284 pub fn may_send(&self, handle: SocketHandle) -> bool {
286 let inner = self.inner.lock();
287 let socket = inner.sockets.get::<TcpSocket>(handle);
288 socket.may_send()
289 }
290
291 pub fn may_recv(&self, handle: SocketHandle) -> bool {
293 let inner = self.inner.lock();
294 let socket = inner.sockets.get::<TcpSocket>(handle);
295 socket.may_recv()
296 }
297
298 pub fn socket_state(&self, handle: SocketHandle) -> TcpState {
300 let inner = self.inner.lock();
301 let socket = inner.sockets.get::<TcpSocket>(handle);
302 socket.state()
303 }
304
305 pub fn send(&self, handle: SocketHandle, data: &[u8]) -> Result<usize> {
307 let mut inner = self.inner.lock();
308 let socket = inner.sockets.get_mut::<TcpSocket>(handle);
309
310 socket
311 .send_slice(data)
312 .map_err(|e| Error::TcpSend(e.to_string()))
313 }
314
315 pub fn recv(&self, handle: SocketHandle, buffer: &mut [u8]) -> Result<usize> {
317 let mut inner = self.inner.lock();
318 let socket = inner.sockets.get_mut::<TcpSocket>(handle);
319
320 socket
321 .recv_slice(buffer)
322 .map_err(|e| Error::TcpRecv(e.to_string()))
323 }
324
325 pub fn close(&self, handle: SocketHandle) {
327 let mut inner = self.inner.lock();
328 let socket = inner.sockets.get_mut::<TcpSocket>(handle);
329 socket.close();
330 }
331
332 pub fn remove_socket(&self, handle: SocketHandle) {
334 let mut inner = self.inner.lock();
335 inner.sockets.remove(handle);
336 }
337
338 pub fn poll(&self) -> bool {
341 let mut inner = self.inner.lock();
342
343 let timestamp = Instant::now();
344
345 let NetStackInner {
347 ref mut interface,
348 ref mut device,
349 ref mut sockets,
350 } = *inner;
351
352 let rx_queue_len = device.rx_queue.len();
354 if rx_queue_len > 0 {
355 log::trace!("NetStack poll: {} packets in rx_queue", rx_queue_len);
356 }
357
358 let poll_result = interface.poll(timestamp, device, sockets);
360 let processed = poll_result != PollResult::None;
361
362 if processed {
363 log::trace!("NetStack poll processed packets");
364 }
365
366 let tx_packets = device.drain_tx();
368 let tx_count = tx_packets.len();
369 drop(inner); if tx_count > 0 {
372 log::trace!("NetStack poll sending {} packets", tx_count);
373 }
374
375 for packet in tx_packets {
376 if log::log_enabled!(log::Level::Debug) {
378 if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
379 let protocol = ip_packet.next_header();
380 if protocol == smoltcp::wire::IpProtocol::Tcp {
381 if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
382 let dst_port = tcp_packet.dst_port();
383 let payload_len = tcp_packet.payload().len();
384
385 let mut flags = String::new();
386 if tcp_packet.syn() {
387 flags.push_str("SYN ");
388 }
389 if tcp_packet.ack() {
390 flags.push_str("ACK ");
391 }
392 if tcp_packet.fin() {
393 flags.push_str("FIN ");
394 }
395 if tcp_packet.rst() {
396 flags.push_str("RST ");
397 }
398 if tcp_packet.psh() {
399 flags.push_str("PSH ");
400 }
401
402 log::debug!(
403 "TX: {}:{} [{}] {} bytes",
404 ip_packet.dst_addr(),
405 dst_port,
406 flags.trim(),
407 payload_len
408 );
409 }
410 }
411 }
412 }
413
414 let tx = self.wg_tx.clone();
415 tokio::spawn(async move {
416 if let Err(e) = tx.send(packet).await {
417 log::error!("Failed to queue packet for WireGuard: {}", e);
418 }
419 });
420 }
421
422 processed
423 }
424
425 pub fn push_rx_packet(&self, packet: BytesMut) {
427 if log::log_enabled!(log::Level::Debug) {
429 if let Ok(ip_packet) = Ipv4Packet::new_checked(&packet) {
430 let protocol = ip_packet.next_header();
431 if protocol == smoltcp::wire::IpProtocol::Tcp {
432 if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
433 let src_port = tcp_packet.src_port();
434 let payload_len = tcp_packet.payload().len();
435
436 let mut flags = String::new();
437 if tcp_packet.syn() {
438 flags.push_str("SYN ");
439 }
440 if tcp_packet.ack() {
441 flags.push_str("ACK ");
442 }
443 if tcp_packet.fin() {
444 flags.push_str("FIN ");
445 }
446 if tcp_packet.rst() {
447 flags.push_str("RST ");
448 }
449 if tcp_packet.psh() {
450 flags.push_str("PSH ");
451 }
452
453 log::debug!(
454 "RX: {}:{} [{}] {} bytes",
455 ip_packet.src_addr(),
456 src_port,
457 flags.trim(),
458 payload_len
459 );
460 }
461 }
462 }
463 }
464
465 let mut inner = self.inner.lock();
466 inner.device.push_rx(packet);
467 }
468
469 pub async fn run_poll_loop(self: &Arc<Self>) -> Result<()> {
471 let mut interval = tokio::time::interval(Duration::from_millis(1));
472
473 loop {
474 interval.tick().await;
475 self.poll();
476 }
477 }
478
479 pub async fn run_rx_loop(self: &Arc<Self>, mut rx: mpsc::Receiver<BytesMut>) -> Result<()> {
481 while let Some(packet) = rx.recv().await {
482 log::debug!("NetStack received packet ({} bytes)", packet.len());
483 self.push_rx_packet(packet);
484 self.poll();
485 }
486
487 Ok(())
488 }
489}
490
/// A single TCP connection running on a shared `NetStack`.
pub struct TcpConnection {
    // The stack that owns the underlying socket.
    pub netstack: Arc<NetStack>,
    // Handle of this connection's socket within the stack's `SocketSet`.
    pub handle: SocketHandle,
}
498
499impl TcpConnection {
500 pub async fn connect(netstack: Arc<NetStack>, addr: SocketAddr) -> Result<Self> {
502 let handle = netstack.create_tcp_socket();
503 netstack.connect(handle, addr)?;
504
505 let start = std::time::Instant::now();
507 let timeout = Duration::from_secs(30);
508
509 loop {
510 netstack.poll();
511
512 let state = netstack.socket_state(handle);
513 log::trace!("TCP state: {:?}", state);
514
515 if state == TcpState::Established {
516 log::info!("TCP connection established to {}", addr);
517 return Ok(Self { netstack, handle });
518 }
519
520 if state == TcpState::Closed || state == TcpState::TimeWait {
521 netstack.remove_socket(handle);
522 return Err(Error::TcpConnect {
523 addr,
524 message: format!("Connection failed (state: {:?})", state),
525 });
526 }
527
528 if start.elapsed() > timeout {
529 netstack.remove_socket(handle);
530 return Err(Error::TcpTimeout);
531 }
532
533 tokio::time::sleep(Duration::from_millis(1)).await;
534 }
535 }
536
537 pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
539 let timeout = Duration::from_secs(30);
540 let start = std::time::Instant::now();
541
542 loop {
543 self.netstack.poll();
544
545 if self.netstack.can_recv(self.handle) {
546 match self.netstack.recv(self.handle, buf) {
547 Ok(n) if n > 0 => return Ok(n),
548 Ok(_) => {}
549 Err(e) => return Err(e),
550 }
551 }
552
553 if !self.netstack.may_recv(self.handle) {
554 return Ok(0);
556 }
557
558 if start.elapsed() > timeout {
559 return Err(Error::ReadTimeout);
560 }
561
562 tokio::time::sleep(Duration::from_millis(1)).await;
563 }
564 }
565
566 pub async fn write(&self, data: &[u8]) -> Result<usize> {
568 let timeout = Duration::from_secs(30);
569 let start = std::time::Instant::now();
570
571 let mut written = 0;
572
573 while written < data.len() {
574 self.netstack.poll();
575
576 if self.netstack.can_send(self.handle) {
577 match self.netstack.send(self.handle, &data[written..]) {
578 Ok(n) => {
579 written += n;
580 log::trace!("Wrote {} bytes (total: {})", n, written);
581 }
582 Err(e) => return Err(e),
583 }
584 }
585
586 if !self.netstack.may_send(self.handle) {
587 return Err(Error::ConnectionClosed);
589 }
590
591 if start.elapsed() > timeout {
592 return Err(Error::WriteTimeout);
593 }
594
595 if written < data.len() {
596 tokio::time::sleep(Duration::from_millis(1)).await;
597 }
598 }
599
600 self.netstack.poll();
601 Ok(written)
602 }
603
604 pub async fn write_all(&self, data: &[u8]) -> Result<()> {
606 let n = self.write(data).await?;
607 if n != data.len() {
608 return Err(Error::ShortWrite {
609 written: n,
610 expected: data.len(),
611 });
612 }
613 Ok(())
614 }
615
616 pub fn shutdown(&self) {
618 self.netstack.close(self.handle);
619 }
620
621 pub fn handle(&self) -> SocketHandle {
623 self.handle
624 }
625}
626
impl Drop for TcpConnection {
    fn drop(&mut self) {
        // Initiate an orderly close and poll once so the FIN gets a chance
        // to go out before the connection object is abandoned.
        // NOTE(review): the socket handle is never removed from the stack's
        // SocketSet here, so closed sockets accumulate unless a caller
        // removes them — verify handle lifecycle ownership.
        self.netstack.close(self.handle);
        self.netstack.poll();
    }
}