use std::convert::TryInto;
use std::future::Future;
use std::mem::transmute;
use std::net::{SocketAddr, TcpListener as StdListener, ToSocketAddrs};
use std::pin::Pin;
use std::task::{self, Poll};
use std::time::Duration;
use std::{fmt, io, matches};
use roa_core::{Accept, AddrStream};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::{sleep, Sleep};
use tracing::{debug, error, trace};
/// A stream of incoming TCP connections accepted from a tokio `TcpListener`.
///
/// Connections are yielded through its `Accept` implementation as
/// `AddrStream<TcpStream>` values carrying the peer address.
#[must_use = "streams do nothing unless polled"]
pub struct TcpIncoming {
/// Local address of the listener, captured once at construction time.
addr: SocketAddr,
/// The listener. It is boxed so that its address does not change when
/// `TcpIncoming` is moved: the in-flight `accept` future borrows it.
listener: Box<TcpListener>,
/// When `true`, a non-per-connection accept error is logged and followed by
/// a one-second back-off; when `false`, the error is returned to the caller.
sleep_on_errors: bool,
/// Value passed to `TcpStream::set_nodelay` for every accepted socket.
tcp_nodelay: bool,
/// Back-off timer armed after an accept error. It is polled to completion
/// before the next accept attempt.
timeout: Option<Pin<Box<Sleep>>>,
/// In-flight accept future. Its `'static` lifetime is manufactured with
/// `transmute`: the future really borrows `*listener`. Fields drop in
/// declaration order, which would drop `listener` before this field, so the
/// `Drop` impl clears it first.
accept: Option<Pin<BoxedAccept<'static>>>,
}
/// Boxed, type-erased accept future. It stores the result of
/// `TcpListener::accept` while a connection is awaited.
type BoxedAccept<'a> =
Box<dyn 'a + Future<Output = io::Result<(TcpStream, SocketAddr)>> + Send + Sync>;
impl TcpIncoming {
    /// Binds a new listener to `addr` and wraps it in a `TcpIncoming`.
    ///
    /// # Errors
    ///
    /// Returns any error from binding the socket or from [`TcpIncoming::from_std`].
    pub fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
        let listener = StdListener::bind(addr)?;
        TcpIncoming::from_std(listener)
    }

    /// Wraps an already-bound standard-library listener.
    ///
    /// The listener is switched to non-blocking mode and converted into a tokio
    /// `TcpListener`. By default `sleep_on_errors` is `true` and `tcp_nodelay`
    /// is `false`.
    ///
    /// # Errors
    ///
    /// Returns an error if the local address cannot be read, non-blocking mode
    /// cannot be enabled, or the conversion to a tokio listener fails.
    pub fn from_std(listener: StdListener) -> io::Result<Self> {
        let addr = listener.local_addr()?;
        listener.set_nonblocking(true)?;
        Ok(TcpIncoming {
            listener: Box::new(listener.try_into()?),
            addr,
            sleep_on_errors: true,
            tcp_nodelay: false,
            timeout: None,
            accept: None,
        })
    }

    /// Returns the local address the listener is bound to.
    pub fn local_addr(&self) -> SocketAddr {
        self.addr
    }

    /// Sets the `TCP_NODELAY` value applied to every accepted connection.
    pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
        self.tcp_nodelay = enabled;
        self
    }

    /// Controls what happens on accept errors that are not per-connection errors.
    ///
    /// When `true`, the error is logged and accepting resumes after a one-second
    /// sleep. When `false`, the error is yielded to the caller.
    pub fn set_sleep_on_errors(&mut self, val: bool) {
        self.sleep_on_errors = val;
    }

    /// Polls for the next accepted connection.
    ///
    /// * Per-connection errors (refused/aborted/reset) are logged at debug level,
    ///   and accepting continues.
    /// * Other errors, with `sleep_on_errors` set, are logged and followed by a
    ///   one-second back-off. Without it, they are returned.
    ///
    /// Accepted sockets get `TCP_NODELAY` set to `self.tcp_nodelay`. A failure to
    /// set it is only traced, not reported.
    fn poll_stream(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
        // Wait out a back-off timer armed by a previous accept error.
        if let Some(timeout) = self.timeout.as_mut() {
            futures::ready!(timeout.as_mut().poll(cx));
        }
        self.timeout = None;

        loop {
            if self.accept.is_none() {
                let accept: Pin<BoxedAccept<'_>> = Box::pin(self.listener.accept());
                // SAFETY: the future borrows `*self.listener`. The listener lives
                // on the heap (`Box`), so moving `self` does not move it, and it is
                // never replaced after construction. `Drop for TcpIncoming` clears
                // `self.accept` before the fields (including `listener`) are
                // dropped. The fabricated `'static` lifetime therefore never
                // outlives the real borrow.
                self.accept = Some(unsafe { transmute(accept) });
            }

            let result = {
                let accept = self
                    .accept
                    .as_mut()
                    .expect("accept future is created just above");
                futures::ready!(accept.as_mut().poll(cx))
            };
            // The future has resolved. Drop it now, whatever the outcome:
            // polling a completed `async fn` future panics, and the next loop
            // iteration (or the next call) must start a fresh `accept()`.
            self.accept = None;

            match result {
                Ok((socket, addr)) => {
                    if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
                        trace!("error trying to set TCP nodelay: {}", e);
                    }
                    return Poll::Ready(Ok((socket, addr)));
                }
                Err(e) if is_connection_error(&e) => {
                    // Only this one connection failed; keep accepting.
                    debug!("accepted connection already errored: {}", e);
                }
                Err(e) if self.sleep_on_errors => {
                    error!("accept error: {}", e);
                    let mut timeout = Box::pin(sleep(Duration::from_secs(1)));
                    if timeout.as_mut().poll(cx).is_pending() {
                        // Park the timer; the next poll resumes once it fires.
                        self.timeout = Some(timeout);
                        return Poll::Pending;
                    }
                    // The timer already elapsed: retry immediately.
                }
                Err(e) => return Poll::Ready(Err(e)),
            }
        }
    }
}
impl Accept for TcpIncoming {
    type Conn = AddrStream<TcpStream>;
    type Error = io::Error;

    /// Yields the next accepted connection, wrapped together with its peer address.
    ///
    /// The stream never ends on its own: every ready result is `Some`. Accept
    /// errors that `poll_stream` chooses to surface are yielded as `Some(Err(_))`.
    #[inline]
    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        self.poll_stream(cx).map(|outcome| {
            Some(outcome.map(|(stream, addr)| AddrStream::new(addr, stream)))
        })
    }
}
impl Drop for TcpIncoming {
    fn drop(&mut self) {
        // The pending accept future borrows the boxed `listener`; its `'static`
        // lifetime was fabricated with `transmute`. Fields are dropped in
        // declaration order, which would free `listener` before `accept`, so the
        // future is released explicitly here, before any field is dropped.
        drop(self.accept.take());
    }
}
/// Reports whether `e` concerns a single connection rather than the listener.
///
/// Refused, aborted and reset connections only affect the socket being
/// accepted. The accept loop logs them and moves on to the next connection.
fn is_connection_error(e: &io::Error) -> bool {
    use std::io::ErrorKind;

    let kind = e.kind();
    matches!(
        kind,
        ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted | ErrorKind::ConnectionRefused
    )
}
impl fmt::Debug for TcpIncoming {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TcpIncoming")
.field("addr", &self.addr)
.field("sleep_on_errors", &self.sleep_on_errors)
.field("tcp_nodelay", &self.tcp_nodelay)
.finish()
}
}