1use std::{future::Future, io, net::SocketAddr};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
6use compio_runtime::{BorrowedBuffer, BufferPool};
7use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
8
9use crate::{
10    OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf,
11};
12
13#[derive(Debug, Clone)]
46pub struct TcpListener {
47    inner: Socket,
48}
49
50impl TcpListener {
51    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
59        super::each_addr(addr, |addr| async move {
60            let sa = SockAddr::from(addr);
61            let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
62            socket.socket.set_reuse_address(true)?;
63            socket.socket.bind(&sa)?;
64            socket.listen(128)?;
65            Ok(Self { inner: socket })
66        })
67        .await
68    }
69
70    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
72        Ok(Self {
73            inner: Socket::from_socket2(Socket2::from(stream))?,
74        })
75    }
76
77    pub fn close(self) -> impl Future<Output = io::Result<()>> {
80        self.inner.close()
81    }
82
83    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
89        let (socket, addr) = self.inner.accept().await?;
90        let stream = TcpStream { inner: socket };
91        Ok((stream, addr.as_socket().expect("should be SocketAddr")))
92    }
93
94    pub fn local_addr(&self) -> io::Result<SocketAddr> {
118        self.inner
119            .local_addr()
120            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
121    }
122}
123
// Macro-generated raw handle access for `TcpListener`; the arguments name the
// raw type (`socket2::Socket`) and the field path `inner.socket` it is reached
// through. NOTE(review): the exact set of generated trait impls lives in
// `compio_driver::impl_raw_fd` — confirm there.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
125
126#[derive(Debug, Clone)]
148pub struct TcpStream {
149    inner: Socket,
150}
151
152impl TcpStream {
153    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
155        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
156
157        super::each_addr(addr, |addr| async move {
158            let addr2 = SockAddr::from(addr);
159            let socket = if cfg!(windows) {
160                let bind_addr = if addr.is_ipv4() {
161                    SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
162                } else if addr.is_ipv6() {
163                    SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
164                } else {
165                    return Err(io::Error::new(
166                        io::ErrorKind::AddrNotAvailable,
167                        "Unsupported address domain.",
168                    ));
169                };
170                Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
171            } else {
172                Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
173            };
174            socket.connect_async(&addr2).await?;
175            Ok(Self { inner: socket })
176        })
177        .await
178    }
179
180    pub async fn bind_and_connect(
182        bind_addr: SocketAddr,
183        addr: impl ToSocketAddrsAsync,
184    ) -> io::Result<Self> {
185        super::each_addr(addr, |addr| async move {
186            let addr = SockAddr::from(addr);
187            let bind_addr = SockAddr::from(bind_addr);
188
189            let socket = Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?;
190            socket.connect_async(&addr).await?;
191            Ok(Self { inner: socket })
192        })
193        .await
194    }
195
196    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
198        Ok(Self {
199            inner: Socket::from_socket2(Socket2::from(stream))?,
200        })
201    }
202
203    pub fn close(self) -> impl Future<Output = io::Result<()>> {
206        self.inner.close()
207    }
208
209    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
211        self.inner
212            .peer_addr()
213            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
214    }
215
216    pub fn local_addr(&self) -> io::Result<SocketAddr> {
218        self.inner
219            .local_addr()
220            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
221    }
222
223    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
230        crate::split(self)
231    }
232
233    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
239        crate::into_split(self)
240    }
241
242    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
244        self.inner.to_poll_fd()
245    }
246
247    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
249        self.inner.into_poll_fd()
250    }
251
252    pub fn nodelay(&self) -> io::Result<bool> {
257        self.inner.socket.nodelay()
258    }
259
260    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
268        self.inner.socket.set_nodelay(nodelay)
269    }
270}
271
272impl AsyncRead for TcpStream {
273    #[inline]
274    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
275        (&*self).read(buf).await
276    }
277
278    #[inline]
279    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
280        (&*self).read_vectored(buf).await
281    }
282}
283
284impl AsyncRead for &TcpStream {
285    #[inline]
286    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
287        self.inner.recv(buf).await
288    }
289
290    #[inline]
291    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
292        self.inner.recv_vectored(buf).await
293    }
294}
295
296impl AsyncReadManaged for TcpStream {
297    type Buffer<'a> = BorrowedBuffer<'a>;
298    type BufferPool = BufferPool;
299
300    async fn read_managed<'a>(
301        &mut self,
302        buffer_pool: &'a Self::BufferPool,
303        len: usize,
304    ) -> io::Result<Self::Buffer<'a>> {
305        (&*self).read_managed(buffer_pool, len).await
306    }
307}
308
309impl AsyncReadManaged for &TcpStream {
310    type Buffer<'a> = BorrowedBuffer<'a>;
311    type BufferPool = BufferPool;
312
313    async fn read_managed<'a>(
314        &mut self,
315        buffer_pool: &'a Self::BufferPool,
316        len: usize,
317    ) -> io::Result<Self::Buffer<'a>> {
318        self.inner.recv_managed(buffer_pool, len as _).await
319    }
320}
321
322impl AsyncWrite for TcpStream {
323    #[inline]
324    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
325        (&*self).write(buf).await
326    }
327
328    #[inline]
329    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
330        (&*self).write_vectored(buf).await
331    }
332
333    #[inline]
334    async fn flush(&mut self) -> io::Result<()> {
335        (&*self).flush().await
336    }
337
338    #[inline]
339    async fn shutdown(&mut self) -> io::Result<()> {
340        (&*self).shutdown().await
341    }
342}
343
344impl AsyncWrite for &TcpStream {
345    #[inline]
346    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
347        self.inner.send(buf).await
348    }
349
350    #[inline]
351    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
352        self.inner.send_vectored(buf).await
353    }
354
355    #[inline]
356    async fn flush(&mut self) -> io::Result<()> {
357        Ok(())
358    }
359
360    #[inline]
361    async fn shutdown(&mut self) -> io::Result<()> {
362        self.inner.shutdown().await
363    }
364}
365
366impl Splittable for TcpStream {
367    type ReadHalf = OwnedReadHalf<Self>;
368    type WriteHalf = OwnedWriteHalf<Self>;
369
370    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
371        crate::into_split(self)
372    }
373}
374
375impl<'a> Splittable for &'a TcpStream {
376    type ReadHalf = ReadHalf<'a, TcpStream>;
377    type WriteHalf = WriteHalf<'a, TcpStream>;
378
379    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
380        crate::split(self)
381    }
382}
383
// Macro-generated raw handle access for `TcpStream`; the arguments name the
// raw type (`socket2::Socket`) and the field path `inner.socket` it is reached
// through. NOTE(review): the exact set of generated trait impls lives in
// `compio_driver::impl_raw_fd` — confirm there.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);