use super::ConnectionService;
use super::config::ServerConfig;
use super::listener::{Listen, ListenersBuilder};
#[cfg(feature = "metrics")]
use super::metrics::{
record_accept_err, record_accept_ok, record_forced_shutdown, record_graceful_shutdown,
record_handler_duration, record_handler_err, record_handler_ok, record_handler_timeout,
record_rate_limiter_closed, record_rate_limiter_timeout, record_shutdown_duration,
record_wait_duration,
};
use crate::core::socket_addr::SocketAddr as CoreSocketAddr;
use std::io;
use std::net::SocketAddr;
#[cfg(not(target_os = "windows"))]
use std::path::Path;
use std::sync::Arc;
#[cfg(test)]
use std::sync::OnceLock;
#[cfg(feature = "scheduler")]
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::signal;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
/// Test-only broadcast: `trigger_test_shutdown` notifies it, and the server loop's
/// `test_shutdown_future` branch breaks out when it fires. It is process-global, so one
/// notification reaches every server loop currently waiting on it.
#[cfg(test)]
static SHUTDOWN_NOTIFY: OnceLock<tokio::sync::Notify> = OnceLock::new();
/// Set while the scheduler task spawned by `ensure_scheduler_running` is active, so at most one
/// such task is spawned at a time; it is cleared again once `Scheduler::schedule` returns.
#[cfg(feature = "scheduler")]
static SCHEDULER_RUNNING: AtomicBool = AtomicBool::new(false);
/// Test helper: wakes every server loop currently awaiting `test_shutdown_future`.
///
/// `notify_waiters` stores no permit, so only loops that have already created their
/// notification future observe this call.
#[cfg(test)]
fn trigger_test_shutdown() {
    let notify = SHUTDOWN_NOTIFY.get_or_init(tokio::sync::Notify::new);
    notify.notify_waiters();
}
/// Future that completes when `trigger_test_shutdown` is called (test builds only).
#[cfg(test)]
fn test_shutdown_future() -> impl std::future::Future<Output = ()> {
    SHUTDOWN_NOTIFY
        .get_or_init(tokio::sync::Notify::new)
        .notified()
}

/// Outside of tests there is no test-driven shutdown: this future never completes.
#[cfg(not(test))]
fn test_shutdown_future() -> impl std::future::Future<Output = ()> {
    futures_util::future::pending::<()>()
}
/// Spawns the global scheduler task on the current Tokio runtime unless one is already running.
///
/// `SCHEDULER_RUNNING` guards against spawning a second task. The flag is cleared by a drop
/// guard owned by the spawned future. It is therefore reset not only when
/// `Scheduler::schedule` returns, but also when the task panics or is dropped without
/// completing. A task is dropped that way when its runtime shuts down, e.g. at the end of
/// `NetServer::run` or a `#[tokio::test]`. Without the guard, the flag would stay `true`
/// forever and no later runtime could start the scheduler again.
#[cfg(feature = "scheduler")]
fn ensure_scheduler_running() {
    /// Clears `SCHEDULER_RUNNING` when dropped.
    struct ResetRunningFlag;

    impl Drop for ResetRunningFlag {
        fn drop(&mut self) {
            SCHEDULER_RUNNING.store(false, Ordering::Release);
        }
    }

    if SCHEDULER_RUNNING
        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
        .is_err()
    {
        // Another caller already owns the running scheduler task.
        return;
    }

    // Created before `spawn` and moved into the future, so the flag is reset even if the
    // future is dropped before it is ever polled.
    let reset = ResetRunningFlag;
    tokio::spawn(async move {
        use crate::scheduler::{SCHEDULER, Scheduler};
        let _reset = reset;
        Scheduler::schedule(SCHEDULER.clone()).await;
    });
}
/// Callback invoked once with the local addresses of all opened listeners when serving starts.
type ListenCallback = Box<dyn Fn(&[CoreSocketAddr]) + Send + Sync>;
/// Settings for the connection rate limiter installed by `NetServer::with_rate_limiter`.
///
/// Every accepted connection must obtain a permit from a semaphore that starts with
/// `capacity` permits. A background task adds one permit every `refill_every` while fewer
/// than `capacity` are available.
#[derive(Clone, Copy, Debug)]
pub struct RateLimiterConfig {
    /// Initial number of permits, and the level the refill task tops the pool back up to.
    pub capacity: usize,
    /// Interval between single-permit refills. Should be non-zero: it is passed to
    /// `tokio::time::interval`, which panics on a zero period (inside the refill task).
    pub refill_every: Duration,
    /// How long an accepted connection may wait for a permit before it is dropped.
    pub max_wait: Duration,
}
/// Connection server that accepts from one or more listeners and runs a
/// [`ConnectionService`] for each accepted connection on its own task.
///
/// Configure it with chained calls (`bind`, `listen`, `on_listen`, `with_rate_limiter`,
/// `with_shutdown`, ...), then start it with `serve`, `serve_arc`, `serve_dyn`, or `run`.
pub struct NetServer {
    /// Collects the listeners that are opened when serving starts.
    listeners_builder: ListenersBuilder,
    /// Invoked after a shutdown signal has been received.
    shutdown_callback: Option<Box<dyn Fn() + Send + Sync>>,
    /// Invoked with the bound addresses; replaces the default "listening on" log when set.
    listen_callback: Option<ListenCallback>,
    /// Optional admission limiter applied to every accepted connection.
    rate_limiter: Option<RateLimiter>,
    /// Graceful-shutdown settings.
    shutdown_cfg: ShutdownConfig,
    /// Server-wide settings; the serve loop reads `connection_limits.handler_timeout`.
    config: ServerConfig,
}
impl Default for NetServer {
fn default() -> Self {
Self::new()
}
}
impl NetServer {
pub fn new() -> Self {
Self {
listeners_builder: ListenersBuilder::new(),
shutdown_callback: None,
listen_callback: None,
rate_limiter: None,
shutdown_cfg: ShutdownConfig::default(),
config: ServerConfig::default(),
}
}
pub(crate) fn from_parts(
listeners_builder: ListenersBuilder,
shutdown_callback: Option<Box<dyn Fn() + Send + Sync>>,
listen_callback: Option<ListenCallback>,
config: ServerConfig,
) -> Self {
Self {
listeners_builder,
shutdown_callback,
listen_callback,
rate_limiter: None,
shutdown_cfg: ShutdownConfig::default(),
config,
}
}
#[inline]
pub fn bind(mut self, addr: SocketAddr) -> Result<Self, io::Error> {
self.listeners_builder.bind(addr)?;
Ok(self)
}
#[cfg(not(target_os = "windows"))]
#[inline]
pub fn bind_unix<P: AsRef<Path>>(mut self, path: P) -> Result<Self, io::Error> {
self.listeners_builder.bind_unix(path)?;
Ok(self)
}
#[inline]
pub fn listen<T: Listen + Send + Sync + 'static>(mut self, listener: T) -> Self {
self.listeners_builder.add_listener(Box::new(listener));
self
}
pub fn on_listen<F>(mut self, callback: F) -> Self
where
F: Fn(&[CoreSocketAddr]) + Send + Sync + 'static,
{
self.listen_callback = Some(Box::new(callback));
self
}
pub fn set_shutdown_callback<F>(mut self, callback: F) -> Self
where
F: Fn() + Send + Sync + 'static,
{
self.shutdown_callback = Some(Box::new(callback));
self
}
pub fn with_rate_limiter(mut self, config: RateLimiterConfig) -> Self {
self.rate_limiter = Some(RateLimiter::new(
config.capacity,
config.refill_every,
config.max_wait,
));
self
}
pub fn with_shutdown(mut self, graceful_wait: Duration) -> Self {
self.shutdown_cfg.graceful_wait = graceful_wait;
self
}
pub async fn serve<H>(self, handler: H)
where
H: ConnectionService + 'static,
{
if let Err(e) = self.serve_arc(std::sync::Arc::new(handler)).await {
panic!("server loop failed: {}", e);
}
}
pub fn run<H>(self, handler: H)
where
H: ConnectionService + 'static,
{
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to build Tokio runtime");
runtime.block_on(async move {
if let Err(e) = self.serve_arc(std::sync::Arc::new(handler)).await {
panic!("server loop failed: {}", e);
}
})
}
pub async fn serve_arc<H>(self, handler: std::sync::Arc<H>) -> io::Result<()>
where
H: ConnectionService + 'static,
{
self.serve_dyn(handler as std::sync::Arc<dyn ConnectionService>)
.await
}
pub async fn serve_dyn(self, handler: std::sync::Arc<dyn ConnectionService>) -> io::Result<()> {
self.serve_connection_loop(handler).await
}
async fn serve_connection_loop(
mut self,
handler: std::sync::Arc<dyn ConnectionService>,
) -> io::Result<()> {
let loop_started = Instant::now();
let mut listeners = self.listeners_builder.listen()?;
let addrs = listeners.local_addrs().to_vec();
let handler_timeout = self.config.connection_limits.handler_timeout;
if let Some(cb) = &self.listen_callback {
(cb)(&addrs);
} else {
if addrs.len() == 1 {
tracing::info!("listening on {}", format!("{:?}", addrs[0]));
} else {
let lines = addrs
.iter()
.map(|a| format!(" - {:?}", a))
.collect::<Vec<_>>()
.join("\n");
tracing::info!("listening on:\n{}", lines);
}
}
#[cfg(feature = "scheduler")]
ensure_scheduler_running();
let mut join_set: JoinSet<()> = JoinSet::new();
let mut shutdown = ShutdownHandle::new(self.shutdown_callback.take(), self.shutdown_cfg);
let rate = self_rate_limiter(self.rate_limiter.as_ref());
let mut refill_handle = rate.as_ref().map(|r| r.spawn_refill_task());
loop {
tokio::select! {
biased;
_ = shutdown.signal() => {
tracing::info!(
elapsed = ?loop_started.elapsed(),
tasks = join_set.len(),
"shutdown signal received"
);
break;
}
accept_result = listeners.accept() => {
match accept_result {
None => {
tracing::info!(elapsed = ?loop_started.elapsed(), "listener closed, shutting down");
break;
}
Some(Ok((stream, peer_addr))) => {
#[cfg(feature = "metrics")]
record_accept_ok();
if let Some(rate) = &rate {
let semaphore = rate.semaphore.clone();
let max_wait = rate.max_wait;
let handler = handler.clone();
let peer = peer_addr.clone();
let accepted_at = Instant::now();
tracing::info!(%peer, "accepted connection");
join_set.spawn(async move {
match tokio::time::timeout(max_wait, semaphore.acquire_owned()).await {
Ok(Ok(_permit)) => {
let wait_cost = accepted_at.elapsed();
#[cfg(feature = "metrics")]
record_wait_duration(wait_cost.as_nanos() as u64);
if let Some(timeout) = handler_timeout {
match tokio::time::timeout(timeout, handler.call(stream, peer.clone())).await {
Ok(res) => {
if let Err(err) = res {
tracing::error!("Failed to serve connection: {:?}", err);
} else {
#[cfg(feature = "metrics")]
record_handler_duration(accepted_at.elapsed().as_nanos() as u64);
tracing::debug!(%peer, wait = ?wait_cost, handle = ?accepted_at.elapsed(), "connection served");
}
}
Err(_) => {
#[cfg(feature = "metrics")]
record_handler_timeout();
tracing::warn!(
%peer,
wait = ?wait_cost,
"Connection handler timed out for peer"
);
}
}
} else {
let handle_started = Instant::now();
if let Err(err) = handler.call(stream, peer.clone()).await {
#[cfg(feature = "metrics")]
record_handler_err();
tracing::error!("Failed to serve connection: {:?}", err);
} else {
#[cfg(feature = "metrics")]
record_handler_ok();
#[cfg(feature = "metrics")]
record_handler_duration(handle_started.elapsed().as_nanos() as u64);
tracing::debug!(%peer, wait = ?wait_cost, handle = ?handle_started.elapsed(), "connection served");
}
}
}
Ok(Err(_)) => {
#[cfg(feature = "metrics")]
record_rate_limiter_closed();
tracing::warn!(%peer, "Rate limiter closed, dropping connection");
}
Err(_) => {
#[cfg(feature = "metrics")]
record_rate_limiter_timeout();
tracing::warn!(%peer, "Rate limiter timeout, dropping connection");
}
}
});
} else {
let handler = handler.clone();
let peer = peer_addr.clone();
let accepted_at = Instant::now();
tracing::info!(%peer, "accepted connection");
join_set.spawn(async move {
if let Some(timeout) = handler_timeout {
match tokio::time::timeout(timeout, handler.call(stream, peer.clone())).await {
Ok(res) => {
if let Err(err) = res {
#[cfg(feature = "metrics")]
record_handler_err();
tracing::error!("Failed to serve connection: {:?}", err);
} else {
#[cfg(feature = "metrics")]
record_handler_ok();
#[cfg(feature = "metrics")]
record_handler_duration(accepted_at.elapsed().as_nanos() as u64);
tracing::debug!(%peer, handle = ?accepted_at.elapsed(), "connection served");
}
}
Err(_) => {
#[cfg(feature = "metrics")]
record_handler_timeout();
tracing::warn!(%peer, "Connection handler timed out for peer");
}
}
} else {
let handle_started = Instant::now();
if let Err(err) = handler.call(stream, peer.clone()).await {
#[cfg(feature = "metrics")]
record_handler_err();
tracing::error!("Failed to serve connection: {:?}", err);
} else {
#[cfg(feature = "metrics")]
record_handler_ok();
tracing::debug!(%peer, handle = ?handle_started.elapsed(), "connection served");
}
}
});
}
}
Some(Err(e)) => {
#[cfg(feature = "metrics")]
record_accept_err();
tracing::error!(error = ?e, tasks = join_set.len(), "accept connection failed");
}
}
}
Some(join_result) = join_set.join_next() => {
if let Err(err) = join_result {
tracing::error!(error = ?err, "connection task panicked");
}
}
_ = test_shutdown_future() => {
tracing::info!("test shutdown notify received");
break;
}
}
}
if shutdown.shutdown_cfg.graceful_wait > Duration::from_millis(0) {
let graceful_started = Instant::now();
let _ = tokio::time::timeout(shutdown.shutdown_cfg.graceful_wait, async {
while let Some(join_result) = join_set.join_next().await {
if let Err(err) = join_result
&& err.is_panic()
{
tracing::error!(error = ?err, "connection task panicked during graceful shutdown");
}
}
})
.await;
tracing::debug!(
elapsed = ?graceful_started.elapsed(),
remaining = join_set.len(),
"graceful shutdown wait finished"
);
#[cfg(feature = "metrics")]
record_graceful_shutdown();
#[cfg(feature = "metrics")]
record_shutdown_duration("graceful", graceful_started.elapsed().as_nanos() as u64);
}
if let Some(h) = &mut refill_handle {
h.abort();
let _ = h.await;
}
join_set.abort_all();
let abort_started = Instant::now();
while let Some(join_result) = join_set.join_next().await {
if let Err(err) = join_result
&& err.is_panic()
{
tracing::error!(error = ?err, "connection task panicked during forced shutdown");
}
}
tracing::debug!(elapsed = ?abort_started.elapsed(), "forced shutdown complete");
#[cfg(feature = "metrics")]
record_forced_shutdown();
#[cfg(feature = "metrics")]
record_shutdown_duration("forced", abort_started.elapsed().as_nanos() as u64);
Ok(())
}
}
/// Semaphore-backed admission limiter used by the accept loop.
///
/// A connection must obtain a permit from `semaphore` within `max_wait` before its handler
/// runs. `spawn_refill_task` adds one permit every `refill_every` while fewer than `capacity`
/// are available. Clones share the same semaphore.
#[derive(Clone)]
struct RateLimiter {
    /// Pool of permits; starts with `capacity` permits.
    semaphore: Arc<Semaphore>,
    /// Longest time a connection waits for a permit before it is dropped.
    max_wait: Duration,
    /// Initial permit count and the ceiling the refill task tops up to.
    capacity: usize,
    /// Period of the background refill task.
    refill_every: Duration,
}
impl RateLimiter {
    /// Builds a limiter whose semaphore starts full, with `capacity` permits.
    fn new(capacity: usize, refill_every: Duration, max_wait: Duration) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(capacity)),
            max_wait,
            capacity,
            refill_every,
        }
    }

    /// Spawns the background task that adds one permit per `refill_every` tick, but only while
    /// fewer than `capacity` permits are available.
    ///
    /// The task loops forever; callers stop it by aborting the returned handle.
    fn spawn_refill_task(&self) -> tokio::task::JoinHandle<()> {
        let (pool, ceiling, period) = (Arc::clone(&self.semaphore), self.capacity, self.refill_every);
        tokio::spawn(async move {
            let mut ticks = tokio::time::interval(period);
            loop {
                ticks.tick().await;
                let below_ceiling = pool.available_permits() < ceiling;
                if below_ceiling {
                    pool.add_permits(1);
                }
            }
        })
    }
}
/// Settings that control how the server winds down after its accept loop stops.
///
/// The derived `Default` sets `graceful_wait` to `Duration::ZERO`, i.e. no grace period.
#[derive(Clone, Copy, Default)]
struct ShutdownConfig {
    /// How long to let in-flight connection tasks finish before they are aborted.
    graceful_wait: Duration,
}
fn self_rate_limiter(rate: Option<&RateLimiter>) -> Option<RateLimiter> {
rate.cloned()
}
/// Waits for process shutdown signals and runs the user's shutdown callback. It also carries
/// the shutdown configuration that the serve loop reads once accepting stops.
struct ShutdownHandle {
    /// Run once after a shutdown signal is received.
    shutdown_callback: Option<Box<dyn Fn() + Send + Sync>>,
    /// Graceful-shutdown settings for the serve loop.
    shutdown_cfg: ShutdownConfig,
}
impl ShutdownHandle {
fn new(callback: Option<Box<dyn Fn() + Send + Sync>>, shutdown_cfg: ShutdownConfig) -> Self {
let shutdown_callback = callback;
Self {
shutdown_callback,
shutdown_cfg,
}
}
async fn signal(&mut self) {
#[cfg(unix)]
{
let mut term =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("Failed to install SIGTERM handler");
tokio::select! {
_ = signal::ctrl_c() => (),
_ = term.recv() => (),
}
}
#[cfg(not(unix))]
{
tokio::select! {
_ = signal::ctrl_c() => (),
}
}
if let Some(cb) = &self.shutdown_callback {
(cb)();
}
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the rate limiter and the `NetServer` accept loop. Most server tests spawn
    // `serve` on the test runtime, sleep briefly, assert on shared flags/counters, then abort
    // the server task. They therefore depend on wall-clock timing.
    use super::*;
    use crate::server::connection;
    use crate::server::connection::BoxedConnection;
    use crate::server::listener::Listen;
    use crate::{AcceptFuture, BoxError};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    // Taking both permits from a capacity-2 limiter leaves none available.
    #[tokio::test]
    async fn test_rate_limiter_capacity_limit() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60), Duration::from_secs(1));
        let _permit1 = limiter
            .semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("first permit should be available");
        let _permit2 = limiter
            .semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("second permit should be available");
        assert_eq!(limiter.semaphore.available_permits(), 0);
    }

    // With the only permit held, the refill task (20 ms period) adds one back within 30 ms.
    #[tokio::test]
    async fn test_rate_limiter_refill_adds_permit() {
        let limiter = RateLimiter::new(1, Duration::from_millis(20), Duration::from_millis(10));
        let _permit = limiter
            .semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("permit should be available");
        assert_eq!(limiter.semaphore.available_permits(), 0);
        let handle = limiter.spawn_refill_task();
        tokio::time::sleep(Duration::from_millis(30)).await;
        assert!(limiter.semaphore.available_permits() >= 1);
        handle.abort();
        let _ = handle.await;
    }

    // Listener that yields its single stored connection once, then never completes.
    struct TestListener {
        addr: std::net::SocketAddr,
        // Incremented each time the connection is handed out; no test currently reads it.
        accepts: Arc<AtomicUsize>,
        once_conn: tokio::sync::Mutex<Option<BoxedConnection>>,
    }
    impl TestListener {
        fn new(conn: BoxedConnection, addr: std::net::SocketAddr) -> Self {
            Self {
                addr,
                accepts: Arc::new(AtomicUsize::new(0)),
                once_conn: tokio::sync::Mutex::new(Some(conn)),
            }
        }
    }
    impl Listen for TestListener {
        // The connection is taken out of `once_conn` when `accept` is *called*, not when the
        // returned future completes. Every later call (or a call that loses the `try_lock`)
        // gets a future that stays pending forever.
        fn accept(&self) -> AcceptFuture<'_> {
            let accepts = self.accepts.clone();
            let addr = self.addr;
            let once = self.once_conn.try_lock();
            if let Ok(mut guard) = once
                && let Some(conn) = guard.take()
            {
                accepts.fetch_add(1, Ordering::SeqCst);
                return Box::pin(async move {
                    Ok((conn, crate::core::socket_addr::SocketAddr::from(addr)))
                });
            }
            Box::pin(async move {
                futures_util::future::pending::<
                    std::io::Result<(
                        Box<dyn connection::Connection + Send + Sync>,
                        crate::core::socket_addr::SocketAddr,
                    )>,
                >()
                .await
            })
        }
        fn local_addr(&self) -> std::io::Result<crate::core::socket_addr::SocketAddr> {
            Ok(crate::core::socket_addr::SocketAddr::from(self.addr))
        }
    }

    // The `on_listen` callback fires after the server starts.
    // NOTE(review): despite the name, the handler invocation itself is not asserted.
    #[tokio::test]
    async fn test_net_server_on_listen_and_handler_called_then_abort() {
        let (_a, b) = tokio::io::duplex(8);
        let boxed: BoxedConnection = Box::new(b);
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TestListener::new(boxed, addr);
        let on_listen_called = Arc::new(AtomicBool::new(false));
        let flag = on_listen_called.clone();
        let handler =
            |_s: BoxedConnection, _p: CoreSocketAddr| async move { Ok::<(), BoxError>(()) };
        let server = NetServer::new().listen(listener).on_listen(move |_addrs| {
            flag.store(true, Ordering::SeqCst);
        });
        let jh = tokio::spawn(async move {
            server.serve(handler).await;
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(on_listen_called.load(Ordering::SeqCst));
        jh.abort();
        let _ = jh.await;
    }

    // Listener whose first `accept` resolves to an error; later calls never complete.
    struct TestErrListener {
        addr: std::net::SocketAddr,
        sent_err: Arc<AtomicBool>,
    }
    impl TestErrListener {
        fn new(addr: std::net::SocketAddr) -> Self {
            Self {
                addr,
                sent_err: Arc::new(AtomicBool::new(false)),
            }
        }
    }
    impl Listen for TestErrListener {
        fn accept(&self) -> AcceptFuture<'_> {
            let sent = self.sent_err.clone();
            Box::pin(async move {
                if !sent.swap(true, Ordering::SeqCst) {
                    Err(std::io::Error::other("accept failed (test)"))
                } else {
                    futures_util::future::pending::<
                        std::io::Result<(
                            Box<dyn connection::Connection + Send + Sync>,
                            crate::core::socket_addr::SocketAddr,
                        )>,
                    >()
                    .await
                }
            })
        }
        fn local_addr(&self) -> std::io::Result<crate::core::socket_addr::SocketAddr> {
            Ok(crate::core::socket_addr::SocketAddr::from(self.addr))
        }
    }

    // An accept error is logged and the loop keeps running; the handler is never invoked.
    #[tokio::test]
    async fn test_net_server_accept_error_path() {
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TestErrListener::new(addr);
        let on_listen_called = Arc::new(AtomicBool::new(false));
        let flag = on_listen_called.clone();
        let handler_calls = Arc::new(AtomicUsize::new(0));
        let hc = handler_calls.clone();
        let handler = move |_s: BoxedConnection, _p: CoreSocketAddr| {
            let hc = hc.clone();
            async move {
                hc.fetch_add(1, Ordering::SeqCst);
                Ok::<(), BoxError>(())
            }
        };
        let server = NetServer::new().listen(listener).on_listen(move |_addrs| {
            flag.store(true, Ordering::SeqCst);
        });
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(on_listen_called.load(Ordering::SeqCst));
        assert_eq!(handler_calls.load(Ordering::SeqCst), 0);
        jh.abort();
        let _ = jh.await;
    }

    // With capacity 0 no permit is ever available, so the connection times out in the limiter
    // (5 ms) and the handler never runs.
    #[tokio::test]
    async fn test_net_server_rate_limiter_timeout_drops_connection() {
        let (_a, b) = tokio::io::duplex(8);
        let boxed: BoxedConnection = Box::new(b);
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TestListener::new(boxed, addr);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_cl = calls.clone();
        let handler = move |_s: BoxedConnection, _p: CoreSocketAddr| {
            let calls_cl = calls_cl.clone();
            async move {
                calls_cl.fetch_add(1, Ordering::SeqCst);
                Ok::<(), BoxError>(())
            }
        };
        let server = NetServer::new()
            .with_rate_limiter(RateLimiterConfig {
                capacity: 0,
                refill_every: Duration::from_millis(100),
                max_wait: Duration::from_millis(5),
            })
            .listen(listener);
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "handler should not be called due to timeout"
        );
        jh.abort();
        let _ = jh.await;
    }

    // A panicking handler must not take down the server loop (the panic surfaces as a
    // `JoinError` from the `JoinSet`).
    // NOTE(review): there is no assertion; the test only checks that nothing else panics.
    #[tokio::test]
    async fn test_net_server_handler_panic_logged() {
        let (_a, b) = tokio::io::duplex(8);
        let boxed: BoxedConnection = Box::new(b);
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TestListener::new(boxed, addr);
        let handler = |_s: BoxedConnection, _p: CoreSocketAddr| async move {
            panic!("panic in handler (test)");
            #[allow(unreachable_code)]
            Ok::<(), BoxError>(())
        };
        let server = NetServer::new().listen(listener);
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        jh.abort();
        let _ = jh.await;
    }

    // The handler (50 ms) outlives the 10 ms graceful window; the server must still return
    // after `trigger_test_shutdown`, by aborting the remaining task.
    // NOTE(review): `SHUTDOWN_NOTIFY` is process-global, so this trigger also stops servers in
    // tests running concurrently. Also, `notify_waiters` stores no permit: if the server loop
    // has not registered its waiter within the 10 ms sleep, `jh.await` could hang.
    #[tokio::test]
    async fn test_net_server_graceful_shutdown_timeout() {
        let (_a, b) = tokio::io::duplex(8);
        let boxed: BoxedConnection = Box::new(b);
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TestListener::new(boxed, addr);
        let handler = |_s: BoxedConnection, _p: CoreSocketAddr| async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok::<(), BoxError>(())
        };
        let server = NetServer::new()
            .with_shutdown(Duration::from_millis(10))
            .listen(listener);
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        trigger_test_shutdown();
        let _ = jh.await;
    }

    // With one permit available, the single accepted connection is handled exactly once.
    #[tokio::test]
    async fn test_net_server_rate_limiter_permit_calls_handler() {
        let (_a, b) = tokio::io::duplex(8);
        let boxed: BoxedConnection = Box::new(b);
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = TestListener::new(boxed, addr);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_cl = calls.clone();
        let handler = move |_s: BoxedConnection, _p: CoreSocketAddr| {
            let calls_cl = calls_cl.clone();
            async move {
                calls_cl.fetch_add(1, Ordering::SeqCst);
                Ok::<(), BoxError>(())
            }
        };
        let server = NetServer::new()
            .with_rate_limiter(RateLimiterConfig {
                capacity: 1,
                refill_every: Duration::from_millis(1000),
                max_wait: Duration::from_millis(50),
            })
            .listen(listener);
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "handler should be called exactly once"
        );
        jh.abort();
        let _ = jh.await;
    }

    // Like `TestListener`, but the first accept future sleeps for `delay` before yielding the
    // connection (which is still taken out of `once_conn` at call time).
    struct TestListenerDelay {
        addr: std::net::SocketAddr,
        once_conn: tokio::sync::Mutex<Option<BoxedConnection>>,
        delay: Duration,
    }
    impl TestListenerDelay {
        fn new(conn: BoxedConnection, addr: std::net::SocketAddr, delay: Duration) -> Self {
            Self {
                addr,
                once_conn: tokio::sync::Mutex::new(Some(conn)),
                delay,
            }
        }
    }
    impl Listen for TestListenerDelay {
        fn accept(&self) -> AcceptFuture<'_> {
            let delay = self.delay;
            let addr = self.addr;
            let once = self.once_conn.try_lock();
            if let Ok(mut guard) = once
                && let Some(conn) = guard.take()
            {
                return Box::pin(async move {
                    tokio::time::sleep(delay).await;
                    Ok((conn, crate::core::socket_addr::SocketAddr::from(addr)))
                });
            }
            Box::pin(async move {
                futures_util::future::pending::<
                    std::io::Result<(
                        Box<dyn connection::Connection + Send + Sync>,
                        crate::core::socket_addr::SocketAddr,
                    )>,
                >()
                .await
            })
        }
        fn local_addr(&self) -> std::io::Result<crate::core::socket_addr::SocketAddr> {
            Ok(crate::core::socket_addr::SocketAddr::from(self.addr))
        }
    }

    // Two delayed listeners (1 ms and 50 ms); after 80 ms exactly one connection was handled.
    // NOTE(review): the slow listener's connection is taken when `accept` is called. This
    // assertion appears to rely on the multi-listener accept dropping the slow, still-pending
    // future once the fast one wins, which loses that connection. Confirm against the
    // listeners' `accept` implementation.
    #[tokio::test]
    async fn test_net_server_multi_listeners_race() {
        let (_a1, b1) = tokio::io::duplex(8);
        let boxed1: BoxedConnection = Box::new(b1);
        let (_a2, b2) = tokio::io::duplex(8);
        let boxed2: BoxedConnection = Box::new(b2);
        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let fast = TestListenerDelay::new(boxed1, addr, Duration::from_millis(1));
        let slow = TestListenerDelay::new(boxed2, addr, Duration::from_millis(50));
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_cl = calls.clone();
        let handler = move |_s: BoxedConnection, _p: CoreSocketAddr| {
            let calls_cl = calls_cl.clone();
            async move {
                calls_cl.fetch_add(1, Ordering::SeqCst);
                Ok::<(), BoxError>(())
            }
        };
        let server = NetServer::new().listen(fast).listen(slow);
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(80)).await;
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "only fast listener's connection handled"
        );
        jh.abort();
        let _ = jh.await;
    }

    // The `on_listen` callback receives the single listener's address as a TCP address. The
    // callback stores the addresses from a spawned task, so the test waits 20 ms before reading.
    #[tokio::test]
    async fn test_net_server_on_listen_addrs_content() {
        let (_a, b) = tokio::io::duplex(8);
        let boxed: BoxedConnection = Box::new(b);
        let addr: std::net::SocketAddr = "127.0.0.1:5555".parse().unwrap();
        let listener = TestListener::new(boxed, addr);
        let seen = Arc::new(tokio::sync::Mutex::new(Vec::<CoreSocketAddr>::new()));
        let seen_cl = seen.clone();
        let server = NetServer::new().listen(listener).on_listen(move |addrs| {
            let addrs = addrs.to_vec();
            let seen_cl = seen_cl.clone();
            tokio::spawn(async move {
                *seen_cl.lock().await = addrs;
            });
        });
        let handler =
            |_s: BoxedConnection, _p: CoreSocketAddr| async move { Ok::<(), BoxError>(()) };
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(20)).await;
        let addrs = seen.lock().await.clone();
        assert_eq!(addrs.len(), 1);
        assert!(matches!(addrs[0], CoreSocketAddr::Tcp(_)));
        jh.abort();
        let _ = jh.await;
    }

    // With two listeners, the `on_listen` callback receives both addresses.
    #[tokio::test]
    async fn test_net_server_on_listen_multi_addrs() {
        let (_a1, b1) = tokio::io::duplex(8);
        let boxed1: BoxedConnection = Box::new(b1);
        let (_a2, b2) = tokio::io::duplex(8);
        let boxed2: BoxedConnection = Box::new(b2);
        let addr1: std::net::SocketAddr = "127.0.0.1:60000".parse().unwrap();
        let addr2: std::net::SocketAddr = "127.0.0.1:60001".parse().unwrap();
        let l1 = TestListener::new(boxed1, addr1);
        let l2 = TestListener::new(boxed2, addr2);
        let seen = Arc::new(tokio::sync::Mutex::new(Vec::<CoreSocketAddr>::new()));
        let seen_cl = seen.clone();
        let server = NetServer::new()
            .listen(l1)
            .listen(l2)
            .on_listen(move |addrs| {
                let addrs = addrs.to_vec();
                let seen_cl = seen_cl.clone();
                tokio::spawn(async move {
                    *seen_cl.lock().await = addrs;
                });
            });
        let handler =
            |_s: BoxedConnection, _p: CoreSocketAddr| async move { Ok::<(), BoxError>(()) };
        let jh = tokio::spawn(async move { server.serve(handler).await });
        tokio::time::sleep(Duration::from_millis(20)).await;
        let addrs = seen.lock().await.clone();
        assert_eq!(addrs.len(), 2);
        jh.abort();
        let _ = jh.await;
    }
}