use std::fmt;
use std::io;
#[cfg(not(target_os = "wasi"))]
use std::net::{SocketAddr, TcpListener as StdTcpListener};
#[cfg(target_os = "wasi")]
use wasmedge_wasi_socket::{SocketAddr, TcpListener as StdTcpListener};
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::time::Sleep;
use tracing::{debug, error, trace};
use crate::common::{task, Future, Pin, Poll};
use super::accept::Accept;
/// A source of incoming TCP connections backed by a bound listener.
///
/// Per-connection socket options and the accept-error back-off policy are
/// stored here and applied each time a connection is accepted.
#[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
    /// Local address reported by the listener when this value was created.
    addr: SocketAddr,
    /// The non-blocking listener that connections are accepted from.
    listener: TcpListener,
    /// When `true`, accept errors that are not per-connection errors are
    /// logged and followed by a pause instead of being returned.
    sleep_on_errors: bool,
    /// TCP keepalive time to configure on accepted sockets, if any.
    tcp_keepalive_timeout: Option<Duration>,
    /// Value handed to `set_nodelay` on accepted sockets.
    tcp_nodelay: bool,
    /// Back-off sleep armed after an accept error; while present, polling
    /// waits for it to finish before accepting again.
    timeout: Option<Pin<Box<Sleep>>>,
}
impl AddrIncoming {
pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
#[cfg(not(target_os = "wasi"))]
let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;
#[cfg(target_os = "wasi")]
let std_listener = StdTcpListener::bind(addr, true).map_err(crate::Error::new_listen)?;
AddrIncoming::from_std(std_listener)
}
#[cfg(not(target_os = "wasi"))]
pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
std_listener
.set_nonblocking(true)
.map_err(crate::Error::new_listen)?;
let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
AddrIncoming::from_listener(listener)
}
#[cfg(target_os = "wasi")]
pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
AddrIncoming::from_listener(listener)
}
pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
AddrIncoming::new(addr)
}
pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
Ok(AddrIncoming {
listener,
addr,
sleep_on_errors: true,
tcp_keepalive_timeout: None,
tcp_nodelay: false,
timeout: None,
})
}
pub fn local_addr(&self) -> SocketAddr {
self.addr
}
pub fn set_keepalive(&mut self, keepalive: Option<Duration>) -> &mut Self {
self.tcp_keepalive_timeout = keepalive;
self
}
pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
self.tcp_nodelay = enabled;
self
}
pub fn set_sleep_on_errors(&mut self, val: bool) {
self.sleep_on_errors = val;
}
fn poll_next_(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<TcpStream>> {
if let Some(ref mut to) = self.timeout {
ready!(Pin::new(to).poll(cx));
}
self.timeout = None;
loop {
match ready!(self.listener.poll_accept(cx)) {
Ok((socket, _)) => {
#[cfg(not(target_os = "wasi"))]
if let Some(dur) = self.tcp_keepalive_timeout {
let socket = socket2::SockRef::from(&socket);
let conf = socket2::TcpKeepalive::new().with_time(dur);
if let Err(e) = socket.set_tcp_keepalive(&conf) {
trace!("error trying to set TCP keepalive: {}", e);
}
}
#[cfg(not(target_os = "wasi"))]
if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
trace!("error trying to set TCP nodelay: {}", e);
}
return Poll::Ready(Ok(socket));
}
Err(e) => {
if is_connection_error(&e) {
debug!("accepted connection already errored: {}", e);
continue;
}
if self.sleep_on_errors {
error!("accept error: {}", e);
let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));
match timeout.as_mut().poll(cx) {
Poll::Ready(()) => {
continue;
}
Poll::Pending => {
self.timeout = Some(timeout);
return Poll::Pending;
}
}
} else {
return Poll::Ready(Err(e));
}
}
}
}
}
}
/// Exposes the listener through the crate's `Accept` trait.
impl Accept for AddrIncoming {
    type Conn = TcpStream;
    type Error = io::Error;

    /// Polls for the next accepted connection.
    ///
    /// Every ready result is wrapped in `Some`; this implementation never
    /// yields `None`.
    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        self.poll_next_(cx).map(Some)
    }
}
/// Reports whether `e` is one of the error kinds treated as belonging to a
/// single failed connection: refused, aborted, or reset.
///
/// The accept loop skips such errors and immediately tries the next
/// connection.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION_KINDS.contains(&e.kind())
}
/// Debug output lists the address and configuration flags; the listener and
/// any pending back-off sleep are left out.
impl fmt::Debug for AddrIncoming {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            addr,
            sleep_on_errors,
            tcp_keepalive_timeout,
            tcp_nodelay,
            ..
        } = self;

        let mut out = f.debug_struct("AddrIncoming");
        out.field("addr", addr);
        out.field("sleep_on_errors", sleep_on_errors);
        out.field("tcp_keepalive_timeout", tcp_keepalive_timeout);
        out.field("tcp_nodelay", tcp_nodelay);
        out.finish()
    }
}