use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
#[cfg(not(feature = "compio"))]
pub async fn serve_tcp<F>(addr: &str, handler: F) -> std::io::Result<()>
where
F: Fn(
tokio::net::TcpStream,
SocketAddr,
) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
+ Send
+ Sync
+ 'static,
{
let listener = tokio::net::TcpListener::bind(addr).await?;
tracing::info!("TCP server listening on {}", listener.local_addr()?);
let handler = Arc::new(handler);
loop {
let (stream, peer_addr) = listener.accept().await?;
let _ = stream.set_nodelay(true);
let handler = Arc::clone(&handler);
tokio::spawn(async move {
if let Err(e) = handler(stream, peer_addr).await {
tracing::error!("TCP connection error from {peer_addr}: {e}");
}
});
}
}
/// Serves TCP connections on `addr` until `signal` resolves.
///
/// Shorthand for [`serve_tcp_with_shutdown_and_drain`] with a fixed 30-second
/// drain timeout for in-flight connections.
///
/// # Errors
///
/// Propagates any error returned by [`serve_tcp_with_shutdown_and_drain`].
#[cfg(not(feature = "compio"))]
pub async fn serve_tcp_with_shutdown<F, S>(addr: &str, handler: F, signal: S) -> std::io::Result<()>
where
    F: Fn(
            tokio::net::TcpStream,
            SocketAddr,
        ) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
        + Send
        + Sync
        + 'static,
    S: Future<Output = ()> + Send + 'static,
{
    const DEFAULT_DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    serve_tcp_with_shutdown_and_drain(addr, handler, signal, DEFAULT_DRAIN_TIMEOUT).await
}
/// Serves TCP connections on `addr` until `signal` resolves, then drains in-flight work.
///
/// Every accepted connection gets a best-effort `TCP_NODELAY` and runs `handler` on a
/// task owned by an internal [`tokio::task::JoinSet`]; handler errors are logged.
/// Finished tasks are reaped while the server runs, so the set only holds live
/// connections. Once `signal` completes, the listener stops accepting and the function
/// waits up to `drain_timeout` for the remaining tasks. Any task still running at the
/// deadline is aborted when the set is dropped on return.
///
/// # Errors
///
/// Returns an error if binding `addr` fails, the bound address cannot be read, or
/// `accept` fails. An accept error returns immediately without draining, which also
/// aborts all in-flight connection tasks (the `JoinSet` is dropped).
#[cfg(not(feature = "compio"))]
pub async fn serve_tcp_with_shutdown_and_drain<F, S>(
    addr: &str,
    handler: F,
    signal: S,
    drain_timeout: std::time::Duration,
) -> std::io::Result<()>
where
    F: Fn(
            tokio::net::TcpStream,
            SocketAddr,
        ) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
        + Send
        + Sync
        + 'static,
    S: Future<Output = ()> + Send + 'static,
{
    let listener = tokio::net::TcpListener::bind(addr).await?;
    tracing::info!("TCP server listening on {}", listener.local_addr()?);
    let handler = Arc::new(handler);
    let mut join_set = tokio::task::JoinSet::new();
    tokio::pin!(signal);
    loop {
        tokio::select! {
            result = listener.accept() => {
                let (stream, peer_addr) = result?;
                // Best effort: serve the connection even if TCP_NODELAY cannot be set.
                let _ = stream.set_nodelay(true);
                let handler = Arc::clone(&handler);
                join_set.spawn(async move {
                    if let Err(e) = handler(stream, peer_addr).await {
                        tracing::error!("TCP connection error from {peer_addr}: {e}");
                    }
                });
            }
            // JoinSet retains each finished task until it is joined. Reaping here keeps a
            // long-running server from accumulating one entry per connection ever served,
            // and makes the drain count below reflect live connections. The join result is
            // ignored: handler errors are already logged inside the task, and panics go
            // through the process panic hook.
            Some(_) = join_set.join_next(), if !join_set.is_empty() => {}
            () = &mut signal => {
                tracing::info!("TCP server shutting down, draining {} connections", join_set.len());
                break;
            }
        }
    }
    let drained = tokio::time::timeout(drain_timeout, async {
        while join_set.join_next().await.is_some() {}
    })
    .await;
    if drained.is_err() {
        // Remaining tasks are aborted when `join_set` is dropped on return.
        tracing::warn!(
            "TCP server drain timed out after {drain_timeout:?}; aborting {} remaining connections",
            join_set.len()
        );
    }
    Ok(())
}
/// Binds `addr` and serves each accepted connection with `handler` on the compio runtime.
///
/// Connections are spawned as detached compio tasks after a best-effort attempt to
/// enable `TCP_NODELAY`. Handler errors are logged and never stop the accept loop.
///
/// # Errors
///
/// Returns an error if binding `addr` fails, the bound address cannot be read, or
/// `accept` fails; the first accept error ends the loop.
#[cfg(feature = "compio")]
pub async fn serve_tcp<F>(addr: &str, handler: F) -> std::io::Result<()>
where
    F: Fn(compio::net::TcpStream, SocketAddr) -> Pin<Box<dyn Future<Output = std::io::Result<()>>>>
        + Send
        + Sync
        + 'static,
{
    let listener = compio::net::TcpListener::bind(addr).await?;
    tracing::info!("TCP server listening on {}", listener.local_addr()?);
    let shared = Arc::new(handler);
    loop {
        let (conn, peer_addr) = listener.accept().await?;
        // Ignoring the result is deliberate: a socket without TCP_NODELAY is still usable.
        let _ = conn.set_nodelay(true);
        let serve = Arc::clone(&shared);
        let task = compio::runtime::spawn(async move {
            let outcome = serve(conn, peer_addr).await;
            if let Err(e) = outcome {
                tracing::error!("TCP connection error from {peer_addr}: {e}");
            }
        });
        task.detach();
    }
}
/// Serves TCP connections on `addr` (compio backend) until `signal` resolves.
///
/// Shorthand for [`serve_tcp_with_shutdown_and_drain`] with a fixed 30-second
/// drain timeout for in-flight connections.
///
/// # Errors
///
/// Propagates any error returned by [`serve_tcp_with_shutdown_and_drain`].
#[cfg(feature = "compio")]
pub async fn serve_tcp_with_shutdown<F, S>(addr: &str, handler: F, signal: S) -> std::io::Result<()>
where
    F: Fn(compio::net::TcpStream, SocketAddr) -> Pin<Box<dyn Future<Output = std::io::Result<()>>>>
        + Send
        + Sync
        + 'static,
    S: Future<Output = ()> + 'static,
{
    const DEFAULT_DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    serve_tcp_with_shutdown_and_drain(addr, handler, signal, DEFAULT_DRAIN_TIMEOUT).await
}
/// Serves TCP connections on `addr` (compio backend) until `signal` resolves, then
/// waits for in-flight connections to finish.
///
/// Every accepted connection gets a best-effort `TCP_NODELAY` and runs `handler` on a
/// detached compio task; handler errors are logged. Once `signal` completes, the
/// listener stops accepting and the function waits up to `drain_timeout` for the
/// in-flight count to reach zero. Tasks still running at the deadline are not
/// cancelled; they stay detached on the runtime.
///
/// # Errors
///
/// Returns an error if binding `addr` fails, the bound address cannot be read, or
/// `accept` fails. An accept error returns immediately without draining.
#[cfg(feature = "compio")]
pub async fn serve_tcp_with_shutdown_and_drain<F, S>(
    addr: &str,
    handler: F,
    signal: S,
    drain_timeout: std::time::Duration,
) -> std::io::Result<()>
where
    F: Fn(compio::net::TcpStream, SocketAddr) -> Pin<Box<dyn Future<Output = std::io::Result<()>>>>
        + Send
        + Sync
        + 'static,
    S: Future<Output = ()> + 'static,
{
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Counts one in-flight connection for as long as it is alive.
    ///
    /// The decrement happens in `Drop` rather than after the handler returns. This keeps
    /// the count correct even if the handler panics or the task is dropped, so the drain
    /// phase is not left waiting for the full timeout on a connection that no longer exists.
    struct InflightGuard {
        inflight: Arc<AtomicUsize>,
        drain_notify: Arc<tokio::sync::Notify>,
    }

    impl InflightGuard {
        fn new(inflight: &Arc<AtomicUsize>, drain_notify: &Arc<tokio::sync::Notify>) -> Self {
            inflight.fetch_add(1, Ordering::SeqCst);
            Self {
                inflight: Arc::clone(inflight),
                drain_notify: Arc::clone(drain_notify),
            }
        }
    }

    impl Drop for InflightGuard {
        fn drop(&mut self) {
            self.inflight.fetch_sub(1, Ordering::SeqCst);
            self.drain_notify.notify_waiters();
        }
    }

    let listener = compio::net::TcpListener::bind(addr).await?;
    tracing::info!("TCP server listening on {}", listener.local_addr()?);
    let handler = Arc::new(handler);
    let inflight = Arc::new(AtomicUsize::new(0));
    let drain_notify = Arc::new(tokio::sync::Notify::new());
    let mut signal = std::pin::pin!(signal);
    loop {
        let accept_fut = std::pin::pin!(listener.accept());
        match futures_util::future::select(accept_fut, &mut signal).await {
            futures_util::future::Either::Left((result, _)) => {
                let (stream, peer_addr) = result?;
                // Best effort: serve the connection even if TCP_NODELAY cannot be set.
                let _ = stream.set_nodelay(true);
                let handler = Arc::clone(&handler);
                // Count the connection before spawning so a shutdown observed right after
                // this accept still sees it as in flight.
                let guard = InflightGuard::new(&inflight, &drain_notify);
                compio::runtime::spawn(async move {
                    // Held for the whole task; dropped on completion, panic, or cancellation.
                    let _guard = guard;
                    if let Err(e) = handler(stream, peer_addr).await {
                        tracing::error!("TCP connection error from {peer_addr}: {e}");
                    }
                })
                .detach();
            }
            futures_util::future::Either::Right(_) => {
                tracing::info!(
                    "TCP server shutting down, draining {} connections",
                    inflight.load(Ordering::SeqCst)
                );
                break;
            }
        }
    }

    let drain_deadline = std::time::Instant::now() + drain_timeout;
    loop {
        // Create the waiter *before* reading the counter. `notify_waiters` only wakes
        // waiters that already exist, so this ordering means a connection finishing
        // between the check and the wait cannot be missed.
        let notified = drain_notify.notified();
        if inflight.load(Ordering::SeqCst) == 0 {
            break;
        }
        let remaining = drain_deadline.saturating_duration_since(std::time::Instant::now());
        if remaining.is_zero() || compio::time::timeout(remaining, notified).await.is_err() {
            tracing::warn!(
                "TCP server drain timed out with {} connections still in flight",
                inflight.load(Ordering::SeqCst)
            );
            break;
        }
    }
    Ok(())
}