1use std::collections::HashMap;
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::sync::atomic::Ordering;
7use std::time::Duration;
8
9use ringbuf::HeapRb;
10use ringbuf::traits::{Consumer, Observer, Producer};
11use smoltcp::iface::{Interface, PollResult, SocketSet};
12use smoltcp::socket::tcp::{CongestionControl, Socket, SocketBuffer};
13use smoltcp::wire::{IpCidr, IpProtocol, TcpPacket};
14use tokio::sync::{Notify, broadcast, mpsc};
15use tokio::task::JoinHandle;
16
17use crate::buffer::BufferPool;
18use crate::device::NetstackDevice;
19use crate::stack::{IpPacket, NetStackConfig, Packet};
20use crate::{debug, error};
21
22pub use stream::TcpStream;
23pub(crate) use stream::{RbConsumer, RbProducer, SharedState};
24
/// Per-worker state owning one smoltcp `Interface` and its sockets.
///
/// Each worker runs `accept_loop` on its own tokio task; inbound packets are
/// routed to a worker by `TcpConnection::distribute_packets`, so a given TCP
/// flow is always serviced by the same worker.
struct TcpConnectionWorker {
    // Shared stack configuration (buffer sizes, timeouts, ...).
    config: Arc<NetStackConfig>,
    // Feeds raw inbound frames into this worker's `NetstackDevice`.
    device_injector: mpsc::Sender<Packet>,
    // smoltcp network interface driving this worker's sockets.
    iface: Interface,
    // The smoltcp sockets owned by this worker.
    sockets: SocketSet<'static>,
    // Maps each smoltcp socket handle to the ring-buffer endpoints and shared
    // flags used to exchange data with the corresponding `TcpStream`.
    socket_maps: HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
    // Packets routed to this worker by the dispatcher task.
    inbound: mpsc::Receiver<Packet>,
    // Emits newly accepted connections into the aggregated stream.
    socket_stream_emitter: mpsc::Sender<TcpStream>,
    // Woken by `TcpStream`s when application-side I/O needs servicing.
    notifier: Arc<Notify>,
    // Receives the shutdown broadcast from the owning `TcpConnection`.
    shutdown_rx: broadcast::Receiver<()>,
}
36
/// Worker-side half of a `TcpStream`: the worker pushes received bytes into
/// `recv_buffer_prod` and drains bytes to transmit from `send_buffer_cons`.
pub(crate) struct SocketIOHandle {
    // Producer end of the receive ring buffer (worker -> application).
    recv_buffer_prod: RbProducer,
    // Consumer end of the send ring buffer (application -> worker).
    send_buffer_cons: RbConsumer,
    // Wakers and close/drop flags shared with the `TcpStream`.
    shared_state: Arc<SharedState>,
}
42
/// Handle to the TCP side of the netstack.
///
/// Implements `futures::Stream`, yielding a `TcpStream` for every inbound
/// connection accepted by any worker. Dropping it broadcasts shutdown to the
/// dispatcher and all worker tasks.
pub struct TcpConnection {
    // Aggregated stream of newly accepted connections from all workers.
    socket_stream: mpsc::Receiver<TcpStream>,
    // Broadcast channel used to signal shutdown on drop.
    shutdown_tx: broadcast::Sender<()>,
    // Worker and dispatcher task handles; kept alive for ownership, tasks
    // exit via the shutdown broadcast rather than via abort.
    _handles: Vec<JoinHandle<()>>,
}
48
49impl Drop for TcpConnection {
50 fn drop(&mut self) {
51 let _ = self.shutdown_tx.send(());
52 }
53}
54
55impl TcpConnection {
56 pub fn new(
57 config: NetStackConfig,
58 inbound: mpsc::Receiver<Packet>,
59 outbound: mpsc::Sender<Packet>,
60 buffer_pool: Arc<BufferPool>,
61 ) -> Self {
62 let num_workers = config.number_workers;
63 let config = Arc::new(config);
64
65 let (aggregated_socket_stream_emitter, aggregated_socket_stream_receiver) =
66 mpsc::channel::<TcpStream>(config.channel_size);
67
68 let (shutdown_tx, _) = broadcast::channel(1);
69
70 let mut _handles = Vec::new();
71 let mut worker_senders = Vec::with_capacity(num_workers);
72
73 for _i in 0..num_workers {
74 let (worker_inbound_sender, worker_inbound_receiver) =
75 mpsc::channel(config.channel_size);
76 worker_senders.push(worker_inbound_sender);
77
78 let mut device = NetstackDevice::new(outbound.clone(), buffer_pool.clone(), &config);
79 let device_injector = device.create_injector();
80 let iface = Self::create_interface(&config, &mut device);
81 let notifier = Arc::new(Notify::new());
82 let shutdown_rx = shutdown_tx.subscribe();
83
84 let mut worker = TcpConnectionWorker {
85 config: config.clone(),
86 device_injector,
87 iface,
88 sockets: SocketSet::new(vec![]),
89 socket_maps: HashMap::new(),
90 inbound: worker_inbound_receiver,
91 socket_stream_emitter: aggregated_socket_stream_emitter.clone(),
92 notifier: notifier.clone(),
93 shutdown_rx,
94 };
95
96 let worker_handle = tokio::spawn(async move {
97 if let Err(_e) = worker.accept_loop(device).await {
98 error!("[Worker {}] exited with error: {}", _i, _e);
99 }
100 });
101 _handles.push(worker_handle);
102 }
103
104 let dispatcher_shutdown_rx = shutdown_tx.subscribe();
105 let dispatcher_handle = tokio::spawn(Self::distribute_packets(
106 inbound,
107 worker_senders,
108 dispatcher_shutdown_rx,
109 ));
110 _handles.push(dispatcher_handle);
111
112 TcpConnection {
113 socket_stream: aggregated_socket_stream_receiver,
114 shutdown_tx,
115 _handles,
116 }
117 }
118
119 async fn distribute_packets(
120 mut inbound: mpsc::Receiver<Packet>,
121 worker_senders: Vec<mpsc::Sender<Packet>>,
122 mut shutdown_rx: broadcast::Receiver<()>,
123 ) {
124 let num_workers = worker_senders.len();
125 loop {
126 tokio::select! {
127 _ = shutdown_rx.recv() => {
128 debug!("[Dispatcher] received shutdown signal, exiting.");
129 break;
130 }
131 maybe_packet = inbound.recv() => {
132 if let Some(packet) = maybe_packet {
133 let worker_index = match IpPacket::new_checked(packet.data()) {
134 Ok(ip_packet) => {
135 if ip_packet.protocol() == IpProtocol::Tcp {
136 if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
137 let mut addr1 =
138 SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
139 let mut addr2 =
140 SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
141 if addr1 > addr2 {
142 std::mem::swap(&mut addr1, &mut addr2);
143 }
144 let mut hasher = DefaultHasher::new();
145 addr1.hash(&mut hasher);
146 addr2.hash(&mut hasher);
147 (hasher.finish() % num_workers as u64) as usize
148 } else { 0 }
149 } else { 0 }
150 }
151 Err(_) => 0,
152 };
153
154 if worker_senders[worker_index].send(packet).await.is_err() {
155 error!(
156 "[Dispatcher] Failed to send packet to worker {}, channel closed.",
157 worker_index
158 );
159 break;
160 }
161 } else {
162 debug!("[Dispatcher] Inbound channel closed, exiting.");
163 break;
164 }
165 }
166 }
167 }
168 debug!("[Dispatcher] stopped.");
169 }
170
171 fn create_interface(config: &NetStackConfig, device: &mut NetstackDevice) -> Interface {
172 let mut iface_config = smoltcp::iface::Config::new(smoltcp::wire::HardwareAddress::Ip);
173 iface_config.random_seed = rand::random();
174 let mut iface =
175 smoltcp::iface::Interface::new(iface_config, device, smoltcp::time::Instant::now());
176
177 iface.set_any_ip(true);
178 iface.update_ip_addrs(|ip_addrs| {
179 let _ = ip_addrs.push(IpCidr::new(config.ipv4_addr.into(), config.ipv4_prefix_len));
180 let _ = ip_addrs.push(IpCidr::new(config.ipv6_addr.into(), config.ipv6_prefix_len));
181 });
182
183 iface
184 .routes_mut()
185 .add_default_ipv4_route(config.ipv4_addr)
186 .expect("Failed to add default IPv4 route");
187 iface
188 .routes_mut()
189 .add_default_ipv6_route(config.ipv6_addr)
190 .expect("Failed to add default IPv6 route");
191
192 iface
193 }
194}
195
196impl TcpConnectionWorker {
197 async fn accept_loop(&mut self, mut device: NetstackDevice) -> std::io::Result<()> {
198 loop {
199 let mut progress = true;
201 while progress {
202 progress = false;
203
204 while let Ok(packet) = self.inbound.try_recv() {
206 if let Err(_e) = self.process_inbound_frame(packet).await {
207 error!("Error processing inbound frame: {}", _e);
208 }
209 progress = true;
210 }
211
212 let now = smoltcp::time::Instant::now();
214 if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
215 progress = true;
216 }
217
218 let mut total_bytes_processed = 0;
220 for (socket_handle, socket_control) in self.socket_maps.iter_mut() {
221 if socket_control
223 .shared_state
224 .socket_dropped
225 .load(Ordering::Acquire)
226 {
227 continue;
228 }
229 let socket = self.sockets.get_mut::<Socket>(*socket_handle);
230 let (read, written) = Self::handle_socket_io(socket, socket_control);
231 if read > 0 || written > 0 {
232 total_bytes_processed += read + written;
233 }
234 }
235 if total_bytes_processed > 0 {
236 progress = true;
237 }
238
239 if Self::prune_sockets(&mut self.sockets, &mut self.socket_maps) {
241 progress = true;
242 }
243
244 let now = smoltcp::time::Instant::now();
246 if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
247 progress = true;
248 }
249
250 if progress && total_bytes_processed == 0 && self.inbound.is_empty() {
251 tokio::task::yield_now().await;
252 }
253 }
254
255 let now = smoltcp::time::Instant::now();
260 let smoltcp_delay = self.iface.poll_delay(now, &self.sockets).map(|d| d.into());
261
262 tokio::select! {
263 biased;
264 _ = self.shutdown_rx.recv() => {
265 debug!("Worker received shutdown signal, exiting gracefully.");
266 return Ok(());
267 }
268
269 maybe_packet = self.inbound.recv() => {
271 match maybe_packet {
272 Some(packet) => {
273 if let Err(_e) = self.process_inbound_frame(packet).await {
274 error!("Error processing inbound frame: {}", _e);
275 }
276 },
278 None => return Ok(()), }
280 },
281
282 _ = self.notifier.notified() => {
284 },
287
288 _ = async {
290 match smoltcp_delay {
291 Some(delay) if delay > Duration::ZERO => tokio::time::sleep(delay).await,
293 _ => std::future::pending().await,
298 }
299 } => {
300 },
303 }
304 }
305 }
306
307 async fn process_inbound_frame(&mut self, frame: Packet) -> std::io::Result<()> {
308 if let Ok(ip_packet) = IpPacket::new_checked(frame.data())
309 && ip_packet.protocol() == IpProtocol::Tcp
310 && let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload())
311 && tcp_packet.syn()
312 && !tcp_packet.ack()
313 {
314 self.accept_new_connection(&ip_packet, &tcp_packet)?;
315 }
316
317 self.device_injector
318 .try_send(frame)
319 .map_err(|e| std::io::Error::other(e.to_string()))?;
320 Ok(())
321 }
322
323 fn accept_new_connection(
324 &mut self,
325 ip_packet: &IpPacket<&[u8]>,
326 tcp_packet: &TcpPacket<&[u8]>,
327 ) -> std::io::Result<()> {
328 let src_addr = SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
329 let dst_addr = SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
330
331 let mut socket = Socket::new(
332 SocketBuffer::new(vec![0u8; self.config.tcp_recv_buffer_size]),
333 SocketBuffer::new(vec![0u8; self.config.tcp_send_buffer_size]),
334 );
335
336 socket.set_keep_alive(Some(self.config.tcp_keep_alive.into()));
337 socket.set_timeout(Some(self.config.tcp_timeout.into()));
338 socket.set_nagle_enabled(false);
339 socket.set_congestion_control(CongestionControl::Cubic);
340
341 socket
342 .listen(dst_addr)
343 .map_err(|e| std::io::Error::other(e.to_string()))?;
344
345 let recv_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_recv_buffer_size));
346 let (recv_prod, recv_cons) = (
347 ringbuf::Prod::new(recv_rb.clone()),
348 ringbuf::Cons::new(recv_rb),
349 );
350
351 let send_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_send_buffer_size));
352 let (send_prod, send_cons) = (
353 ringbuf::Prod::new(send_rb.clone()),
354 ringbuf::Cons::new(send_rb),
355 );
356
357 let shared_state = Arc::new(SharedState::new());
358 let stream = TcpStream {
359 local_addr: src_addr,
360 remote_addr: dst_addr,
361 recv_buffer_cons: recv_cons,
362 send_buffer_prod: send_prod,
363 shared_state: shared_state.clone(),
364 worker_notifier: self.notifier.clone(),
365 };
366
367 let io_handle = SocketIOHandle {
368 recv_buffer_prod: recv_prod,
369 send_buffer_cons: send_cons,
370 shared_state,
371 };
372
373 if self.socket_stream_emitter.try_send(stream).is_ok() {
374 let socket_handle = self.sockets.add(socket);
375 self.socket_maps.insert(socket_handle, io_handle);
376 } else {
377 error!(
378 "[Worker] Failed to emit new TcpStream to application, channel is full or closed. Dropping new connection from {}.",
379 src_addr
380 );
381 }
382
383 Ok(())
384 }
385
386 fn handle_socket_io(
387 socket: &mut Socket,
388 socket_control: &mut SocketIOHandle,
389 ) -> (usize, usize) {
390 let mut bytes_read = 0;
391 let mut bytes_written = 0;
392 let mut notify_read = false;
393
394 if socket.can_recv() {
395 match socket.recv(|buffer| {
396 let n = socket_control.recv_buffer_prod.push_slice(buffer);
397 if n > 0 {
398 bytes_read += n;
399 }
400 (n, buffer.len())
401 }) {
402 Ok(n) => {
403 if n > 0 {
404 notify_read = true;
405 }
406 }
407 Err(_e) => {
408 error!("Socket recv error: {}. Closing read side.", _e);
409 socket_control
410 .shared_state
411 .read_closed
412 .store(true, Ordering::Release);
413 notify_read = true;
414 }
415 }
416 }
417
418 if !socket.is_open()
419 && !socket_control
420 .shared_state
421 .read_closed
422 .load(Ordering::Acquire)
423 {
424 socket_control
425 .shared_state
426 .read_closed
427 .store(true, Ordering::Release);
428 notify_read = true;
429 }
430
431 if notify_read {
432 socket_control.shared_state.recv_waker.wake();
433 }
434
435 let mut notify_write = false;
436
437 while socket.can_send() && !socket_control.send_buffer_cons.is_empty() {
438 match socket.send(|buffer| {
439 let n = socket_control.send_buffer_cons.pop_slice(buffer);
440 (n, buffer.len())
441 }) {
442 Ok(n) if n > 0 => {
443 bytes_written += n;
444 notify_write = true;
445 }
446 _ => break,
447 }
448 }
449
450 if notify_write {
451 socket_control.shared_state.send_waker.wake();
452 }
453
454 (bytes_read, bytes_written)
455 }
456
457 fn prune_sockets(
458 sockets: &mut SocketSet,
459 socket_maps: &mut HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
460 ) -> bool {
461 let initial_len = socket_maps.len();
462 socket_maps.retain(|handle, socket_control| {
463 let socket = sockets.get_mut::<Socket>(*handle);
464
465 if socket_control
466 .shared_state
467 .socket_dropped
468 .load(Ordering::Acquire)
469 {
470 socket.abort();
471 sockets.remove(*handle);
472 return false;
473 }
474
475 if !socket.is_active() {
476 sockets.remove(*handle);
477 return false;
478 }
479
480 true
481 });
482 initial_len != socket_maps.len()
483 }
484}
485
486impl futures::Stream for TcpConnection {
487 type Item = TcpStream;
488
489 fn poll_next(
490 mut self: std::pin::Pin<&mut Self>,
491 cx: &mut std::task::Context<'_>,
492 ) -> std::task::Poll<Option<Self::Item>> {
493 self.socket_stream.poll_recv(cx)
494 }
495}
496
497mod stream {
498 use std::net::SocketAddr;
499 use std::sync::Arc;
500 use std::sync::atomic::{AtomicBool, Ordering};
501 use std::task::{Context, Poll};
502
503 use futures::task::AtomicWaker;
504 use ringbuf::HeapRb;
505 use ringbuf::traits::{Consumer, Observer, Producer};
506 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
507 use tokio::sync::Notify;
508
509 pub(crate) type RbProducer = ringbuf::Prod<Arc<HeapRb<u8>>>;
510 pub(crate) type RbConsumer = ringbuf::Cons<Arc<HeapRb<u8>>>;
511
512 pub(crate) struct SharedState {
513 pub(crate) recv_waker: AtomicWaker,
514 pub(crate) send_waker: AtomicWaker,
515 pub(crate) read_closed: AtomicBool,
516 pub(crate) socket_dropped: AtomicBool,
517 }
518
519 impl SharedState {
520 pub fn new() -> Self {
521 Self {
522 recv_waker: AtomicWaker::new(),
523 send_waker: AtomicWaker::new(),
524 read_closed: AtomicBool::new(false),
525 socket_dropped: AtomicBool::new(false),
526 }
527 }
528 }
529
530 pub struct TcpStream {
531 pub(crate) local_addr: SocketAddr,
532 pub(crate) remote_addr: SocketAddr,
533 pub(crate) recv_buffer_cons: RbConsumer,
534 pub(crate) send_buffer_prod: RbProducer,
535 pub(crate) shared_state: Arc<SharedState>,
536 pub(crate) worker_notifier: Arc<Notify>,
537 }
538
539 impl TcpStream {
540 pub fn local_addr(&self) -> SocketAddr {
541 self.local_addr
542 }
543
544 pub fn remote_addr(&self) -> SocketAddr {
545 self.remote_addr
546 }
547
548 pub fn split(self) -> (ReadHalf<Self>, WriteHalf<Self>) {
549 tokio::io::split(self)
550 }
551 }
552
553 impl AsyncRead for TcpStream {
554 fn poll_read(
555 mut self: std::pin::Pin<&mut Self>,
556 cx: &mut Context<'_>,
557 buf: &mut ReadBuf<'_>,
558 ) -> Poll<std::io::Result<()>> {
559 let len_before = self.recv_buffer_cons.occupied_len();
560
561 if len_before == 0 {
562 if self.shared_state.read_closed.load(Ordering::Acquire) {
563 return Poll::Ready(Ok(()));
564 }
565 self.shared_state.recv_waker.register(cx.waker());
566
567 if self.recv_buffer_cons.is_empty() {
568 return Poll::Pending;
569 }
570 }
571
572 let unfilled_slice = buf.initialize_unfilled();
573 let n = self.recv_buffer_cons.pop_slice(unfilled_slice);
574 buf.advance(n);
575
576 if n > 0 {
577 self.worker_notifier.notify_one();
578 }
579
580 Poll::Ready(Ok(()))
581 }
582 }
583
584 impl AsyncWrite for TcpStream {
585 fn poll_write(
586 mut self: std::pin::Pin<&mut Self>,
587 cx: &mut Context<'_>,
588 buf: &[u8],
589 ) -> Poll<std::io::Result<usize>> {
590 if self.shared_state.socket_dropped.load(Ordering::Acquire) {
591 return Poll::Ready(Err(std::io::Error::new(
592 std::io::ErrorKind::BrokenPipe,
593 "Socket is closing",
594 )));
595 }
596
597 if self.send_buffer_prod.is_full() {
598 self.shared_state.send_waker.register(cx.waker());
599 if self.send_buffer_prod.is_full() {
600 return Poll::Pending;
601 }
602 }
603
604 let n = self.send_buffer_prod.push_slice(buf);
605 if n > 0 {
606 self.worker_notifier.notify_one();
607 }
608
609 Poll::Ready(Ok(n))
610 }
611
612 fn poll_flush(
613 self: std::pin::Pin<&mut Self>,
614 cx: &mut Context<'_>,
615 ) -> Poll<std::io::Result<()>> {
616 if !self.send_buffer_prod.is_empty() {
617 self.shared_state.send_waker.register(cx.waker());
618 if !self.send_buffer_prod.is_empty() {
619 return Poll::Pending;
620 }
621 }
622 Poll::Ready(Ok(()))
623 }
624
625 fn poll_shutdown(
626 mut self: std::pin::Pin<&mut Self>,
627 cx: &mut Context<'_>,
628 ) -> Poll<std::io::Result<()>> {
629 std::task::ready!(self.as_mut().poll_flush(cx))?;
630
631 self.shared_state
632 .socket_dropped
633 .store(true, Ordering::Release);
634 self.worker_notifier.notify_one();
635 Poll::Ready(Ok(()))
636 }
637 }
638
639 impl Drop for TcpStream {
640 fn drop(&mut self) {
641 self.shared_state
642 .socket_dropped
643 .store(true, Ordering::Release);
644 self.worker_notifier.notify_one();
645 }
646 }
647}