//! ombrac_netstack/tcp.rs — TCP handling for the userspace network stack:
//! accepts connections from raw IP packets via smoltcp and exposes them as
//! async `TcpStream`s backed by lock-free ring buffers.

use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;

use ringbuf::HeapRb;
use ringbuf::traits::{Consumer, Observer, Producer};
use smoltcp::iface::{Interface, PollResult, SocketSet};
use smoltcp::socket::tcp::{CongestionControl, Socket, SocketBuffer};
use smoltcp::wire::{IpCidr, IpProtocol, TcpPacket};
use tokio::sync::{Notify, broadcast, mpsc};
use tokio::task::JoinHandle;

use crate::buffer::BufferPool;
use crate::device::NetstackDevice;
use crate::stack::{IpPacket, NetStackConfig, Packet};
use crate::{debug, error};

pub use stream::TcpStream;
pub(crate) use stream::{RbConsumer, RbProducer, SharedState};
24
/// Per-worker state driving one smoltcp `Interface` and its socket set.
///
/// Each worker owns an independent device/interface pair; inbound packets
/// are sharded across workers by `TcpConnection::distribute_packets` so both
/// directions of one connection always reach the same worker.
struct TcpConnectionWorker {
    // Shared stack configuration (buffer sizes, addresses, timeouts).
    config: Arc<NetStackConfig>,
    // Re-injects inbound frames into this worker's `NetstackDevice` so the
    // next `iface.poll` sees them.
    device_injector: mpsc::Sender<Packet>,
    // smoltcp interface owned exclusively by this worker.
    iface: Interface,
    // Sockets managed by this worker's interface.
    sockets: SocketSet<'static>,
    // Maps each smoltcp socket handle to its application-side IO handle.
    socket_maps: HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
    // Packets routed to this worker by the dispatcher task.
    inbound: mpsc::Receiver<Packet>,
    // Emits newly accepted `TcpStream`s to the aggregated application channel.
    socket_stream_emitter: mpsc::Sender<TcpStream>,
    // Woken by `TcpStream`s when the application reads, writes, or drops.
    notifier: Arc<Notify>,
    // Receives the shutdown broadcast from `TcpConnection`.
    shutdown_rx: broadcast::Receiver<()>,
}
36
/// Worker-side handle for moving bytes between a smoltcp socket and its
/// application-facing `TcpStream` via lock-free ring buffers.
pub(crate) struct SocketIOHandle {
    // Producer half of the receive ring: worker pushes bytes read from the socket.
    recv_buffer_prod: RbProducer,
    // Consumer half of the send ring: worker pops bytes to write to the socket.
    send_buffer_cons: RbConsumer,
    // Wakers and close/drop flags shared with the `TcpStream`.
    shared_state: Arc<SharedState>,
}
42
/// Public entry point: accepts TCP connections from raw IP packets and
/// yields them as `TcpStream`s through the `futures::Stream` impl.
pub struct TcpConnection {
    // Receives `TcpStream`s emitted by all workers.
    socket_stream: mpsc::Receiver<TcpStream>,
    // Broadcasts shutdown to the workers and the dispatcher (sent on drop).
    shutdown_tx: broadcast::Sender<()>,
    // Spawned worker/dispatcher task handles; retained so their lifetime is
    // tied to this value (tasks exit via the shutdown broadcast).
    _handles: Vec<JoinHandle<()>>,
}
48
49impl Drop for TcpConnection {
50    fn drop(&mut self) {
51        let _ = self.shutdown_tx.send(());
52    }
53}
54
55impl TcpConnection {
56    pub fn new(
57        config: NetStackConfig,
58        inbound: mpsc::Receiver<Packet>,
59        outbound: mpsc::Sender<Packet>,
60        buffer_pool: Arc<BufferPool>,
61    ) -> Self {
62        let num_workers = config.number_workers;
63        let config = Arc::new(config);
64
65        let (aggregated_socket_stream_emitter, aggregated_socket_stream_receiver) =
66            mpsc::channel::<TcpStream>(config.channel_size);
67
68        let (shutdown_tx, _) = broadcast::channel(1);
69
70        let mut _handles = Vec::new();
71        let mut worker_senders = Vec::with_capacity(num_workers);
72
73        for _i in 0..num_workers {
74            let (worker_inbound_sender, worker_inbound_receiver) =
75                mpsc::channel(config.channel_size);
76            worker_senders.push(worker_inbound_sender);
77
78            let mut device = NetstackDevice::new(outbound.clone(), buffer_pool.clone(), &config);
79            let device_injector = device.create_injector();
80            let iface = Self::create_interface(&config, &mut device);
81            let notifier = Arc::new(Notify::new());
82            let shutdown_rx = shutdown_tx.subscribe();
83
84            let mut worker = TcpConnectionWorker {
85                config: config.clone(),
86                device_injector,
87                iface,
88                sockets: SocketSet::new(vec![]),
89                socket_maps: HashMap::new(),
90                inbound: worker_inbound_receiver,
91                socket_stream_emitter: aggregated_socket_stream_emitter.clone(),
92                notifier: notifier.clone(),
93                shutdown_rx,
94            };
95
96            let worker_handle = tokio::spawn(async move {
97                if let Err(_e) = worker.accept_loop(device).await {
98                    error!("[Worker {}] exited with error: {}", _i, _e);
99                }
100            });
101            _handles.push(worker_handle);
102        }
103
104        let dispatcher_shutdown_rx = shutdown_tx.subscribe();
105        let dispatcher_handle = tokio::spawn(Self::distribute_packets(
106            inbound,
107            worker_senders,
108            dispatcher_shutdown_rx,
109        ));
110        _handles.push(dispatcher_handle);
111
112        TcpConnection {
113            socket_stream: aggregated_socket_stream_receiver,
114            shutdown_tx,
115            _handles,
116        }
117    }
118
119    async fn distribute_packets(
120        mut inbound: mpsc::Receiver<Packet>,
121        worker_senders: Vec<mpsc::Sender<Packet>>,
122        mut shutdown_rx: broadcast::Receiver<()>,
123    ) {
124        let num_workers = worker_senders.len();
125        loop {
126            tokio::select! {
127                _ = shutdown_rx.recv() => {
128                    debug!("[Dispatcher] received shutdown signal, exiting.");
129                    break;
130                }
131                maybe_packet = inbound.recv() => {
132                    if let Some(packet) = maybe_packet {
133                        let worker_index = match IpPacket::new_checked(packet.data()) {
134                            Ok(ip_packet) => {
135                                if ip_packet.protocol() == IpProtocol::Tcp {
136                                    if let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload()) {
137                                        let mut addr1 =
138                                            SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
139                                        let mut addr2 =
140                                            SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
141                                        if addr1 > addr2 {
142                                            std::mem::swap(&mut addr1, &mut addr2);
143                                        }
144                                        let mut hasher = DefaultHasher::new();
145                                        addr1.hash(&mut hasher);
146                                        addr2.hash(&mut hasher);
147                                        (hasher.finish() % num_workers as u64) as usize
148                                    } else { 0 }
149                                } else { 0 }
150                            }
151                            Err(_) => 0,
152                        };
153
154                        if worker_senders[worker_index].send(packet).await.is_err() {
155                            error!(
156                                "[Dispatcher] Failed to send packet to worker {}, channel closed.",
157                                worker_index
158                            );
159                            break;
160                        }
161                    } else {
162                        debug!("[Dispatcher] Inbound channel closed, exiting.");
163                        break;
164                    }
165                }
166            }
167        }
168        debug!("[Dispatcher] stopped.");
169    }
170
171    fn create_interface(config: &NetStackConfig, device: &mut NetstackDevice) -> Interface {
172        let mut iface_config = smoltcp::iface::Config::new(smoltcp::wire::HardwareAddress::Ip);
173        iface_config.random_seed = rand::random();
174        let mut iface =
175            smoltcp::iface::Interface::new(iface_config, device, smoltcp::time::Instant::now());
176
177        iface.set_any_ip(true);
178        iface.update_ip_addrs(|ip_addrs| {
179            let _ = ip_addrs.push(IpCidr::new(config.ipv4_addr.into(), config.ipv4_prefix_len));
180            let _ = ip_addrs.push(IpCidr::new(config.ipv6_addr.into(), config.ipv6_prefix_len));
181        });
182
183        iface
184            .routes_mut()
185            .add_default_ipv4_route(config.ipv4_addr)
186            .expect("Failed to add default IPv4 route");
187        iface
188            .routes_mut()
189            .add_default_ipv6_route(config.ipv6_addr)
190            .expect("Failed to add default IPv6 route");
191
192        iface
193    }
194}
195
196impl TcpConnectionWorker {
197    async fn accept_loop(&mut self, mut device: NetstackDevice) -> std::io::Result<()> {
198        loop {
199            // --- STAGE 1: WORK-DRAINING ---
200            let mut progress = true;
201            while progress {
202                progress = false;
203
204                // Drain inbound packets
205                while let Ok(packet) = self.inbound.try_recv() {
206                    if let Err(_e) = self.process_inbound_frame(packet).await {
207                        error!("Error processing inbound frame: {}", _e);
208                    }
209                    progress = true;
210                }
211
212                // Poll smoltcp for network events
213                let now = smoltcp::time::Instant::now();
214                if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
215                    progress = true;
216                }
217
218                // Handle IO for all active sockets
219                let mut total_bytes_processed = 0;
220                for (socket_handle, socket_control) in self.socket_maps.iter_mut() {
221                    // Skip sockets that have been marked as dropped (they will be cleaned up in next prune)
222                    if socket_control
223                        .shared_state
224                        .socket_dropped
225                        .load(Ordering::Acquire)
226                    {
227                        continue;
228                    }
229                    let socket = self.sockets.get_mut::<Socket>(*socket_handle);
230                    let (read, written) = Self::handle_socket_io(socket, socket_control);
231                    if read > 0 || written > 0 {
232                        total_bytes_processed += read + written;
233                    }
234                }
235                if total_bytes_processed > 0 {
236                    progress = true;
237                }
238
239                // Prune again after IO to clean up any sockets that became inactive during IO
240                if Self::prune_sockets(&mut self.sockets, &mut self.socket_maps) {
241                    progress = true;
242                }
243
244                // Poll smoltcp again after IO
245                let now = smoltcp::time::Instant::now();
246                if self.iface.poll(now, &mut device, &mut self.sockets) != PollResult::None {
247                    progress = true;
248                }
249
250                if progress && total_bytes_processed == 0 && self.inbound.is_empty() {
251                    tokio::task::yield_now().await;
252                }
253            }
254
255            // --- STAGE 2: IDLE / WAITING ---
256            // If we are here, it means the work-draining loop completed a full iteration
257            // without any work being done. We are now idle and can safely wait for the next event.
258
259            let now = smoltcp::time::Instant::now();
260            let smoltcp_delay = self.iface.poll_delay(now, &self.sockets).map(|d| d.into());
261
262            tokio::select! {
263                biased;
264                _ = self.shutdown_rx.recv() => {
265                    debug!("Worker received shutdown signal, exiting gracefully.");
266                    return Ok(());
267                }
268
269                // Wait for a new inbound packet to arrive.
270                maybe_packet = self.inbound.recv() => {
271                    match maybe_packet {
272                        Some(packet) => {
273                            if let Err(_e) = self.process_inbound_frame(packet).await {
274                                error!("Error processing inbound frame: {}", _e);
275                            }
276                            // After processing, continue to the top of 'main_loop to enter the work-draining phase.
277                        },
278                        None => return Ok(()), // Channel closed, exit worker.
279                    }
280                },
281
282                // Wait for a notification from a TcpStream (e.g., app wrote data).
283                _ = self.notifier.notified() => {
284                    // Notification received. No action needed here.
285                    // We'll simply loop back to the top and enter the work-draining phase.
286                },
287
288                // Wait for smoltcp's timer to expire (e.g., for retransmissions).
289                _ = async {
290                    match smoltcp_delay {
291                        // If there's a specific delay, we sleep for that duration.
292                        Some(delay) if delay > Duration::ZERO => tokio::time::sleep(delay).await,
293                        // In all other cases (no timer, or timer is for "now"),
294                        // we wait on a future that never completes. This effectively
295                        // disables this select arm and forces the select to wait for
296                        // one of the other arms (shutdown, inbound, notifier).
297                        _ => std::future::pending().await,
298                    }
299                } => {
300                    // Timer expired. No action needed here.
301                    // We'll simply loop back to the top and the work-draining phase will poll smoltcp.
302                },
303            }
304        }
305    }
306
307    async fn process_inbound_frame(&mut self, frame: Packet) -> std::io::Result<()> {
308        if let Ok(ip_packet) = IpPacket::new_checked(frame.data())
309            && ip_packet.protocol() == IpProtocol::Tcp
310            && let Ok(tcp_packet) = TcpPacket::new_checked(ip_packet.payload())
311            && tcp_packet.syn()
312            && !tcp_packet.ack()
313        {
314            self.accept_new_connection(&ip_packet, &tcp_packet)?;
315        }
316
317        self.device_injector
318            .try_send(frame)
319            .map_err(|e| std::io::Error::other(e.to_string()))?;
320        Ok(())
321    }
322
323    fn accept_new_connection(
324        &mut self,
325        ip_packet: &IpPacket<&[u8]>,
326        tcp_packet: &TcpPacket<&[u8]>,
327    ) -> std::io::Result<()> {
328        let src_addr = SocketAddr::new(ip_packet.src_addr(), tcp_packet.src_port());
329        let dst_addr = SocketAddr::new(ip_packet.dst_addr(), tcp_packet.dst_port());
330
331        let mut socket = Socket::new(
332            SocketBuffer::new(vec![0u8; self.config.tcp_recv_buffer_size]),
333            SocketBuffer::new(vec![0u8; self.config.tcp_send_buffer_size]),
334        );
335
336        socket.set_keep_alive(Some(self.config.tcp_keep_alive.into()));
337        socket.set_timeout(Some(self.config.tcp_timeout.into()));
338        socket.set_nagle_enabled(false);
339        socket.set_congestion_control(CongestionControl::Cubic);
340
341        socket
342            .listen(dst_addr)
343            .map_err(|e| std::io::Error::other(e.to_string()))?;
344
345        let recv_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_recv_buffer_size));
346        let (recv_prod, recv_cons) = (
347            ringbuf::Prod::new(recv_rb.clone()),
348            ringbuf::Cons::new(recv_rb),
349        );
350
351        let send_rb = Arc::new(HeapRb::<u8>::new(self.config.tcp_send_buffer_size));
352        let (send_prod, send_cons) = (
353            ringbuf::Prod::new(send_rb.clone()),
354            ringbuf::Cons::new(send_rb),
355        );
356
357        let shared_state = Arc::new(SharedState::new());
358        let stream = TcpStream {
359            local_addr: src_addr,
360            remote_addr: dst_addr,
361            recv_buffer_cons: recv_cons,
362            send_buffer_prod: send_prod,
363            shared_state: shared_state.clone(),
364            worker_notifier: self.notifier.clone(),
365        };
366
367        let io_handle = SocketIOHandle {
368            recv_buffer_prod: recv_prod,
369            send_buffer_cons: send_cons,
370            shared_state,
371        };
372
373        if self.socket_stream_emitter.try_send(stream).is_ok() {
374            let socket_handle = self.sockets.add(socket);
375            self.socket_maps.insert(socket_handle, io_handle);
376        } else {
377            error!(
378                "[Worker] Failed to emit new TcpStream to application, channel is full or closed. Dropping new connection from {}.",
379                src_addr
380            );
381        }
382
383        Ok(())
384    }
385
386    fn handle_socket_io(
387        socket: &mut Socket,
388        socket_control: &mut SocketIOHandle,
389    ) -> (usize, usize) {
390        let mut bytes_read = 0;
391        let mut bytes_written = 0;
392        let mut notify_read = false;
393
394        if socket.can_recv() {
395            match socket.recv(|buffer| {
396                let n = socket_control.recv_buffer_prod.push_slice(buffer);
397                if n > 0 {
398                    bytes_read += n;
399                }
400                (n, buffer.len())
401            }) {
402                Ok(n) => {
403                    if n > 0 {
404                        notify_read = true;
405                    }
406                }
407                Err(_e) => {
408                    error!("Socket recv error: {}. Closing read side.", _e);
409                    socket_control
410                        .shared_state
411                        .read_closed
412                        .store(true, Ordering::Release);
413                    notify_read = true;
414                }
415            }
416        }
417
418        if !socket.is_open()
419            && !socket_control
420                .shared_state
421                .read_closed
422                .load(Ordering::Acquire)
423        {
424            socket_control
425                .shared_state
426                .read_closed
427                .store(true, Ordering::Release);
428            notify_read = true;
429        }
430
431        if notify_read {
432            socket_control.shared_state.recv_waker.wake();
433        }
434
435        let mut notify_write = false;
436
437        while socket.can_send() && !socket_control.send_buffer_cons.is_empty() {
438            match socket.send(|buffer| {
439                let n = socket_control.send_buffer_cons.pop_slice(buffer);
440                (n, buffer.len())
441            }) {
442                Ok(n) if n > 0 => {
443                    bytes_written += n;
444                    notify_write = true;
445                }
446                _ => break,
447            }
448        }
449
450        if notify_write {
451            socket_control.shared_state.send_waker.wake();
452        }
453
454        (bytes_read, bytes_written)
455    }
456
457    fn prune_sockets(
458        sockets: &mut SocketSet,
459        socket_maps: &mut HashMap<smoltcp::iface::SocketHandle, SocketIOHandle>,
460    ) -> bool {
461        let initial_len = socket_maps.len();
462        socket_maps.retain(|handle, socket_control| {
463            let socket = sockets.get_mut::<Socket>(*handle);
464
465            if socket_control
466                .shared_state
467                .socket_dropped
468                .load(Ordering::Acquire)
469            {
470                socket.abort();
471                sockets.remove(*handle);
472                return false;
473            }
474
475            if !socket.is_active() {
476                sockets.remove(*handle);
477                return false;
478            }
479
480            true
481        });
482        initial_len != socket_maps.len()
483    }
484}
485
/// Yields accepted connections by delegating straight to the aggregated
/// worker-to-application channel.
impl futures::Stream for TcpConnection {
    type Item = TcpStream;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Returns `Ready(None)` once every worker's emitter has been dropped.
        self.socket_stream.poll_recv(cx)
    }
}
496
497mod stream {
498    use std::net::SocketAddr;
499    use std::sync::Arc;
500    use std::sync::atomic::{AtomicBool, Ordering};
501    use std::task::{Context, Poll};
502
503    use futures::task::AtomicWaker;
504    use ringbuf::HeapRb;
505    use ringbuf::traits::{Consumer, Observer, Producer};
506    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
507    use tokio::sync::Notify;
508
509    pub(crate) type RbProducer = ringbuf::Prod<Arc<HeapRb<u8>>>;
510    pub(crate) type RbConsumer = ringbuf::Cons<Arc<HeapRb<u8>>>;
511
512    pub(crate) struct SharedState {
513        pub(crate) recv_waker: AtomicWaker,
514        pub(crate) send_waker: AtomicWaker,
515        pub(crate) read_closed: AtomicBool,
516        pub(crate) socket_dropped: AtomicBool,
517    }
518
519    impl SharedState {
520        pub fn new() -> Self {
521            Self {
522                recv_waker: AtomicWaker::new(),
523                send_waker: AtomicWaker::new(),
524                read_closed: AtomicBool::new(false),
525                socket_dropped: AtomicBool::new(false),
526            }
527        }
528    }
529
530    pub struct TcpStream {
531        pub(crate) local_addr: SocketAddr,
532        pub(crate) remote_addr: SocketAddr,
533        pub(crate) recv_buffer_cons: RbConsumer,
534        pub(crate) send_buffer_prod: RbProducer,
535        pub(crate) shared_state: Arc<SharedState>,
536        pub(crate) worker_notifier: Arc<Notify>,
537    }
538
539    impl TcpStream {
540        pub fn local_addr(&self) -> SocketAddr {
541            self.local_addr
542        }
543
544        pub fn remote_addr(&self) -> SocketAddr {
545            self.remote_addr
546        }
547
548        pub fn split(self) -> (ReadHalf<Self>, WriteHalf<Self>) {
549            tokio::io::split(self)
550        }
551    }
552
553    impl AsyncRead for TcpStream {
554        fn poll_read(
555            mut self: std::pin::Pin<&mut Self>,
556            cx: &mut Context<'_>,
557            buf: &mut ReadBuf<'_>,
558        ) -> Poll<std::io::Result<()>> {
559            let len_before = self.recv_buffer_cons.occupied_len();
560
561            if len_before == 0 {
562                if self.shared_state.read_closed.load(Ordering::Acquire) {
563                    return Poll::Ready(Ok(()));
564                }
565                self.shared_state.recv_waker.register(cx.waker());
566
567                if self.recv_buffer_cons.is_empty() {
568                    return Poll::Pending;
569                }
570            }
571
572            let unfilled_slice = buf.initialize_unfilled();
573            let n = self.recv_buffer_cons.pop_slice(unfilled_slice);
574            buf.advance(n);
575
576            if n > 0 {
577                self.worker_notifier.notify_one();
578            }
579
580            Poll::Ready(Ok(()))
581        }
582    }
583
584    impl AsyncWrite for TcpStream {
585        fn poll_write(
586            mut self: std::pin::Pin<&mut Self>,
587            cx: &mut Context<'_>,
588            buf: &[u8],
589        ) -> Poll<std::io::Result<usize>> {
590            if self.shared_state.socket_dropped.load(Ordering::Acquire) {
591                return Poll::Ready(Err(std::io::Error::new(
592                    std::io::ErrorKind::BrokenPipe,
593                    "Socket is closing",
594                )));
595            }
596
597            if self.send_buffer_prod.is_full() {
598                self.shared_state.send_waker.register(cx.waker());
599                if self.send_buffer_prod.is_full() {
600                    return Poll::Pending;
601                }
602            }
603
604            let n = self.send_buffer_prod.push_slice(buf);
605            if n > 0 {
606                self.worker_notifier.notify_one();
607            }
608
609            Poll::Ready(Ok(n))
610        }
611
612        fn poll_flush(
613            self: std::pin::Pin<&mut Self>,
614            cx: &mut Context<'_>,
615        ) -> Poll<std::io::Result<()>> {
616            if !self.send_buffer_prod.is_empty() {
617                self.shared_state.send_waker.register(cx.waker());
618                if !self.send_buffer_prod.is_empty() {
619                    return Poll::Pending;
620                }
621            }
622            Poll::Ready(Ok(()))
623        }
624
625        fn poll_shutdown(
626            mut self: std::pin::Pin<&mut Self>,
627            cx: &mut Context<'_>,
628        ) -> Poll<std::io::Result<()>> {
629            std::task::ready!(self.as_mut().poll_flush(cx))?;
630
631            self.shared_state
632                .socket_dropped
633                .store(true, Ordering::Release);
634            self.worker_notifier.notify_one();
635            Poll::Ready(Ok(()))
636        }
637    }
638
639    impl Drop for TcpStream {
640        fn drop(&mut self) {
641            self.shared_state
642                .socket_dropped
643                .store(true, Ordering::Release);
644            self.worker_notifier.notify_one();
645        }
646    }
647}