use std::any::Any;
use std::collections::hash_map::RandomState;
use std::future::Future;
use std::hash::BuildHasher;
use std::net::SocketAddr;
use std::num::NonZeroU64;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use bytes::Bytes;
use http::Response;
use http::StatusCode;
use http::header;
use http_body_util::Full;
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use hyper_util::rt::TokioTimer;
use hyper_util::server::conn::auto::Builder as AutoBuilder;
use hyper_util::server::graceful::GracefulConnection;
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio::task::JoinSet;
use tower::Service;
use tower::ServiceBuilder;
use tower_http::catch_panic::CatchPanicLayer;
use crate::codec::content_type;
use crate::dispatcher::Dispatcher;
use crate::error::ConnectError;
use crate::error::ErrorCode;
use crate::router::Router;
use crate::service::ConnectRpcService;
#[derive(Clone, Debug)]
pub struct PeerAddr(pub SocketAddr);
#[cfg(feature = "server-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "server-tls")))]
#[derive(Clone, Debug)]
pub struct PeerCerts(pub Arc<[rustls::pki_types::CertificateDer<'static>]>);
#[derive(Clone, Debug)]
struct PeerInfo {
addr: SocketAddr,
#[cfg(feature = "server-tls")]
certs: Option<Arc<[rustls::pki_types::CertificateDer<'static>]>>,
}
impl PeerInfo {
fn insert_into(&self, ext: &mut http::Extensions) {
ext.insert(PeerAddr(self.addr));
#[cfg(feature = "server-tls")]
if let Some(certs) = &self.certs {
ext.insert(PeerCerts(Arc::clone(certs)));
}
}
}
#[cfg(feature = "server-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "server-tls")))]
pub const DEFAULT_TLS_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
pub const DEFAULT_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(30);
const DEFAULT_MAX_CONNECTION_AGE_GRACE: Duration = Duration::from_secs(5);
const MAX_CONNECTION_AGE_JITTER_BASIS_POINTS: u128 = 10_000;
const MAX_CONNECTION_AGE_JITTER_SPREAD_BASIS_POINTS: u128 = 1_000;
const NANOS_PER_SEC: u128 = 1_000_000_000;
pub const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
pub const DEFAULT_HTTP2_ADAPTIVE_WINDOW: bool = true;
#[derive(Clone, Copy, Debug)]
struct Http2Config {
adaptive_window: bool,
initial_stream_window_size: Option<u32>,
initial_connection_window_size: Option<u32>,
max_concurrent_streams: Option<u32>,
keepalive_interval: Option<Duration>,
keepalive_timeout: Duration,
}
impl Default for Http2Config {
fn default() -> Self {
Self {
adaptive_window: DEFAULT_HTTP2_ADAPTIVE_WINDOW,
initial_stream_window_size: None,
initial_connection_window_size: None,
max_concurrent_streams: None,
keepalive_interval: None,
keepalive_timeout: DEFAULT_HTTP2_KEEPALIVE_TIMEOUT,
}
}
}
impl Http2Config {
fn effective_windows(self) -> (Option<u32>, Option<u32>) {
if self.adaptive_window {
(None, None)
} else {
(
self.initial_stream_window_size,
self.initial_connection_window_size,
)
}
}
}
pub struct Server {
service: ConnectRpcService,
http1_keep_alive: bool,
#[cfg(feature = "server-tls")]
tls_config: Option<Arc<rustls::ServerConfig>>,
#[cfg(feature = "server-tls")]
tls_handshake_timeout: std::time::Duration,
header_read_timeout: Option<Duration>,
max_connection_age: Option<Duration>,
max_connection_age_grace: Duration,
max_connection_idle: Option<Duration>,
http2: Http2Config,
max_requests_per_connection: Option<NonZeroU64>,
}
impl Server {
pub fn new(router: Router) -> Self {
Self {
service: ConnectRpcService::new(router),
http1_keep_alive: true,
#[cfg(feature = "server-tls")]
tls_config: None,
#[cfg(feature = "server-tls")]
tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
header_read_timeout: Some(DEFAULT_HEADER_READ_TIMEOUT),
max_connection_age: None,
max_connection_age_grace: DEFAULT_MAX_CONNECTION_AGE_GRACE,
max_connection_idle: None,
http2: Http2Config::default(),
max_requests_per_connection: None,
}
}
pub fn from_service(service: ConnectRpcService) -> Self {
Self {
service,
http1_keep_alive: true,
#[cfg(feature = "server-tls")]
tls_config: None,
#[cfg(feature = "server-tls")]
tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
header_read_timeout: Some(DEFAULT_HEADER_READ_TIMEOUT),
max_connection_age: None,
max_connection_age_grace: DEFAULT_MAX_CONNECTION_AGE_GRACE,
max_connection_idle: None,
http2: Http2Config::default(),
max_requests_per_connection: None,
}
}
#[cfg(feature = "server-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "server-tls")))]
#[must_use]
pub fn with_tls(mut self, config: Arc<rustls::ServerConfig>) -> Self {
self.tls_config = Some(config);
self
}
#[cfg(feature = "server-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "server-tls")))]
#[must_use]
pub fn with_tls_handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
self.tls_handshake_timeout = timeout;
self
}
#[must_use]
pub fn with_header_read_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
self.header_read_timeout = timeout.into();
self
}
#[must_use]
pub fn with_http1_keep_alive(mut self, enabled: bool) -> Self {
self.http1_keep_alive = enabled;
self
}
#[must_use]
pub fn with_deadline_policy(mut self, policy: crate::DeadlinePolicy) -> Self {
self.service = self.service.with_deadline_policy(policy);
self
}
#[must_use]
pub fn with_limits(mut self, limits: crate::Limits) -> Self {
self.service = self.service.with_limits(limits);
self
}
#[must_use]
pub fn with_compression(mut self, registry: crate::CompressionRegistry) -> Self {
self.service = self.service.with_compression(registry);
self
}
#[must_use]
pub fn with_compression_policy(mut self, policy: crate::CompressionPolicy) -> Self {
self.service = self.service.with_compression_policy(policy);
self
}
#[must_use]
pub fn with_interceptor(mut self, interceptor: impl crate::Interceptor) -> Self {
self.service = self.service.with_interceptor(interceptor);
self
}
#[must_use]
pub fn with_interceptor_arc(mut self, interceptor: Arc<dyn crate::Interceptor>) -> Self {
self.service = self.service.with_interceptor_arc(interceptor);
self
}
#[must_use]
pub fn with_max_connection_age(mut self, max_age: Duration) -> Self {
assert!(
!max_age.is_zero(),
"with_max_connection_age requires a non-zero duration",
);
self.max_connection_age = Some(max_age);
self
}
#[must_use]
pub fn with_max_connection_age_grace(mut self, grace: Duration) -> Self {
self.max_connection_age_grace = grace;
self
}
#[must_use]
pub fn with_max_connection_idle(mut self, duration: Duration) -> Self {
assert!(
!duration.is_zero(),
"with_max_connection_idle requires a non-zero duration",
);
self.max_connection_idle = Some(duration);
self
}
#[must_use]
pub fn with_http2_adaptive_window(mut self, enabled: bool) -> Self {
self.http2.adaptive_window = enabled;
self
}
#[must_use]
pub fn with_http2_initial_stream_window_size(mut self, size: impl Into<Option<u32>>) -> Self {
self.http2.initial_stream_window_size = size.into();
if self.http2.initial_stream_window_size.is_some() {
self.http2.adaptive_window = false;
}
self
}
#[must_use]
pub fn with_http2_initial_connection_window_size(
mut self,
size: impl Into<Option<u32>>,
) -> Self {
self.http2.initial_connection_window_size = size.into();
if self.http2.initial_connection_window_size.is_some() {
self.http2.adaptive_window = false;
}
self
}
#[must_use]
pub fn with_max_concurrent_streams(mut self, max_streams: u32) -> Self {
assert!(
max_streams != 0,
"with_max_concurrent_streams requires a non-zero value",
);
self.http2.max_concurrent_streams = Some(max_streams);
self
}
#[must_use]
pub fn with_max_requests_per_connection(mut self, max: NonZeroU64) -> Self {
self.max_requests_per_connection = Some(max);
self
}
#[must_use]
pub fn with_http2_keepalive_interval(mut self, interval: Duration) -> Self {
assert!(
!interval.is_zero(),
"with_http2_keepalive_interval requires a non-zero duration",
);
self.http2.keepalive_interval = Some(interval);
self
}
#[must_use]
pub fn with_http2_keepalive_timeout(mut self, timeout: Duration) -> Self {
self.http2.keepalive_timeout = timeout;
self
}
fn connection_age_config(&self) -> Option<ConnectionAgeConfig> {
build_connection_age_config(
self.max_connection_age,
self.max_connection_idle,
self.max_connection_age_grace,
self.max_requests_per_connection.is_some(),
)
}
fn connection_idle_config(&self) -> Option<IdleConfig> {
build_connection_idle_config(self.max_connection_idle, self.max_connection_age_grace)
}
fn request_retirement_config(&self) -> Option<RequestRetirementConfig> {
build_request_retirement_config(
self.max_requests_per_connection,
self.max_connection_age_grace,
)
}
fn retirement_config(&self) -> RetirementConfig {
RetirementConfig {
age: self.connection_age_config(),
idle: self.connection_idle_config(),
requests: self.request_retirement_config(),
}
}
pub fn router(&self) -> &Router {
self.service.dispatcher()
}
pub async fn serve(
self,
addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listener = TcpListener::bind(addr).await?;
let retirement = self.retirement_config();
#[cfg(feature = "server-tls")]
let tls_acceptor = self.tls_config.map(tokio_rustls::TlsAcceptor::from);
#[cfg(not(feature = "server-tls"))]
let tls_acceptor: Option<()> = None;
let scheme = if tls_acceptor.is_some() {
"https"
} else {
"http"
};
tracing::info!("ConnectRPC server listening on {scheme}://{addr}");
serve_with_listener(
listener,
self.service,
tls_acceptor,
self.http1_keep_alive,
self.header_read_timeout,
#[cfg(feature = "server-tls")]
self.tls_handshake_timeout,
None,
retirement,
self.http2,
)
.await
}
#[must_use]
pub fn from_listener(listener: TcpListener) -> BoundServer {
BoundServer {
listener,
http1_keep_alive: true,
#[cfg(feature = "server-tls")]
tls_config: None,
#[cfg(feature = "server-tls")]
tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
header_read_timeout: Some(DEFAULT_HEADER_READ_TIMEOUT),
max_connection_age: None,
max_connection_age_grace: DEFAULT_MAX_CONNECTION_AGE_GRACE,
max_connection_idle: None,
http2: Http2Config::default(),
max_requests_per_connection: None,
}
}
pub async fn bind(
addr: impl tokio::net::ToSocketAddrs,
) -> Result<BoundServer, Box<dyn std::error::Error + Send + Sync>> {
let listener = TcpListener::bind(addr).await?;
Ok(BoundServer {
listener,
http1_keep_alive: true,
#[cfg(feature = "server-tls")]
tls_config: None,
#[cfg(feature = "server-tls")]
tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
header_read_timeout: Some(DEFAULT_HEADER_READ_TIMEOUT),
max_connection_age: None,
max_connection_age_grace: DEFAULT_MAX_CONNECTION_AGE_GRACE,
max_connection_idle: None,
http2: Http2Config::default(),
max_requests_per_connection: None,
})
}
}
pub struct BoundServer {
listener: TcpListener,
http1_keep_alive: bool,
#[cfg(feature = "server-tls")]
tls_config: Option<Arc<rustls::ServerConfig>>,
#[cfg(feature = "server-tls")]
tls_handshake_timeout: std::time::Duration,
header_read_timeout: Option<Duration>,
max_connection_age: Option<Duration>,
max_connection_age_grace: Duration,
max_connection_idle: Option<Duration>,
http2: Http2Config,
max_requests_per_connection: Option<NonZeroU64>,
}
impl BoundServer {
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.listener.local_addr()
}
#[cfg(feature = "server-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "server-tls")))]
#[must_use]
pub fn with_tls(mut self, config: Arc<rustls::ServerConfig>) -> Self {
self.tls_config = Some(config);
self
}
#[cfg(feature = "server-tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "server-tls")))]
#[must_use]
pub fn with_tls_handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
self.tls_handshake_timeout = timeout;
self
}
#[must_use]
pub fn with_header_read_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
self.header_read_timeout = timeout.into();
self
}
#[must_use]
pub fn with_http1_keep_alive(mut self, enabled: bool) -> Self {
self.http1_keep_alive = enabled;
self
}
#[must_use]
pub fn with_max_connection_age(mut self, max_age: Duration) -> Self {
assert!(
!max_age.is_zero(),
"with_max_connection_age requires a non-zero duration",
);
self.max_connection_age = Some(max_age);
self
}
#[must_use]
pub fn with_max_connection_age_grace(mut self, grace: Duration) -> Self {
self.max_connection_age_grace = grace;
self
}
#[must_use]
pub fn with_max_connection_idle(mut self, duration: Duration) -> Self {
assert!(
!duration.is_zero(),
"with_max_connection_idle requires a non-zero duration",
);
self.max_connection_idle = Some(duration);
self
}
#[must_use]
pub fn with_http2_adaptive_window(mut self, enabled: bool) -> Self {
self.http2.adaptive_window = enabled;
self
}
#[must_use]
pub fn with_http2_initial_stream_window_size(mut self, size: impl Into<Option<u32>>) -> Self {
self.http2.initial_stream_window_size = size.into();
if self.http2.initial_stream_window_size.is_some() {
self.http2.adaptive_window = false;
}
self
}
#[must_use]
pub fn with_http2_initial_connection_window_size(
mut self,
size: impl Into<Option<u32>>,
) -> Self {
self.http2.initial_connection_window_size = size.into();
if self.http2.initial_connection_window_size.is_some() {
self.http2.adaptive_window = false;
}
self
}
#[must_use]
pub fn with_max_concurrent_streams(mut self, max_streams: u32) -> Self {
assert!(
max_streams != 0,
"with_max_concurrent_streams requires a non-zero value",
);
self.http2.max_concurrent_streams = Some(max_streams);
self
}
#[must_use]
pub fn with_max_requests_per_connection(mut self, max: NonZeroU64) -> Self {
self.max_requests_per_connection = Some(max);
self
}
#[must_use]
pub fn with_http2_keepalive_interval(mut self, interval: Duration) -> Self {
assert!(
!interval.is_zero(),
"with_http2_keepalive_interval requires a non-zero duration",
);
self.http2.keepalive_interval = Some(interval);
self
}
#[must_use]
pub fn with_http2_keepalive_timeout(mut self, timeout: Duration) -> Self {
self.http2.keepalive_timeout = timeout;
self
}
pub async fn serve(
self,
router: Router,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.serve_with_service(ConnectRpcService::new(router))
.await
}
pub async fn serve_with_graceful_shutdown<F>(
self,
router: Router,
signal: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
F: Future<Output = ()> + Send + 'static,
{
self.serve_with_service_and_shutdown(ConnectRpcService::new(router), signal)
.await
}
pub async fn serve_with_service<D: Dispatcher>(
self,
service: ConnectRpcService<D>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let retirement = self.retirement_config();
#[cfg(feature = "server-tls")]
let tls_acceptor = self.tls_config.map(tokio_rustls::TlsAcceptor::from);
#[cfg(not(feature = "server-tls"))]
let tls_acceptor: Option<()> = None;
serve_with_listener(
self.listener,
service,
tls_acceptor,
self.http1_keep_alive,
self.header_read_timeout,
#[cfg(feature = "server-tls")]
self.tls_handshake_timeout,
None,
retirement,
self.http2,
)
.await
}
pub async fn serve_with_service_and_shutdown<D, F>(
self,
service: ConnectRpcService<D>,
signal: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
D: Dispatcher,
F: Future<Output = ()> + Send + 'static,
{
let retirement = self.retirement_config();
#[cfg(feature = "server-tls")]
let tls_acceptor = self.tls_config.map(tokio_rustls::TlsAcceptor::from);
#[cfg(not(feature = "server-tls"))]
let tls_acceptor: Option<()> = None;
serve_with_listener(
self.listener,
service,
tls_acceptor,
self.http1_keep_alive,
self.header_read_timeout,
#[cfg(feature = "server-tls")]
self.tls_handshake_timeout,
Some(Box::pin(signal)),
retirement,
self.http2,
)
.await
}
fn connection_age_config(&self) -> Option<ConnectionAgeConfig> {
build_connection_age_config(
self.max_connection_age,
self.max_connection_idle,
self.max_connection_age_grace,
self.max_requests_per_connection.is_some(),
)
}
fn connection_idle_config(&self) -> Option<IdleConfig> {
build_connection_idle_config(self.max_connection_idle, self.max_connection_age_grace)
}
fn request_retirement_config(&self) -> Option<RequestRetirementConfig> {
build_request_retirement_config(
self.max_requests_per_connection,
self.max_connection_age_grace,
)
}
fn retirement_config(&self) -> RetirementConfig {
RetirementConfig {
age: self.connection_age_config(),
idle: self.connection_idle_config(),
requests: self.request_retirement_config(),
}
}
}
fn build_connection_age_config(
max_age: Option<Duration>,
max_idle: Option<Duration>,
grace: Duration,
request_retirement_active: bool,
) -> Option<ConnectionAgeConfig> {
let Some(max_age) = max_age else {
if max_idle.is_none()
&& !request_retirement_active
&& grace != DEFAULT_MAX_CONNECTION_AGE_GRACE
{
tracing::debug!(
"max_connection_age_grace is set but none of max_connection_age, \
max_connection_idle, or max_requests_per_connection are; the \
grace period has no effect",
);
}
return None;
};
Some(ConnectionAgeConfig { max_age, grace })
}
fn build_connection_idle_config(max_idle: Option<Duration>, grace: Duration) -> Option<IdleConfig> {
max_idle.map(|idle| IdleConfig { idle, grace })
}
fn build_request_retirement_config(
max_requests: Option<NonZeroU64>,
grace: Duration,
) -> Option<RequestRetirementConfig> {
max_requests.map(|max| RequestRetirementConfig { max, grace })
}
type WrappedService<D> = tower_http::catch_panic::CatchPanic<
ConnectRpcService<D>,
fn(Box<dyn Any + Send>) -> Response<Full<Bytes>>,
>;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct ConnectionAgeConfig {
max_age: Duration,
grace: Duration,
}
impl ConnectionAgeConfig {
fn with_jitter(self, sample: u64) -> Self {
Self {
max_age: jitter_connection_age(self.max_age, sample),
grace: self.grace,
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct IdleConfig {
idle: Duration,
grace: Duration,
}
#[derive(Clone, Copy, Debug, Default)]
struct RetirementConfig {
age: Option<ConnectionAgeConfig>,
idle: Option<IdleConfig>,
requests: Option<RequestRetirementConfig>,
}
#[derive(Debug, Default)]
struct ConnectionActivity {
in_flight: AtomicUsize,
epoch: AtomicU64,
}
impl ConnectionActivity {
fn request_started(&self) {
self.in_flight.fetch_add(1, Ordering::SeqCst);
self.epoch.fetch_add(1, Ordering::SeqCst);
}
fn request_finished(&self) {
self.in_flight.fetch_sub(1, Ordering::SeqCst);
self.epoch.fetch_add(1, Ordering::SeqCst);
}
fn snapshot(&self) -> (usize, u64) {
(
self.in_flight.load(Ordering::SeqCst),
self.epoch.load(Ordering::SeqCst),
)
}
}
struct ActiveRequestGuard(Arc<ConnectionActivity>);
impl ActiveRequestGuard {
fn new(activity: Arc<ConnectionActivity>) -> Self {
activity.request_started();
Self(activity)
}
}
impl Drop for ActiveRequestGuard {
fn drop(&mut self) {
self.0.request_finished();
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct RequestRetirementConfig {
max: NonZeroU64,
grace: Duration,
}
#[allow(clippy::too_many_arguments)]
async fn serve_accepted_stream<D, S>(
io: S,
peer: PeerInfo,
service: Arc<WrappedService<D>>,
http1_keep_alive: bool,
header_read_timeout: Option<Duration>,
global_shutdown: watch::Receiver<bool>,
retirement: RetirementConfig,
http2: Http2Config,
) where
D: Dispatcher,
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
tracing::trace!(remote_addr = %peer.addr, "Accepted new connection");
let activity = retirement
.idle
.map(|_| Arc::new(ConnectionActivity::default()));
let (request_counter, request_retire) = match retirement.requests {
Some(config) => {
let (tx, rx) = watch::channel(false);
(
Some(RequestCounter {
served: AtomicU64::new(0),
max: config.max,
retire: tx,
}),
Some((rx, config.grace)),
)
}
None => (None, None),
};
let peer_for_requests = peer.clone();
let activity_for_requests = activity.clone();
let svc = hyper::service::service_fn(move |mut req| {
peer_for_requests.insert_into(req.extensions_mut());
if let Some(counter) = &request_counter {
counter.record_request();
}
let mut service = (*service).clone();
let guard = activity_for_requests
.as_ref()
.map(|activity| ActiveRequestGuard::new(Arc::clone(activity)));
async move {
let _guard = guard;
service.call(req).await
}
});
let mut builder = AutoBuilder::new(TokioExecutor::new());
builder
.http1()
.timer(TokioTimer::new())
.keep_alive(http1_keep_alive)
.header_read_timeout(header_read_timeout);
configure_http2(&mut builder, http2);
let conn = builder.serve_connection(TokioIo::new(io), svc).into_owned();
serve_connection_with_lifecycle(
conn,
peer.addr,
global_shutdown,
retirement.age,
retirement.idle.zip(activity),
request_retire,
)
.await;
}
struct RequestCounter {
served: AtomicU64,
max: NonZeroU64,
retire: watch::Sender<bool>,
}
impl RequestCounter {
fn record_request(&self) {
let served = self
.served
.fetch_add(1, Ordering::Relaxed)
.saturating_add(1);
if served >= self.max.get() {
let _ = self.retire.send(true);
}
}
}
fn configure_http2(builder: &mut AutoBuilder<TokioExecutor>, config: Http2Config) {
let mut http2 = builder.http2();
http2.adaptive_window(config.adaptive_window);
let (stream_window, connection_window) = config.effective_windows();
if let Some(size) = stream_window {
http2.initial_stream_window_size(size);
}
if let Some(size) = connection_window {
http2.initial_connection_window_size(size);
}
if let Some(max) = config.max_concurrent_streams {
http2.max_concurrent_streams(max);
}
if let Some(interval) = config.keepalive_interval {
http2
.timer(TokioTimer::new())
.keep_alive_interval(interval)
.keep_alive_timeout(config.keepalive_timeout);
}
}
fn serve_connection_with_lifecycle<C>(
conn: C,
remote_addr: SocketAddr,
global_shutdown: watch::Receiver<bool>,
connection_age: Option<ConnectionAgeConfig>,
connection_idle: Option<(IdleConfig, Arc<ConnectionActivity>)>,
request_retire: Option<(watch::Receiver<bool>, Duration)>,
) -> ConnectionLifecycle<C>
where
C: GracefulConnection,
C::Error: std::fmt::Display,
{
ConnectionLifecycle {
conn: Box::pin(conn),
remote_addr,
global_shutdown: global_shutdown_future(global_shutdown),
age: connection_age.map(|config| (Box::pin(tokio::time::sleep(config.max_age)), config)),
idle: connection_idle.map(|(config, activity)| {
let armed_epoch = activity.snapshot().1;
IdleTracker {
config,
activity,
timer: Box::pin(tokio::time::sleep(config.idle)),
armed_epoch,
}
}),
requests: request_retire.map(|(rx, grace)| (global_shutdown_future(rx), grace)),
state: ConnectionLifecycleState::Serving,
}
}
fn global_shutdown_future(
mut global_shutdown: watch::Receiver<bool>,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(async move {
let _ = global_shutdown.wait_for(|fired| *fired).await;
})
}
type RetirementSignal = Pin<Box<dyn Future<Output = ()> + Send>>;
struct ConnectionLifecycle<C: GracefulConnection> {
conn: Pin<Box<C>>,
remote_addr: SocketAddr,
global_shutdown: Pin<Box<dyn Future<Output = ()> + Send>>,
age: Option<(Pin<Box<tokio::time::Sleep>>, ConnectionAgeConfig)>,
idle: Option<IdleTracker>,
requests: Option<(RetirementSignal, Duration)>,
state: ConnectionLifecycleState,
}
struct IdleTracker {
config: IdleConfig,
activity: Arc<ConnectionActivity>,
timer: Pin<Box<tokio::time::Sleep>>,
armed_epoch: u64,
}
enum ConnectionLifecycleState {
Serving,
GlobalDraining,
Draining {
grace: Pin<Box<tokio::time::Sleep>>,
duration: Duration,
},
}
impl<C> Future for ConnectionLifecycle<C>
where
C: GracefulConnection,
C::Error: std::fmt::Display,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let this = self.get_mut();
loop {
match &mut this.state {
ConnectionLifecycleState::Serving => {
if let Poll::Ready(result) = this.conn.as_mut().poll(cx) {
log_connection_result(this.remote_addr, result);
return Poll::Ready(());
}
if let Poll::Ready(()) = this.global_shutdown.as_mut().poll(cx) {
this.conn.as_mut().graceful_shutdown();
this.state = ConnectionLifecycleState::GlobalDraining;
continue;
}
if let Some((age, config)) = &mut this.age
&& age.as_mut().poll(cx).is_ready()
{
tracing::trace!(
remote_addr = %this.remote_addr,
max_age = ?config.max_age,
grace = ?config.grace,
"Connection reached maximum age; starting graceful shutdown",
);
this.conn.as_mut().graceful_shutdown();
this.state = ConnectionLifecycleState::Draining {
grace: Box::pin(tokio::time::sleep(config.grace)),
duration: config.grace,
};
continue;
}
if let Some(idle) = &mut this.idle
&& idle.timer.as_mut().poll(cx).is_ready()
{
let (in_flight, epoch) = idle.activity.snapshot();
if in_flight == 0 && epoch == idle.armed_epoch {
tracing::trace!(
remote_addr = %this.remote_addr,
idle = ?idle.config.idle,
grace = ?idle.config.grace,
"Connection idle; starting graceful shutdown",
);
this.conn.as_mut().graceful_shutdown();
this.state = ConnectionLifecycleState::Draining {
grace: Box::pin(tokio::time::sleep(idle.config.grace)),
duration: idle.config.grace,
};
continue;
}
idle.armed_epoch = epoch;
let next = tokio::time::Instant::now() + idle.config.idle;
idle.timer.as_mut().reset(next);
continue;
}
if let Some((requests, grace)) = &mut this.requests
&& requests.as_mut().poll(cx).is_ready()
{
let grace = *grace;
tracing::trace!(
remote_addr = %this.remote_addr,
grace = ?grace,
"Connection reached maximum requests; starting graceful shutdown",
);
this.conn.as_mut().graceful_shutdown();
this.state = ConnectionLifecycleState::Draining {
grace: Box::pin(tokio::time::sleep(grace)),
duration: grace,
};
continue;
}
return Poll::Pending;
}
ConnectionLifecycleState::GlobalDraining => {
if let Poll::Ready(result) = this.conn.as_mut().poll(cx) {
log_connection_result(this.remote_addr, result);
return Poll::Ready(());
}
return Poll::Pending;
}
ConnectionLifecycleState::Draining { grace, duration } => {
if let Poll::Ready(result) = this.conn.as_mut().poll(cx) {
log_connection_result(this.remote_addr, result);
return Poll::Ready(());
}
if let Poll::Ready(()) = this.global_shutdown.as_mut().poll(cx) {
this.state = ConnectionLifecycleState::GlobalDraining;
continue;
}
if grace.as_mut().poll(cx).is_ready() {
tracing::trace!(
remote_addr = %this.remote_addr,
grace = ?duration,
"Connection retirement grace expired; closing connection",
);
return Poll::Ready(());
}
return Poll::Pending;
}
}
}
}
}
fn log_connection_result<E: std::fmt::Display>(remote_addr: SocketAddr, result: Result<(), E>) {
match result {
Ok(()) => {
tracing::trace!(remote_addr = %remote_addr, "Connection completed normally");
}
Err(err) => {
tracing::trace!(
remote_addr = %remote_addr,
error = %err,
"Connection ended with error",
);
}
}
}
fn jitter_connection_age(age: Duration, sample: u64) -> Duration {
if age.is_zero() {
return age;
}
let spread = MAX_CONNECTION_AGE_JITTER_SPREAD_BASIS_POINTS * 2;
let offset = (u128::from(sample) * spread) / u128::from(u64::MAX);
let basis_points = MAX_CONNECTION_AGE_JITTER_BASIS_POINTS
- MAX_CONNECTION_AGE_JITTER_SPREAD_BASIS_POINTS
+ offset;
let scaled = age.as_nanos().saturating_mul(basis_points);
let nanos = if basis_points < MAX_CONNECTION_AGE_JITTER_BASIS_POINTS {
scaled.saturating_add(MAX_CONNECTION_AGE_JITTER_BASIS_POINTS - 1)
/ MAX_CONNECTION_AGE_JITTER_BASIS_POINTS
} else {
scaled / MAX_CONNECTION_AGE_JITTER_BASIS_POINTS
};
duration_from_nanos(nanos.min(Duration::MAX.as_nanos()))
}
fn duration_from_nanos(nanos: u128) -> Duration {
Duration::new(
(nanos / NANOS_PER_SEC) as u64,
(nanos % NANOS_PER_SEC) as u32,
)
}
#[cfg(feature = "server-tls")]
type MaybeTlsAcceptor = Option<tokio_rustls::TlsAcceptor>;
#[cfg(not(feature = "server-tls"))]
type MaybeTlsAcceptor = Option<()>;
type ShutdownSignal = Option<Pin<Box<dyn Future<Output = ()> + Send>>>;
#[allow(clippy::too_many_arguments)]
async fn serve_with_listener<D: Dispatcher>(
listener: TcpListener,
service: ConnectRpcService<D>,
tls_acceptor: MaybeTlsAcceptor,
http1_keep_alive: bool,
header_read_timeout: Option<Duration>,
#[cfg(feature = "server-tls")] tls_handshake_timeout: std::time::Duration,
shutdown: ShutdownSignal,
retirement: RetirementConfig,
http2: Http2Config,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if http2.keepalive_interval.is_none()
&& http2.keepalive_timeout != DEFAULT_HTTP2_KEEPALIVE_TIMEOUT
{
tracing::debug!(
"http2_keepalive_timeout is set but http2_keepalive_interval is not; \
HTTP/2 keepalive stays disabled and the timeout has no effect",
);
}
let service: WrappedService<D> = ServiceBuilder::new()
.layer(CatchPanicLayer::custom(panic_handler as fn(_) -> _))
.service(service);
let service = Arc::new(service);
#[cfg(feature = "server-tls")]
let tls_acceptor = tls_acceptor.map(Arc::new);
#[cfg(not(feature = "server-tls"))]
let _ = tls_acceptor;
let mut shutdown = shutdown.unwrap_or_else(|| Box::pin(std::future::pending()));
let (global_shutdown_tx, global_shutdown_rx) = watch::channel(false);
let mut connections = JoinSet::new();
let jitter_state = RandomState::new();
let mut connection_sequence = 0u64;
loop {
let (stream, remote_addr) = tokio::select! {
biased;
_ = &mut shutdown => {
tracing::info!("Shutdown signal received; draining connections");
break;
}
Some(result) = connections.join_next(), if !connections.is_empty() => {
log_connection_task_result(result);
continue;
}
accept_result = listener.accept() => match accept_result {
Ok(conn) => conn,
Err(err) => {
if is_transient_accept_error(&err) {
tracing::warn!("Transient accept error (continuing): {}", err);
continue;
}
connections.detach_all();
return Err(err.into());
}
},
};
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!("failed to set TCP_NODELAY: {e}");
}
let service = Arc::clone(&service);
let global_shutdown = global_shutdown_rx.clone();
connection_sequence = connection_sequence.wrapping_add(1);
let retirement = RetirementConfig {
age: retirement.age.map(|config| {
config.with_jitter(jitter_state.hash_one((remote_addr, connection_sequence)))
}),
..retirement
};
#[cfg(feature = "server-tls")]
let tls_acceptor = tls_acceptor.clone();
connections.spawn(async move {
#[cfg(feature = "server-tls")]
if let Some(acceptor) = tls_acceptor {
match tokio::time::timeout(tls_handshake_timeout, acceptor.accept(stream)).await {
Ok(Ok(tls_stream)) => {
let (_, conn) = tls_stream.get_ref();
let certs = conn.peer_certificates().map(|chain| -> Arc<[_]> {
chain.iter().map(|c| c.clone().into_owned()).collect()
});
let peer = PeerInfo {
addr: remote_addr,
certs,
};
serve_accepted_stream(
tls_stream,
peer,
service,
http1_keep_alive,
header_read_timeout,
global_shutdown,
retirement,
http2,
)
.await;
}
Ok(Err(err)) => {
tracing::debug!(
remote_addr = %remote_addr,
error = ?err,
"TLS handshake failed: {err}",
);
}
Err(_) => {
tracing::warn!(
remote_addr = %remote_addr,
"TLS handshake timed out after {tls_handshake_timeout:?}",
);
}
}
return;
}
let peer = PeerInfo {
addr: remote_addr,
#[cfg(feature = "server-tls")]
certs: None,
};
serve_accepted_stream(
stream,
peer,
service,
http1_keep_alive,
header_read_timeout,
global_shutdown,
retirement,
http2,
)
.await;
});
}
drop(listener);
let _ = global_shutdown_tx.send(true);
while let Some(result) = connections.join_next().await {
log_connection_task_result(result);
}
tracing::info!("All connections drained; shutdown complete");
Ok(())
}
fn log_connection_task_result(result: Result<(), tokio::task::JoinError>) {
if let Err(err) = result {
tracing::warn!(error = %err, "Connection task ended unexpectedly");
}
}
fn panic_handler(err: Box<dyn Any + Send + 'static>) -> Response<Full<Bytes>> {
let backtrace = std::backtrace::Backtrace::capture();
let message = if let Some(s) = err.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = err.downcast_ref::<&str>() {
(*s).to_string()
} else {
"handler panicked".to_string()
};
match backtrace.status() {
std::backtrace::BacktraceStatus::Captured => {
tracing::error!(
"Request handler panicked: {}\n\nBacktrace:\n{}",
message,
backtrace
);
}
_ => {
tracing::error!(
"Request handler panicked: {} (set RUST_BACKTRACE=1 for backtrace)",
message
);
}
}
let error = ConnectError::new(ErrorCode::Internal, "internal server error");
let body = error.to_json();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, content_type::JSON)
.body(Full::new(body))
.unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::new()))
.unwrap()
})
}
pub(crate) fn is_transient_accept_error(err: &std::io::Error) -> bool {
use std::io::ErrorKind;
matches!(
err.kind(),
ErrorKind::WouldBlock |
ErrorKind::Interrupted |
ErrorKind::ConnectionAborted |
ErrorKind::ConnectionReset
) || {
err.raw_os_error()
.is_some_and(|code| code == libc::EMFILE || code == libc::ENFILE)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
const ECHO_REQ: &[u8] = concat!(
"POST /svc/Echo HTTP/1.1\r\n",
"Host: localhost\r\n",
"Content-Type: application/proto\r\n",
"Content-Length: 0\r\n",
"Connection: close\r\n",
"\r\n",
)
.as_bytes();
const KEEPALIVE_ECHO_REQ: &[u8] = concat!(
"POST /svc/Echo HTTP/1.1\r\n",
"Host: localhost\r\n",
"Content-Type: application/proto\r\n",
"Content-Length: 0\r\n",
"Connection: keep-alive\r\n",
"\r\n",
)
.as_bytes();
#[test]
fn test_server_creation() {
let router = Router::new();
let _server = Server::new(router);
}
#[test]
fn test_server_dispatch_config_proxies() {
use crate::service::Limits;
use crate::{CompressionPolicy, CompressionRegistry};
let limits = Limits {
max_request_body_size: 1024,
max_message_size: 512,
};
let server = Server::new(Router::new())
.with_limits(limits.clone())
.with_compression(CompressionRegistry::default())
.with_compression_policy(CompressionPolicy::default().min_size(8192))
.with_http1_keep_alive(false);
assert_eq!(server.service.limits().max_request_body_size, 1024);
assert_eq!(server.service.limits().max_message_size, 512);
assert!(!server.http1_keep_alive);
}
#[test]
fn test_server_interceptor_proxies() {
struct Noop;
#[async_trait::async_trait]
impl crate::Interceptor for Noop {}
let shared: Arc<dyn crate::Interceptor> = Arc::new(Noop);
assert_eq!(Arc::strong_count(&shared), 1);
let server = Server::new(Router::new())
.with_interceptor(Noop)
.with_interceptor_arc(Arc::clone(&shared));
assert_eq!(
Arc::strong_count(&shared),
2,
"Server::with_interceptor_arc must reach the underlying service"
);
drop(server);
assert_eq!(Arc::strong_count(&shared), 1);
}
#[tokio::test]
async fn test_graceful_shutdown_immediate() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
rx.await.ok();
})
.await
});
tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down in time")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test]
async fn test_graceful_shutdown_drains_inflight_request() {
let (entered_tx, entered_rx) = tokio::sync::oneshot::channel();
let (release_tx, release_rx) = tokio::sync::oneshot::channel();
let chans = Arc::new(Mutex::new(Some((entered_tx, release_rx))));
let router = Router::new().route(
"svc",
"Slow",
crate::handler_fn(
move |_ctx: crate::RequestContext, _req: buffa_types::Empty| {
let chans = Arc::clone(&chans);
async move {
let taken = chans.lock().unwrap().take();
if let Some((entered_tx, release_rx)) = taken {
entered_tx.send(()).ok();
release_rx.await.ok();
}
crate::Response::ok(buffa_types::Empty::default())
}
},
),
);
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Slow"))
.header(header::CONTENT_TYPE, "application/proto")
.body(())
.unwrap();
let (resp_fut, _) = send_request.send_request(req, true).unwrap();
let mut resp_fut = tokio::spawn(resp_fut);
tokio::time::timeout(Duration::from_secs(5), entered_rx)
.await
.expect("handler never entered")
.unwrap();
shutdown_tx.send(()).unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(
!serve.is_finished(),
"server shut down before in-flight request completed"
);
assert!(!resp_fut.is_finished(), "response arrived too early");
release_tx.send(()).unwrap();
let resp = tokio::time::timeout(Duration::from_secs(5), &mut resp_fut)
.await
.expect("response never arrived")
.expect("join error")
.expect("h2 request failed");
assert!(resp.status().is_success(), "got status {}", resp.status());
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down after in-flight request drained")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test]
async fn test_graceful_shutdown_rejects_new_connections() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
rx.await.ok();
})
.await
});
tokio::time::sleep(Duration::from_millis(20)).await;
tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), serve)
.await
.unwrap()
.unwrap()
.unwrap();
let connect_result = tokio::net::TcpStream::connect(addr).await;
assert!(
connect_result.is_err(),
"expected connection refused after shutdown"
);
}
#[tokio::test]
async fn test_graceful_shutdown_sends_h2_goaway() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
resp.await.unwrap();
tx.send(()).unwrap();
let conn_result = tokio::time::timeout(Duration::from_secs(2), h2_task)
.await
.expect("server did not close idle h2 connection (no GOAWAY?)")
.expect("h2 connection task panicked");
if let Err(e) = conn_result {
assert!(
e.is_go_away(),
"h2 connection ended with non-GOAWAY error: {e:?}"
);
}
let result = tokio::time::timeout(Duration::from_secs(2), serve)
.await
.expect("server did not shut down after GOAWAY drained the connection")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
drop(send_request);
}
#[test]
fn max_connection_age_jitter_stays_within_bounds() {
let samples = [0, 1, u64::MAX / 2, u64::MAX - 1, u64::MAX];
let ages = [
Duration::ZERO,
Duration::from_nanos(1),
Duration::from_secs(10),
Duration::MAX,
];
assert_eq!(
jitter_connection_age(Duration::from_secs(10), 0),
Duration::from_secs(9)
);
assert_eq!(
jitter_connection_age(Duration::from_secs(10), u64::MAX),
Duration::from_secs(11)
);
for age in ages {
for sample in samples {
let jittered = jitter_connection_age(age, sample);
if age.is_zero() {
assert_eq!(jittered, Duration::ZERO);
continue;
}
assert!(
jittered
.as_nanos()
.saturating_mul(MAX_CONNECTION_AGE_JITTER_BASIS_POINTS)
>= age.as_nanos().saturating_mul(
MAX_CONNECTION_AGE_JITTER_BASIS_POINTS
- MAX_CONNECTION_AGE_JITTER_SPREAD_BASIS_POINTS
),
"{jittered:?} was below the 90% jitter bound for {age:?}"
);
assert!(
jittered
.as_nanos()
.saturating_mul(MAX_CONNECTION_AGE_JITTER_BASIS_POINTS)
<= age.as_nanos().saturating_mul(
MAX_CONNECTION_AGE_JITTER_BASIS_POINTS
+ MAX_CONNECTION_AGE_JITTER_SPREAD_BASIS_POINTS
),
"{jittered:?} was above the 110% jitter bound for {age:?}"
);
}
}
}
#[tokio::test]
async fn max_connection_age_builder_defaults_and_overrides() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener);
assert_eq!(bound.max_connection_age, None);
assert_eq!(
bound.max_connection_age_grace,
DEFAULT_MAX_CONNECTION_AGE_GRACE
);
let bound = bound
.with_max_connection_age(Duration::from_secs(30))
.with_max_connection_age_grace(Duration::ZERO);
assert_eq!(bound.max_connection_age, Some(Duration::from_secs(30)));
assert_eq!(bound.max_connection_age_grace, Duration::ZERO);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound =
Server::from_listener(listener).with_max_connection_age_grace(Duration::from_secs(2));
assert_eq!(bound.max_connection_age, None);
assert_eq!(bound.max_connection_age_grace, Duration::from_secs(2));
}
#[test]
fn server_max_connection_age_builder_threads_through() {
let server = Server::new(Router::new());
assert_eq!(server.max_connection_age, None);
assert_eq!(server.connection_age_config(), None);
let server = Server::new(Router::new())
.with_max_connection_age(Duration::from_secs(30))
.with_max_connection_age_grace(Duration::from_secs(2));
assert_eq!(
server.connection_age_config(),
Some(ConnectionAgeConfig {
max_age: Duration::from_secs(30),
grace: Duration::from_secs(2),
})
);
}
#[test]
#[should_panic(expected = "non-zero duration")]
fn with_max_connection_age_rejects_zero() {
let _ = Server::new(Router::new()).with_max_connection_age(Duration::ZERO);
}
#[test]
fn header_read_timeout_builder_defaults_and_overrides() {
let server = Server::new(Router::new());
assert_eq!(
server.header_read_timeout,
Some(DEFAULT_HEADER_READ_TIMEOUT)
);
let server =
Server::new(Router::new()).with_header_read_timeout(Some(Duration::from_secs(5)));
assert_eq!(server.header_read_timeout, Some(Duration::from_secs(5)));
let server = Server::new(Router::new()).with_header_read_timeout(None::<Duration>);
assert_eq!(server.header_read_timeout, None);
}
#[tokio::test]
async fn bound_server_header_read_timeout_builder_defaults_and_overrides() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener);
assert_eq!(bound.header_read_timeout, Some(DEFAULT_HEADER_READ_TIMEOUT));
let bound = bound.with_header_read_timeout(Some(Duration::from_secs(2)));
assert_eq!(bound.header_read_timeout, Some(Duration::from_secs(2)));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener).with_header_read_timeout(None::<Duration>);
assert_eq!(bound.header_read_timeout, None);
}
#[tokio::test(start_paused = true)]
async fn header_read_timeout_closes_stalled_connection() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_header_read_timeout(Some(Duration::from_secs(10)));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
stream
.write_all(b"POST /svc/Echo HTTP/1.1\r\nHost: localhost\r\n")
.await
.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
let mut buf = [0; 1];
let read = stream.read(&mut buf).await.unwrap();
assert_eq!(
read, 0,
"stalled connection stayed open past the header read timeout"
);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test]
async fn header_read_timeout_allows_prompt_requests() {
let router = Router::new().route(
"svc",
"Echo",
crate::handler_fn(
|_ctx: crate::RequestContext, _req: buffa_types::Empty| async move {
crate::Response::ok(buffa_types::Empty::default())
},
),
);
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_header_read_timeout(Some(Duration::from_secs(30)));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
stream.write_all(ECHO_REQ).await.unwrap();
let resp = read_http1_response(&mut stream).await;
assert!(
resp.starts_with(b"HTTP/1.1 2"),
"expected 2xx, got: {}",
String::from_utf8_lossy(&resp[..resp.len().min(80)])
);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[test]
fn http2_config_default_enables_adaptive_window() {
let config = Http2Config::default();
assert!(config.adaptive_window);
assert_eq!(config.adaptive_window, DEFAULT_HTTP2_ADAPTIVE_WINDOW);
assert_eq!(config.initial_stream_window_size, None);
assert_eq!(config.initial_connection_window_size, None);
}
#[test]
fn server_http2_builder_defaults_match_adaptive_on() {
let server = Server::new(Router::new());
assert!(server.http2.adaptive_window);
assert_eq!(server.http2.initial_stream_window_size, None);
assert_eq!(server.http2.initial_connection_window_size, None);
let from_service = Server::from_service(ConnectRpcService::new(Router::new()));
assert!(from_service.http2.adaptive_window);
}
#[tokio::test]
async fn bound_server_http2_builder_defaults_match_adaptive_on() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener);
assert!(bound.http2.adaptive_window);
assert_eq!(bound.http2.initial_stream_window_size, None);
assert_eq!(bound.http2.initial_connection_window_size, None);
let bound = Server::bind("127.0.0.1:0").await.unwrap();
assert!(bound.http2.adaptive_window);
}
#[test]
fn with_http2_adaptive_window_toggles_flag() {
let server = Server::new(Router::new()).with_http2_adaptive_window(false);
assert!(!server.http2.adaptive_window);
let server = server.with_http2_adaptive_window(true);
assert!(server.http2.adaptive_window);
}
#[test]
fn explicit_stream_window_disables_adaptive() {
let server = Server::new(Router::new()).with_http2_initial_stream_window_size(1 << 20);
assert_eq!(server.http2.initial_stream_window_size, Some(1 << 20));
assert!(
!server.http2.adaptive_window,
"an explicit stream window must turn adaptive sizing off"
);
}
#[test]
fn explicit_connection_window_disables_adaptive() {
let server = Server::new(Router::new()).with_http2_initial_connection_window_size(2 << 20);
assert_eq!(server.http2.initial_connection_window_size, Some(2 << 20));
assert!(
!server.http2.adaptive_window,
"an explicit connection window must turn adaptive sizing off"
);
}
#[test]
fn clearing_window_with_none_keeps_adaptive_flag() {
let server = Server::new(Router::new())
.with_http2_initial_stream_window_size(None)
.with_http2_initial_connection_window_size(None);
assert!(server.http2.adaptive_window);
assert_eq!(server.http2.initial_stream_window_size, None);
assert_eq!(server.http2.initial_connection_window_size, None);
}
#[test]
fn re_enabling_adaptive_after_explicit_window_wins() {
let server = Server::new(Router::new())
.with_http2_initial_stream_window_size(1 << 20)
.with_http2_adaptive_window(true);
assert!(server.http2.adaptive_window);
assert_eq!(server.http2.initial_stream_window_size, Some(1 << 20));
assert_eq!(server.http2.effective_windows(), (None, None));
}
#[test]
fn effective_windows_resolves_adaptive_precedence() {
assert_eq!(Http2Config::default().effective_windows(), (None, None));
let off = Server::new(Router::new()).with_http2_adaptive_window(false);
assert_eq!(off.http2.effective_windows(), (None, None));
let fixed = Server::new(Router::new())
.with_http2_initial_stream_window_size(1 << 20)
.with_http2_initial_connection_window_size(2 << 20);
assert!(!fixed.http2.adaptive_window);
assert_eq!(
fixed.http2.effective_windows(),
(Some(1 << 20), Some(2 << 20))
);
}
#[tokio::test]
async fn bound_server_http2_window_setters_thread_through() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener)
.with_http2_initial_stream_window_size(512 * 1024)
.with_http2_initial_connection_window_size(1024 * 1024);
assert_eq!(bound.http2.initial_stream_window_size, Some(512 * 1024));
assert_eq!(
bound.http2.initial_connection_window_size,
Some(1024 * 1024)
);
assert!(!bound.http2.adaptive_window);
}
#[tokio::test]
async fn http2_explicit_windows_serve_request() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_http2_initial_stream_window_size(256 * 1024)
.with_http2_initial_connection_window_size(512 * 1024);
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
let _resp = resp.await.expect("h2 request failed");
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
h2_task.await.expect("h2 connection task panicked").ok();
}
#[tokio::test]
async fn http2_adaptive_window_default_serves_request() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
assert!(bound.http2.adaptive_window);
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
let _resp = resp.await.expect("h2 request failed");
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
h2_task.await.expect("h2 connection task panicked").ok();
}
#[tokio::test]
async fn http2_keepalive_builder_defaults_and_overrides() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener);
assert_eq!(bound.http2.keepalive_interval, None);
assert_eq!(
bound.http2.keepalive_timeout,
DEFAULT_HTTP2_KEEPALIVE_TIMEOUT
);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener)
.with_http2_keepalive_interval(Duration::from_secs(30))
.with_http2_keepalive_timeout(Duration::from_secs(5));
assert_eq!(
bound.http2.keepalive_interval,
Some(Duration::from_secs(30))
);
assert_eq!(bound.http2.keepalive_timeout, Duration::from_secs(5));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound =
Server::from_listener(listener).with_http2_keepalive_timeout(Duration::from_secs(1));
assert_eq!(bound.http2.keepalive_interval, None);
assert_eq!(bound.http2.keepalive_timeout, Duration::from_secs(1));
}
#[test]
fn server_http2_keepalive_builder_threads_through() {
let server = Server::new(Router::new());
assert_eq!(server.http2.keepalive_interval, None);
assert_eq!(
server.http2.keepalive_timeout,
DEFAULT_HTTP2_KEEPALIVE_TIMEOUT
);
let server = Server::new(Router::new())
.with_http2_keepalive_interval(Duration::from_millis(500))
.with_http2_keepalive_timeout(Duration::from_millis(250));
assert_eq!(
server.http2.keepalive_interval,
Some(Duration::from_millis(500))
);
assert_eq!(server.http2.keepalive_timeout, Duration::from_millis(250));
}
#[test]
#[should_panic(expected = "non-zero duration")]
fn with_http2_keepalive_interval_rejects_zero() {
let _ = Server::new(Router::new()).with_http2_keepalive_interval(Duration::ZERO);
}
#[test]
fn configure_http2_default_leaves_keepalive_disabled() {
assert!(Http2Config::default().keepalive_interval.is_none());
let mut builder = AutoBuilder::new(TokioExecutor::new());
configure_http2(&mut builder, Http2Config::default());
}
#[tokio::test]
async fn http2_keepalive_closes_unresponsive_peer() {
let (router, entered_rx, _release_tx) = slow_router();
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_http2_keepalive_interval(Duration::from_millis(100))
.with_http2_keepalive_timeout(Duration::from_millis(100));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, mut h2_conn) = h2::client::handshake(tcp).await.unwrap();
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Slow"))
.header(header::CONTENT_TYPE, "application/proto")
.body(())
.unwrap();
let (_resp, _) = send_request.send_request(req, true).unwrap();
tokio::select! {
result = &mut h2_conn => panic!("connection closed before handler ran: {result:?}"),
entered = entered_rx => entered.expect("handler never entered"),
}
tokio::time::sleep(Duration::from_secs(1)).await;
let closed = tokio::time::timeout(Duration::from_secs(5), &mut h2_conn).await;
assert!(
closed.is_ok(),
"server did not close the unresponsive connection; keepalive PINGs were not plumbed through",
);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test]
async fn max_concurrent_streams_builder_defaults_and_overrides() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener);
assert_eq!(bound.http2.max_concurrent_streams, None);
let bound = bound.with_max_concurrent_streams(64);
assert_eq!(bound.http2.max_concurrent_streams, Some(64));
let server = Server::new(Router::new());
assert_eq!(server.http2.max_concurrent_streams, None);
let server = server.with_max_concurrent_streams(64);
assert_eq!(server.http2.max_concurrent_streams, Some(64));
}
#[test]
#[should_panic(expected = "non-zero value")]
fn with_max_concurrent_streams_rejects_zero() {
let _ = Server::new(Router::new()).with_max_concurrent_streams(0);
}
#[tokio::test]
async fn max_concurrent_streams_is_advertised_in_settings() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_concurrent_streams(7);
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let advertised = read_advertised_max_concurrent_streams(addr).await;
assert_eq!(
advertised,
Some(7),
"server did not advertise the configured max_concurrent_streams",
);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test]
async fn max_concurrent_streams_unset_uses_hyper_default() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let advertised = read_advertised_max_concurrent_streams(addr).await;
assert_eq!(
advertised,
Some(200),
"unset max_concurrent_streams should keep hyper's default of 200",
);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
const SETTINGS_MAX_CONCURRENT_STREAMS_ID: u16 = 0x3;
async fn read_advertised_max_concurrent_streams(addr: SocketAddr) -> Option<u32> {
let mut tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
tcp.write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
.await
.unwrap();
tcp.write_all(&[0, 0, 0, 0x4, 0, 0, 0, 0, 0]).await.unwrap();
tcp.flush().await.unwrap();
loop {
let mut header = [0u8; 9];
tcp.read_exact(&mut header).await.unwrap();
let length = u32::from_be_bytes([0, header[0], header[1], header[2]]) as usize;
let frame_type = header[3];
let flags = header[4];
let mut payload = vec![0u8; length];
tcp.read_exact(&mut payload).await.unwrap();
if frame_type == 0x4 && flags & 0x1 == 0 {
return parse_max_concurrent_streams(&payload);
}
}
}
fn parse_max_concurrent_streams(payload: &[u8]) -> Option<u32> {
payload.chunks_exact(6).find_map(|entry| {
let id = u16::from_be_bytes([entry[0], entry[1]]);
(id == SETTINGS_MAX_CONCURRENT_STREAMS_ID)
.then(|| u32::from_be_bytes([entry[2], entry[3], entry[4], entry[5]]))
})
}
#[tokio::test]
async fn global_shutdown_future_resolves_on_signal() {
let (tx, rx) = tokio::sync::watch::channel(false);
let mut fut = global_shutdown_future(rx);
assert!(
tokio::time::timeout(Duration::from_millis(50), &mut fut)
.await
.is_err(),
"shutdown future resolved before any signal",
);
tx.send(true).unwrap();
tokio::time::timeout(Duration::from_secs(1), fut)
.await
.expect("shutdown future must resolve after send(true)");
}
#[tokio::test]
async fn global_shutdown_future_resolves_when_sender_dropped() {
let (tx, rx) = tokio::sync::watch::channel(false);
let fut = global_shutdown_future(rx);
drop(tx);
tokio::time::timeout(Duration::from_secs(1), fut)
.await
.expect("shutdown future must resolve when the sender is dropped");
}
#[tokio::test(start_paused = true)]
async fn max_connection_age_sends_h2_goaway_without_global_shutdown() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(10));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
resp.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"server did not close idle h2 connection after max age"
);
let conn_result = h2_task.await.expect("h2 connection task panicked");
if let Err(err) = conn_result {
assert!(
err.is_go_away(),
"h2 connection ended with non-GOAWAY error: {err:?}"
);
}
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_age_retiring_one_connection_keeps_listener_running() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(10));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
resp.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"aged connection should retire without stopping listener"
);
h2_task.await.expect("h2 connection task panicked").ok();
drop(send_request);
let second = tokio::net::TcpStream::connect(addr).await;
assert!(
second.is_ok(),
"listener should still accept new connections after one ages out"
);
drop(second);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_age_inflight_stream_completes_during_grace() {
let (router, entered_rx, release_tx) = slow_router();
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(10))
.with_max_connection_age_grace(Duration::from_secs(5));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Slow"))
.header(header::CONTENT_TYPE, "application/proto")
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
let resp_task = tokio::spawn(resp);
entered_rx.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
!resp_task.is_finished(),
"response should remain in-flight during max-age grace"
);
release_tx.send(()).unwrap();
yield_to_tasks().await;
assert!(
resp_task.is_finished(),
"in-flight response did not complete during grace"
);
let resp = resp_task
.await
.expect("response task panicked")
.expect("h2 request failed");
assert!(resp.status().is_success(), "got status {}", resp.status());
drain_h2_body(resp).await;
drop(send_request);
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"h2 connection should close after graceful max-age drain"
);
h2_task.await.expect("h2 connection task panicked").ok();
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_age_unfinished_stream_closes_after_grace() {
let (router, entered_rx, _release_tx) = slow_router();
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(10))
.with_max_connection_age_grace(Duration::from_secs(5));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Slow"))
.header(header::CONTENT_TYPE, "application/proto")
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
let resp_task = tokio::spawn(resp);
entered_rx.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
!resp_task.is_finished(),
"unfinished stream should remain open until age grace expires"
);
tokio::time::advance(Duration::from_secs(6)).await;
yield_to_tasks().await;
assert!(
resp_task.is_finished(),
"unfinished in-flight stream should close after age grace"
);
let resp_result = resp_task.await.expect("response task panicked");
assert!(
resp_result.is_err(),
"unfinished stream unexpectedly completed after max-age grace"
);
drop(send_request);
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"h2 connection should close after max-age grace expires"
);
h2_task.await.expect("h2 connection task panicked").ok();
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_age_http1_keep_alive_connections_retire() {
let router = Router::new().route(
"svc",
"Echo",
crate::handler_fn(
|_ctx: crate::RequestContext, _req: buffa_types::Empty| async move {
crate::Response::ok(buffa_types::Empty::default())
},
),
);
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(10));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
stream.write_all(KEEPALIVE_ECHO_REQ).await.unwrap();
let resp = read_http1_response(&mut stream).await;
assert!(
resp.starts_with(b"HTTP/1.1 2"),
"expected 2xx, got: {}",
String::from_utf8_lossy(&resp[..resp.len().min(80)])
);
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
let mut buf = [0; 1];
let read = stream.read(&mut buf).await.unwrap();
assert_eq!(read, 0, "HTTP/1.1 keep-alive connection stayed open");
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_age_grace_does_not_cap_global_shutdown() {
let (router, entered_rx, release_tx) = slow_router();
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(10))
.with_max_connection_age_grace(Duration::from_secs(1));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Slow"))
.header(header::CONTENT_TYPE, "application/proto")
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
let resp_task = tokio::spawn(resp);
entered_rx.await.unwrap();
shutdown_tx.send(()).unwrap();
tokio::time::advance(Duration::from_secs(30)).await;
yield_to_tasks().await;
assert!(
!serve.is_finished(),
"global shutdown should not be capped by max-age grace"
);
assert!(
!resp_task.is_finished(),
"global shutdown should keep in-flight request alive"
);
release_tx.send(()).unwrap();
yield_to_tasks().await;
assert!(
resp_task.is_finished(),
"in-flight response did not complete after release"
);
let resp = resp_task
.await
.expect("response task panicked")
.expect("h2 request failed");
assert!(resp.status().is_success(), "got status {}", resp.status());
drain_h2_body(resp).await;
drop(send_request);
yield_to_tasks().await;
h2_task.await.expect("h2 connection task panicked").ok();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_age_global_shutdown_during_age_grace_drains_indefinitely() {
let (router, entered_rx, release_tx) = slow_router();
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(10))
.with_max_connection_age_grace(Duration::from_secs(1));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Slow"))
.header(header::CONTENT_TYPE, "application/proto")
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
let resp_task = tokio::spawn(resp);
entered_rx.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
!resp_task.is_finished(),
"request should still be in-flight during age grace"
);
shutdown_tx.send(()).unwrap();
tokio::time::advance(Duration::from_secs(30)).await;
yield_to_tasks().await;
assert!(
!serve.is_finished(),
"global shutdown during age grace should drain indefinitely"
);
assert!(
!resp_task.is_finished(),
"global shutdown during age grace should not force-close the request"
);
release_tx.send(()).unwrap();
yield_to_tasks().await;
assert!(
resp_task.is_finished(),
"in-flight response did not complete after release"
);
let resp = resp_task
.await
.expect("response task panicked")
.expect("h2 request failed");
assert!(resp.status().is_success(), "got status {}", resp.status());
drain_h2_body(resp).await;
drop(send_request);
yield_to_tasks().await;
h2_task.await.expect("h2 connection task panicked").ok();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[test]
fn connection_activity_tracks_in_flight_and_epoch() {
let activity = ConnectionActivity::default();
assert_eq!(activity.snapshot(), (0, 0));
let guard = ActiveRequestGuard::new(Arc::new(ConnectionActivity::default()));
drop(guard);
let shared = Arc::new(ConnectionActivity::default());
let g1 = ActiveRequestGuard::new(Arc::clone(&shared));
let g2 = ActiveRequestGuard::new(Arc::clone(&shared));
assert_eq!(shared.snapshot(), (2, 2));
drop(g1);
assert_eq!(shared.snapshot(), (1, 3));
drop(g2);
assert_eq!(shared.snapshot(), (0, 4));
}
#[tokio::test]
async fn max_connection_idle_builder_defaults_and_overrides() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener);
assert_eq!(bound.max_connection_idle, None);
assert_eq!(bound.connection_idle_config(), None);
let bound = bound.with_max_connection_idle(Duration::from_secs(30));
assert_eq!(bound.max_connection_idle, Some(Duration::from_secs(30)));
assert_eq!(
bound.connection_idle_config(),
Some(IdleConfig {
idle: Duration::from_secs(30),
grace: DEFAULT_MAX_CONNECTION_AGE_GRACE,
})
);
}
#[test]
fn server_max_connection_idle_builder_threads_through() {
let server = Server::new(Router::new());
assert_eq!(server.max_connection_idle, None);
assert_eq!(server.connection_idle_config(), None);
let server = Server::new(Router::new())
.with_max_connection_idle(Duration::from_secs(30))
.with_max_connection_age_grace(Duration::from_secs(2));
assert_eq!(
server.connection_idle_config(),
Some(IdleConfig {
idle: Duration::from_secs(30),
grace: Duration::from_secs(2),
})
);
}
#[test]
#[should_panic(expected = "non-zero duration")]
fn with_max_connection_idle_rejects_zero() {
let _ = Server::new(Router::new()).with_max_connection_idle(Duration::ZERO);
}
#[tokio::test(start_paused = true)]
async fn max_connection_idle_reaps_quiet_connection() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_idle(Duration::from_secs(10));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
resp.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
!h2_task.is_finished(),
"connection reaped despite activity within the idle window"
);
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"idle connection was not reaped after a quiet window"
);
let conn_result = h2_task.await.expect("h2 connection task panicked");
if let Err(err) = conn_result {
assert!(
err.is_go_away(),
"h2 connection ended with non-GOAWAY error: {err:?}"
);
}
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_idle_inflight_request_prevents_reaping() {
let (router, entered_rx, release_tx) = slow_router();
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_idle(Duration::from_secs(10));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Slow"))
.header(header::CONTENT_TYPE, "application/proto")
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
let resp_task = tokio::spawn(resp);
entered_rx.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
!h2_task.is_finished(),
"connection with an in-flight request was retired by the idle timer"
);
assert!(
!resp_task.is_finished(),
"in-flight request unexpectedly ended"
);
release_tx.send(()).unwrap();
yield_to_tasks().await;
let resp = resp_task
.await
.expect("response task panicked")
.expect("h2 request failed");
assert!(resp.status().is_success(), "got status {}", resp.status());
drain_h2_body(resp).await;
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"connection was not reaped after the in-flight request completed"
);
h2_task.await.expect("h2 connection task panicked").ok();
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_connection_idle_fires_before_a_longer_max_age() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(60))
.with_max_connection_idle(Duration::from_secs(10));
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
resp.await.unwrap();
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
tokio::time::advance(Duration::from_secs(11)).await;
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"idle timer did not retire the connection before max age"
);
h2_task.await.expect("h2 connection task panicked").ok();
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test]
async fn max_requests_per_connection_builder_defaults_and_threads_through() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener);
assert_eq!(bound.max_requests_per_connection, None);
assert_eq!(bound.request_retirement_config(), None);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let bound = Server::from_listener(listener)
.with_max_requests_per_connection(NonZeroU64::new(100).unwrap())
.with_max_connection_age_grace(Duration::from_secs(3));
assert_eq!(bound.max_requests_per_connection, NonZeroU64::new(100));
assert_eq!(
bound.request_retirement_config(),
Some(RequestRetirementConfig {
max: NonZeroU64::new(100).unwrap(),
grace: Duration::from_secs(3),
})
);
let server = Server::new(Router::new());
assert_eq!(server.max_requests_per_connection, None);
assert_eq!(server.request_retirement_config(), None);
let server = Server::new(Router::new())
.with_max_requests_per_connection(NonZeroU64::new(5).unwrap());
assert_eq!(
server.request_retirement_config(),
Some(RequestRetirementConfig {
max: NonZeroU64::new(5).unwrap(),
grace: DEFAULT_MAX_CONNECTION_AGE_GRACE,
})
);
}
#[tokio::test(start_paused = true)]
async fn max_requests_per_connection_retires_h2_after_limit() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_requests_per_connection(NonZeroU64::new(2).unwrap());
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
send_unary(&mut send_request, addr).await;
yield_to_tasks().await;
assert!(
!h2_task.is_finished(),
"connection retired before reaching the request limit"
);
send_unary(&mut send_request, addr).await;
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"connection did not retire after reaching the request limit"
);
let conn_result = h2_task.await.expect("h2 connection task panicked");
if let Err(err) = conn_result {
assert!(
err.is_go_away(),
"h2 connection ended with non-GOAWAY error: {err:?}"
);
}
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_requests_per_connection_unlimited_when_unset() {
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
for _ in 0..5 {
send_unary(&mut send_request, addr).await;
}
yield_to_tasks().await;
assert!(
!h2_task.is_finished(),
"connection retired despite no request limit being configured"
);
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_requests_per_connection_first_trigger_wins_over_age() {
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_connection_age(Duration::from_secs(3600))
.with_max_requests_per_connection(NonZeroU64::new(1).unwrap());
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(Router::new(), async {
shutdown_rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let (mut send_request, h2_conn) = h2::client::handshake(tcp).await.unwrap();
let h2_task = tokio::spawn(h2_conn);
send_unary(&mut send_request, addr).await;
yield_to_tasks().await;
assert!(
h2_task.is_finished(),
"request limit should retire the connection before the max age"
);
h2_task.await.expect("h2 connection task panicked").ok();
drop(send_request);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
#[tokio::test(start_paused = true)]
async fn max_requests_per_connection_retires_http1_after_limit() {
let router = Router::new().route(
"svc",
"Echo",
crate::handler_fn(
|_ctx: crate::RequestContext, _req: buffa_types::Empty| async move {
crate::Response::ok(buffa_types::Empty::default())
},
),
);
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_max_requests_per_connection(NonZeroU64::new(1).unwrap());
let addr = bound.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
shutdown_rx.await.ok();
})
.await
});
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
stream.write_all(KEEPALIVE_ECHO_REQ).await.unwrap();
let resp = read_http1_response(&mut stream).await;
assert!(
resp.starts_with(b"HTTP/1.1 2"),
"expected 2xx, got: {}",
String::from_utf8_lossy(&resp[..resp.len().min(80)])
);
yield_to_tasks().await;
let mut buf = [0; 1];
let read = stream.read(&mut buf).await.unwrap();
assert_eq!(
read, 0,
"HTTP/1.1 keep-alive connection stayed open past the request limit"
);
shutdown_tx.send(()).unwrap();
let result = tokio::time::timeout(Duration::from_secs(1), serve)
.await
.expect("server did not shut down")
.expect("join error");
assert!(result.is_ok(), "serve returned error: {result:?}");
}
async fn send_unary(send_request: &mut h2::client::SendRequest<Bytes>, addr: SocketAddr) {
let req = http::Request::builder()
.method(http::Method::POST)
.uri(format!("http://{addr}/svc/Unknown"))
.body(())
.unwrap();
let (resp, _) = send_request.send_request(req, true).unwrap();
resp.await.unwrap();
}
fn slow_router() -> (
Router,
tokio::sync::oneshot::Receiver<()>,
tokio::sync::oneshot::Sender<()>,
) {
let (entered_tx, entered_rx) = tokio::sync::oneshot::channel();
let (release_tx, release_rx) = tokio::sync::oneshot::channel();
let chans = Arc::new(Mutex::new(Some((entered_tx, release_rx))));
let router = Router::new().route(
"svc",
"Slow",
crate::handler_fn(
move |_ctx: crate::RequestContext, _req: buffa_types::Empty| {
let chans = Arc::clone(&chans);
async move {
let taken = chans.lock().unwrap().take();
if let Some((entered_tx, release_rx)) = taken {
entered_tx.send(()).ok();
release_rx.await.ok();
}
crate::Response::ok(buffa_types::Empty::default())
}
},
),
);
(router, entered_rx, release_tx)
}
async fn yield_to_tasks() {
for _ in 0..5 {
tokio::task::yield_now().await;
}
}
async fn drain_h2_body(mut resp: http::Response<h2::RecvStream>) {
while let Some(chunk) = resp.body_mut().data().await {
chunk.expect("h2 response body failed");
}
}
async fn read_http1_response(stream: &mut tokio::net::TcpStream) -> Vec<u8> {
let mut resp = Vec::new();
let mut buf = [0; 1024];
loop {
let read = stream.read(&mut buf).await.unwrap();
assert!(read > 0, "connection closed before full response arrived");
resp.extend_from_slice(&buf[..read]);
let Some(header_end) = find_header_end(&resp) else {
continue;
};
let body_start = header_end + 4;
let content_length = content_length(&resp[..header_end]).unwrap_or(0);
if resp.len() >= body_start + content_length {
return resp;
}
}
}
fn find_header_end(bytes: &[u8]) -> Option<usize> {
bytes.windows(4).position(|window| window == b"\r\n\r\n")
}
fn content_length(headers: &[u8]) -> Option<usize> {
std::str::from_utf8(headers).ok()?.lines().find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse().ok())
.flatten()
})
}
#[tokio::test]
async fn peer_addr_reaches_handler() {
let captured: Arc<Mutex<Option<std::net::SocketAddr>>> = Arc::new(Mutex::new(None));
let handler_captured = Arc::clone(&captured);
let router = Router::new().route(
"svc",
"Echo",
crate::handler_fn(
move |ctx: crate::RequestContext, _req: buffa_types::Empty| {
let cap = Arc::clone(&handler_captured);
async move {
*cap.lock().unwrap() = ctx.peer_addr();
crate::Response::ok(buffa_types::Empty::default())
}
},
),
);
let bound = Server::bind("127.0.0.1:0").await.unwrap();
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
rx.await.ok();
})
.await
});
let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
let client_local = stream.local_addr().unwrap();
stream.write_all(ECHO_REQ).await.unwrap();
let mut resp = Vec::new();
stream.read_to_end(&mut resp).await.unwrap();
assert!(
resp.starts_with(b"HTTP/1.1 2"),
"expected 2xx, got: {}",
String::from_utf8_lossy(&resp[..resp.len().min(80)])
);
tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), serve)
.await
.unwrap()
.unwrap()
.unwrap();
let peer = captured
.lock()
.unwrap()
.take()
.expect("handler should have captured PeerAddr");
assert_eq!(peer, client_local);
}
#[cfg(feature = "server-tls")]
#[tokio::test]
async fn peer_certs_reach_handler() {
fn pki() -> (
Arc<rustls::ServerConfig>,
Arc<rustls::ClientConfig>,
rustls::pki_types::CertificateDer<'static>,
) {
use rcgen::CertificateParams;
use rcgen::KeyPair;
use rcgen::SanType;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivatePkcs8KeyDer;
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let ca_key = KeyPair::generate().unwrap();
let mut ca_params = CertificateParams::default();
ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
let ca = rcgen::CertifiedIssuer::self_signed(ca_params, ca_key).unwrap();
let issue = |sans: &[SanType]| {
let k = KeyPair::generate().unwrap();
let mut p = CertificateParams::default();
p.subject_alt_names = sans.to_vec();
let c = p.signed_by(&k, &ca).unwrap();
(
CertificateDer::from(c.der().to_vec()),
PrivatePkcs8KeyDer::from(k.serialized_der().to_vec()).into(),
)
};
let (srv_cert, srv_key) = issue(&[SanType::DnsName("localhost".try_into().unwrap())]);
let (cli_cert, cli_key) = issue(&[]);
let mut roots = rustls::RootCertStore::empty();
roots.add(CertificateDer::from(ca.der().to_vec())).unwrap();
let roots = Arc::new(roots);
let cv = rustls::server::WebPkiClientVerifier::builder(Arc::clone(&roots))
.build()
.unwrap();
let server = rustls::ServerConfig::builder()
.with_client_cert_verifier(cv)
.with_single_cert(vec![srv_cert], srv_key)
.unwrap();
let client = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_client_auth_cert(vec![cli_cert.clone()], cli_key)
.unwrap();
(Arc::new(server), Arc::new(client), cli_cert)
}
let (server_cfg, client_cfg, expected_client_der) = pki();
type CapturedCerts = Vec<rustls::pki_types::CertificateDer<'static>>;
let captured: Arc<Mutex<Option<CapturedCerts>>> = Arc::new(Mutex::new(None));
let handler_captured = Arc::clone(&captured);
let router = Router::new().route(
"svc",
"Echo",
crate::handler_fn(
move |ctx: crate::RequestContext, _req: buffa_types::Empty| {
let cap = Arc::clone(&handler_captured);
async move {
*cap.lock().unwrap() = ctx.peer_certs().map(<[_]>::to_vec);
crate::Response::ok(buffa_types::Empty::default())
}
},
),
);
let bound = Server::bind("127.0.0.1:0")
.await
.unwrap()
.with_tls(server_cfg);
let addr = bound.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let serve = tokio::spawn(async move {
bound
.serve_with_graceful_shutdown(router, async {
rx.await.ok();
})
.await
});
let tcp = tokio::net::TcpStream::connect(addr).await.unwrap();
let connector = tokio_rustls::TlsConnector::from(client_cfg);
let sni = rustls::pki_types::ServerName::try_from("localhost").unwrap();
let mut tls = connector.connect(sni, tcp).await.unwrap();
tls.write_all(ECHO_REQ).await.unwrap();
let mut resp = Vec::new();
tls.read_to_end(&mut resp).await.unwrap();
assert!(
resp.starts_with(b"HTTP/1.1 2"),
"expected 2xx, got: {}",
String::from_utf8_lossy(&resp[..resp.len().min(80)])
);
tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), serve)
.await
.unwrap()
.unwrap()
.unwrap();
let certs = captured
.lock()
.unwrap()
.take()
.expect("handler should have captured PeerCerts");
assert_eq!(certs.len(), 1);
assert_eq!(certs[0].as_ref(), expected_client_der.as_ref());
}
}