netstack_smoltcp/
tcp.rs

1use std::{
2    collections::HashMap,
3    net::SocketAddr,
4    pin::Pin,
5    sync::{
6        atomic::{AtomicBool, Ordering},
7        Arc,
8    },
9    task::{Context, Poll, Waker},
10};
11
12use futures::Stream;
13use smoltcp::{
14    iface::{Config as InterfaceConfig, Interface, SocketHandle, SocketSet},
15    phy::Device,
16    socket::tcp::{Socket as TcpSocket, SocketBuffer as TcpSocketBuffer, State as TcpState},
17    storage::RingBuffer,
18    time::{Duration, Instant},
19    wire::{HardwareAddress, IpAddress, IpCidr, IpProtocol, Ipv4Address, Ipv6Address, TcpPacket},
20};
21use spin::Mutex as SpinMutex;
22use tokio::{
23    io::{AsyncRead, AsyncWrite, ReadBuf},
24    sync::{
25        mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender},
26        Notify,
27    },
28};
29use tracing::{error, trace};
30
31use crate::{
32    device::VirtualDevice,
33    packet::{AnyIpPktFrame, IpPacket},
34    Runner,
35};
36
// NOTE: Each default buffer is sized to hold a fixed number of AEAD packets.

/// Byte size budgeted for one AEAD packet.
const AEAD_PACKET_SIZE: u32 = 0x3FFF;
/// Number of AEAD packets a default buffer can contain.
const DEFAULT_BUFFERED_PACKETS: u32 = 20;

/// Default capacity, in bytes, of the send-side buffers of a connection.
const DEFAULT_TCP_SEND_BUFFER_SIZE: u32 = AEAD_PACKET_SIZE * DEFAULT_BUFFERED_PACKETS;
/// Default capacity, in bytes, of the receive-side buffers of a connection.
const DEFAULT_TCP_RECV_BUFFER_SIZE: u32 = AEAD_PACKET_SIZE * DEFAULT_BUFFERED_PACKETS;
40
/// Lifecycle of one half (read or write) of a stream, as tracked in the
/// control block shared between `TcpStream` and the socket runner.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TcpSocketState {
    /// The half is open and usable.
    Normal,
    /// The stream side requested a close (via shutdown or drop); the runner
    /// has not acted on it yet.
    Close,
    /// The runner has asked the underlying socket to close and is waiting
    /// for it to finish.
    Closing,
    /// The half is finished; reads report EOF and shutdown completes.
    Closed,
}
48
49struct TcpSocketControl {
50    send_buffer: RingBuffer<'static, u8>,
51    send_waker: Option<Waker>,
52    recv_buffer: RingBuffer<'static, u8>,
53    recv_waker: Option<Waker>,
54    recv_state: TcpSocketState,
55    send_state: TcpSocketState,
56}
57
58struct TcpSocketCreation {
59    control: SharedControl,
60    socket: TcpSocket<'static>,
61}
62
63type SharedNotify = Arc<Notify>;
64type SharedControl = Arc<SpinMutex<TcpSocketControl>>;
65
66struct TcpListenerRunner;
67
68impl TcpListenerRunner {
69    fn create(
70        device: VirtualDevice,
71        iface: Interface,
72        iface_ingress_tx: UnboundedSender<Vec<u8>>,
73        iface_ingress_tx_avail: Arc<AtomicBool>,
74        tcp_rx: Receiver<AnyIpPktFrame>,
75        stream_tx: UnboundedSender<TcpStream>,
76        sockets: HashMap<SocketHandle, SharedControl>,
77    ) -> Runner {
78        Runner::new(async move {
79            let notify = Arc::new(Notify::new());
80            let (socket_tx, socket_rx) = unbounded_channel::<TcpSocketCreation>();
81            let res = tokio::select! {
82                v = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => v,
83                v = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => v,
84            };
85            res?;
86            trace!("VirtDevice::poll thread exited");
87            Ok(())
88        })
89    }
90
91    async fn handle_packet(
92        notify: SharedNotify,
93        iface_ingress_tx: UnboundedSender<Vec<u8>>,
94        iface_ingress_tx_avail: Arc<AtomicBool>,
95        mut tcp_rx: Receiver<AnyIpPktFrame>,
96        stream_tx: UnboundedSender<TcpStream>,
97        socket_tx: UnboundedSender<TcpSocketCreation>,
98    ) -> std::io::Result<()> {
99        while let Some(frame) = tcp_rx.recv().await {
100            let packet = match IpPacket::new_checked(frame.as_slice()) {
101                Ok(p) => p,
102                Err(err) => {
103                    error!("invalid TCP IP packet: {:?}", err,);
104                    continue;
105                }
106            };
107
108            // Specially handle icmp packet by TCP interface.
109            if matches!(packet.protocol(), IpProtocol::Icmp | IpProtocol::Icmpv6) {
110                iface_ingress_tx
111                    .send(frame)
112                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
113                iface_ingress_tx_avail.store(true, Ordering::Release);
114                notify.notify_one();
115                continue;
116            }
117
118            let src_ip = packet.src_addr();
119            let dst_ip = packet.dst_addr();
120            let payload = packet.payload();
121
122            let packet = match TcpPacket::new_checked(payload) {
123                Ok(p) => p,
124                Err(err) => {
125                    error!("invalid TCP err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
126                    continue;
127                }
128            };
129            let src_port = packet.src_port();
130            let dst_port = packet.dst_port();
131
132            let src_addr = SocketAddr::new(src_ip, src_port);
133            let dst_addr = SocketAddr::new(dst_ip, dst_port);
134
135            // TCP first handshake packet, create a new Connection
136            if packet.syn() && !packet.ack() {
137                let mut socket = TcpSocket::new(
138                    TcpSocketBuffer::new(vec![0u8; DEFAULT_TCP_RECV_BUFFER_SIZE as usize]),
139                    TcpSocketBuffer::new(vec![0u8; DEFAULT_TCP_SEND_BUFFER_SIZE as usize]),
140                );
141                socket.set_keep_alive(Some(Duration::from_secs(28)));
142                // FIXME: It should follow system's setting. 7200 is Linux's default.
143                socket.set_timeout(Some(Duration::from_secs(7200)));
144                // NO ACK delay
145                // socket.set_ack_delay(None);
146
147                if let Err(err) = socket.listen(dst_addr) {
148                    error!("listen error: {:?}", err);
149                    continue;
150                }
151
152                trace!("created TCP connection for {} <-> {}", src_addr, dst_addr);
153
154                let control = Arc::new(SpinMutex::new(TcpSocketControl {
155                    send_buffer: RingBuffer::new(vec![0u8; DEFAULT_TCP_SEND_BUFFER_SIZE as usize]),
156                    send_waker: None,
157                    recv_buffer: RingBuffer::new(vec![0u8; DEFAULT_TCP_RECV_BUFFER_SIZE as usize]),
158                    recv_waker: None,
159                    recv_state: TcpSocketState::Normal,
160                    send_state: TcpSocketState::Normal,
161                }));
162
163                stream_tx
164                    .send(TcpStream {
165                        src_addr,
166                        dst_addr,
167                        notify: notify.clone(),
168                        control: control.clone(),
169                    })
170                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
171                socket_tx
172                    .send(TcpSocketCreation { control, socket })
173                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
174            }
175
176            // Pipeline tcp stream packet
177            iface_ingress_tx
178                .send(frame)
179                .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
180            iface_ingress_tx_avail.store(true, Ordering::Release);
181            notify.notify_one();
182        }
183        Ok(())
184    }
185
186    async fn handle_socket(
187        notify: SharedNotify,
188        mut device: VirtualDevice,
189        mut iface: Interface,
190        iface_ingress_tx_avail: Arc<AtomicBool>,
191        mut sockets: HashMap<SocketHandle, SharedControl>,
192        mut socket_rx: UnboundedReceiver<TcpSocketCreation>,
193    ) -> std::io::Result<()> {
194        let mut socket_set = SocketSet::new(vec![]);
195        loop {
196            while let Ok(TcpSocketCreation { control, socket }) = socket_rx.try_recv() {
197                let handle = socket_set.add(socket);
198                sockets.insert(handle, control);
199            }
200
201            let before_poll = Instant::now();
202            let updated_sockets = iface.poll(before_poll, &mut device, &mut socket_set);
203            if matches!(
204                updated_sockets,
205                smoltcp::iface::PollResult::SocketStateChanged
206            ) {
207                trace!("VirtDevice::poll costed {}", Instant::now() - before_poll);
208            }
209
210            // Check all the sockets' status
211            let mut sockets_to_remove = Vec::new();
212
213            for (socket_handle, control) in sockets.iter() {
214                let socket_handle = *socket_handle;
215                let socket = socket_set.get_mut::<TcpSocket>(socket_handle);
216                let mut control = control.lock();
217
218                // Remove the socket only when it is in the closed state.
219                if socket.state() == TcpState::Closed {
220                    sockets_to_remove.push(socket_handle);
221
222                    control.send_state = TcpSocketState::Closed;
223                    control.recv_state = TcpSocketState::Closed;
224
225                    if let Some(waker) = control.send_waker.take() {
226                        waker.wake();
227                    }
228                    if let Some(waker) = control.recv_waker.take() {
229                        waker.wake();
230                    }
231
232                    trace!("closed TCP connection");
233                    continue;
234                }
235
236                // SHUT_WR
237                if matches!(control.send_state, TcpSocketState::Close) {
238                    trace!("closing TCP Write Half, {:?}", socket.state());
239
240                    // Close the socket. Set to FIN state
241                    socket.close();
242                    control.send_state = TcpSocketState::Closing;
243
244                    // We can still process the pending buffer.
245                }
246
247                // Check if readable
248                let mut wake_receiver = false;
249                while socket.can_recv() && !control.recv_buffer.is_full() {
250                    let result = socket.recv(|buffer| {
251                        let n = control.recv_buffer.enqueue_slice(buffer);
252                        (n, ())
253                    });
254
255                    match result {
256                        Ok(..) => wake_receiver = true,
257                        Err(err) => {
258                            error!("socket recv error: {:?}, {:?}", err, socket.state());
259
260                            // Don't know why. Abort the connection.
261                            socket.abort();
262
263                            if matches!(control.recv_state, TcpSocketState::Normal) {
264                                control.recv_state = TcpSocketState::Closed;
265                            }
266                            wake_receiver = true;
267
268                            // The socket will be recycled in the next poll.
269                            break;
270                        }
271                    }
272                }
273
274                // If socket is not in ESTABLISH, FIN-WAIT-1, FIN-WAIT-2,
275                // the local client have closed our receiver.
276                let states = [
277                    TcpState::Listen,
278                    TcpState::SynReceived,
279                    TcpState::Established,
280                    TcpState::FinWait1,
281                    TcpState::FinWait2,
282                ];
283                if matches!(control.recv_state, TcpSocketState::Normal)
284                    && !socket.may_recv()
285                    && !states.contains(&socket.state())
286                {
287                    trace!("closed TCP Read Half, {:?}", socket.state());
288
289                    // Let TcpStream::poll_read returns EOF.
290                    control.recv_state = TcpSocketState::Closed;
291                    wake_receiver = true;
292                }
293
294                if wake_receiver && control.recv_waker.is_some() {
295                    if let Some(waker) = control.recv_waker.take() {
296                        waker.wake();
297                    }
298                }
299
300                // Check if writable
301                let mut wake_sender = false;
302                while socket.can_send() && !control.send_buffer.is_empty() {
303                    let result = socket.send(|buffer| {
304                        let n = control.send_buffer.dequeue_slice(buffer);
305                        (n, ())
306                    });
307
308                    match result {
309                        Ok(..) => wake_sender = true,
310                        Err(err) => {
311                            error!("socket send error: {:?}, {:?}", err, socket.state());
312
313                            // Don't know why. Abort the connection.
314                            socket.abort();
315
316                            if matches!(control.send_state, TcpSocketState::Normal) {
317                                control.send_state = TcpSocketState::Closed;
318                            }
319                            wake_sender = true;
320
321                            // The socket will be recycled in the next poll.
322                            break;
323                        }
324                    }
325                }
326
327                if wake_sender && control.send_waker.is_some() {
328                    if let Some(waker) = control.send_waker.take() {
329                        waker.wake();
330                    }
331                }
332            }
333
334            for socket_handle in sockets_to_remove {
335                sockets.remove(&socket_handle);
336                socket_set.remove(socket_handle);
337            }
338
339            if !iface_ingress_tx_avail.load(Ordering::Acquire) {
340                let next_duration = iface
341                    .poll_delay(before_poll, &socket_set)
342                    .unwrap_or(Duration::from_millis(5));
343                if next_duration != Duration::ZERO {
344                    let _ = tokio::time::timeout(
345                        tokio::time::Duration::from(next_duration),
346                        notify.notified(),
347                    )
348                    .await;
349                }
350            }
351        }
352    }
353}
354
355pub struct TcpListener {
356    stream_rx: UnboundedReceiver<TcpStream>,
357}
358
359impl TcpListener {
360    pub(super) fn new(
361        tcp_rx: Receiver<AnyIpPktFrame>,
362        stack_tx: Sender<AnyIpPktFrame>,
363    ) -> std::io::Result<(Runner, Self)> {
364        let (mut device, iface_ingress_tx, iface_ingress_tx_avail) = VirtualDevice::new(stack_tx);
365        let iface = Self::create_interface(&mut device)?;
366
367        let (stream_tx, stream_rx) = unbounded_channel();
368
369        let runner = TcpListenerRunner::create(
370            device,
371            iface,
372            iface_ingress_tx,
373            iface_ingress_tx_avail,
374            tcp_rx,
375            stream_tx,
376            HashMap::new(),
377        );
378
379        Ok((runner, Self { stream_rx }))
380    }
381
382    fn create_interface<D>(device: &mut D) -> std::io::Result<Interface>
383    where
384        D: Device + ?Sized,
385    {
386        let mut iface_config = InterfaceConfig::new(HardwareAddress::Ip);
387        iface_config.random_seed = rand::random();
388        let mut iface = Interface::new(iface_config, device, Instant::now());
389        iface.update_ip_addrs(|ip_addrs| {
390            ip_addrs
391                .push(IpCidr::new(IpAddress::v4(0, 0, 0, 1), 0))
392                .expect("iface IPv4");
393            ip_addrs
394                .push(IpCidr::new(IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 1), 0))
395                .expect("iface IPv6");
396        });
397        iface
398            .routes_mut()
399            .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1))
400            .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
401        iface
402            .routes_mut()
403            .add_default_ipv6_route(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1))
404            .map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
405        iface.set_any_ip(true);
406        Ok(iface)
407    }
408}
409
410impl Stream for TcpListener {
411    type Item = (TcpStream, SocketAddr, SocketAddr);
412
413    fn poll_next(
414        mut self: std::pin::Pin<&mut Self>,
415        cx: &mut std::task::Context<'_>,
416    ) -> std::task::Poll<Option<Self::Item>> {
417        self.stream_rx.poll_recv(cx).map(|stream| {
418            stream.map(|stream| {
419                let local_addr = *stream.local_addr();
420                let remote_addr: SocketAddr = *stream.remote_addr();
421                (stream, local_addr, remote_addr)
422            })
423        })
424    }
425}
426
427pub struct TcpStream {
428    src_addr: SocketAddr,
429    dst_addr: SocketAddr,
430    notify: SharedNotify,
431    control: SharedControl,
432}
433
434impl Drop for TcpStream {
435    fn drop(&mut self) {
436        let mut control = self.control.lock();
437
438        if matches!(control.recv_state, TcpSocketState::Normal) {
439            control.recv_state = TcpSocketState::Close;
440        }
441
442        if matches!(control.send_state, TcpSocketState::Normal) {
443            control.send_state = TcpSocketState::Close;
444        }
445
446        self.notify.notify_one();
447    }
448}
449
450impl TcpStream {
451    pub fn local_addr(&self) -> &SocketAddr {
452        &self.src_addr
453    }
454
455    pub fn remote_addr(&self) -> &SocketAddr {
456        &self.dst_addr
457    }
458}
459
460impl AsyncRead for TcpStream {
461    fn poll_read(
462        self: Pin<&mut Self>,
463        cx: &mut Context<'_>,
464        buf: &mut ReadBuf<'_>,
465    ) -> Poll<std::io::Result<()>> {
466        let mut control = self.control.lock();
467
468        // Read from buffer
469        if control.recv_buffer.is_empty() {
470            // If socket is already closed / half closed, just return EOF directly.
471            if matches!(control.recv_state, TcpSocketState::Closed) {
472                return Ok(()).into();
473            }
474
475            // Nothing could be read. Wait for notify.
476            if let Some(old_waker) = control.recv_waker.replace(cx.waker().clone()) {
477                if !old_waker.will_wake(cx.waker()) {
478                    old_waker.wake();
479                }
480            }
481
482            return Poll::Pending;
483        }
484
485        let recv_buf = unsafe {
486            std::mem::transmute::<&mut [std::mem::MaybeUninit<u8>], &mut [u8]>(buf.unfilled_mut())
487        };
488        let n = control.recv_buffer.dequeue_slice(recv_buf);
489        buf.advance(n);
490
491        if n > 0 {
492            self.notify.notify_one();
493        }
494
495        Ok(()).into()
496    }
497}
498
499impl AsyncWrite for TcpStream {
500    fn poll_write(
501        self: Pin<&mut Self>,
502        cx: &mut Context<'_>,
503        buf: &[u8],
504    ) -> Poll<std::io::Result<usize>> {
505        let mut control = self.control.lock();
506
507        // If state == Close | Closing | Closed, the TCP stream WR half is closed.
508        if !matches!(control.send_state, TcpSocketState::Normal) {
509            return Err(std::io::ErrorKind::BrokenPipe.into()).into();
510        }
511
512        // Write to buffer
513
514        if control.send_buffer.is_full() {
515            if let Some(old_waker) = control.send_waker.replace(cx.waker().clone()) {
516                if !old_waker.will_wake(cx.waker()) {
517                    old_waker.wake();
518                }
519            }
520
521            return Poll::Pending;
522        }
523
524        let n = control.send_buffer.enqueue_slice(buf);
525
526        if n > 0 {
527            self.notify.notify_one();
528        }
529
530        Ok(n).into()
531    }
532
533    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
534        Ok(()).into()
535    }
536
537    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
538        let mut control = self.control.lock();
539
540        if matches!(control.send_state, TcpSocketState::Closed) {
541            return Ok(()).into();
542        }
543
544        // SHUT_WR
545        if matches!(control.send_state, TcpSocketState::Normal) {
546            control.send_state = TcpSocketState::Close;
547        }
548
549        if let Some(old_waker) = control.send_waker.replace(cx.waker().clone()) {
550            if !old_waker.will_wake(cx.waker()) {
551                old_waker.wake();
552            }
553        }
554
555        self.notify.notify_one();
556
557        Poll::Pending
558    }
559}