1use std::collections::HashMap;
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::sync::atomic::Ordering;
7use std::time::Duration;
8
9use ringbuf::HeapRb;
10use ringbuf::traits::{Consumer, Observer, Producer};
11use smoltcp::iface::{Interface, PollResult, SocketSet};
12use smoltcp::socket::tcp::{CongestionControl, Socket, SocketBuffer, State};
13use smoltcp::wire::{IpCidr, IpProtocol, TcpPacket};
14use tokio::sync::{Notify, broadcast, mpsc};
15use tokio::task::JoinHandle;
16
17use crate::buffer::BufferPool;
18use crate::device::NetstackDevice;
19use crate::stack::{IpPacket, NetStackConfig, Packet};
20use crate::{debug, error};
21
22pub use stream::TcpStream;
23pub(crate) use stream::{RbConsumer, RbProducer, SharedState};
24
/// Per-worker state driving one smoltcp `Interface` and its socket set.
///
/// Each worker owns an independent smoltcp stack instance; inbound IP
/// packets are sharded across workers by the dispatcher spawned in
/// `TcpConnection::new`.
struct TcpConnectionWorker {
    // Shared netstack configuration (buffer sizes, timeouts, ...).
    config: Arc<NetStackConfig>,
    // Re-injects raw inbound frames into this worker's `NetstackDevice`
    // so smoltcp consumes them on the next interface poll.
    device_injector: mpsc::Sender<Packet>,
    // smoltcp network interface owned exclusively by this worker.
    iface: Interface,
    // Sockets managed by `iface`.
    sockets: SocketSet<'static>,
    // Maps each smoltcp socket handle to the ring buffers / shared state
    // connecting it to its user-facing `TcpStream`.
    socket_maps: HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
    // Inbound packets routed to this worker by the dispatcher.
    inbound: mpsc::Receiver<Packet>,
    // Emits newly accepted connections into the aggregated channel that
    // `TcpConnection` exposes as a `futures::Stream`.
    socket_stream_emitter: mpsc::Sender<TcpStream>,
    // Notified by `TcpStream` I/O to re-run this worker's poll loop.
    notifier: Arc<Notify>,
    // Receives the shutdown broadcast sent when `TcpConnection` drops.
    shutdown_rx: broadcast::Receiver<()>,
}
36
/// Worker-side half of a `TcpStream`: the worker pushes received bytes into
/// `recv_buffer_prod` and drains bytes to transmit from `send_buffer_cons`.
pub(crate) struct SocketIOHandle {
    // Producer side of the receive ring buffer (worker -> stream).
    recv_buffer_prod: RbProducer,
    // Consumer side of the send ring buffer (stream -> worker).
    send_buffer_cons: RbConsumer,
    // Wakers and close/drop flags shared with the `TcpStream`.
    shared_state: Arc<SharedState>,
}
42
/// User-facing handle to the TCP side of the netstack.
///
/// Implements `futures::Stream`, yielding one `TcpStream` per inbound TCP
/// connection accepted by any worker. Dropping this handle broadcasts a
/// shutdown signal to the dispatcher and all workers.
pub struct TcpConnection {
    // Aggregated stream of accepted connections from all workers.
    socket_stream: mpsc::Receiver<TcpStream>,
    // Broadcast channel used to signal shutdown in `Drop`.
    shutdown_tx: broadcast::Sender<()>,
    // Spawned worker/dispatcher task handles; retained so their lifetime is
    // tied to this handle (the tasks themselves exit via the broadcast).
    _handles: Vec<JoinHandle<()>>,
}
48
impl Drop for TcpConnection {
    // Signal the dispatcher and all workers to shut down. A send error is
    // ignored: it only occurs when every subscriber has already exited.
    fn drop(&mut self) {
        let _ = self.shutdown_tx.send(());
    }
}
54
impl TcpConnection {
    /// Builds the TCP netstack: spawns `config.number_workers` worker tasks
    /// (each owning an independent smoltcp interface) plus one dispatcher
    /// task that shards inbound packets across workers by TCP flow.
    ///
    /// * `config` — netstack configuration (worker count, channel sizes, ...).
    /// * `inbound` — raw IP packets arriving from outside.
    /// * `outbound` — sink for packets the stacks emit.
    /// * `buffer_pool` — pool used by each worker's `NetstackDevice`.
    pub fn new(
        config: NetStackConfig,
        inbound: mpsc::Receiver<Packet>,
        outbound: mpsc::Sender<Packet>,
        buffer_pool: Arc<BufferPool>,
    ) -> Self {
        let num_workers = config.number_workers;
        let config = Arc::new(config);

        // All workers emit accepted connections into this single channel,
        // which backs the `futures::Stream` impl on `TcpConnection`.
        let (aggregated_socket_stream_emitter, aggregated_socket_stream_receiver) =
            mpsc::channel::<TcpStream>(config.channel_size);

        // Capacity 1 suffices: shutdown is signalled at most once, on drop.
        let (shutdown_tx, _) = broadcast::channel(1);

        let mut _handles = Vec::new();
        let mut worker_senders = Vec::with_capacity(num_workers);

        for _i in 0..num_workers {
            let (worker_inbound_sender, worker_inbound_receiver) =
                mpsc::channel(config.channel_size);
            worker_senders.push(worker_inbound_sender);

            // Each worker gets its own device and interface, so no smoltcp
            // state is ever shared between worker tasks.
            let mut device = NetstackDevice::new(outbound.clone(), buffer_pool.clone(), &config);
            let device_injector = device.create_injector();
            let iface = Self::create_interface(&config, &mut device);
            let notifier = Arc::new(Notify::new());
            let shutdown_rx = shutdown_tx.subscribe();

            let mut worker = TcpConnectionWorker {
                config: config.clone(),
                device_injector,
                iface,
                sockets: SocketSet::new(vec![]),
                socket_maps: HashMap::new(),
                inbound: worker_inbound_receiver,
                socket_stream_emitter: aggregated_socket_stream_emitter.clone(),
                notifier: notifier.clone(),
                shutdown_rx,
            };

            let worker_handle = tokio::spawn(async move {
                if let Err(_e) = worker.accept_loop(device).await {
                    error!("[Worker {}] exited with error: {}", _i, _e);
                }
            });
            _handles.push(worker_handle);
        }

        // The dispatcher owns the global inbound receiver and routes each
        // packet to a worker based on its TCP endpoint pair.
        let dispatcher_shutdown_rx = shutdown_tx.subscribe();
        let dispatcher_handle = tokio::spawn(Self::distribute_packets(
            inbound,
            worker_senders,
            dispatcher_shutdown_rx,
        ));
        _handles.push(dispatcher_handle);

        TcpConnection {
            socket_stream: aggregated_socket_stream_receiver,
            shutdown_tx,
            _handles,
        }
    }

    /// Routes inbound packets to workers.
    ///
    /// TCP packets are assigned by hashing the order-normalized
    /// (source, destination) endpoint pair, so both directions of a flow —
    /// and every packet of that flow — always land on the same worker.
    /// Non-TCP or unparseable packets fall back to worker 0.
    ///
    /// Exits on shutdown, when the inbound channel closes, or when any
    /// worker channel is closed.
    async fn distribute_packets(
        mut inbound: mpsc::Receiver<Packet>,
        worker_senders: Vec<mpsc::Sender<Packet>>,
        mut shutdown_rx: broadcast::Receiver<()>,
    ) {
        let num_workers = worker_senders.len();
        loop {
            tokio::select! {
                _ = shutdown_rx.recv() => {
                    debug!("[Dispatcher] received shutdown signal, exiting.");
                    break;
                }
                maybe_packet = inbound.recv() => {
                    if let Some(packet) = maybe_packet {
                        let worker_index = match IpPacket::new_checked(packet.data()) {
                            Ok(ip_packet) => {
                                if ip_packet.protocol() == IpProtocol::Tcp {
                                    if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
                                        let mut addr1 =
                                            SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
                                        let mut addr2 =
                                            SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
                                        // Sort the endpoints so A->B and B->A
                                        // hash to the same worker.
                                        if addr1 > addr2 {
                                            std::mem::swap(&mut addr1, &mut addr2);
                                        }
                                        let mut hasher = DefaultHasher::new();
                                        addr1.hash(&mut hasher);
                                        addr2.hash(&mut hasher);
                                        // NOTE(review): panics (modulo by zero)
                                        // if num_workers == 0 — assumes config
                                        // guarantees at least one worker.
                                        (hasher.finish() % num_workers as u64) as usize
                                    } else { 0 }
                                } else { 0 }
                            }
                            Err(_) => 0,
                        };

                        if worker_senders[worker_index].send(packet).await.is_err() {
                            error!(
                                "[Dispatcher] Failed to send packet to worker {}, channel closed.",
                                worker_index
                            );
                            break;
                        }
                    } else {
                        debug!("[Dispatcher] Inbound channel closed, exiting.");
                        break;
                    }
                }
            }
        }
        debug!("[Dispatcher] stopped.");
    }

    /// Creates a smoltcp interface bound to `device`, operating at the IP
    /// layer (`HardwareAddress::Ip`, no Ethernet framing).
    ///
    /// `set_any_ip(true)` lets the stack accept packets addressed to IPs
    /// other than the interface's own, which is what allows intercepting
    /// arbitrary destination addresses.
    ///
    /// # Panics
    /// Panics if the default IPv4/IPv6 routes cannot be installed.
    fn create_interface(config: &NetStackConfig, device: &mut NetstackDevice) -> Interface {
        let mut iface_config = smoltcp::iface::Config::new(smoltcp::wire::HardwareAddress::Ip);
        iface_config.random_seed = rand::random();
        let mut iface =
            smoltcp::iface::Interface::new(iface_config, device, smoltcp::time::Instant::now());

        iface.set_any_ip(true);
        iface.update_ip_addrs(|ip_addrs| {
            // Push errors are ignored: the address list has a fixed capacity
            // and these are the first two entries.
            let _ = ip_addrs.push(IpCidr::new(config.ipv4_addr.into(), config.ipv4_prefix_len));
            let _ = ip_addrs.push(IpCidr::new(config.ipv6_addr.into(), config.ipv6_prefix_len));
        });

        iface
            .routes_mut()
            .add_default_ipv4_route(config.ipv4_addr)
            .expect("Failed to add default IPv4 route");
        iface
            .routes_mut()
            .add_default_ipv6_route(config.ipv6_addr)
            .expect("Failed to add default IPv6 route");

        iface
    }
}
195
196impl TcpConnectionWorker {
197 async fn accept_loop(&mut self, mut device: NetstackDevice) -> std::io::Result<()> {
198 loop {
199 let mut progress = true;
201 while progress {
202 progress = false;
203
204 while let Ok(packet) = self.inbound.try_recv() {
206 if let Err(_e) = self.process_inbound_frame(packet).await {
207 error!("Error processing inbound frame: {}", _e);
208 }
209 progress = true;
210 }
211
212 let now = smoltcp::time::Instant::now();
214 if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
215 progress = true;
216 }
217
218 let mut total_bytes_processed = 0;
220 for (socket_handle, socket_control) in self.socket_maps.iter_mut() {
221 let socket = self.sockets.get_mut::<Socket>(*socket_handle);
222 let (read, written) = Self::handle_socket_io(socket, socket_control);
223 if read > 0 || written > 0 {
224 total_bytes_processed += read + written;
225 }
226 }
227 if total_bytes_processed > 0 {
228 progress = true;
229 }
230
231 if Self::prune_sockets(&mut self.sockets, &mut self.socket_maps) {
233 progress = true;
234 }
235
236 let now = smoltcp::time::Instant::now();
238 if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
239 progress = true;
240 }
241
242 if progress && total_bytes_processed == 0 && self.inbound.is_empty() {
243 tokio::task::yield_now().await;
244 }
245 }
246
247 let now = smoltcp::time::Instant::now();
252 let smoltcp_delay = self.iface.poll_delay(now, &self.sockets).map(|d| d.into());
253
254 tokio::select! {
255 biased;
256 _ = self.shutdown_rx.recv() => {
257 debug!("Worker received shutdown signal, exiting gracefully.");
258 return Ok(());
259 }
260
261 maybe_packet = self.inbound.recv() => {
263 match maybe_packet {
264 Some(packet) => {
265 if let Err(_e) = self.process_inbound_frame(packet).await {
266 error!("Error processing inbound frame: {}", _e);
267 }
268 },
270 None => return Ok(()), }
272 },
273
274 _ = self.notifier.notified() => {
276 },
279
280 _ = async {
282 match smoltcp_delay {
283 Some(delay) if delay > Duration::ZERO => tokio::time::sleep(delay).await,
285 _ => std::future::pending().await,
290 }
291 } => {
292 },
295 }
296 }
297 }
298
299 async fn process_inbound_frame(&mut self, frame: Packet) -> std::io::Result<()> {
300 if let Ok(ip_packet) = IpPacket::new_checked(frame.data())
301 && ip_packet.protocol() == IpProtocol::Tcp
302 && let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload())
303 && tcp_packet.syn()
304 && !tcp_packet.ack()
305 {
306 self.accept_new_connection(&ip_packet, &tcp_packet)?;
307 }
308
309 self.device_injector
310 .try_send(frame)
311 .map_err(|e| std::io::Error::other(e.to_string()))?;
312 Ok(())
313 }
314
315 fn accept_new_connection(
316 &mut self,
317 ip_packet: &IpPacket<&[u8]>,
318 tcp_packet: &TcpPacket<&[u8]>,
319 ) -> std::io::Result<()> {
320 let src_addr = SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
321 let dst_addr = SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
322
323 let mut socket = Socket::new(
324 SocketBuffer::new(vec![0u8; self.config.tcp_recv_buffer_size]),
325 SocketBuffer::new(vec![0u8; self.config.tcp_send_buffer_size]),
326 );
327
328 socket.set_keep_alive(Some(self.config.tcp_keep_alive.into()));
329 socket.set_timeout(Some(self.config.tcp_timeout.into()));
330 socket.set_nagle_enabled(false);
331 socket.set_congestion_control(CongestionControl::Cubic);
332
333 socket
334 .listen(dst_addr)
335 .map_err(|e| std::io::Error::other(e.to_string()))?;
336
337 let recv_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_recv_buffer_size));
338 let (recv_prod, recv_cons) = (
339 ringbuf::Prod::new(recv_rb.clone()),
340 ringbuf::Cons::new(recv_rb),
341 );
342
343 let send_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_send_buffer_size));
344 let (send_prod, send_cons) = (
345 ringbuf::Prod::new(send_rb.clone()),
346 ringbuf::Cons::new(send_rb),
347 );
348
349 let shared_state = Arc::new(SharedState::new());
350 let stream = TcpStream {
351 local_addr: src_addr,
352 remote_addr: dst_addr,
353 recv_buffer_cons: recv_cons,
354 send_buffer_prod: send_prod,
355 shared_state: shared_state.clone(),
356 worker_notifier: self.notifier.clone(),
357 };
358
359 let io_handle = SocketIOHandle {
360 recv_buffer_prod: recv_prod,
361 send_buffer_cons: send_cons,
362 shared_state,
363 };
364
365 if self.socket_stream_emitter.try_send(stream).is_ok() {
366 let socket_handle = self.sockets.add(socket);
367 self.socket_maps.insert(socket_handle, io_handle);
368 } else {
369 error!(
370 "[Worker] Failed to emit new TcpStream to application, channel is full or closed. Dropping new connection from {}.",
371 src_addr
372 );
373 }
374
375 Ok(())
376 }
377
378 fn handle_socket_io(
379 socket: &mut Socket,
380 socket_control: &mut SocketIOHandle,
381 ) -> (usize, usize) {
382 let mut bytes_read = 0;
383 let mut bytes_written = 0;
384 let mut notify_read = false;
385
386 if socket.can_recv() {
387 match socket.recv(|buffer| {
388 let n = socket_control.recv_buffer_prod.push_slice(buffer);
389 if n > 0 {
390 bytes_read += n;
391 }
392 (n, buffer.len())
393 }) {
394 Ok(n) => {
395 if n > 0 {
396 notify_read = true;
397 }
398 }
399 Err(_e) => {
400 error!("Socket recv error: {}. Closing read side.", _e);
401 socket_control
402 .shared_state
403 .read_closed
404 .store(true, Ordering::Release);
405 notify_read = true;
406 }
407 }
408 }
409
410 if !socket.is_open()
411 && !socket_control
412 .shared_state
413 .read_closed
414 .load(Ordering::Acquire)
415 {
416 socket_control
417 .shared_state
418 .read_closed
419 .store(true, Ordering::Release);
420 notify_read = true;
421 }
422
423 if notify_read {
424 socket_control.shared_state.recv_waker.wake();
425 }
426
427 let mut notify_write = false;
428
429 while socket.can_send() && !socket_control.send_buffer_cons.is_empty() {
430 match socket.send(|buffer| {
431 let n = socket_control.send_buffer_cons.pop_slice(buffer);
432 (n, buffer.len())
433 }) {
434 Ok(n) if n > 0 => {
435 bytes_written += n;
436 notify_write = true;
437 }
438 _ => break,
439 }
440 }
441
442 if notify_write {
443 socket_control.shared_state.send_waker.wake();
444 }
445
446 (bytes_read, bytes_written)
447 }
448
449 fn prune_sockets(
450 sockets: &mut SocketSet,
451 socket_maps: &mut HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
452 ) -> bool {
453 let initial_len = socket_maps.len();
454 socket_maps.retain(|handle, socket_control| {
455 let socket = sockets.get_mut::<Socket>(*handle);
456
457 if socket_control
458 .shared_state
459 .socket_dropped
460 .load(Ordering::Acquire)
461 {
462 socket.abort();
463 }
464
465 if !socket.is_active() && socket.state() == State::Closed {
466 sockets.remove(*handle);
467 return false;
468 }
469
470 true
471 });
472 initial_len != socket_maps.len()
473 }
474}
475
476impl futures::Stream for TcpConnection {
477 type Item = TcpStream;
478
479 fn poll_next(
480 mut self: std::pin::Pin<&mut Self>,
481 cx: &mut std::task::Context<'_>,
482 ) -> std::task::Poll<Option<Self::Item>> {
483 self.socket_stream.poll_recv(cx)
484 }
485}
486
487mod stream {
488 use std::net::SocketAddr;
489 use std::sync::Arc;
490 use std::sync::atomic::{AtomicBool, Ordering};
491 use std::task::{Context, Poll};
492
493 use futures::task::AtomicWaker;
494 use ringbuf::HeapRb;
495 use ringbuf::traits::{Consumer, Observer, Producer};
496 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
497 use tokio::sync::Notify;
498
499 pub(crate) type RbProducer = ringbuf::Prod<Arc<HeapRb<u8>>>;
500 pub(crate) type RbConsumer = ringbuf::Cons<Arc<HeapRb<u8>>>;
501
502 pub(crate) struct SharedState {
503 pub(crate) recv_waker: AtomicWaker,
504 pub(crate) send_waker: AtomicWaker,
505 pub(crate) read_closed: AtomicBool,
506 pub(crate) socket_dropped: AtomicBool,
507 }
508
509 impl SharedState {
510 pub fn new() -> Self {
511 Self {
512 recv_waker: AtomicWaker::new(),
513 send_waker: AtomicWaker::new(),
514 read_closed: AtomicBool::new(false),
515 socket_dropped: AtomicBool::new(false),
516 }
517 }
518 }
519
520 pub struct TcpStream {
521 pub(crate) local_addr: SocketAddr,
522 pub(crate) remote_addr: SocketAddr,
523 pub(crate) recv_buffer_cons: RbConsumer,
524 pub(crate) send_buffer_prod: RbProducer,
525 pub(crate) shared_state: Arc<SharedState>,
526 pub(crate) worker_notifier: Arc<Notify>,
527 }
528
529 impl TcpStream {
530 pub fn local_addr(&self) -> SocketAddr {
531 self.local_addr
532 }
533
534 pub fn remote_addr(&self) -> SocketAddr {
535 self.remote_addr
536 }
537
538 pub fn split(self) -> (ReadHalf<Self>, WriteHalf<Self>) {
539 tokio::io::split(self)
540 }
541 }
542
543 impl AsyncRead for TcpStream {
544 fn poll_read(
545 mut self: std::pin::Pin<&mut Self>,
546 cx: &mut Context<'_>,
547 buf: &mut ReadBuf<'_>,
548 ) -> Poll<std::io::Result<()>> {
549 if self.recv_buffer_cons.is_empty() {
550 if self.shared_state.read_closed.load(Ordering::Acquire) {
551 return Poll::Ready(Ok(()));
552 }
553 self.shared_state.recv_waker.register(cx.waker());
554 return Poll::Pending;
555 }
556
557 let unfilled_slice = buf.initialize_unfilled();
558 let n = self.recv_buffer_cons.pop_slice(unfilled_slice);
559 buf.advance(n);
560
561 self.worker_notifier.notify_one();
562
563 Poll::Ready(Ok(()))
564 }
565 }
566
567 impl AsyncWrite for TcpStream {
568 fn poll_write(
569 mut self: std::pin::Pin<&mut Self>,
570 cx: &mut Context<'_>,
571 buf: &[u8],
572 ) -> Poll<std::io::Result<usize>> {
573 if self.shared_state.socket_dropped.load(Ordering::Relaxed) {
574 return Poll::Ready(Err(std::io::Error::new(
575 std::io::ErrorKind::BrokenPipe,
576 "Socket is closing",
577 )));
578 }
579
580 if self.send_buffer_prod.is_full() {
581 self.shared_state.send_waker.register(cx.waker());
582 return Poll::Pending;
583 }
584
585 let n = self.send_buffer_prod.push_slice(buf);
586 if n > 0 {
587 self.worker_notifier.notify_one();
588 }
589
590 Poll::Ready(Ok(n))
591 }
592
593 fn poll_flush(
594 self: std::pin::Pin<&mut Self>,
595 cx: &mut Context<'_>,
596 ) -> Poll<std::io::Result<()>> {
597 if !self.send_buffer_prod.is_empty() {
598 self.shared_state.send_waker.register(cx.waker());
599 self.worker_notifier.notify_one();
600 return Poll::Pending;
601 }
602 Poll::Ready(Ok(()))
603 }
604
605 fn poll_shutdown(
606 mut self: std::pin::Pin<&mut Self>,
607 cx: &mut Context<'_>,
608 ) -> Poll<std::io::Result<()>> {
609 std::task::ready!(self.as_mut().poll_flush(cx))?;
610
611 self.shared_state
612 .socket_dropped
613 .store(true, Ordering::Release);
614 self.worker_notifier.notify_one();
615 Poll::Ready(Ok(()))
616 }
617 }
618
619 impl Drop for TcpStream {
620 fn drop(&mut self) {
621 self.shared_state
622 .socket_dropped
623 .store(true, Ordering::Release);
624 self.worker_notifier.notify_one();
625 }
626 }
627}