use std::{
future::Future,
io,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_driver::{
BufferRef, impl_raw_fd,
op::{RecvFlags, RecvMsgMultiResult, SendFlags},
};
use compio_io::{
AsyncRead, AsyncReadManaged, AsyncReadMulti, AsyncWrite,
ancillary::{
AsyncReadAncillary, AsyncReadAncillaryManaged, AsyncReadAncillaryMulti, AsyncWriteAncillary,
},
util::Splittable,
};
use compio_runtime::fd::PollFd;
use futures_util::{Stream, StreamExt, stream::FusedStream};
use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
use crate::{
Incoming, MSG_NOSIGNAL, OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, ToSocketAddrsAsync,
WriteHalf,
};
/// A TCP listener: a thin wrapper around the crate's `Socket`.
///
/// Values are produced by `TcpListener::bind`, `TcpListener::from_std` or
/// `TcpSocket::listen`; incoming connections are obtained with `accept` or
/// `incoming`.
///
/// `Clone` is derived, so cloning is whatever cloning the inner `Socket` does.
/// NOTE(review): presumably the clone refers to the same OS socket — confirm
/// against `Socket`'s `Clone` implementation.
#[derive(Debug, Clone)]
pub struct TcpListener {
/// Crate-level socket that every listener method delegates to.
inner: Socket,
}
impl TcpListener {
    /// Creates a listener bound to `addr`.
    ///
    /// For an address handed over by `each_addr`, a TCP stream socket of the
    /// matching domain is created, address reuse is switched on, the socket is
    /// bound and listening starts with a fixed backlog of 128.
    ///
    /// # Errors
    ///
    /// Propagates any failure from socket creation, setting the reuse option,
    /// `bind` or `listen`. How multiple resolved addresses are handled is up
    /// to `each_addr`, which is defined elsewhere in the crate.
    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
        super::each_addr(addr, |candidate| async move {
            let bind_addr = SockAddr::from(candidate);
            let domain = bind_addr.domain();
            let listener = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).await?;
            listener.socket.set_reuse_address(true)?;
            listener.bind(&bind_addr).await?;
            listener.listen(128).await?;
            Ok(Self { inner: listener })
        })
        .await
    }

    /// Wraps an existing standard-library listener.
    ///
    /// The listener is converted into a `socket2` socket and handed to
    /// `Socket::from_socket2`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Socket::from_socket2` reports.
    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
        let raw = Socket2::from(stream);
        let inner = Socket::from_socket2(raw)?;
        Ok(Self { inner })
    }

    /// Consumes the listener and returns the future produced by closing the
    /// inner socket.
    pub fn close(self) -> impl Future<Output = io::Result<()>> {
        let Self { inner } = self;
        inner.close()
    }

    /// Accepts one incoming connection, yielding the stream and the peer
    /// address.
    ///
    /// # Panics
    ///
    /// Panics if the peer address cannot be represented as a `SocketAddr`.
    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
        let (accepted, peer) = self.inner.accept().await?;
        let connection = TcpStream { inner: accepted };
        let peer = peer.as_socket().expect("should be SocketAddr");
        Ok((connection, peer))
    }

    /// Returns a stream of incoming connections borrowed from this listener.
    pub fn incoming(&self) -> TcpIncoming<'_> {
        let inner = self.inner.incoming();
        TcpIncoming { inner }
    }

    /// Returns the local address this listener is bound to.
    ///
    /// # Panics
    ///
    /// Panics if the local address cannot be represented as a `SocketAddr`.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let bound = self.inner.local_addr()?;
        Ok(bound.as_socket().expect("should be SocketAddr"))
    }

    /// Forwards to the underlying socket's `take_error`.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        let socket = &self.inner.socket;
        socket.take_error()
    }

    /// Reads the IPv4 time-to-live setting from the underlying socket.
    pub fn ttl_v4(&self) -> io::Result<u32> {
        let socket = &self.inner.socket;
        socket.ttl_v4()
    }

    /// Writes the IPv4 time-to-live setting on the underlying socket.
    pub fn set_ttl_v4(&self, ttl: u32) -> io::Result<()> {
        let socket = &self.inner.socket;
        socket.set_ttl_v4(ttl)
    }
}
// Expands `compio_driver::impl_raw_fd!` for `TcpListener`, naming
// `socket2::Socket` as the underlying type reached via the `inner.socket`
// field path.
// NOTE(review): presumably this generates the raw fd/socket conversion trait
// impls — confirm against the macro definition.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
/// Stream of incoming connections, returned by `TcpListener::incoming`.
///
/// Wraps the crate-level `Incoming` stream and converts each accepted socket
/// into a `TcpStream`.
pub struct TcpIncoming<'a> {
/// The wrapped socket-level stream, borrowed for `'a`.
inner: Incoming<'a>,
}
/// Yields one `io::Result<TcpStream>` per item produced by the wrapped
/// `Incoming` stream.
impl Stream for TcpIncoming<'_> {
    type Item = io::Result<TcpStream>;

    /// Polls the wrapped stream and wraps a successfully accepted socket in a
    /// `TcpStream`; errors, end-of-stream and `Pending` pass through as-is.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.get_mut().inner.poll_next_unpin(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(accepted)) => {
                Poll::Ready(Some(accepted.map(|socket| TcpStream { inner: socket })))
            }
        }
    }
}
/// Termination state is taken directly from the wrapped `Incoming` stream.
impl FusedStream for TcpIncoming<'_> {
    /// Reports whether the wrapped stream considers itself terminated.
    fn is_terminated(&self) -> bool {
        let Self { inner } = self;
        inner.is_terminated()
    }
}
/// A TCP stream: a thin wrapper around the crate's `Socket`.
///
/// Values are produced by `TcpStream::connect`, `TcpStream::from_std`,
/// `TcpListener::accept` / `TcpIncoming`, or `TcpSocket::connect`.
///
/// The I/O traits are implemented both for `TcpStream` and for `&TcpStream`;
/// the owned implementations forward to the shared-reference ones.
///
/// `Clone` is derived, so cloning is whatever cloning the inner `Socket` does.
/// NOTE(review): presumably the clone refers to the same OS socket — confirm
/// against `Socket`'s `Clone` implementation.
#[derive(Debug, Clone)]
pub struct TcpStream {
/// Crate-level socket that every stream method delegates to.
inner: Socket,
}
impl TcpStream {
    /// Opens a TCP connection to `addr`.
    ///
    /// For an address handed over by `each_addr`, a TCP stream socket of the
    /// matching domain is created and connected asynchronously.
    ///
    /// On Windows the socket is first bound to the unspecified address of the
    /// same family with port 0.
    /// NOTE(review): presumably because the Windows connect path needs an
    /// already-bound socket — confirm against `Socket::connect_async`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from socket creation, the Windows-only bind, or
    /// the connect itself. How multiple resolved addresses are handled is up
    /// to `each_addr`, which is defined elsewhere in the crate.
    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

        super::each_addr(addr, |addr| async move {
            let addr2 = SockAddr::from(addr);
            let socket = Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
            if cfg!(windows) {
                // `SocketAddr` has exactly two variants, so an exhaustive
                // match needs no fallback error branch.
                let bind_addr = match addr {
                    SocketAddr::V4(_) => {
                        SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
                    }
                    SocketAddr::V6(_) => {
                        SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
                    }
                };
                socket.bind(&bind_addr).await?;
            }
            socket.connect_async(&addr2).await?;
            Ok(Self { inner: socket })
        })
        .await
    }

    /// Wraps an existing standard-library stream.
    ///
    /// The stream is converted into a `socket2` socket and handed to
    /// `Socket::from_socket2`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Socket::from_socket2` reports.
    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
        Ok(Self {
            inner: Socket::from_socket2(Socket2::from(stream))?,
        })
    }

    /// Consumes the stream and returns the future produced by closing the
    /// inner socket.
    pub fn close(self) -> impl Future<Output = io::Result<()>> {
        self.inner.close()
    }

    /// Returns the address of the remote peer.
    ///
    /// # Panics
    ///
    /// Panics if the peer address cannot be represented as a `SocketAddr`.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner
            .peer_addr()
            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
    }

    /// Returns the local address of this stream.
    ///
    /// # Panics
    ///
    /// Panics if the local address cannot be represented as a `SocketAddr`.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner
            .local_addr()
            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
    }

    /// Forwards to the underlying socket's `take_error`.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.socket.take_error()
    }

    /// Splits the stream into borrowed read and write halves via
    /// `crate::split`.
    pub fn split(&self) -> (ReadHalf<'_, Self>, WriteHalf<'_, Self>) {
        crate::split(self)
    }

    /// Splits the stream into owned read and write halves via
    /// `crate::into_split`.
    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
        crate::into_split(self)
    }

    /// Creates a `PollFd` from the inner socket without consuming the stream.
    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
        self.inner.to_poll_fd()
    }

    /// Converts the stream into a `PollFd`, consuming it.
    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
        self.inner.into_poll_fd()
    }

    /// Reads the TCP no-delay option from the underlying socket.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.inner.socket.tcp_nodelay()
    }

    /// Sets the TCP no-delay option on the underlying socket.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner.socket.set_tcp_nodelay(nodelay)
    }

    /// Reads the TCP quick-ack option from the underlying socket.
    ///
    /// Only compiled on Linux, Android, Fuchsia and Cygwin.
    #[cfg(any(
        target_os = "linux",
        target_os = "android",
        target_os = "fuchsia",
        target_os = "cygwin",
    ))]
    pub fn quickack(&self) -> io::Result<bool> {
        self.inner.socket.tcp_quickack()
    }

    /// Sets the TCP quick-ack option on the underlying socket.
    ///
    /// Only compiled on Linux, Android, Fuchsia and Cygwin.
    #[cfg(any(
        target_os = "linux",
        target_os = "android",
        target_os = "fuchsia",
        target_os = "cygwin",
    ))]
    pub fn set_quickack(&self, quickack: bool) -> io::Result<()> {
        self.inner.socket.set_tcp_quickack(quickack)
    }

    /// Reads the linger setting from the underlying socket.
    pub fn linger(&self) -> io::Result<Option<Duration>> {
        self.inner.socket.linger()
    }

    /// Sets the linger setting of the underlying socket to a zero duration.
    pub fn set_zero_linger(&self) -> io::Result<()> {
        self.inner.socket.set_linger(Some(Duration::ZERO))
    }

    /// Reads the IPv4 time-to-live setting from the underlying socket.
    pub fn ttl_v4(&self) -> io::Result<u32> {
        self.inner.socket.ttl_v4()
    }

    /// Writes the IPv4 time-to-live setting on the underlying socket.
    pub fn set_ttl_v4(&self, ttl: u32) -> io::Result<()> {
        self.inner.socket.set_ttl_v4(ttl)
    }

    /// Sends `buf` with the platform's `MSG_OOB` flag combined with the
    /// crate's `MSG_NOSIGNAL` flags.
    ///
    /// Returns the send result together with the buffer.
    pub async fn send_out_of_band<T: IoBuf>(&self, buf: T) -> BufResult<usize, T> {
        // `MSG_OOB` comes from a different crate on each platform family.
        #[cfg(unix)]
        use libc::MSG_OOB;
        #[cfg(windows)]
        use windows_sys::Win32::Networking::WinSock::MSG_OOB;

        self.inner
            .send(
                buf,
                SendFlags::from_bits_retain(MSG_OOB as _) | MSG_NOSIGNAL,
            )
            .await
    }

    /// Sends `buf` through the inner socket's `send_zerocopy` with
    /// `MSG_NOSIGNAL`.
    ///
    /// The buffer slot of the returned `BufResult` is a future that resolves
    /// to the buffer.
    /// NOTE(review): presumably that future completes once the kernel no
    /// longer needs the buffer — confirm against `Socket::send_zerocopy`.
    pub async fn send_zerocopy<T: IoBuf>(
        &self,
        buf: T,
    ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
        self.inner.send_zerocopy(buf, MSG_NOSIGNAL).await
    }

    /// Vectored counterpart of `send_zerocopy`, forwarding to the inner
    /// socket's `send_zerocopy_vectored` with `MSG_NOSIGNAL`.
    pub async fn send_zerocopy_vectored<T: IoVectoredBuf>(
        &self,
        buf: T,
    ) -> BufResult<usize, impl Future<Output = T> + use<T>> {
        self.inner.send_zerocopy_vectored(buf, MSG_NOSIGNAL).await
    }

    /// Forwards to the inner socket's `sock_nonempty`.
    ///
    /// NOTE(review): the meaning of `None` versus `Some(_)` is defined by
    /// `Socket::sock_nonempty` — confirm there.
    pub fn sock_nonempty(&self) -> Option<bool> {
        self.inner.sock_nonempty()
    }
}
/// Reading through an exclusive borrow forwards to the `AsyncRead`
/// implementation for `&TcpStream`.
impl AsyncRead for TcpStream {
    /// Reads into `buf` by delegating to the shared-reference implementation.
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let mut shared: &Self = self;
        shared.read(buf).await
    }

    /// Vectored read, delegating to the shared-reference implementation.
    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let mut shared: &Self = self;
        shared.read_vectored(buf).await
    }
}
/// Reads are performed with the inner socket's `recv` family, using empty
/// receive flags.
impl AsyncRead for &TcpStream {
    /// Receives into `buf` with no receive flags set.
    #[inline]
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        let no_flags = RecvFlags::empty();
        self.inner.recv(buf, no_flags).await
    }

    /// Vectored receive into `buf` with no receive flags set.
    #[inline]
    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
        let no_flags = RecvFlags::empty();
        self.inner.recv_vectored(buf, no_flags).await
    }
}
/// Managed reads through an exclusive borrow forward to the implementation
/// for `&TcpStream`.
impl AsyncReadManaged for TcpStream {
    type Buffer = BufferRef;

    /// Delegates to the shared-reference implementation with the same `len`.
    async fn read_managed(&mut self, len: usize) -> io::Result<Option<Self::Buffer>> {
        let mut shared: &Self = self;
        shared.read_managed(len).await
    }
}
/// Managed reads use the inner socket's `recv_managed` with empty receive
/// flags.
impl AsyncReadManaged for &TcpStream {
    type Buffer = BufferRef;

    /// Forwards `len` to `recv_managed` with no receive flags set.
    async fn read_managed(&mut self, len: usize) -> io::Result<Option<Self::Buffer>> {
        let no_flags = RecvFlags::empty();
        self.inner.recv_managed(len, no_flags).await
    }
}
/// Multi-shot reads use the inner socket's `recv_multi` with empty receive
/// flags.
impl AsyncReadMulti for TcpStream {
    /// Returns the stream produced by `recv_multi` for the given `len`.
    fn read_multi(&mut self, len: usize) -> impl Stream<Item = io::Result<Self::Buffer>> {
        let no_flags = RecvFlags::empty();
        self.inner.recv_multi(len, no_flags)
    }
}
/// Multi-shot reads on a shared reference use the inner socket's `recv_multi`
/// with empty receive flags.
impl AsyncReadMulti for &TcpStream {
    /// Returns the stream produced by `recv_multi` for the given `len`.
    fn read_multi(&mut self, len: usize) -> impl Stream<Item = io::Result<Self::Buffer>> {
        let no_flags = RecvFlags::empty();
        self.inner.recv_multi(len, no_flags)
    }
}
/// Ancillary reads through an exclusive borrow forward to the implementation
/// for `&TcpStream`.
impl AsyncReadAncillary for TcpStream {
    /// Delegates to the shared-reference implementation.
    #[inline]
    async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<(usize, usize), (T, C)> {
        let mut shared: &Self = self;
        shared.read_with_ancillary(buffer, control).await
    }

    /// Vectored variant; delegates to the shared-reference implementation.
    #[inline]
    async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<(usize, usize), (T, C)> {
        let mut shared: &Self = self;
        shared.read_vectored_with_ancillary(buffer, control).await
    }
}
/// Ancillary reads use the inner socket's `recv_msg` family with empty
/// receive flags; the third element of the socket-level result is discarded.
impl AsyncReadAncillary for &TcpStream {
    /// Receives into `buffer` and `control`, keeping the first two elements
    /// of the socket-level result tuple.
    ///
    /// NOTE(review): presumably these are the payload and control lengths —
    /// confirm against `Socket::recv_msg`.
    #[inline]
    async fn read_with_ancillary<T: IoBufMut, C: IoBufMut>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<(usize, usize), (T, C)> {
        let received = self
            .inner
            .recv_msg(buffer, control, RecvFlags::empty())
            .await;
        received.map_res(|(first, second, _peer)| (first, second))
    }

    /// Vectored variant of `read_with_ancillary`, using `recv_msg_vectored`.
    #[inline]
    async fn read_vectored_with_ancillary<T: IoVectoredBufMut, C: IoBufMut>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<(usize, usize), (T, C)> {
        let received = self
            .inner
            .recv_msg_vectored(buffer, control, RecvFlags::empty())
            .await;
        received.map_res(|(first, second, _peer)| (first, second))
    }
}
/// Managed ancillary reads through an exclusive borrow forward to the
/// implementation for `&TcpStream`.
impl AsyncReadAncillaryManaged for TcpStream {
    /// Delegates to the shared-reference implementation.
    #[inline]
    async fn read_managed_with_ancillary<C: IoBufMut>(
        &mut self,
        len: usize,
        control: C,
    ) -> io::Result<Option<(Self::Buffer, C)>> {
        let mut shared: &Self = self;
        shared.read_managed_with_ancillary(len, control).await
    }
}
/// Managed ancillary reads use the inner socket's `recv_msg_managed` with
/// empty receive flags; the third element of each result is discarded.
impl AsyncReadAncillaryManaged for &TcpStream {
    /// Forwards `len` and `control` to `recv_msg_managed` and keeps the first
    /// two elements of the yielded tuple, if any.
    #[inline]
    async fn read_managed_with_ancillary<C: IoBufMut>(
        &mut self,
        len: usize,
        control: C,
    ) -> io::Result<Option<(Self::Buffer, C)>> {
        let received = self
            .inner
            .recv_msg_managed(len, control, RecvFlags::empty())
            .await?;
        Ok(received.map(|(buffer, ancillary, _peer)| (buffer, ancillary)))
    }
}
/// Multi-shot ancillary reads use the inner socket's `recv_msg_multi` with
/// empty receive flags.
impl AsyncReadAncillaryMulti for TcpStream {
    type Return = RecvMsgMultiResult;

    /// Returns the stream produced by `recv_msg_multi` for `control_len`.
    #[inline]
    fn read_multi_with_ancillary(
        &mut self,
        control_len: usize,
    ) -> impl Stream<Item = io::Result<Self::Return>> {
        let no_flags = RecvFlags::empty();
        self.inner.recv_msg_multi(control_len, no_flags)
    }
}
/// Multi-shot ancillary reads on a shared reference use the inner socket's
/// `recv_msg_multi` with empty receive flags.
impl AsyncReadAncillaryMulti for &TcpStream {
    type Return = RecvMsgMultiResult;

    /// Returns the stream produced by `recv_msg_multi` for `control_len`.
    #[inline]
    fn read_multi_with_ancillary(
        &mut self,
        control_len: usize,
    ) -> impl Stream<Item = io::Result<Self::Return>> {
        let no_flags = RecvFlags::empty();
        self.inner.recv_msg_multi(control_len, no_flags)
    }
}
/// Writing through an exclusive borrow forwards every operation to the
/// `AsyncWrite` implementation for `&TcpStream`.
impl AsyncWrite for TcpStream {
    /// Delegates to the shared-reference `write`.
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let mut shared: &Self = self;
        shared.write(buf).await
    }

    /// Delegates to the shared-reference `write_vectored`.
    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let mut shared: &Self = self;
        shared.write_vectored(buf).await
    }

    /// Delegates to the shared-reference `flush`.
    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        let mut shared: &Self = self;
        shared.flush().await
    }

    /// Delegates to the shared-reference `shutdown`.
    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        let mut shared: &Self = self;
        shared.shutdown().await
    }
}
/// Writes go through the inner socket's `send` family with the crate's
/// `MSG_NOSIGNAL` flags.
impl AsyncWrite for &TcpStream {
    /// Sends `buf` with `MSG_NOSIGNAL`.
    #[inline]
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let flags = MSG_NOSIGNAL;
        self.inner.send(buf, flags).await
    }

    /// Vectored send of `buf` with `MSG_NOSIGNAL`.
    #[inline]
    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let flags = MSG_NOSIGNAL;
        self.inner.send_vectored(buf, flags).await
    }

    /// Does nothing and always succeeds; this implementation has no buffer of
    /// its own to flush.
    #[inline]
    async fn flush(&mut self) -> io::Result<()> {
        io::Result::Ok(())
    }

    /// Forwards to the inner socket's `shutdown`.
    #[inline]
    async fn shutdown(&mut self) -> io::Result<()> {
        let outcome = self.inner.shutdown().await;
        outcome
    }
}
/// Ancillary writes through an exclusive borrow forward to the implementation
/// for `&TcpStream`.
impl AsyncWriteAncillary for TcpStream {
    /// Delegates to the shared-reference implementation.
    #[inline]
    async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let mut shared: &Self = self;
        shared.write_with_ancillary(buffer, control).await
    }

    /// Vectored variant; delegates to the shared-reference implementation.
    #[inline]
    async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let mut shared: &Self = self;
        shared.write_vectored_with_ancillary(buffer, control).await
    }
}
/// Ancillary writes use the inner socket's `send_msg` family with no
/// destination address (`None`) and the crate's `MSG_NOSIGNAL` flags.
impl AsyncWriteAncillary for &TcpStream {
    /// Sends `buffer` together with the `control` data.
    #[inline]
    async fn write_with_ancillary<T: IoBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let sent = self
            .inner
            .send_msg(buffer, control, None, MSG_NOSIGNAL)
            .await;
        sent
    }

    /// Vectored variant of `write_with_ancillary`, using `send_msg_vectored`.
    #[inline]
    async fn write_vectored_with_ancillary<T: IoVectoredBuf, C: IoBuf>(
        &mut self,
        buffer: T,
        control: C,
    ) -> BufResult<usize, (T, C)> {
        let sent = self
            .inner
            .send_msg_vectored(buffer, control, None, MSG_NOSIGNAL)
            .await;
        sent
    }
}
/// Splitting an owned stream yields owned halves.
impl Splittable for TcpStream {
    type ReadHalf = OwnedReadHalf<Self>;
    type WriteHalf = OwnedWriteHalf<Self>;

    /// Consumes the stream; same result as the inherent
    /// `TcpStream::into_split`, which calls `crate::into_split`.
    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
        self.into_split()
    }
}
impl<'a> Splittable for &'a TcpStream {
type ReadHalf = ReadHalf<'a, TcpStream>;
type WriteHalf = WriteHalf<'a, TcpStream>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::split(self)
}
}
impl<'a> Splittable for &'a mut TcpStream {
type ReadHalf = ReadHalf<'a, TcpStream>;
type WriteHalf = WriteHalf<'a, TcpStream>;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
crate::split(self)
}
}
// Expands `compio_driver::impl_raw_fd!` for `TcpStream`, naming
// `socket2::Socket` as the underlying type reached via the `inner.socket`
// field path.
// NOTE(review): presumably this generates the raw fd/socket conversion trait
// impls — confirm against the macro definition.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);
/// A TCP socket that has not yet been turned into a stream or a listener.
///
/// Created with `TcpSocket::new_v4`, `TcpSocket::new_v6` or
/// `TcpSocket::from_std_stream`, configured through the option setters, and
/// finally consumed by `TcpSocket::connect` (yielding a `TcpStream`) or
/// `TcpSocket::listen` (yielding a `TcpListener`).
#[derive(Debug)]
pub struct TcpSocket {
/// Crate-level socket that every method delegates to.
inner: Socket,
}
impl TcpSocket {
    /// Creates a new TCP socket in the IPv4 domain.
    pub async fn new_v4() -> io::Result<TcpSocket> {
        Self::new(socket2::Domain::IPV4).await
    }

    /// Creates a new TCP socket in the IPv6 domain.
    pub async fn new_v6() -> io::Result<TcpSocket> {
        Self::new(socket2::Domain::IPV6).await
    }

    /// Shared constructor: creates a TCP stream socket in `domain`.
    async fn new(domain: socket2::Domain) -> io::Result<TcpSocket> {
        Ok(Self {
            inner: Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))
                .await?,
        })
    }

    /// Sets the keepalive option on the underlying socket.
    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
        self.inner.socket.set_keepalive(keepalive)
    }

    /// Reads the keepalive option from the underlying socket.
    pub fn keepalive(&self) -> io::Result<bool> {
        self.inner.socket.keepalive()
    }

    /// Sets the address-reuse option on the underlying socket.
    pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
        self.inner.socket.set_reuse_address(reuseaddr)
    }

    /// Reads the address-reuse option from the underlying socket.
    pub fn reuseaddr(&self) -> io::Result<bool> {
        self.inner.socket.reuse_address()
    }

    /// Sets the port-reuse option on the underlying socket.
    ///
    /// Only compiled on Unix targets other than Solaris, illumos and Cygwin.
    #[cfg(all(
        unix,
        not(target_os = "solaris"),
        not(target_os = "illumos"),
        not(target_os = "cygwin"),
    ))]
    pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> {
        self.inner.socket.set_reuse_port(reuseport)
    }

    /// Reads the port-reuse option from the underlying socket.
    ///
    /// Only compiled on Unix targets other than Solaris, illumos and Cygwin.
    #[cfg(all(
        unix,
        not(target_os = "solaris"),
        not(target_os = "illumos"),
        not(target_os = "cygwin"),
    ))]
    pub fn reuseport(&self) -> io::Result<bool> {
        self.inner.socket.reuse_port()
    }

    /// Sets the send buffer size; `size` is widened to `usize` before being
    /// passed to the underlying socket.
    pub fn set_send_buffer_size(&self, size: u32) -> io::Result<()> {
        let requested = size as usize;
        self.inner.socket.set_send_buffer_size(requested)
    }

    /// Reads the send buffer size, converting the reported value to `u32`
    /// with an `as` cast.
    pub fn send_buffer_size(&self) -> io::Result<u32> {
        let reported = self.inner.socket.send_buffer_size()?;
        Ok(reported as u32)
    }

    /// Sets the receive buffer size; `size` is widened to `usize` before
    /// being passed to the underlying socket.
    pub fn set_recv_buffer_size(&self, size: u32) -> io::Result<()> {
        let requested = size as usize;
        self.inner.socket.set_recv_buffer_size(requested)
    }

    /// Reads the receive buffer size, converting the reported value to `u32`
    /// with an `as` cast.
    pub fn recv_buffer_size(&self) -> io::Result<u32> {
        let reported = self.inner.socket.recv_buffer_size()?;
        Ok(reported as u32)
    }

    /// Sets the linger setting of the underlying socket to a zero duration.
    pub fn set_zero_linger(&self) -> io::Result<()> {
        let zero = Some(Duration::ZERO);
        self.inner.socket.set_linger(zero)
    }

    /// Reads the linger setting from the underlying socket.
    pub fn linger(&self) -> io::Result<Option<Duration>> {
        self.inner.socket.linger()
    }

    /// Sets the TCP no-delay option on the underlying socket.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner.socket.set_tcp_nodelay(nodelay)
    }

    /// Reads the TCP no-delay option from the underlying socket.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.inner.socket.tcp_nodelay()
    }

    /// Reads the IPv6 traffic-class value from the underlying socket.
    ///
    /// Only compiled on the operating systems listed in the `cfg` below.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "openbsd",
        target_os = "cygwin",
    ))]
    pub fn tclass_v6(&self) -> io::Result<u32> {
        self.inner.socket.tclass_v6()
    }

    /// Sets the IPv6 traffic-class value on the underlying socket.
    ///
    /// Only compiled on the operating systems listed in the `cfg` below.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "openbsd",
        target_os = "cygwin",
    ))]
    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
        self.inner.socket.set_tclass_v6(tclass)
    }

    /// Reads the IPv4 type-of-service value from the underlying socket.
    ///
    /// Not compiled on Fuchsia, Redox, Solaris, illumos or Haiku.
    #[cfg(not(any(
        target_os = "fuchsia",
        target_os = "redox",
        target_os = "solaris",
        target_os = "illumos",
        target_os = "haiku"
    )))]
    pub fn tos_v4(&self) -> io::Result<u32> {
        self.inner.socket.tos_v4()
    }

    /// Sets the IPv4 type-of-service value on the underlying socket.
    ///
    /// Not compiled on Fuchsia, Redox, Solaris, illumos or Haiku.
    #[cfg(not(any(
        target_os = "fuchsia",
        target_os = "redox",
        target_os = "solaris",
        target_os = "illumos",
        target_os = "haiku"
    )))]
    pub fn set_tos_v4(&self, tos: u32) -> io::Result<()> {
        self.inner.socket.set_tos_v4(tos)
    }

    /// Returns the device the socket is bound to, as reported by the
    /// underlying socket's `device`.
    ///
    /// Only compiled on Android, Fuchsia and Linux.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux",))]
    pub fn device(&self) -> io::Result<Option<Vec<u8>>> {
        self.inner.socket.device()
    }

    /// Forwards `interface` to the underlying socket's `bind_device`.
    ///
    /// Only compiled on Android, Fuchsia and Linux.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    pub fn bind_device(&self, interface: Option<&[u8]>) -> io::Result<()> {
        self.inner.socket.bind_device(interface)
    }

    /// Returns the local address of the socket.
    ///
    /// # Panics
    ///
    /// Panics if the local address cannot be represented as a `SocketAddr`.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner
            .local_addr()
            .map(|bound| bound.as_socket().expect("should be SocketAddr"))
    }

    /// Forwards to the underlying socket's `take_error`.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.socket.take_error()
    }

    /// Binds the socket to `addr`.
    pub async fn bind(&self, addr: SocketAddr) -> io::Result<()> {
        let target = addr.into();
        self.inner.bind(&target).await
    }

    /// Connects to `addr`, consuming the socket and returning a `TcpStream`
    /// on success.
    ///
    /// Unlike `TcpStream::connect`, no implicit bind is performed here.
    /// NOTE(review): confirm whether callers on Windows must call `bind`
    /// first.
    pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
        let Self { inner } = self;
        let remote = addr.into();
        inner.connect_async(&remote).await?;
        Ok(TcpStream { inner })
    }

    /// Starts listening with the given `backlog`, consuming the socket and
    /// returning a `TcpListener` on success.
    pub async fn listen(self, backlog: i32) -> io::Result<TcpListener> {
        let Self { inner } = self;
        inner.listen(backlog).await?;
        Ok(TcpListener { inner })
    }

    /// Wraps an existing standard-library stream as a `TcpSocket`.
    ///
    /// The stream is converted into a `socket2` socket and handed to
    /// `Socket::from_socket2`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Socket::from_socket2` reports.
    pub fn from_std_stream(stream: std::net::TcpStream) -> io::Result<TcpSocket> {
        let raw = Socket2::from(stream);
        let inner = Socket::from_socket2(raw)?;
        Ok(Self { inner })
    }
}
// Expands `compio_driver::impl_raw_fd!` for `TcpSocket`, naming
// `socket2::Socket` as the underlying type reached via the `inner.socket`
// field path.
// NOTE(review): presumably this generates the raw fd/socket conversion trait
// impls — confirm against the macro definition.
impl_raw_fd!(TcpSocket, socket2::Socket, inner, socket);