madsim/sim/net/tcp/
stream.rs

1use crate::net::{IpProtocol::Tcp, *};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5use std::{
6    fmt,
7    io::Result,
8    pin::Pin,
9    task::{Context, Poll},
10};
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12use tracing::*;
13
14/// A TCP stream between a local and a remote socket.
15pub struct TcpStream {
16    pub(super) guard: Option<Arc<BindGuard>>,
17    pub(super) addr: SocketAddr,
18    pub(super) peer: SocketAddr,
19    /// Buffer write data to be flushed.
20    pub(super) write_buf: BytesMut,
21    pub(super) read_buf: Bytes,
22    pub(super) tx: PayloadSender,
23    pub(super) rx: PayloadReceiver,
24}
25
26impl fmt::Debug for TcpStream {
27    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
28        fmt.debug_struct("TcpStream")
29            .field("addr", &self.addr)
30            .field("peer", &self.peer)
31            .finish()
32    }
33}
34
35impl TcpStream {
36    /// Opens a simulated TCP connection to a remote host.
37    ///
38    /// `addr` is an address of the remote host. Anything which implements the
39    /// [`ToSocketAddrs`] trait can be supplied as the address.  If `addr`
40    /// yields multiple addresses, connect will be attempted with each of the
41    /// addresses until a connection is successful. If none of the addresses
42    /// result in a successful connection, the error returned from the last
43    /// connection attempt (the last address) is returned.
44    ///
45    /// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs
46    #[instrument]
47    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
48        let mut last_err = None;
49
50        for addr in lookup_host(addr).await? {
51            match Self::connect_one(addr).await {
52                Ok(stream) => return Ok(stream),
53                Err(e) => last_err = Some(e),
54            }
55        }
56        Err(last_err.unwrap_or_else(|| {
57            io::Error::new(
58                io::ErrorKind::InvalidInput,
59                "could not resolve to any addresses",
60            )
61        }))
62    }
63
64    /// Connects to one address.
65    #[instrument]
66    async fn connect_one(addr: SocketAddr) -> Result<TcpStream> {
67        let net = plugin::simulator::<NetSim>();
68        net.rand_delay().await?;
69
70        // send a request to listener and wait for TcpStream
71        // FIXME: the port it uses should not be exclusive
72        let guard = BindGuard::bind("0.0.0.0:0", Tcp, Arc::new(TcpStreamSocket)).await?;
73        let (tx, rx, local_addr) = net
74            .connect1(plugin::node(), guard.addr.port(), addr, Tcp)
75            .await?;
76        let stream = TcpStream {
77            guard: Some(Arc::new(guard)),
78            addr: local_addr,
79            peer: addr,
80            write_buf: Default::default(),
81            read_buf: Default::default(),
82            tx,
83            rx,
84        };
85        Ok(stream)
86    }
87
88    /// Sets the value of the `TCP_NODELAY` option on this socket.
89    pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
90        // TODO: simulate TCP_NODELAY
91        Ok(())
92    }
93
94    /// Returns the socket address of the local half of this TCP connection.
95    pub fn local_addr(&self) -> Result<SocketAddr> {
96        Ok(self.addr)
97    }
98
99    /// Returns the socket address of the remote peer of this TCP connection.
100    pub fn peer_addr(&self) -> Result<SocketAddr> {
101        Ok(self.peer)
102    }
103
104    /// Tries to read data from the stream into the provided buffer, advancing
105    /// the buffer's internal cursor, returning how many bytes were read.
106    ///
107    /// Receives any pending data from the socket but does not wait for new data
108    /// to arrive. On success, returns the number of bytes read. Because
109    /// `try_read_buf()` is non-blocking, the buffer does not have to be stored
110    /// by the async task and can exist entirely on the stack.
111    pub fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<usize> {
112        // read the buffer if not empty
113        if !self.read_buf.is_empty() {
114            let len = self.read_buf.len().min(buf.remaining_mut());
115            buf.put_slice(&self.read_buf[..len]);
116            self.read_buf.advance(len);
117            return Ok(len);
118        }
119        Err(io::Error::new(
120            io::ErrorKind::WouldBlock,
121            "read buffer is empty",
122        ))
123    }
124}
125
126#[cfg(unix)]
127impl AsRawFd for TcpStream {
128    fn as_raw_fd(&self) -> RawFd {
129        todo!("TcpStream::as_raw_fd");
130    }
131}
132
133impl AsyncRead for TcpStream {
134    fn poll_read(
135        mut self: Pin<&mut Self>,
136        cx: &mut Context<'_>,
137        buf: &mut ReadBuf<'_>,
138    ) -> Poll<Result<()>> {
139        // read the buffer if not empty
140        if !self.read_buf.is_empty() {
141            let len = self.read_buf.len().min(buf.remaining());
142            buf.put_slice(&self.read_buf[..len]);
143            self.read_buf.advance(len);
144            return Poll::Ready(Ok(()));
145        }
146        // otherwise wait on channel
147        let poll_res = { self.rx.poll_next_unpin(cx) };
148        match poll_res {
149            Poll::Pending => Poll::Pending,
150            Poll::Ready(Some(data)) => {
151                self.read_buf = *data.downcast::<Bytes>().unwrap();
152                self.poll_read(cx, buf)
153            }
154            // ref: https://man7.org/linux/man-pages/man2/recv.2.html
155            // > When a stream socket peer has performed an orderly shutdown, the
156            // > return value will be 0 (the traditional "end-of-file" return).
157            Poll::Ready(None) => Poll::Ready(Ok(())),
158        }
159    }
160}
161
162impl AsyncWrite for TcpStream {
163    fn poll_write(
164        mut self: Pin<&mut Self>,
165        _cx: &mut Context<'_>,
166        buf: &[u8],
167    ) -> Poll<Result<usize>> {
168        self.write_buf.extend_from_slice(buf);
169        // TODO: simulate buffer full, partial write
170        Poll::Ready(Ok(buf.len()))
171    }
172
173    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
174        // send data
175        let data = self.write_buf.split().freeze();
176        self.tx
177            .send(Box::new(data))
178            .ok_or_else(|| io::Error::new(io::ErrorKind::ConnectionReset, "connection reset"))?;
179        Poll::Ready(Ok(()))
180    }
181
182    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
183        // TODO: simulate shutdown
184        Poll::Ready(Ok(()))
185    }
186}
187
188/// Socket registered in the [`Network`].
189struct TcpStreamSocket;
190
191impl Socket for TcpStreamSocket {}