use std::time::Duration;
/// An asynchronous source of incoming connections.
///
/// Each call to [`Listener::accept`] yields one I/O stream together with an
/// address of type [`Listener::Addr`].
pub trait Listener: Send + 'static {
/// The I/O stream type produced for each accepted connection.
type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static;
/// The address type returned alongside accepted connections and by `local_addr`.
type Addr: Clone + Send + Sync + 'static;
/// Waits for the next connection and returns its stream and address.
///
/// The output has no error case, so implementations must deal with accept
/// failures themselves (e.g. by retrying) before returning.
fn accept(&mut self) -> impl std::future::Future<Output = (Self::Io, Self::Addr)> + Send;
/// Returns the local address this listener is bound to.
fn local_addr(&self) -> std::io::Result<Self::Addr>;
}
/// Extension combinators available on every [`Listener`].
pub trait ListenerExt: Listener + Sized {
    /// Wraps this listener so that `tap_fn` is called with a mutable reference
    /// to every accepted I/O stream before `accept` returns it.
    ///
    /// The returned [`TapIo`] takes ownership of `self`; ignoring the result
    /// therefore drops the original listener, which is why it is `#[must_use]`.
    #[must_use = "`tap_io` consumes the listener; dropping the returned `TapIo` drops the listener"]
    fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
    where
        F: FnMut(&mut Self::Io) + Send + 'static,
    {
        TapIo {
            listener: self,
            tap_fn,
        }
    }
}
// Blanket impl: every `Listener` automatically gets the `ListenerExt` combinators.
impl<L: Listener> ListenerExt for L {}
impl Listener for tokio::net::TcpListener {
    type Io = tokio::net::TcpStream;
    type Addr = std::net::SocketAddr;

    /// Accepts the next TCP connection, retrying until an attempt succeeds.
    ///
    /// Each failed attempt is passed to `AcceptBackoff::handle_accept_error`,
    /// which may sleep before the next attempt.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        let mut backoff = AcceptBackoff::new();
        loop {
            // `Self::accept` resolves to tokio's inherent method, not this trait method.
            let err = match Self::accept(self).await {
                Ok(accepted) => break accepted,
                Err(err) => err,
            };
            backoff.handle_accept_error(err).await;
        }
    }

    /// Delegates to tokio's inherent `TcpListener::local_addr`.
    #[inline]
    fn local_addr(&self) -> std::io::Result<Self::Addr> {
        tokio::net::TcpListener::local_addr(self)
    }
}
/// A [`tokio::net::TcpListener`] wrapper that applies socket options to every
/// connection it accepts.
#[derive(Debug)]
pub struct TcpListenerWithOptions {
    /// The underlying listening socket.
    inner: tokio::net::TcpListener,
    /// Whether to enable `TCP_NODELAY` on accepted streams.
    nodelay: bool,
    /// If set, the TCP keepalive idle time applied to accepted streams;
    /// `None` leaves keepalive settings untouched.
    keepalive: Option<Duration>,
}
impl TcpListenerWithOptions {
    /// Binds a TCP listener to `addr` and wraps it with the given options.
    ///
    /// * `nodelay` — if `true`, `TCP_NODELAY` is enabled on each accepted stream.
    /// * `keepalive` — if `Some`, used as the TCP keepalive idle time
    ///   (`socket2::TcpKeepalive::with_time`) on each accepted stream.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails, if switching the socket to
    /// non-blocking mode fails, or if converting it into a
    /// [`tokio::net::TcpListener`] fails.
    pub fn new<A: std::net::ToSocketAddrs>(
        addr: A,
        nodelay: bool,
        keepalive: Option<Duration>,
    ) -> Result<Self, crate::BoxError> {
        let std_listener = std::net::TcpListener::bind(addr)?;
        // Tokio's `from_std` expects the std socket to already be non-blocking.
        std_listener.set_nonblocking(true)?;
        let listener = tokio::net::TcpListener::from_std(std_listener)?;
        Ok(Self::from_listener(listener, nodelay, keepalive))
    }

    /// Wraps an already-bound Tokio listener.
    ///
    /// `nodelay` and `keepalive` have the same meaning as in [`Self::new`].
    pub fn from_listener(
        listener: tokio::net::TcpListener,
        nodelay: bool,
        keepalive: Option<Duration>,
    ) -> Self {
        Self {
            inner: listener,
            nodelay,
            keepalive,
        }
    }

    /// Applies the configured options to a freshly accepted stream.
    ///
    /// Best-effort: if an option cannot be set, the error is logged at `warn`
    /// level and the stream is left as it is.
    fn set_accepted_socket_options(&self, stream: &tokio::net::TcpStream) {
        if self.nodelay && let Err(e) = stream.set_nodelay(true) {
            tracing::warn!("error trying to set TCP nodelay: {}", e);
        }
        if let Some(idle_time) = self.keepalive {
            // `stream` is already a reference; borrow it directly rather than
            // creating a `&&TcpStream`.
            let sock_ref = socket2::SockRef::from(stream);
            let sock_keepalive = socket2::TcpKeepalive::new().with_time(idle_time);
            if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
                tracing::warn!("error trying to set TCP keepalive: {}", e);
            }
        }
    }
}
impl Listener for TcpListenerWithOptions {
    type Io = tokio::net::TcpStream;
    type Addr = std::net::SocketAddr;

    /// Accepts a connection from the inner listener, then applies the
    /// configured socket options to it before returning.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        let accepted = Listener::accept(&mut self.inner).await;
        self.set_accepted_socket_options(&accepted.0);
        accepted
    }

    /// Reports the address of the wrapped listener.
    #[inline]
    fn local_addr(&self) -> std::io::Result<Self::Addr> {
        <tokio::net::TcpListener as Listener>::local_addr(&self.inner)
    }
}
/// A [`Listener`] adapter that runs a callback on each accepted I/O stream
/// before handing it to the caller.
///
/// Constructed via [`ListenerExt::tap_io`].
pub struct TapIo<L, F> {
    /// The wrapped listener.
    listener: L,
    /// Callback invoked with a mutable reference to each accepted stream.
    tap_fn: F,
}
impl<L, F> std::fmt::Debug for TapIo<L, F>
where
L: Listener + std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TapIo")
.field("listener", &self.listener)
.finish_non_exhaustive()
}
}
impl<L, F> Listener for TapIo<L, F>
where
L: Listener,
F: FnMut(&mut L::Io) + Send + 'static,
{
type Io = L::Io;
type Addr = L::Addr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
let (mut io, addr) = self.listener.accept().await;
(self.tap_fn)(&mut io);
(io, addr)
}
fn local_addr(&self) -> std::io::Result<Self::Addr> {
self.listener.local_addr()
}
}
/// Exponential backoff state for retrying failed `accept` calls.
struct AcceptBackoff {
    /// Delay to sleep after the next backoff-worthy accept error; doubled after
    /// each sleep and capped at `AcceptBackoff::MAX`.
    next_delay: Duration,
}
impl AcceptBackoff {
const MIN: Duration = Duration::from_millis(5);
const MAX: Duration = Duration::from_secs(1);
fn new() -> Self {
Self {
next_delay: Self::MIN,
}
}
async fn handle_accept_error(&mut self, e: std::io::Error) {
if is_connection_error(&e) {
return;
}
tracing::error!(backoff = ?self.next_delay, "accept error: {e}");
tokio::time::sleep(self.next_delay).await;
self.next_delay = (self.next_delay * 2).min(Self::MAX);
}
}
/// Reports whether `e` has one of the connection-level error kinds listed below.
///
/// These are: refused, aborted or reset connections, broken pipe, interrupted,
/// would-block and timed out. `AcceptBackoff` skips its sleep for these kinds.
fn is_connection_error(e: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    const CONNECTION_KINDS: [ErrorKind; 7] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionAborted,
        ErrorKind::ConnectionReset,
        ErrorKind::BrokenPipe,
        ErrorKind::Interrupted,
        ErrorKind::WouldBlock,
        ErrorKind::TimedOut,
    ];

    CONNECTION_KINDS.contains(&e.kind())
}