compio_net/
tcp.rs

1use std::{future::Future, io, net::SocketAddr};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
6use compio_runtime::{BorrowedBuffer, BufferPool};
7use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
8
9use crate::{
10    OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf,
11};
12
13/// A TCP socket server, listening for connections.
14///
15/// You can accept a new connection by using the
16/// [`accept`](`TcpListener::accept`) method.
17///
18/// # Examples
19///
20/// ```
21/// use std::net::SocketAddr;
22///
23/// use compio_io::{AsyncReadExt, AsyncWriteExt};
24/// use compio_net::{TcpListener, TcpStream};
25/// use socket2::SockAddr;
26///
27/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
28/// let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
29///
30/// let addr = listener.local_addr().unwrap();
31///
32/// let tx_fut = TcpStream::connect(&addr);
33///
34/// let rx_fut = listener.accept();
35///
36/// let (mut tx, (mut rx, _)) = futures_util::try_join!(tx_fut, rx_fut).unwrap();
37///
38/// tx.write_all("test").await.0.unwrap();
39///
40/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
41///
42/// assert_eq!(buf, b"test");
43/// # });
44/// ```
45#[derive(Debug, Clone)]
46pub struct TcpListener {
47    inner: Socket,
48}
49
50impl TcpListener {
51    /// Creates a new `TcpListener`, which will be bound to the specified
52    /// address.
53    ///
54    /// The returned listener is ready for accepting connections.
55    ///
56    /// Binding with a port number of 0 will request that the OS assigns a port
57    /// to this listener.
58    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
59        super::each_addr(addr, |addr| async move {
60            let sa = SockAddr::from(addr);
61            let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
62            socket.socket.set_reuse_address(true)?;
63            socket.socket.bind(&sa)?;
64            socket.listen(128)?;
65            Ok(Self { inner: socket })
66        })
67        .await
68    }
69
70    /// Creates new TcpListener from a [`std::net::TcpListener`].
71    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
72        Ok(Self {
73            inner: Socket::from_socket2(Socket2::from(stream))?,
74        })
75    }
76
77    /// Close the socket. If the returned future is dropped before polling, the
78    /// socket won't be closed.
79    pub fn close(self) -> impl Future<Output = io::Result<()>> {
80        self.inner.close()
81    }
82
83    /// Accepts a new incoming connection from this listener.
84    ///
85    /// This function will yield once a new TCP connection is established. When
86    /// established, the corresponding [`TcpStream`] and the remote peer's
87    /// address will be returned.
88    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
89        let (socket, addr) = self.inner.accept().await?;
90        let stream = TcpStream { inner: socket };
91        Ok((stream, addr.as_socket().expect("should be SocketAddr")))
92    }
93
94    /// Returns the local address that this listener is bound to.
95    ///
96    /// This can be useful, for example, when binding to port 0 to
97    /// figure out which port was actually bound.
98    ///
99    /// # Examples
100    ///
101    /// ```
102    /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
103    ///
104    /// use compio_net::TcpListener;
105    /// use socket2::SockAddr;
106    ///
107    /// # compio_runtime::Runtime::new().unwrap().block_on(async {
108    /// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
109    ///
110    /// let addr = listener.local_addr().expect("Couldn't get local address");
111    /// assert_eq!(
112    ///     addr,
113    ///     SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080))
114    /// );
115    /// # });
116    /// ```
117    pub fn local_addr(&self) -> io::Result<SocketAddr> {
118        self.inner
119            .local_addr()
120            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
121    }
122}
123
124impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
125
126/// A TCP stream between a local and a remote socket.
127///
128/// A TCP stream can either be created by connecting to an endpoint, via the
129/// `connect` method, or by accepting a connection from a listener.
130///
131/// # Examples
132///
133/// ```no_run
134/// use std::net::SocketAddr;
135///
136/// use compio_io::AsyncWrite;
137/// use compio_net::TcpStream;
138///
139/// # compio_runtime::Runtime::new().unwrap().block_on(async {
140/// // Connect to a peer
141/// let mut stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();
142///
143/// // Write some data.
144/// stream.write("hello world!").await.unwrap();
145/// # })
146/// ```
147#[derive(Debug, Clone)]
148pub struct TcpStream {
149    inner: Socket,
150}
151
152impl TcpStream {
153    /// Opens a TCP connection to a remote host.
154    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
155        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
156
157        super::each_addr(addr, |addr| async move {
158            let addr2 = SockAddr::from(addr);
159            let socket = if cfg!(windows) {
160                let bind_addr = if addr.is_ipv4() {
161                    SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
162                } else if addr.is_ipv6() {
163                    SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
164                } else {
165                    return Err(io::Error::new(
166                        io::ErrorKind::AddrNotAvailable,
167                        "Unsupported address domain.",
168                    ));
169                };
170                Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
171            } else {
172                Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
173            };
174            socket.connect_async(&addr2).await?;
175            Ok(Self { inner: socket })
176        })
177        .await
178    }
179
180    /// Bind to `bind_addr` then opens a TCP connection to a remote host.
181    pub async fn bind_and_connect(
182        bind_addr: SocketAddr,
183        addr: impl ToSocketAddrsAsync,
184    ) -> io::Result<Self> {
185        super::each_addr(addr, |addr| async move {
186            let addr = SockAddr::from(addr);
187            let bind_addr = SockAddr::from(bind_addr);
188
189            let socket = Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?;
190            socket.connect_async(&addr).await?;
191            Ok(Self { inner: socket })
192        })
193        .await
194    }
195
196    /// Creates new TcpStream from a [`std::net::TcpStream`].
197    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
198        Ok(Self {
199            inner: Socket::from_socket2(Socket2::from(stream))?,
200        })
201    }
202
203    /// Close the socket. If the returned future is dropped before polling, the
204    /// socket won't be closed.
205    pub fn close(self) -> impl Future<Output = io::Result<()>> {
206        self.inner.close()
207    }
208
209    /// Returns the socket address of the remote peer of this TCP connection.
210    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
211        self.inner
212            .peer_addr()
213            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
214    }
215
216    /// Returns the socket address of the local half of this TCP connection.
217    pub fn local_addr(&self) -> io::Result<SocketAddr> {
218        self.inner
219            .local_addr()
220            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
221    }
222
223    /// Splits a [`TcpStream`] into a read half and a write half, which can be
224    /// used to read and write the stream concurrently.
225    ///
226    /// This method is more efficient than
227    /// [`into_split`](TcpStream::into_split), but the halves cannot
228    /// be moved into independently spawned tasks.
229    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
230        crate::split(self)
231    }
232
233    /// Splits a [`TcpStream`] into a read half and a write half, which can be
234    /// used to read and write the stream concurrently.
235    ///
236    /// Unlike [`split`](TcpStream::split), the owned halves can be moved to
237    /// separate tasks, however this comes at the cost of a heap allocation.
238    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
239        crate::into_split(self)
240    }
241
242    /// Create [`PollFd`] from inner socket.
243    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
244        self.inner.to_poll_fd()
245    }
246
247    /// Create [`PollFd`] from inner socket.
248    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
249        self.inner.into_poll_fd()
250    }
251
252    /// Gets the value of the `TCP_NODELAY` option on this socket.
253    ///
254    /// For more information about this option, see
255    /// [`TcpStream::set_nodelay`].
256    pub fn nodelay(&self) -> io::Result<bool> {
257        self.inner.socket.nodelay()
258    }
259
260    /// Sets the value of the TCP_NODELAY option on this socket.
261    ///
262    /// If set, this option disables the Nagle algorithm. This means
263    /// that segments are always sent as soon as possible, even if
264    /// there is only a small amount of data. When not set, data is
265    /// buffered until there is a sufficient amount to send out,
266    /// thereby avoiding the frequent sending of small packets.
267    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
268        self.inner.socket.set_nodelay(nodelay)
269    }
270}
271
272impl AsyncRead for TcpStream {
273    #[inline]
274    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
275        (&*self).read(buf).await
276    }
277
278    #[inline]
279    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
280        (&*self).read_vectored(buf).await
281    }
282}
283
284impl AsyncRead for &TcpStream {
285    #[inline]
286    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
287        self.inner.recv(buf).await
288    }
289
290    #[inline]
291    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
292        self.inner.recv_vectored(buf).await
293    }
294}
295
296impl AsyncReadManaged for TcpStream {
297    type Buffer<'a> = BorrowedBuffer<'a>;
298    type BufferPool = BufferPool;
299
300    async fn read_managed<'a>(
301        &mut self,
302        buffer_pool: &'a Self::BufferPool,
303        len: usize,
304    ) -> io::Result<Self::Buffer<'a>> {
305        (&*self).read_managed(buffer_pool, len).await
306    }
307}
308
309impl AsyncReadManaged for &TcpStream {
310    type Buffer<'a> = BorrowedBuffer<'a>;
311    type BufferPool = BufferPool;
312
313    async fn read_managed<'a>(
314        &mut self,
315        buffer_pool: &'a Self::BufferPool,
316        len: usize,
317    ) -> io::Result<Self::Buffer<'a>> {
318        self.inner.recv_managed(buffer_pool, len as _).await
319    }
320}
321
322impl AsyncWrite for TcpStream {
323    #[inline]
324    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
325        (&*self).write(buf).await
326    }
327
328    #[inline]
329    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
330        (&*self).write_vectored(buf).await
331    }
332
333    #[inline]
334    async fn flush(&mut self) -> io::Result<()> {
335        (&*self).flush().await
336    }
337
338    #[inline]
339    async fn shutdown(&mut self) -> io::Result<()> {
340        (&*self).shutdown().await
341    }
342}
343
344impl AsyncWrite for &TcpStream {
345    #[inline]
346    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
347        self.inner.send(buf).await
348    }
349
350    #[inline]
351    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
352        self.inner.send_vectored(buf).await
353    }
354
355    #[inline]
356    async fn flush(&mut self) -> io::Result<()> {
357        Ok(())
358    }
359
360    #[inline]
361    async fn shutdown(&mut self) -> io::Result<()> {
362        self.inner.shutdown().await
363    }
364}
365
366impl Splittable for TcpStream {
367    type ReadHalf = OwnedReadHalf<Self>;
368    type WriteHalf = OwnedWriteHalf<Self>;
369
370    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
371        crate::into_split(self)
372    }
373}
374
375impl<'a> Splittable for &'a TcpStream {
376    type ReadHalf = ReadHalf<'a, TcpStream>;
377    type WriteHalf = WriteHalf<'a, TcpStream>;
378
379    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
380        crate::split(self)
381    }
382}
383
384impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);