madsim/sim/net/tcp/
stream.rs1use crate::net::{IpProtocol::Tcp, *};
2use bytes::{Buf, BufMut, BytesMut};
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5use std::{
6 fmt,
7 io::Result,
8 pin::Pin,
9 task::{Context, Poll},
10};
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12use tracing::*;
13
/// A TCP connection in the simulated network (`NetSim`), exchanging byte
/// payloads with a single peer over a pair of payload channels.
pub struct TcpStream {
    /// Keeps the local port bound for as long as the stream (or a clone of the
    /// `Arc`) is alive. `connect_one` always sets `Some`; other constructors in
    /// the parent module may leave it `None` — TODO confirm.
    pub(super) guard: Option<Arc<BindGuard>>,
    /// Local address of this end of the connection (returned by `local_addr`).
    pub(super) addr: SocketAddr,
    /// Remote address of the peer (returned by `peer_addr`).
    pub(super) peer: SocketAddr,
    /// Bytes accepted by `poll_write` that have not yet been sent; drained and
    /// sent as one payload by `poll_flush`.
    pub(super) write_buf: BytesMut,
    /// Bytes already received from the peer but not yet handed to a reader.
    pub(super) read_buf: Bytes,
    /// Outgoing channel to the peer; `poll_flush` sends each chunk as a boxed `Bytes`.
    pub(super) tx: PayloadSender,
    /// Incoming channel from the peer; each item is expected to downcast to `Bytes`.
    pub(super) rx: PayloadReceiver,
}
25
26impl fmt::Debug for TcpStream {
27 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
28 fmt.debug_struct("TcpStream")
29 .field("addr", &self.addr)
30 .field("peer", &self.peer)
31 .finish()
32 }
33}
34
35impl TcpStream {
36 #[instrument]
47 pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
48 let mut last_err = None;
49
50 for addr in lookup_host(addr).await? {
51 match Self::connect_one(addr).await {
52 Ok(stream) => return Ok(stream),
53 Err(e) => last_err = Some(e),
54 }
55 }
56 Err(last_err.unwrap_or_else(|| {
57 io::Error::new(
58 io::ErrorKind::InvalidInput,
59 "could not resolve to any addresses",
60 )
61 }))
62 }
63
64 #[instrument]
66 async fn connect_one(addr: SocketAddr) -> Result<TcpStream> {
67 let net = plugin::simulator::<NetSim>();
68 net.rand_delay().await?;
69
70 let guard = BindGuard::bind("0.0.0.0:0", Tcp, Arc::new(TcpStreamSocket)).await?;
73 let (tx, rx, local_addr) = net
74 .connect1(plugin::node(), guard.addr.port(), addr, Tcp)
75 .await?;
76 let stream = TcpStream {
77 guard: Some(Arc::new(guard)),
78 addr: local_addr,
79 peer: addr,
80 write_buf: Default::default(),
81 read_buf: Default::default(),
82 tx,
83 rx,
84 };
85 Ok(stream)
86 }
87
88 pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
90 Ok(())
92 }
93
94 pub fn local_addr(&self) -> Result<SocketAddr> {
96 Ok(self.addr)
97 }
98
99 pub fn peer_addr(&self) -> Result<SocketAddr> {
101 Ok(self.peer)
102 }
103
104 pub fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<usize> {
112 if !self.read_buf.is_empty() {
114 let len = self.read_buf.len().min(buf.remaining_mut());
115 buf.put_slice(&self.read_buf[..len]);
116 self.read_buf.advance(len);
117 return Ok(len);
118 }
119 Err(io::Error::new(
120 io::ErrorKind::WouldBlock,
121 "read buffer is empty",
122 ))
123 }
124}
125
126#[cfg(unix)]
127impl AsRawFd for TcpStream {
128 fn as_raw_fd(&self) -> RawFd {
129 todo!("TcpStream::as_raw_fd");
130 }
131}
132
133impl AsyncRead for TcpStream {
134 fn poll_read(
135 mut self: Pin<&mut Self>,
136 cx: &mut Context<'_>,
137 buf: &mut ReadBuf<'_>,
138 ) -> Poll<Result<()>> {
139 if !self.read_buf.is_empty() {
141 let len = self.read_buf.len().min(buf.remaining());
142 buf.put_slice(&self.read_buf[..len]);
143 self.read_buf.advance(len);
144 return Poll::Ready(Ok(()));
145 }
146 let poll_res = { self.rx.poll_next_unpin(cx) };
148 match poll_res {
149 Poll::Pending => Poll::Pending,
150 Poll::Ready(Some(data)) => {
151 self.read_buf = *data.downcast::<Bytes>().unwrap();
152 self.poll_read(cx, buf)
153 }
154 Poll::Ready(None) => Poll::Ready(Ok(())),
158 }
159 }
160}
161
162impl AsyncWrite for TcpStream {
163 fn poll_write(
164 mut self: Pin<&mut Self>,
165 _cx: &mut Context<'_>,
166 buf: &[u8],
167 ) -> Poll<Result<usize>> {
168 self.write_buf.extend_from_slice(buf);
169 Poll::Ready(Ok(buf.len()))
171 }
172
173 fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
174 let data = self.write_buf.split().freeze();
176 self.tx
177 .send(Box::new(data))
178 .ok_or_else(|| io::Error::new(io::ErrorKind::ConnectionReset, "connection reset"))?;
179 Poll::Ready(Ok(()))
180 }
181
182 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
183 Poll::Ready(Ok(()))
185 }
186}
187
/// Marker socket type passed to `BindGuard::bind` when a stream binds its
/// local port in `connect_one`.
struct TcpStreamSocket;

/// Empty impl: relies entirely on `Socket`'s default behavior.
impl Socket for TcpStreamSocket {}