compio_net/
tcp.rs

1use std::{future::Future, io, net::SocketAddr};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
6use compio_runtime::{BorrowedBuffer, BufferPool};
7use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
8
9use crate::{
10    OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, TcpOpts, ToSocketAddrsAsync, WriteHalf,
11};
12
/// A TCP socket server, listening for connections.
///
/// You can accept a new connection by using the
/// [`accept`](`TcpListener::accept`) method.
///
/// # Examples
///
/// ```
/// use compio_io::{AsyncReadExt, AsyncWriteExt};
/// use compio_net::{TcpListener, TcpStream};
///
/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
/// // Port 0 lets the OS pick a free port; `local_addr` reveals which one.
/// let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
///
/// let addr = listener.local_addr().unwrap();
///
/// let tx_fut = TcpStream::connect(&addr);
///
/// let rx_fut = listener.accept();
///
/// let (mut tx, (mut rx, _)) = futures_util::try_join!(tx_fut, rx_fut).unwrap();
///
/// tx.write_all("test").await.0.unwrap();
///
/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
///
/// assert_eq!(buf, b"test");
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct TcpListener {
    // The underlying crate-level socket that every method delegates to.
    inner: Socket,
}
49
50impl TcpListener {
51    /// Creates a new `TcpListener`, which will be bound to the specified
52    /// address.
53    ///
54    /// The returned listener is ready for accepting connections.
55    ///
56    /// Binding with a port number of 0 will request that the OS assigns a port
57    /// to this listener.
58    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
59        Self::bind_with_options(addr, TcpOpts::default().reuse_address(true)).await
60    }
61
62    /// Creates a new `TcpListener`, which will be bound to the specified
63    /// address using `TcpOpts`.
64    ///
65    /// The returned listener is ready for accepting connections.
66    ///
67    /// Binding with a port number of 0 will request that the OS assigns a port
68    /// to this listener.
69    pub async fn bind_with_options(
70        addr: impl ToSocketAddrsAsync,
71        options: TcpOpts,
72    ) -> io::Result<Self> {
73        super::each_addr(addr, |addr| async move {
74            let sa = SockAddr::from(addr);
75            let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
76            options.setup_socket(&socket)?;
77            socket.socket.bind(&sa)?;
78            socket.listen(128)?;
79            Ok(Self { inner: socket })
80        })
81        .await
82    }
83
84    /// Creates new TcpListener from a [`std::net::TcpListener`].
85    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
86        Ok(Self {
87            inner: Socket::from_socket2(Socket2::from(stream))?,
88        })
89    }
90
91    /// Close the socket. If the returned future is dropped before polling, the
92    /// socket won't be closed.
93    pub fn close(self) -> impl Future<Output = io::Result<()>> {
94        self.inner.close()
95    }
96
97    /// Accepts a new incoming connection from this listener.
98    ///
99    /// This function will yield once a new TCP connection is established. When
100    /// established, the corresponding [`TcpStream`] and the remote peer's
101    /// address will be returned.
102    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
103        let (socket, addr) = self.inner.accept().await?;
104        let stream = TcpStream { inner: socket };
105        Ok((stream, addr.as_socket().expect("should be SocketAddr")))
106    }
107
108    /// Returns the local address that this listener is bound to.
109    ///
110    /// This can be useful, for example, when binding to port 0 to
111    /// figure out which port was actually bound.
112    ///
113    /// # Examples
114    ///
115    /// ```
116    /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
117    ///
118    /// use compio_net::TcpListener;
119    /// use socket2::SockAddr;
120    ///
121    /// # compio_runtime::Runtime::new().unwrap().block_on(async {
122    /// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
123    ///
124    /// let addr = listener.local_addr().expect("Couldn't get local address");
125    /// assert_eq!(
126    ///     addr,
127    ///     SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080))
128    /// );
129    /// # });
130    /// ```
131    pub fn local_addr(&self) -> io::Result<SocketAddr> {
132        self.inner
133            .local_addr()
134            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
135    }
136}
137
// Expands `compio_driver::impl_raw_fd!` for `TcpListener`, naming
// `socket2::Socket` as the raw type and the field path `inner` -> `socket`.
// NOTE(review): presumably this generates the raw fd/handle conversion traits
// — confirm against the macro definition.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
139
/// A TCP stream between a local and a remote socket.
///
/// A TCP stream can either be created by connecting to an endpoint, via the
/// [`connect`](TcpStream::connect) method, or by accepting a connection from
/// a listener.
///
/// # Examples
///
/// ```no_run
/// use compio_io::AsyncWrite;
/// use compio_net::TcpStream;
///
/// # compio_runtime::Runtime::new().unwrap().block_on(async {
/// // Connect to a peer
/// let mut stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();
///
/// // Write some data.
/// stream.write("hello world!").await.unwrap();
/// # })
/// ```
#[derive(Debug, Clone)]
pub struct TcpStream {
    // The underlying crate-level socket that every method delegates to.
    inner: Socket,
}
165
166impl TcpStream {
167    /// Opens a TCP connection to a remote host.
168    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
169        Self::connect_base(addr, None).await
170    }
171
172    /// Opens a TCP connection to a remote host using `TcpOpts`.
173    pub async fn connect_with_options(
174        addr: impl ToSocketAddrsAsync,
175        options: TcpOpts,
176    ) -> io::Result<Self> {
177        Self::connect_base(addr, Some(options)).await
178    }
179
180    async fn connect_base(
181        addr: impl ToSocketAddrsAsync,
182        options: Option<TcpOpts>,
183    ) -> io::Result<Self> {
184        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
185
186        super::each_addr(addr, |addr| async move {
187            let addr2 = SockAddr::from(addr);
188            let socket = if cfg!(windows) {
189                let bind_addr = if addr.is_ipv4() {
190                    SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
191                } else if addr.is_ipv6() {
192                    SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
193                } else {
194                    return Err(io::Error::new(
195                        io::ErrorKind::AddrNotAvailable,
196                        "Unsupported address domain.",
197                    ));
198                };
199                Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
200            } else {
201                Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
202            };
203            if let Some(options) = &options {
204                options.setup_socket(&socket)?;
205            }
206            socket.connect_async(&addr2).await?;
207            Ok(Self { inner: socket })
208        })
209        .await
210    }
211
212    /// Bind to `bind_addr` then opens a TCP connection to a remote host.
213    pub async fn bind_and_connect(
214        bind_addr: SocketAddr,
215        addr: impl ToSocketAddrsAsync,
216    ) -> io::Result<Self> {
217        Self::bind_and_connect_base(bind_addr, addr, None).await
218    }
219
220    /// Bind to `bind_addr` then opens a TCP connection to a remote host using
221    /// `TcpOpts`.
222    pub async fn bind_and_connect_with_options(
223        bind_addr: SocketAddr,
224        addr: impl ToSocketAddrsAsync,
225        options: TcpOpts,
226    ) -> io::Result<Self> {
227        Self::bind_and_connect_base(bind_addr, addr, Some(options)).await
228    }
229
230    async fn bind_and_connect_base(
231        bind_addr: SocketAddr,
232        addr: impl ToSocketAddrsAsync,
233        options: Option<TcpOpts>,
234    ) -> io::Result<Self> {
235        let options = options.unwrap_or_default();
236        super::each_addr(addr, |addr| async move {
237            let addr = SockAddr::from(addr);
238            let bind_addr = SockAddr::from(bind_addr);
239
240            let socket = Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?;
241            options.setup_socket(&socket)?;
242            socket.connect_async(&addr).await?;
243            Ok(Self { inner: socket })
244        })
245        .await
246    }
247
248    /// Creates new TcpStream from a [`std::net::TcpStream`].
249    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
250        Ok(Self {
251            inner: Socket::from_socket2(Socket2::from(stream))?,
252        })
253    }
254
255    /// Close the socket. If the returned future is dropped before polling, the
256    /// socket won't be closed.
257    pub fn close(self) -> impl Future<Output = io::Result<()>> {
258        self.inner.close()
259    }
260
261    /// Returns the socket address of the remote peer of this TCP connection.
262    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
263        self.inner
264            .peer_addr()
265            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
266    }
267
268    /// Returns the socket address of the local half of this TCP connection.
269    pub fn local_addr(&self) -> io::Result<SocketAddr> {
270        self.inner
271            .local_addr()
272            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
273    }
274
275    /// Splits a [`TcpStream`] into a read half and a write half, which can be
276    /// used to read and write the stream concurrently.
277    ///
278    /// This method is more efficient than
279    /// [`into_split`](TcpStream::into_split), but the halves cannot
280    /// be moved into independently spawned tasks.
281    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
282        crate::split(self)
283    }
284
285    /// Splits a [`TcpStream`] into a read half and a write half, which can be
286    /// used to read and write the stream concurrently.
287    ///
288    /// Unlike [`split`](TcpStream::split), the owned halves can be moved to
289    /// separate tasks, however this comes at the cost of a heap allocation.
290    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
291        crate::into_split(self)
292    }
293
294    /// Create [`PollFd`] from inner socket.
295    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
296        self.inner.to_poll_fd()
297    }
298
299    /// Create [`PollFd`] from inner socket.
300    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
301        self.inner.into_poll_fd()
302    }
303
304    /// Gets the value of the `TCP_NODELAY` option on this socket.
305    ///
306    /// For more information about this option, see
307    /// [`TcpStream::set_nodelay`].
308    pub fn nodelay(&self) -> io::Result<bool> {
309        self.inner.socket.tcp_nodelay()
310    }
311
312    /// Sets the value of the TCP_NODELAY option on this socket.
313    ///
314    /// If set, this option disables the Nagle algorithm. This means
315    /// that segments are always sent as soon as possible, even if
316    /// there is only a small amount of data. When not set, data is
317    /// buffered until there is a sufficient amount to send out,
318    /// thereby avoiding the frequent sending of small packets.
319    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
320        self.inner.socket.set_tcp_nodelay(nodelay)
321    }
322}
323
324impl AsyncRead for TcpStream {
325    #[inline]
326    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
327        (&*self).read(buf).await
328    }
329
330    #[inline]
331    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
332        (&*self).read_vectored(buf).await
333    }
334}
335
336impl AsyncRead for &TcpStream {
337    #[inline]
338    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
339        self.inner.recv(buf).await
340    }
341
342    #[inline]
343    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
344        self.inner.recv_vectored(buf).await
345    }
346}
347
348impl AsyncReadManaged for TcpStream {
349    type Buffer<'a> = BorrowedBuffer<'a>;
350    type BufferPool = BufferPool;
351
352    async fn read_managed<'a>(
353        &mut self,
354        buffer_pool: &'a Self::BufferPool,
355        len: usize,
356    ) -> io::Result<Self::Buffer<'a>> {
357        (&*self).read_managed(buffer_pool, len).await
358    }
359}
360
361impl AsyncReadManaged for &TcpStream {
362    type Buffer<'a> = BorrowedBuffer<'a>;
363    type BufferPool = BufferPool;
364
365    async fn read_managed<'a>(
366        &mut self,
367        buffer_pool: &'a Self::BufferPool,
368        len: usize,
369    ) -> io::Result<Self::Buffer<'a>> {
370        self.inner.recv_managed(buffer_pool, len as _).await
371    }
372}
373
374impl AsyncWrite for TcpStream {
375    #[inline]
376    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
377        (&*self).write(buf).await
378    }
379
380    #[inline]
381    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
382        (&*self).write_vectored(buf).await
383    }
384
385    #[inline]
386    async fn flush(&mut self) -> io::Result<()> {
387        (&*self).flush().await
388    }
389
390    #[inline]
391    async fn shutdown(&mut self) -> io::Result<()> {
392        (&*self).shutdown().await
393    }
394}
395
396impl AsyncWrite for &TcpStream {
397    #[inline]
398    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
399        self.inner.send(buf).await
400    }
401
402    #[inline]
403    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
404        self.inner.send_vectored(buf).await
405    }
406
407    #[inline]
408    async fn flush(&mut self) -> io::Result<()> {
409        Ok(())
410    }
411
412    #[inline]
413    async fn shutdown(&mut self) -> io::Result<()> {
414        self.inner.shutdown().await
415    }
416}
417
418impl Splittable for TcpStream {
419    type ReadHalf = OwnedReadHalf<Self>;
420    type WriteHalf = OwnedWriteHalf<Self>;
421
422    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
423        crate::into_split(self)
424    }
425}
426
427impl<'a> Splittable for &'a TcpStream {
428    type ReadHalf = ReadHalf<'a, TcpStream>;
429    type WriteHalf = WriteHalf<'a, TcpStream>;
430
431    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
432        crate::split(self)
433    }
434}
435
// Expands `compio_driver::impl_raw_fd!` for `TcpStream`, naming
// `socket2::Socket` as the raw type and the field path `inner` -> `socket`.
// NOTE(review): presumably this generates the raw fd/handle conversion traits
// — confirm against the macro definition.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);