use crate::RuntimeError;
use std::future::Future;
use std::sync::Arc;
/// A listener that asynchronously hands out accepted connections.
///
/// Abstracts over concrete listener types so `accept_loop` can be written
/// once and driven by any of them.
pub(crate) trait Acceptor {
    /// The value produced for each accepted connection.
    type Accepted;

    /// Waits for the next incoming connection.
    ///
    /// The returned future is `Send` and borrows `self` for as long as it lives.
    fn accept(&self) -> impl Future<Output = std::io::Result<Self::Accepted>> + Send + '_;
}
/// TCP listeners yield the connected stream together with the peer address.
impl Acceptor for tokio::net::TcpListener {
    type Accepted = (tokio::net::TcpStream, std::net::SocketAddr);

    async fn accept(&self) -> Result<Self::Accepted, std::io::Error> {
        // Fully qualified so this calls tokio's inherent `accept`, not this
        // trait method recursively.
        tokio::net::TcpListener::accept(self).await
    }
}
/// Unix-domain listeners yield only the stream; the peer address is discarded.
impl Acceptor for tokio::net::UnixListener {
    type Accepted = tokio::net::UnixStream;

    async fn accept(&self) -> Result<Self::Accepted, std::io::Error> {
        // Fully qualified to reach tokio's inherent `accept` rather than this
        // trait method; keep only the stream half of the returned pair.
        tokio::net::UnixListener::accept(self)
            .await
            .map(|(stream, _peer_addr)| stream)
    }
}
/// Accepts connections from `listener` until `shutdown_notify` fires.
///
/// Each accepted connection is passed to `on_accept`. The returned future is
/// spawned as its own task via `spawn_with_limit`, which first waits for a
/// permit from `conn_limit` when one is supplied.
///
/// Accept errors classified by `crate::error::is_transient_accept_error` are
/// logged and retried after a 100 ms back-off. Any other accept error ends the
/// loop.
///
/// Shutdown is observed at every await point in the loop: while accepting,
/// while waiting for a connection permit, and while backing off. A saturated
/// connection limit or an accept back-off therefore cannot delay or swallow a
/// shutdown request.
///
/// # Errors
///
/// Returns the first non-transient accept error, converted into `RuntimeError`.
pub(crate) async fn accept_loop<L, F, Fut>(
    listener: &L,
    shutdown_notify: &tokio::sync::Notify,
    conn_limit: Option<&Arc<tokio::sync::Semaphore>>,
    on_accept: F,
) -> Result<(), RuntimeError>
where
    L: Acceptor,
    F: Fn(L::Accepted) -> Fut,
    Fut: Future<Output = ()> + Send + 'static,
{
    // Create the shutdown future once, outside the loop, so it stays
    // registered for the whole loop. Re-creating it every iteration would miss
    // a `notify_waiters()` call arriving while no `Notified` exists, e.g.
    // while waiting for a permit or sleeping below.
    let shutdown = shutdown_notify.notified();
    tokio::pin!(shutdown);

    loop {
        tokio::select! {
            result = listener.accept() => {
                match result {
                    Ok(accepted) => {
                        let conn = on_accept(accepted);
                        // Acquiring a permit can wait indefinitely when the
                        // limit is saturated, so race it against shutdown.
                        // Losing the race drops the pending connection.
                        tokio::select! {
                            () = spawn_with_limit(conn_limit, conn) => {}
                            () = &mut shutdown => return Ok(()),
                        }
                    }
                    Err(e) if crate::error::is_transient_accept_error(&e) => {
                        tracing::warn!("accept: fd limit reached, backing off");
                        tokio::select! {
                            () = tokio::time::sleep(std::time::Duration::from_millis(100)) => {}
                            () = &mut shutdown => return Ok(()),
                        }
                    }
                    Err(e) => return Err(e.into()),
                }
            }
            () = &mut shutdown => {
                return Ok(());
            }
        }
    }
}
/// Spawns `fut` as a detached task, optionally gated by a connection limit.
///
/// With no `conn_limit`, `fut` is spawned immediately. Otherwise this waits for
/// an owned permit from the semaphore. The permit is held inside the spawned
/// task until `fut` completes. If the semaphore has been closed, `fut` is
/// dropped without being spawned.
async fn spawn_with_limit<Fut>(conn_limit: Option<&Arc<tokio::sync::Semaphore>>, fut: Fut)
where
    Fut: Future<Output = ()> + Send + 'static,
{
    let Some(semaphore) = conn_limit else {
        tokio::spawn(fut);
        return;
    };

    match Arc::clone(semaphore).acquire_owned().await {
        Ok(permit) => {
            tokio::spawn(async move {
                // Binding (not `_`) moves the permit into the task and keeps it
                // alive until the task body finishes.
                let _permit = permit;
                fut.await;
            });
        }
        // Closed semaphore: intentionally drop the connection future.
        Err(_closed) => {}
    }
}
pub(crate) async fn tls_handshake(
stream: tokio::net::TcpStream,
acceptor: &tokio_rustls::TlsAcceptor,
) -> Option<tokio_rustls::server::TlsStream<tokio::net::TcpStream>> {
let result =
tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream)).await;
match result {
Ok(Ok(s)) => Some(s),
Ok(Err(e)) if crate::error::is_benign_io(&e) => None,
Ok(Err(e)) => {
tracing::warn!("TLS handshake error: {e}");
None
}
Err(_) => None,
}
}