use crate::socket::to_socket_protocol;
use crate::socket::{IpVersion, SocketOption};
use async_io::{Async, Timer};
use futures_lite::future::FutureExt;
use socket2::{SockAddr, Socket as SystemSocket};
use std::io::{self, Read, Write};
use std::mem::MaybeUninit;
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket};
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone, Debug)]
pub struct AsyncSocket {
inner: Arc<Async<SystemSocket>>,
}
impl AsyncSocket {
pub fn new(socket_option: SocketOption) -> io::Result<AsyncSocket> {
let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
Some(to_socket_protocol(protocol)),
)?
} else {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
None,
)?
};
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub async fn new_with_async_connect(addr: &SocketAddr) -> io::Result<AsyncSocket> {
let stream = Async::<TcpStream>::connect(*addr).await?;
let socket = SystemSocket::from(stream.into_inner()?);
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub async fn new_with_async_connect_timeout(
addr: &SocketAddr,
timeout: Duration,
) -> io::Result<AsyncSocket> {
let stream = Async::<TcpStream>::connect(*addr)
.or(async {
Timer::after(timeout).await;
Err(io::ErrorKind::TimedOut.into())
})
.await?;
let socket = SystemSocket::from(stream.into_inner()?);
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub fn new_with_connect(
socket_option: SocketOption,
addr: &SocketAddr,
) -> io::Result<AsyncSocket> {
let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
Some(to_socket_protocol(protocol)),
)?
} else {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
None,
)?
};
let addr: SockAddr = SockAddr::from(*addr);
socket.connect(&addr)?;
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub fn new_with_connect_timeout(
socket_option: SocketOption,
addr: &SocketAddr,
timeout: Duration,
) -> io::Result<AsyncSocket> {
let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
Some(to_socket_protocol(protocol)),
)?
} else {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
None,
)?
};
let addr: SockAddr = SockAddr::from(*addr);
socket.connect_timeout(&addr, timeout)?;
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub fn new_with_listener(
socket_option: SocketOption,
addr: &SocketAddr,
) -> io::Result<AsyncSocket> {
let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
Some(to_socket_protocol(protocol)),
)?
} else {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
None,
)?
};
socket.set_nonblocking(true)?;
let addr: SockAddr = SockAddr::from(*addr);
socket.bind(&addr)?;
socket.listen(1024)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub fn new_with_bind(
socket_option: SocketOption,
addr: &SocketAddr,
) -> io::Result<AsyncSocket> {
let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
Some(to_socket_protocol(protocol)),
)?
} else {
SystemSocket::new(
socket_option.ip_version.to_domain(),
socket_option.socket_type.to_type(),
None,
)?
};
socket.set_nonblocking(true)?;
let addr: SockAddr = SockAddr::from(*addr);
socket.bind(&addr)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub fn from_tcp_stream(tcp_stream: TcpStream) -> io::Result<AsyncSocket> {
let socket = SystemSocket::from(tcp_stream);
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub fn from_tcp_listener(tcp_listener: TcpListener) -> io::Result<AsyncSocket> {
let socket = SystemSocket::from(tcp_listener);
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub fn from_udp_socket(udp_socket: UdpSocket) -> io::Result<AsyncSocket> {
let socket = SystemSocket::from(udp_socket);
socket.set_nonblocking(true)?;
Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
})
}
pub async fn bind(&self, addr: SocketAddr) -> io::Result<()> {
let addr: SockAddr = SockAddr::from(addr);
self.inner.write_with(|inner| inner.bind(&addr)).await
}
pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
loop {
self.inner.writable().await?;
match self.inner.write_with(|inner| inner.send(buf)).await {
Ok(n) => return Ok(n),
Err(_) => continue,
}
}
}
pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
let target: SockAddr = SockAddr::from(target);
loop {
self.inner.writable().await?;
match self
.inner
.write_with(|inner| inner.send_to(buf, &target))
.await
{
Ok(n) => return Ok(n),
Err(_) => continue,
}
}
}
pub async fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
loop {
self.inner.readable().await?;
match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
Ok(result) => return Ok(result),
Err(_) => continue,
}
}
}
pub async fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
loop {
self.inner.readable().await?;
match self
.inner
.read_with(|inner| inner.recv_from(recv_buf))
.await
{
Ok(result) => {
let (n, addr) = result;
match addr.as_socket() {
Some(addr) => return Ok((n, addr)),
None => continue,
}
}
Err(_) => continue,
}
}
}
pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
loop {
self.inner.writable().await?;
match self.inner.write_with(|inner| inner.send(buf)).await {
Ok(n) => return Ok(n),
Err(_) => continue,
}
}
}
pub async fn write_timeout(&self, buf: &[u8], timeout: Duration) -> io::Result<usize> {
loop {
self.inner.writable().await?;
match self
.inner
.write_with(|inner| {
match inner.set_write_timeout(Some(timeout)) {
Ok(_) => {}
Err(e) => return Err(e),
}
inner.send(buf)
})
.await
{
Ok(n) => return Ok(n),
Err(_) => continue,
}
}
}
pub async fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
loop {
self.inner.readable().await?;
match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
Ok(result) => return Ok(result),
Err(_) => continue,
}
}
}
pub async fn read_timeout(&self, buf: &mut Vec<u8>, timeout: Duration) -> io::Result<usize> {
let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
loop {
self.inner.readable().await?;
match self
.inner
.read_with(|inner| {
match inner.set_read_timeout(Some(timeout)) {
Ok(_) => {}
Err(e) => return Err(e),
}
inner.recv(recv_buf)
})
.await
{
Ok(result) => return Ok(result),
Err(_) => continue,
}
}
}
pub async fn ttl(&self, ip_version: IpVersion) -> io::Result<u32> {
match ip_version {
IpVersion::V4 => self.inner.read_with(|inner| inner.ttl()).await,
IpVersion::V6 => self.inner.read_with(|inner| inner.unicast_hops_v6()).await,
}
}
pub async fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
match ip_version {
IpVersion::V4 => self.inner.write_with(|inner| inner.set_ttl(ttl)).await,
IpVersion::V6 => {
self.inner
.write_with(|inner| inner.set_unicast_hops_v6(ttl))
.await
}
}
}
pub async fn tos(&self) -> io::Result<u32> {
self.inner.read_with(|inner| inner.tos()).await
}
pub async fn set_tos(&self, tos: u32) -> io::Result<()> {
self.inner.write_with(|inner| inner.set_tos(tos)).await
}
pub async fn receive_tos(&self) -> io::Result<bool> {
self.inner.read_with(|inner| inner.recv_tos()).await
}
pub async fn set_receive_tos(&self, receive_tos: bool) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_recv_tos(receive_tos))
.await
}
pub async fn connect(&mut self, addr: &SocketAddr) -> io::Result<()> {
let addr: SockAddr = SockAddr::from(*addr);
self.inner.write_with(|inner| inner.connect(&addr)).await
}
pub async fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
let addr: SockAddr = SockAddr::from(*addr);
self.inner
.write_with(|inner| inner.connect_timeout(&addr, timeout))
.await
}
pub async fn listen(&self, backlog: i32) -> io::Result<()> {
self.inner.write_with(|inner| inner.listen(backlog)).await
}
pub async fn accept(&self) -> io::Result<(AsyncSocket, SocketAddr)> {
match self.inner.read_with(|inner| inner.accept()).await {
Ok((socket, addr)) => {
let socket = AsyncSocket {
inner: Arc::new(Async::new(socket)?),
};
Ok((socket, addr.as_socket().unwrap()))
}
Err(e) => Err(e),
}
}
pub async fn local_addr(&self) -> io::Result<SocketAddr> {
match self.inner.read_with(|inner| inner.local_addr()).await {
Ok(addr) => Ok(addr.as_socket().unwrap()),
Err(e) => Err(e),
}
}
pub async fn peer_addr(&self) -> io::Result<SocketAddr> {
match self.inner.read_with(|inner| inner.peer_addr()).await {
Ok(addr) => Ok(addr.as_socket().unwrap()),
Err(e) => Err(e),
}
}
pub async fn socket_type(&self) -> io::Result<crate::socket::SocketType> {
match self.inner.read_with(|inner| inner.r#type()).await {
Ok(socktype) => Ok(crate::socket::SocketType::from_type(socktype)),
Err(e) => Err(e),
}
}
pub async fn try_clone(&self) -> io::Result<AsyncSocket> {
match self.inner.read_with(|inner| inner.try_clone()).await {
Ok(socket) => Ok(AsyncSocket {
inner: Arc::new(Async::new(socket)?),
}),
Err(e) => Err(e),
}
}
#[cfg(not(target_os = "windows"))]
pub async fn is_nonblocking(&self) -> io::Result<bool> {
self.inner.read_with(|inner| inner.nonblocking()).await
}
pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_nonblocking(nonblocking))
.await
}
pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.inner.write_with(|inner| inner.shutdown(how)).await
}
pub async fn is_broadcast(&self) -> io::Result<bool> {
self.inner.read_with(|inner| inner.broadcast()).await
}
pub async fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_broadcast(broadcast))
.await
}
pub async fn get_error(&self) -> io::Result<Option<io::Error>> {
self.inner.read_with(|inner| inner.take_error()).await
}
pub async fn is_keepalive(&self) -> io::Result<bool> {
self.inner.read_with(|inner| inner.keepalive()).await
}
pub async fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_keepalive(keepalive))
.await
}
pub async fn linger(&self) -> io::Result<Option<Duration>> {
self.inner.read_with(|inner| inner.linger()).await
}
pub async fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.write_with(|inner| inner.set_linger(dur)).await
}
pub async fn receive_buffer_size(&self) -> io::Result<usize> {
self.inner.read_with(|inner| inner.recv_buffer_size()).await
}
pub async fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_recv_buffer_size(size))
.await
}
pub async fn receive_timeout(&self) -> io::Result<Option<Duration>> {
self.inner.read_with(|inner| inner.read_timeout()).await
}
pub async fn set_receive_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_read_timeout(duration))
.await
}
pub async fn reuse_address(&self) -> io::Result<bool> {
self.inner.read_with(|inner| inner.reuse_address()).await
}
pub async fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_reuse_address(reuse))
.await
}
pub async fn send_buffer_size(&self) -> io::Result<usize> {
self.inner.read_with(|inner| inner.send_buffer_size()).await
}
pub async fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_send_buffer_size(size))
.await
}
pub async fn send_timeout(&self) -> io::Result<Option<Duration>> {
self.inner.read_with(|inner| inner.write_timeout()).await
}
pub async fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_write_timeout(duration))
.await
}
pub async fn is_ip_header_included(&self) -> io::Result<bool> {
self.inner.read_with(|inner| inner.header_included()).await
}
pub async fn set_ip_header_included(&self, include: bool) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_header_included(include))
.await
}
pub async fn is_nodelay(&self) -> io::Result<bool> {
self.inner.read_with(|inner| inner.nodelay()).await
}
pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.inner
.write_with(|inner| inner.set_nodelay(nodelay))
.await
}
pub fn into_tcp_stream(&self) -> io::Result<TcpStream> {
let socket = Arc::try_unwrap(self.inner.clone())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))?
.into_inner()?;
let tcp_stream = TcpStream::from(socket);
Ok(tcp_stream)
}
pub fn into_tcp_listener(&self) -> io::Result<TcpListener> {
let socket = Arc::try_unwrap(self.inner.clone())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))?
.into_inner()?;
let tcp_listener = TcpListener::from(socket);
Ok(tcp_listener)
}
pub fn into_udp_socket(&self) -> io::Result<UdpSocket> {
let socket = Arc::try_unwrap(self.inner.clone())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))?
.into_inner()?;
let udp_socket = UdpSocket::from(socket);
Ok(udp_socket)
}
}
/// Asynchronous TCP stream registered with the `async-io` reactor.
///
/// Cloning is cheap: every clone shares the same underlying stream through an
/// [`Arc`].
#[derive(Clone, Debug)]
pub struct AsyncTcpStream {
    inner: Arc<Async<TcpStream>>,
}

impl AsyncTcpStream {
    /// Opens a TCP connection to `addr` without blocking the executor.
    pub async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let stream = Async::<TcpStream>::connect(addr).await?;
        Ok(AsyncTcpStream {
            inner: Arc::new(stream),
        })
    }

    /// Opens a TCP connection to `addr`, failing with `TimedOut` if it has not
    /// completed within `timeout`.
    pub async fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<Self> {
        let stream = Async::<TcpStream>::connect(*addr)
            .or(async {
                Timer::after(timeout).await;
                Err(io::ErrorKind::TimedOut.into())
            })
            .await?;
        Ok(AsyncTcpStream {
            inner: Arc::new(stream),
        })
    }

    /// Returns the local address of the stream.
    pub async fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.read_with(|inner| inner.local_addr()).await
    }

    /// Returns the remote address of the stream.
    pub async fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner.read_with(|inner| inner.peer_addr()).await
    }

    /// Writes some of `buf`, waiting until the stream is writable, and returns
    /// the number of bytes written.
    pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
        self.inner.write_with(|mut inner| inner.write(buf)).await
    }

    /// Writes all of `buf`.
    ///
    /// Progress is tracked across readiness waits, so no byte is written twice.
    /// `Write::write_all` inside `write_with` would restart from the beginning
    /// after a partial write followed by `WouldBlock`.
    ///
    /// # Errors
    /// Returns `WriteZero` if the stream accepts no bytes, or any other I/O
    /// error except `Interrupted` (which is retried).
    pub async fn write_all(&self, buf: &[u8]) -> io::Result<()> {
        let mut remaining = buf;
        while !remaining.is_empty() {
            match self.write(remaining).await {
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    ));
                }
                Ok(written) => remaining = &remaining[written..],
                Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
                Err(err) => return Err(err),
            }
        }
        Ok(())
    }

    /// Reads into `buf`, waiting until the stream is readable, and returns the
    /// number of bytes read.
    pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.read_with(|mut inner| inner.read(buf)).await
    }

    /// Reads until EOF, appending to `buf`, and returns the total number of
    /// bytes appended by this call.
    ///
    /// Bytes appended before an intermediate `WouldBlock` stay in `buf`, so the
    /// count is measured from `buf`'s length instead of from the last retry.
    pub async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
        let start_len = buf.len();
        self.inner
            .read_with(|mut inner| inner.read_to_end(buf))
            .await?;
        Ok(buf.len() - start_len)
    }

    /// Reads until EOF or until `timeout` elapses, appending to `buf`.
    ///
    /// Returns the number of bytes appended by this call whenever that is
    /// non-zero, even if the read ended with a timeout or another error.
    ///
    /// # Errors
    /// If nothing was appended, returns the read error (`TimedOut` when the
    /// deadline fired), or an `Other` error "No response" when the peer closed
    /// the stream without sending data.
    pub async fn read_to_end_timeout(
        &self,
        buf: &mut Vec<u8>,
        timeout: Duration,
    ) -> io::Result<usize> {
        let start_len = buf.len();
        let outcome = self
            .read_to_end(buf)
            .or(async {
                Timer::after(timeout).await;
                Err(io::ErrorKind::TimedOut.into())
            })
            .await;
        let received = buf.len() - start_len;
        match outcome {
            _ if received > 0 => Ok(received),
            Ok(_) => Err(io::Error::new(io::ErrorKind::Other, "No response")),
            Err(err) => Err(err),
        }
    }

    /// Shuts down the read half, write half, or both halves of the connection.
    pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.inner.write_with(|inner| inner.shutdown(how)).await
    }

    /// Takes (and clears) the pending `SO_ERROR` value, if any.
    pub async fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.read_with(|inner| inner.take_error()).await
    }

    /// Duplicates the OS stream handle and registers the copy with the reactor.
    pub async fn try_clone(&self) -> io::Result<Self> {
        let stream = self.inner.read_with(|inner| inner.try_clone()).await?;
        Ok(AsyncTcpStream {
            inner: Arc::new(Async::new(stream)?),
        })
    }

    /// Sets the `SO_RCVTIMEO` value.
    pub async fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.inner
            .write_with(|inner| inner.set_read_timeout(dur))
            .await
    }

    /// Sets the `SO_SNDTIMEO` value.
    pub async fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.inner
            .write_with(|inner| inner.set_write_timeout(dur))
            .await
    }

    /// Returns the `SO_RCVTIMEO` value.
    pub async fn read_timeout(&self) -> io::Result<Option<Duration>> {
        self.inner.read_with(|inner| inner.read_timeout()).await
    }

    /// Returns the `SO_SNDTIMEO` value.
    pub async fn write_timeout(&self) -> io::Result<Option<Duration>> {
        self.inner.read_with(|inner| inner.write_timeout()).await
    }

    /// Enables or disables `TCP_NODELAY`.
    pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner
            .write_with(|inner| inner.set_nodelay(nodelay))
            .await
    }

    /// Returns whether `TCP_NODELAY` is enabled.
    pub async fn nodelay(&self) -> io::Result<bool> {
        self.inner.read_with(|inner| inner.nodelay()).await
    }

    /// Sets the IP TTL.
    pub async fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.inner.write_with(|inner| inner.set_ttl(ttl)).await
    }

    /// Returns the IP TTL.
    pub async fn ttl(&self) -> io::Result<u32> {
        self.inner.read_with(|inner| inner.ttl()).await
    }

    /// Sets the stream's non-blocking mode.
    ///
    /// NOTE(review): the reactor relies on non-blocking mode; passing `false`
    /// presumably makes later I/O on this wrapper block the executor thread.
    pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        self.inner
            .write_with(|inner| inner.set_nonblocking(nonblocking))
            .await
    }
}