1use std::{future::Future, io, net::SocketAddr};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
6use compio_runtime::{BorrowedBuffer, BufferPool};
7use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
8
9use crate::{
10 OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf,
11};
12
/// A TCP socket server, listening for connections.
///
/// Created with [`TcpListener::bind`] or [`TcpListener::from_std`]; incoming
/// connections are taken with [`TcpListener::accept`], which yields a [`TcpStream`]
/// and the peer address.
///
/// NOTE(review): `Clone` clones the inner `Socket`; whether that shares or
/// duplicates the OS handle depends on `Socket`'s own `Clone` impl — confirm.
#[derive(Debug, Clone)]
pub struct TcpListener {
    inner: Socket,
}
49
50impl TcpListener {
51 pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
59 super::each_addr(addr, |addr| async move {
60 let sa = SockAddr::from(addr);
61 let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
62 socket.socket.set_reuse_address(true)?;
63 socket.socket.bind(&sa)?;
64 socket.listen(128)?;
65 Ok(Self { inner: socket })
66 })
67 .await
68 }
69
70 pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
72 Ok(Self {
73 inner: Socket::from_socket2(Socket2::from(stream))?,
74 })
75 }
76
77 pub fn close(self) -> impl Future<Output = io::Result<()>> {
80 self.inner.close()
81 }
82
83 pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
89 let (socket, addr) = self.inner.accept().await?;
90 let stream = TcpStream { inner: socket };
91 Ok((stream, addr.as_socket().expect("should be SocketAddr")))
92 }
93
94 pub fn local_addr(&self) -> io::Result<SocketAddr> {
118 self.inner
119 .local_addr()
120 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
121 }
122}
123
// Raw OS handle/fd conversions for `TcpListener`, generated by
// `compio_driver::impl_raw_fd!`. The arguments point the macro at the
// `socket2::Socket` reached through `self.inner.socket`.
// NOTE(review): the exact traits generated depend on the macro and target platform.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
125
/// A TCP stream between a local and a remote socket.
///
/// Obtained from [`TcpStream::connect`], [`TcpStream::bind_and_connect`],
/// [`TcpStream::from_std`], or [`TcpListener::accept`]. Both `TcpStream` and
/// `&TcpStream` implement `AsyncRead`, `AsyncReadManaged` and `AsyncWrite`.
///
/// NOTE(review): `Clone` clones the inner `Socket`; whether that shares or
/// duplicates the OS handle depends on `Socket`'s own `Clone` impl — confirm.
#[derive(Debug, Clone)]
pub struct TcpStream {
    inner: Socket,
}
151
152impl TcpStream {
153 pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
155 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
156
157 super::each_addr(addr, |addr| async move {
158 let addr2 = SockAddr::from(addr);
159 let socket = if cfg!(windows) {
160 let bind_addr = if addr.is_ipv4() {
161 SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
162 } else if addr.is_ipv6() {
163 SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
164 } else {
165 return Err(io::Error::new(
166 io::ErrorKind::AddrNotAvailable,
167 "Unsupported address domain.",
168 ));
169 };
170 Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
171 } else {
172 Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
173 };
174 socket.connect_async(&addr2).await?;
175 Ok(Self { inner: socket })
176 })
177 .await
178 }
179
180 pub async fn bind_and_connect(
182 bind_addr: SocketAddr,
183 addr: impl ToSocketAddrsAsync,
184 ) -> io::Result<Self> {
185 super::each_addr(addr, |addr| async move {
186 let addr = SockAddr::from(addr);
187 let bind_addr = SockAddr::from(bind_addr);
188
189 let socket = Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?;
190 socket.connect_async(&addr).await?;
191 Ok(Self { inner: socket })
192 })
193 .await
194 }
195
196 pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
198 Ok(Self {
199 inner: Socket::from_socket2(Socket2::from(stream))?,
200 })
201 }
202
203 pub fn close(self) -> impl Future<Output = io::Result<()>> {
206 self.inner.close()
207 }
208
209 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
211 self.inner
212 .peer_addr()
213 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
214 }
215
216 pub fn local_addr(&self) -> io::Result<SocketAddr> {
218 self.inner
219 .local_addr()
220 .map(|addr| addr.as_socket().expect("should be SocketAddr"))
221 }
222
223 pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
230 crate::split(self)
231 }
232
233 pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
239 crate::into_split(self)
240 }
241
242 pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
244 self.inner.to_poll_fd()
245 }
246
247 pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
249 self.inner.into_poll_fd()
250 }
251
252 pub fn nodelay(&self) -> io::Result<bool> {
257 self.inner.socket.nodelay()
258 }
259
260 pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
268 self.inner.socket.set_nodelay(nodelay)
269 }
270}
271
272impl AsyncRead for TcpStream {
273 #[inline]
274 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
275 (&*self).read(buf).await
276 }
277
278 #[inline]
279 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
280 (&*self).read_vectored(buf).await
281 }
282}
283
284impl AsyncRead for &TcpStream {
285 #[inline]
286 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
287 self.inner.recv(buf).await
288 }
289
290 #[inline]
291 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
292 self.inner.recv_vectored(buf).await
293 }
294}
295
296impl AsyncReadManaged for TcpStream {
297 type Buffer<'a> = BorrowedBuffer<'a>;
298 type BufferPool = BufferPool;
299
300 async fn read_managed<'a>(
301 &mut self,
302 buffer_pool: &'a Self::BufferPool,
303 len: usize,
304 ) -> io::Result<Self::Buffer<'a>> {
305 (&*self).read_managed(buffer_pool, len).await
306 }
307}
308
309impl AsyncReadManaged for &TcpStream {
310 type Buffer<'a> = BorrowedBuffer<'a>;
311 type BufferPool = BufferPool;
312
313 async fn read_managed<'a>(
314 &mut self,
315 buffer_pool: &'a Self::BufferPool,
316 len: usize,
317 ) -> io::Result<Self::Buffer<'a>> {
318 self.inner.recv_managed(buffer_pool, len as _).await
319 }
320}
321
322impl AsyncWrite for TcpStream {
323 #[inline]
324 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
325 (&*self).write(buf).await
326 }
327
328 #[inline]
329 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
330 (&*self).write_vectored(buf).await
331 }
332
333 #[inline]
334 async fn flush(&mut self) -> io::Result<()> {
335 (&*self).flush().await
336 }
337
338 #[inline]
339 async fn shutdown(&mut self) -> io::Result<()> {
340 (&*self).shutdown().await
341 }
342}
343
344impl AsyncWrite for &TcpStream {
345 #[inline]
346 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
347 self.inner.send(buf).await
348 }
349
350 #[inline]
351 async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
352 self.inner.send_vectored(buf).await
353 }
354
355 #[inline]
356 async fn flush(&mut self) -> io::Result<()> {
357 Ok(())
358 }
359
360 #[inline]
361 async fn shutdown(&mut self) -> io::Result<()> {
362 self.inner.shutdown().await
363 }
364}
365
366impl Splittable for TcpStream {
367 type ReadHalf = OwnedReadHalf<Self>;
368 type WriteHalf = OwnedWriteHalf<Self>;
369
370 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
371 crate::into_split(self)
372 }
373}
374
375impl<'a> Splittable for &'a TcpStream {
376 type ReadHalf = ReadHalf<'a, TcpStream>;
377 type WriteHalf = WriteHalf<'a, TcpStream>;
378
379 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
380 crate::split(self)
381 }
382}
383
// Raw OS handle/fd conversions for `TcpStream`, generated by
// `compio_driver::impl_raw_fd!`. The arguments point the macro at the
// `socket2::Socket` reached through `self.inner.socket`.
// NOTE(review): the exact traits generated depend on the macro and target platform.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);