ombrac_netstack/
tcp.rs

1use std::collections::HashMap;
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::sync::atomic::Ordering;
7use std::time::Duration;
8
9use ringbuf::HeapRb;
10use ringbuf::traits::{Consumer, Observer, Producer};
11use smoltcp::iface::{Interface, PollResult, SocketSet};
12use smoltcp::socket::tcp::{CongestionControl, Socket, SocketBuffer, State};
13use smoltcp::wire::{IpCidr, IpProtocol, TcpPacket};
14use tokio::sync::{Notify, broadcast, mpsc};
15use tokio::task::JoinHandle;
16
17use crate::buffer::BufferPool;
18use crate::device::NetstackDevice;
19use crate::stack::{IpPacket, NetStackConfig, Packet};
20use crate::{debug, error};
21
22pub use stream::TcpStream;
23pub(crate) use stream::{RbConsumer, RbProducer, SharedState};
24
/// Per-worker state for driving one smoltcp interface instance.
///
/// Each worker owns its own `Interface` + `SocketSet` and processes only the
/// subset of inbound packets routed to it by the dispatcher, so workers never
/// contend on shared smoltcp state.
struct TcpConnectionWorker {
    config: Arc<NetStackConfig>,
    /// Injects raw inbound frames into this worker's `NetstackDevice`.
    device_injector: mpsc::Sender<Packet>,
    iface: Interface,
    sockets: SocketSet<'static>,
    /// Maps each smoltcp socket to the ring-buffer halves / shared flags of its `TcpStream`.
    socket_maps: HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
    /// Packets routed to this worker by `distribute_packets`.
    inbound: mpsc::Receiver<Packet>,
    /// Delivers newly accepted streams to the aggregated `TcpConnection` receiver.
    socket_stream_emitter: mpsc::Sender<TcpStream>,
    /// Signalled by `TcpStream`s when the application reads/writes, to re-run the event loop.
    notifier: Arc<Notify>,
    shutdown_rx: broadcast::Receiver<()>,
}
36
/// Worker-side handle for one connection's IO: the producer feeding the
/// stream's read buffer, the consumer draining its write buffer, and the
/// flag/waker state shared with the corresponding `TcpStream`.
pub(crate) struct SocketIOHandle {
    /// Producer half of the stream's receive ring buffer (worker pushes payload here).
    recv_buffer_prod: RbProducer,
    /// Consumer half of the stream's send ring buffer (worker pops data to transmit).
    send_buffer_cons: RbConsumer,
    shared_state: Arc<SharedState>,
}
42
/// User-facing TCP acceptor: yields newly established `TcpStream`s through
/// its `futures::Stream` impl and shuts down all spawned tasks when dropped.
pub struct TcpConnection {
    /// Aggregated receiver of streams accepted by any worker.
    socket_stream: mpsc::Receiver<TcpStream>,
    /// Broadcasting on this channel tells the dispatcher and all workers to exit.
    shutdown_tx: broadcast::Sender<()>,
    /// Spawned worker/dispatcher task handles, retained for the acceptor's lifetime.
    _handles: Vec<JoinHandle<()>>,
}
48
49impl Drop for TcpConnection {
50    fn drop(&mut self) {
51        let _ = self.shutdown_tx.send(());
52    }
53}
54
55impl TcpConnection {
    /// Spawn `config.number_workers` worker tasks plus one dispatcher task and
    /// return the acceptor that aggregates their accepted streams.
    ///
    /// * `inbound`     - raw IP packets arriving from the device/tun side.
    /// * `outbound`    - channel on which workers emit outgoing packets.
    /// * `buffer_pool` - shared buffer allocator handed to every device.
    ///
    /// Each worker gets a private `NetstackDevice`/`Interface` pair; the
    /// dispatcher hashes each TCP flow to a fixed worker so a single
    /// connection is always handled by the same interface.
    pub fn new(
        config: NetStackConfig,
        inbound: mpsc::Receiver<Packet>,
        outbound: mpsc::Sender<Packet>,
        buffer_pool: Arc<BufferPool>,
    ) -> Self {
        let num_workers = config.number_workers;
        let config = Arc::new(config);

        // All workers emit accepted streams into one aggregated channel.
        let (aggregated_socket_stream_emitter, aggregated_socket_stream_receiver) =
            mpsc::channel::<TcpStream>(config.channel_size);

        let (shutdown_tx, _) = broadcast::channel(1);

        let mut _handles = Vec::new();
        let mut worker_senders = Vec::with_capacity(num_workers);

        for _i in 0..num_workers {
            let (worker_inbound_sender, worker_inbound_receiver) =
                mpsc::channel(config.channel_size);
            worker_senders.push(worker_inbound_sender);

            // Each worker owns its own device + interface, so smoltcp state
            // is never shared across tasks.
            let mut device = NetstackDevice::new(outbound.clone(), buffer_pool.clone(), &config);
            let device_injector = device.create_injector();
            let iface = Self::create_interface(&config, &mut device);
            let notifier = Arc::new(Notify::new());
            // Subscribe before spawning so no shutdown signal can be missed.
            let shutdown_rx = shutdown_tx.subscribe();

            let mut worker = TcpConnectionWorker {
                config: config.clone(),
                device_injector,
                iface,
                sockets: SocketSet::new(vec![]),
                socket_maps: HashMap::new(),
                inbound: worker_inbound_receiver,
                socket_stream_emitter: aggregated_socket_stream_emitter.clone(),
                notifier: notifier.clone(),
                shutdown_rx,
            };

            let worker_handle = tokio::spawn(async move {
                if let Err(_e) = worker.accept_loop(device).await {
                    error!("[Worker {}] exited with error: {}", _i, _e);
                }
            });
            _handles.push(worker_handle);
        }

        // The dispatcher fans inbound packets out to the per-worker channels.
        let dispatcher_shutdown_rx = shutdown_tx.subscribe();
        let dispatcher_handle = tokio::spawn(Self::distribute_packets(
            inbound,
            worker_senders,
            dispatcher_shutdown_rx,
        ));
        _handles.push(dispatcher_handle);

        TcpConnection {
            socket_stream: aggregated_socket_stream_receiver,
            shutdown_tx,
            _handles,
        }
    }
118
119    async fn distribute_packets(
120        mut inbound: mpsc::Receiver<Packet>,
121        worker_senders: Vec<mpsc::Sender<Packet>>,
122        mut shutdown_rx: broadcast::Receiver<()>,
123    ) {
124        let num_workers = worker_senders.len();
125        loop {
126            tokio::select! {
127                _ = shutdown_rx.recv() => {
128                    debug!("[Dispatcher] received shutdown signal, exiting.");
129                    break;
130                }
131                maybe_packet = inbound.recv() => {
132                    if let Some(packet) = maybe_packet {
133                        let worker_index = match IpPacket::new_checked(packet.data()) {
134                            Ok(ip_packet) => {
135                                if ip_packet.protocol() == IpProtocol::Tcp {
136                                    if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
137                                        let mut addr1 =
138                                            SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
139                                        let mut addr2 =
140                                            SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
141                                        if addr1 > addr2 {
142                                            std::mem::swap(&mut addr1, &mut addr2);
143                                        }
144                                        let mut hasher = DefaultHasher::new();
145                                        addr1.hash(&mut hasher);
146                                        addr2.hash(&mut hasher);
147                                        (hasher.finish() % num_workers as u64) as usize
148                                    } else { 0 }
149                                } else { 0 }
150                            }
151                            Err(_) => 0,
152                        };
153
154                        if worker_senders[worker_index].send(packet).await.is_err() {
155                            error!(
156                                "[Dispatcher] Failed to send packet to worker {}, channel closed.",
157                                worker_index
158                            );
159                            break;
160                        }
161                    } else {
162                        debug!("[Dispatcher] Inbound channel closed, exiting.");
163                        break;
164                    }
165                }
166            }
167        }
168        debug!("[Dispatcher] stopped.");
169    }
170
171    fn create_interface(config: &NetStackConfig, device: &mut NetstackDevice) -> Interface {
172        let mut iface_config = smoltcp::iface::Config::new(smoltcp::wire::HardwareAddress::Ip);
173        iface_config.random_seed = rand::random();
174        let mut iface =
175            smoltcp::iface::Interface::new(iface_config, device, smoltcp::time::Instant::now());
176
177        iface.set_any_ip(true);
178        iface.update_ip_addrs(|ip_addrs| {
179            let _ = ip_addrs.push(IpCidr::new(config.ipv4_addr.into(), config.ipv4_prefix_len));
180            let _ = ip_addrs.push(IpCidr::new(config.ipv6_addr.into(), config.ipv6_prefix_len));
181        });
182
183        iface
184            .routes_mut()
185            .add_default_ipv4_route(config.ipv4_addr)
186            .expect("Failed to add default IPv4 route");
187        iface
188            .routes_mut()
189            .add_default_ipv6_route(config.ipv6_addr)
190            .expect("Failed to add default IPv6 route");
191
192        iface
193    }
194}
195
196impl TcpConnectionWorker {
    /// Worker event loop, structured in two stages:
    ///
    /// 1. **Work-draining** - repeatedly ingest pending inbound packets, poll
    ///    smoltcp, and shuttle bytes between sockets and their ring buffers
    ///    until a full pass makes no progress.
    /// 2. **Idle** - wait (via `select!`) for the next inbound packet, a
    ///    `TcpStream` notification, a smoltcp retransmission timer, or a
    ///    shutdown signal; then loop back to stage 1.
    ///
    /// Returns `Ok(())` when the shutdown signal fires or the inbound channel
    /// closes.
    async fn accept_loop(&mut self, mut device: NetstackDevice) -> std::io::Result<()> {
        loop {
            // --- STAGE 1: WORK-DRAINING ---
            let mut progress = true;
            while progress {
                progress = false;

                // Drain inbound packets
                while let Ok(packet) = self.inbound.try_recv() {
                    if let Err(_e) = self.process_inbound_frame(packet).await {
                        error!("Error processing inbound frame: {}", _e);
                    }
                    progress = true;
                }

                // Poll smoltcp for network events
                let now = smoltcp::time::Instant::now();
                if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
                    progress = true;
                }

                // Handle IO for all active sockets
                let mut total_bytes_processed = 0;
                for (socket_handle, socket_control) in self.socket_maps.iter_mut() {
                    let socket = self.sockets.get_mut::<Socket>(*socket_handle);
                    let (read, written) = Self::handle_socket_io(socket, socket_control);
                    if read > 0 || written > 0 {
                        total_bytes_processed += read + written;
                    }
                }
                if total_bytes_processed > 0 {
                    progress = true;
                }

                // Prune any closed/aborted sockets
                if Self::prune_sockets(&mut self.sockets, &mut self.socket_maps) {
                    progress = true;
                }

                // Poll smoltcp again after IO
                let now = smoltcp::time::Instant::now();
                if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
                    progress = true;
                }

                // Yield to the scheduler when progress came only from polling
                // (no payload moved and nothing queued), so a busy interface
                // cannot starve other tasks on this runtime thread.
                if progress && total_bytes_processed == 0 && self.inbound.is_empty() {
                    tokio::task::yield_now().await;
                }
            }

            // --- STAGE 2: IDLE / WAITING ---
            // If we are here, it means the work-draining loop completed a full iteration
            // without any work being done. We are now idle and can safely wait for the next event.

            let now = smoltcp::time::Instant::now();
            let smoltcp_delay = self.iface.poll_delay(now, &self.sockets).map(|d| d.into());

            tokio::select! {
                biased;
                _ = self.shutdown_rx.recv() => {
                    debug!("Worker received shutdown signal, exiting gracefully.");
                    return Ok(());
                }

                // Wait for a new inbound packet to arrive.
                maybe_packet = self.inbound.recv() => {
                    match maybe_packet {
                        Some(packet) => {
                            if let Err(_e) = self.process_inbound_frame(packet).await {
                                error!("Error processing inbound frame: {}", _e);
                            }
                            // After processing, continue to the top of 'main_loop to enter the work-draining phase.
                        },
                        None => return Ok(()), // Channel closed, exit worker.
                    }
                },

                // Wait for a notification from a TcpStream (e.g., app wrote data).
                _ = self.notifier.notified() => {
                    // Notification received. No action needed here.
                    // We'll simply loop back to the top and enter the work-draining phase.
                },

                // Wait for smoltcp's timer to expire (e.g., for retransmissions).
                _ = async {
                    match smoltcp_delay {
                        // If there's a specific delay, we sleep for that duration.
                        Some(delay) if delay > Duration::ZERO => tokio::time::sleep(delay).await,
                        // In all other cases (no timer, or timer is for "now"),
                        // we wait on a future that never completes. This effectively
                        // disables this select arm and forces the select to wait for
                        // one of the other arms (shutdown, inbound, notifier).
                        _ => std::future::pending().await,
                    }
                } => {
                    // Timer expired. No action needed here.
                    // We'll simply loop back to the top and the work-draining phase will poll smoltcp.
                },
            }
        }
    }
298
299    async fn process_inbound_frame(&mut self, frame: Packet) -> std::io::Result<()> {
300        if let Ok(ip_packet) = IpPacket::new_checked(frame.data())
301            && ip_packet.protocol() == IpProtocol::Tcp
302            && let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload())
303            && tcp_packet.syn()
304            && !tcp_packet.ack()
305        {
306            self.accept_new_connection(&ip_packet, &tcp_packet)?;
307        }
308
309        self.device_injector
310            .try_send(frame)
311            .map_err(|e| std::io::Error::other(e.to_string()))?;
312        Ok(())
313    }
314
315    fn accept_new_connection(
316        &mut self,
317        ip_packet: &IpPacket<&[u8]>,
318        tcp_packet: &TcpPacket<&[u8]>,
319    ) -> std::io::Result<()> {
320        let src_addr = SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
321        let dst_addr = SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
322
323        let mut socket = Socket::new(
324            SocketBuffer::new(vec![0u8; self.config.tcp_recv_buffer_size]),
325            SocketBuffer::new(vec![0u8; self.config.tcp_send_buffer_size]),
326        );
327
328        socket.set_keep_alive(Some(self.config.tcp_keep_alive.into()));
329        socket.set_timeout(Some(self.config.tcp_timeout.into()));
330        socket.set_nagle_enabled(false);
331        socket.set_congestion_control(CongestionControl::Cubic);
332
333        socket
334            .listen(dst_addr)
335            .map_err(|e| std::io::Error::other(e.to_string()))?;
336
337        let recv_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_recv_buffer_size));
338        let (recv_prod, recv_cons) = (
339            ringbuf::Prod::new(recv_rb.clone()),
340            ringbuf::Cons::new(recv_rb),
341        );
342
343        let send_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_send_buffer_size));
344        let (send_prod, send_cons) = (
345            ringbuf::Prod::new(send_rb.clone()),
346            ringbuf::Cons::new(send_rb),
347        );
348
349        let shared_state = Arc::new(SharedState::new());
350        let stream = TcpStream {
351            local_addr: src_addr,
352            remote_addr: dst_addr,
353            recv_buffer_cons: recv_cons,
354            send_buffer_prod: send_prod,
355            shared_state: shared_state.clone(),
356            worker_notifier: self.notifier.clone(),
357        };
358
359        let io_handle = SocketIOHandle {
360            recv_buffer_prod: recv_prod,
361            send_buffer_cons: send_cons,
362            shared_state,
363        };
364
365        if self.socket_stream_emitter.try_send(stream).is_ok() {
366            let socket_handle = self.sockets.add(socket);
367            self.socket_maps.insert(socket_handle, io_handle);
368        } else {
369            error!(
370                "[Worker] Failed to emit new TcpStream to application, channel is full or closed. Dropping new connection from {}.",
371                src_addr
372            );
373        }
374
375        Ok(())
376    }
377
    /// Move bytes between a smoltcp socket and its stream's ring buffers.
    ///
    /// Returns `(bytes_read, bytes_written)`: bytes copied from the socket
    /// into the stream's recv ring buffer, and bytes copied from the stream's
    /// send ring buffer into the socket. Wakes the stream's read/write wakers
    /// whenever the corresponding side made progress or was closed.
    fn handle_socket_io(
        socket: &mut Socket,
        socket_control: &mut SocketIOHandle,
    ) -> (usize, usize) {
        let mut bytes_read = 0;
        let mut bytes_written = 0;
        let mut notify_read = false;

        // Socket -> recv ring buffer. The closure consumes only what the ring
        // buffer accepts (`n`), but reports `buffer.len()` back to `recv`, so
        // the `Ok(n)` below carries the size of the chunk smoltcp offered.
        if socket.can_recv() {
            match socket.recv(|buffer| {
                let n = socket_control.recv_buffer_prod.push_slice(buffer);
                if n > 0 {
                    bytes_read += n;
                }
                (n, buffer.len())
            }) {
                Ok(n) => {
                    if n > 0 {
                        notify_read = true;
                    }
                }
                Err(_e) => {
                    error!("Socket recv error: {}. Closing read side.", _e);
                    socket_control
                        .shared_state
                        .read_closed
                        .store(true, Ordering::Release);
                    notify_read = true;
                }
            }
        }

        // Socket is no longer open but EOF was not yet signalled: mark the
        // read side closed so the stream's reader observes end-of-stream.
        if !socket.is_open()
            && !socket_control
                .shared_state
                .read_closed
                .load(Ordering::Acquire)
        {
            socket_control
                .shared_state
                .read_closed
                .store(true, Ordering::Release);
            notify_read = true;
        }

        if notify_read {
            socket_control.shared_state.recv_waker.wake();
        }

        let mut notify_write = false;

        // Send ring buffer -> socket, draining while both sides can proceed.
        while socket.can_send() && !socket_control.send_buffer_cons.is_empty() {
            match socket.send(|buffer| {
                let n = socket_control.send_buffer_cons.pop_slice(buffer);
                (n, buffer.len())
            }) {
                Ok(n) if n > 0 => {
                    bytes_written += n;
                    notify_write = true;
                }
                _ => break,
            }
        }

        // Space freed in the send buffer: let a pending writer resume.
        if notify_write {
            socket_control.shared_state.send_waker.wake();
        }

        (bytes_read, bytes_written)
    }
448
449    fn prune_sockets(
450        sockets: &mut SocketSet,
451        socket_maps: &mut HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
452    ) -> bool {
453        let initial_len = socket_maps.len();
454        socket_maps.retain(|handle, socket_control| {
455            let socket = sockets.get_mut::<Socket>(*handle);
456
457            if socket_control
458                .shared_state
459                .socket_dropped
460                .load(Ordering::Acquire)
461            {
462                socket.abort();
463            }
464
465            if !socket.is_active() && socket.state() == State::Closed {
466                sockets.remove(*handle);
467                return false;
468            }
469
470            true
471        });
472        initial_len != socket_maps.len()
473    }
474}
475
476impl futures::Stream for TcpConnection {
477    type Item = TcpStream;
478
479    fn poll_next(
480        mut self: std::pin::Pin<&mut Self>,
481        cx: &mut std::task::Context<'_>,
482    ) -> std::task::Poll<Option<Self::Item>> {
483        self.socket_stream.poll_recv(cx)
484    }
485}
486
487mod stream {
488    use std::net::SocketAddr;
489    use std::sync::Arc;
490    use std::sync::atomic::{AtomicBool, Ordering};
491    use std::task::{Context, Poll};
492
493    use futures::task::AtomicWaker;
494    use ringbuf::HeapRb;
495    use ringbuf::traits::{Consumer, Observer, Producer};
496    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
497    use tokio::sync::Notify;
498
499    pub(crate) type RbProducer = ringbuf::Prod<Arc<HeapRb<u8>>>;
500    pub(crate) type RbConsumer = ringbuf::Cons<Arc<HeapRb<u8>>>;
501
502    pub(crate) struct SharedState {
503        pub(crate) recv_waker: AtomicWaker,
504        pub(crate) send_waker: AtomicWaker,
505        pub(crate) read_closed: AtomicBool,
506        pub(crate) socket_dropped: AtomicBool,
507    }
508
509    impl SharedState {
510        pub fn new() -> Self {
511            Self {
512                recv_waker: AtomicWaker::new(),
513                send_waker: AtomicWaker::new(),
514                read_closed: AtomicBool::new(false),
515                socket_dropped: AtomicBool::new(false),
516            }
517        }
518    }
519
520    pub struct TcpStream {
521        pub(crate) local_addr: SocketAddr,
522        pub(crate) remote_addr: SocketAddr,
523        pub(crate) recv_buffer_cons: RbConsumer,
524        pub(crate) send_buffer_prod: RbProducer,
525        pub(crate) shared_state: Arc<SharedState>,
526        pub(crate) worker_notifier: Arc<Notify>,
527    }
528
529    impl TcpStream {
530        pub fn local_addr(&self) -> SocketAddr {
531            self.local_addr
532        }
533
534        pub fn remote_addr(&self) -> SocketAddr {
535            self.remote_addr
536        }
537
538        pub fn split(self) -> (ReadHalf<Self>, WriteHalf<Self>) {
539            tokio::io::split(self)
540        }
541    }
542
543    impl AsyncRead for TcpStream {
544        fn poll_read(
545            mut self: std::pin::Pin<&mut Self>,
546            cx: &mut Context<'_>,
547            buf: &mut ReadBuf<'_>,
548        ) -> Poll<std::io::Result<()>> {
549            if self.recv_buffer_cons.is_empty() {
550                if self.shared_state.read_closed.load(Ordering::Acquire) {
551                    return Poll::Ready(Ok(()));
552                }
553                self.shared_state.recv_waker.register(cx.waker());
554                return Poll::Pending;
555            }
556
557            let unfilled_slice = buf.initialize_unfilled();
558            let n = self.recv_buffer_cons.pop_slice(unfilled_slice);
559            buf.advance(n);
560
561            self.worker_notifier.notify_one();
562
563            Poll::Ready(Ok(()))
564        }
565    }
566
567    impl AsyncWrite for TcpStream {
568        fn poll_write(
569            mut self: std::pin::Pin<&mut Self>,
570            cx: &mut Context<'_>,
571            buf: &[u8],
572        ) -> Poll<std::io::Result<usize>> {
573            if self.shared_state.socket_dropped.load(Ordering::Relaxed) {
574                return Poll::Ready(Err(std::io::Error::new(
575                    std::io::ErrorKind::BrokenPipe,
576                    "Socket is closing",
577                )));
578            }
579
580            if self.send_buffer_prod.is_full() {
581                self.shared_state.send_waker.register(cx.waker());
582                return Poll::Pending;
583            }
584
585            let n = self.send_buffer_prod.push_slice(buf);
586            if n > 0 {
587                self.worker_notifier.notify_one();
588            }
589
590            Poll::Ready(Ok(n))
591        }
592
593        fn poll_flush(
594            self: std::pin::Pin<&mut Self>,
595            cx: &mut Context<'_>,
596        ) -> Poll<std::io::Result<()>> {
597            if !self.send_buffer_prod.is_empty() {
598                self.shared_state.send_waker.register(cx.waker());
599                self.worker_notifier.notify_one();
600                return Poll::Pending;
601            }
602            Poll::Ready(Ok(()))
603        }
604
605        fn poll_shutdown(
606            mut self: std::pin::Pin<&mut Self>,
607            cx: &mut Context<'_>,
608        ) -> Poll<std::io::Result<()>> {
609            std::task::ready!(self.as_mut().poll_flush(cx))?;
610
611            self.shared_state
612                .socket_dropped
613                .store(true, Ordering::Release);
614            self.worker_notifier.notify_one();
615            Poll::Ready(Ok(()))
616        }
617    }
618
619    impl Drop for TcpStream {
620        fn drop(&mut self) {
621            self.shared_state
622                .socket_dropped
623                .store(true, Ordering::Release);
624            self.worker_notifier.notify_one();
625        }
626    }
627}