use std::{future::Future, io, net::SocketAddr};
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_driver::impl_raw_fd;
use compio_io::{AsyncRead, AsyncReadManaged, AsyncWrite, util::Splittable};
use compio_runtime::{BorrowedBuffer, BufferPool};
use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
use crate::{
OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, SocketOpts, ToSocketAddrsAsync,
WriteHalf,
};
/// A TCP socket that has been bound and put into the listening state,
/// from which incoming connections can be accepted.
#[derive(Clone, Debug)]
pub struct TcpListener {
    inner: Socket,
}
impl TcpListener {
    /// Creates a listener bound to `addr`, with address reuse enabled
    /// (`SocketOpts::reuse_address(true)`).
    ///
    /// See [`TcpListener::bind_with_options`] for details.
    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
        Self::bind_with_options(addr, &SocketOpts::default().reuse_address(true)).await
    }

    /// Creates a listener bound to `addr`, applying `options` to the socket
    /// before it is bound.
    ///
    /// Each address resolved from `addr` is tried in turn through
    /// `each_addr`. The listen backlog is fixed at 128.
    ///
    /// # Errors
    /// Returns any error raised while creating, configuring, binding, or
    /// listening on the socket.
    pub async fn bind_with_options(
        addr: impl ToSocketAddrsAsync,
        options: &SocketOpts,
    ) -> io::Result<Self> {
        super::each_addr(addr, |candidate| async move {
            let local = SockAddr::from(candidate);
            let socket = Socket::new(local.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
            // Configure before `bind` so that options affecting binding
            // (e.g. address reuse) are in effect when the bind happens.
            options.setup_socket(&socket)?;
            socket.socket.bind(&local)?;
            socket.listen(128)?;
            Ok(Self { inner: socket })
        })
        .await
    }

    /// Wraps an existing standard-library listener.
    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
        let inner = Socket::from_socket2(Socket2::from(stream))?;
        Ok(Self { inner })
    }

    /// Closes the underlying socket asynchronously, consuming the listener.
    pub fn close(self) -> impl Future<Output = io::Result<()>> {
        self.inner.close()
    }

    /// Accepts one incoming connection using default socket options.
    ///
    /// Returns the connected stream and the peer's address.
    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
        let defaults = SocketOpts::default();
        self.accept_with_options(&defaults).await
    }

    /// Accepts one incoming connection and applies `options` to it.
    ///
    /// Returns the connected stream and the peer's address.
    ///
    /// # Panics
    /// Panics if the peer address is not an IPv4/IPv6 socket address, which
    /// is not expected for a TCP socket.
    pub async fn accept_with_options(
        &self,
        options: &SocketOpts,
    ) -> io::Result<(TcpStream, SocketAddr)> {
        let (accepted, peer) = self.inner.accept().await?;
        options.setup_socket(&accepted)?;
        let peer = peer.as_socket().expect("should be SocketAddr");
        Ok((TcpStream { inner: accepted }, peer))
    }

    /// Returns the local address this listener is bound to.
    ///
    /// # Panics
    /// Panics if the local address is not an IPv4/IPv6 socket address.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let addr = self.inner.local_addr()?;
        Ok(addr.as_socket().expect("should be SocketAddr"))
    }
}
// Presumably implements the platform raw-handle traits for `TcpListener` by
// delegating to the `socket2::Socket` stored at `self.inner.socket`.
// NOTE(review): confirm the exact traits against `compio_driver::impl_raw_fd`.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
/// A TCP stream connecting a local socket to a remote peer.
#[derive(Clone, Debug)]
pub struct TcpStream {
    inner: Socket,
}
impl TcpStream {
    /// Opens a TCP connection to a remote host using default socket options.
    ///
    /// See [`TcpStream::connect_with_options`].
    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
        Self::connect_with_options(addr, &SocketOpts::default()).await
    }

    /// Opens a TCP connection to a remote host, applying `options` to the
    /// socket before it is bound or connected.
    ///
    /// Each address resolved from `addr` is tried in turn through
    /// `each_addr`.
    ///
    /// On Windows the socket is first bound to the unspecified address of the
    /// same family with port 0, because `ConnectEx` requires a bound socket.
    ///
    /// # Errors
    /// Returns any error raised while creating, configuring, binding, or
    /// connecting the socket.
    pub async fn connect_with_options(
        addr: impl ToSocketAddrsAsync,
        options: &SocketOpts,
    ) -> io::Result<Self> {
        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

        super::each_addr(addr, |addr| async move {
            let remote = SockAddr::from(addr);
            let socket = Socket::new(remote.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
            // Options must be applied before any `bind`; options that affect
            // binding (e.g. address reuse) are ignored once the socket is bound.
            options.setup_socket(&socket)?;
            if cfg!(windows) {
                // `SocketAddr` is exhaustively V4 or V6, so no fallback arm is needed.
                let local = match addr {
                    SocketAddr::V4(_) => {
                        SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
                    }
                    SocketAddr::V6(_) => {
                        SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
                    }
                };
                socket.socket.bind(&local)?;
            }
            socket.connect_async(&remote).await?;
            Ok(Self { inner: socket })
        })
        .await
    }

    /// Binds to `bind_addr`, then opens a TCP connection to a remote host,
    /// using default socket options.
    ///
    /// See [`TcpStream::bind_and_connect_with_options`].
    pub async fn bind_and_connect(
        bind_addr: SocketAddr,
        addr: impl ToSocketAddrsAsync,
    ) -> io::Result<Self> {
        Self::bind_and_connect_with_options(bind_addr, addr, &SocketOpts::default()).await
    }

    /// Binds to `bind_addr`, then opens a TCP connection to a remote host.
    ///
    /// `options` is applied to the socket *before* it is bound, matching
    /// [`TcpListener::bind_with_options`], so that bind-time options take
    /// effect. Each address resolved from `addr` is tried in turn through
    /// `each_addr`.
    ///
    /// # Errors
    /// Returns any error raised while creating, configuring, binding, or
    /// connecting the socket.
    pub async fn bind_and_connect_with_options(
        bind_addr: SocketAddr,
        addr: impl ToSocketAddrsAsync,
        options: &SocketOpts,
    ) -> io::Result<Self> {
        super::each_addr(addr, |addr| async move {
            let remote = SockAddr::from(addr);
            let local = SockAddr::from(bind_addr);
            let socket = Socket::new(local.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
            // Configure first: options set after `bind` cannot influence the bind.
            options.setup_socket(&socket)?;
            socket.socket.bind(&local)?;
            socket.connect_async(&remote).await?;
            Ok(Self { inner: socket })
        })
        .await
    }

    /// Wraps an existing standard-library TCP stream.
    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
        Ok(Self {
            inner: Socket::from_socket2(Socket2::from(stream))?,
        })
    }

    /// Closes the underlying socket asynchronously, consuming the stream.
    pub fn close(self) -> impl Future<Output = io::Result<()>> {
        self.inner.close()
    }

    /// Returns the address of the remote peer.
    ///
    /// # Panics
    /// Panics if the peer address is not an IPv4/IPv6 socket address.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner
            .peer_addr()
            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
    }

    /// Returns the local address of this stream.
    ///
    /// # Panics
    /// Panics if the local address is not an IPv4/IPv6 socket address.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner
            .local_addr()
            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
    }

    /// Splits a borrowed stream into a read half and a write half.
    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
        crate::split(self)
    }

    /// Consumes the stream and splits it into owned read and write halves.
    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
        crate::into_split(self)
    }

    /// Creates a [`PollFd`] for the underlying socket without consuming the stream.
    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
        self.inner.to_poll_fd()
    }

    /// Consumes the stream and converts it into a [`PollFd`].
    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
        self.inner.into_poll_fd()
    }

    /// Returns whether `TCP_NODELAY` is set on the socket.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.inner.socket.tcp_nodelay()
    }

    /// Sets `TCP_NODELAY` on the socket.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner.socket.set_tcp_nodelay(nodelay)
    }

    /// Sends `buf` with the `MSG_OOB` flag (out-of-band / urgent data).
    ///
    /// Returns the number of bytes sent together with the buffer.
    pub async fn send_out_of_band<T: IoBuf>(&self, buf: T) -> BufResult<usize, T> {
        #[cfg(unix)]
        use libc::MSG_OOB;
        #[cfg(windows)]
        use windows_sys::Win32::Networking::WinSock::MSG_OOB;
        self.inner.send(buf, MSG_OOB).await
    }
}
impl AsyncRead for TcpStream {
    /// Reads into `buf` by delegating to the `&TcpStream` implementation.
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let mut shared: &TcpStream = self;
        <&TcpStream as AsyncRead>::read(&mut shared, buf).await
    }

    /// Vectored read, delegating to the `&TcpStream` implementation.
    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let mut shared: &TcpStream = self;
        <&TcpStream as AsyncRead>::read_vectored(&mut shared, buf).await
    }
}
impl AsyncRead for &TcpStream {
    /// Receives into `buf` from the socket with no flags.
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let stream: &TcpStream = *self;
        stream.inner.recv(buf, 0).await
    }

    /// Vectored receive from the socket with no flags.
    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let stream: &TcpStream = *self;
        stream.inner.recv_vectored(buf, 0).await
    }
}
impl AsyncReadManaged for TcpStream {
    type Buffer<'a> = BorrowedBuffer<'a>;
    type BufferPool = BufferPool;

    /// Reads up to `len` bytes into a buffer taken from `buffer_pool`,
    /// delegating to the `&TcpStream` implementation.
    async fn read_managed<'a>(
        &mut self,
        buffer_pool: &'a Self::BufferPool,
        len: usize,
    ) -> io::Result<Self::Buffer<'a>> {
        let mut shared: &TcpStream = self;
        <&TcpStream as AsyncReadManaged>::read_managed(&mut shared, buffer_pool, len).await
    }
}
impl AsyncReadManaged for &TcpStream {
    type Buffer<'a> = BorrowedBuffer<'a>;
    type BufferPool = BufferPool;

    /// Receives up to `len` bytes into a buffer from `buffer_pool`, with no flags.
    async fn read_managed<'a>(
        &mut self,
        buffer_pool: &'a Self::BufferPool,
        len: usize,
    ) -> io::Result<Self::Buffer<'a>> {
        let stream: &TcpStream = *self;
        stream.inner.recv_managed(buffer_pool, len as _, 0).await
    }
}
impl AsyncWrite for TcpStream {
    /// Writes `buf` by delegating to the `&TcpStream` implementation.
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let mut shared: &TcpStream = self;
        <&TcpStream as AsyncWrite>::write(&mut shared, buf).await
    }

    /// Vectored write, delegating to the `&TcpStream` implementation.
    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let mut shared: &TcpStream = self;
        <&TcpStream as AsyncWrite>::write_vectored(&mut shared, buf).await
    }

    /// Flushes by delegating to the `&TcpStream` implementation.
    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        let mut shared: &TcpStream = self;
        <&TcpStream as AsyncWrite>::flush(&mut shared).await
    }

    /// Shuts down by delegating to the `&TcpStream` implementation.
    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        let mut shared: &TcpStream = self;
        <&TcpStream as AsyncWrite>::shutdown(&mut shared).await
    }
}
impl AsyncWrite for &TcpStream {
    /// Sends `buf` on the socket with no flags.
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let stream: &TcpStream = *self;
        stream.inner.send(buf, 0).await
    }

    /// Vectored send on the socket with no flags.
    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let stream: &TcpStream = *self;
        stream.inner.send_vectored(buf, 0).await
    }

    /// No-op: writes go straight to the socket, and this type holds no buffer.
    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }

    /// Shuts down the underlying socket via `Socket::shutdown`.
    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        let stream: &TcpStream = *self;
        stream.inner.shutdown().await
    }
}
impl Splittable for TcpStream {
    type ReadHalf = OwnedReadHalf<Self>;
    type WriteHalf = OwnedWriteHalf<Self>;

    /// Consumes the stream and returns owned read and write halves.
    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
        let (read_half, write_half) = crate::into_split(self);
        (read_half, write_half)
    }
}
impl<'a> Splittable for &'a TcpStream {
    type ReadHalf = ReadHalf<'a, TcpStream>;
    type WriteHalf = WriteHalf<'a, TcpStream>;

    /// Splits a borrowed stream into read and write halves tied to `'a`.
    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
        let (read_half, write_half) = crate::split(self);
        (read_half, write_half)
    }
}
// Presumably implements the platform raw-handle traits for `TcpStream` by
// delegating to the `socket2::Socket` stored at `self.inner.socket`.
// NOTE(review): confirm the exact traits against `compio_driver::impl_raw_fd`.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);