mod conn;
mod display_error_stack;
mod incoming;
mod io_stream;
mod service;
#[cfg(feature = "_tls-any")]
mod tls;
#[cfg(unix)]
mod unix;
use tokio_stream::StreamExt as _;
use tracing::{debug, trace};
#[cfg(feature = "router")]
use crate::{server::NamedService, service::Routes};
#[cfg(feature = "router")]
use std::convert::Infallible;
pub use conn::{Connected, TcpConnectInfo};
use hyper_util::{
rt::{TokioExecutor, TokioIo, TokioTimer},
server::conn::auto::{Builder as ConnectionBuilder, HttpServerConnExec},
service::TowerToHyperService,
};
#[cfg(feature = "_tls-any")]
pub use tls::ServerTlsConfig;
#[cfg(feature = "_tls-any")]
pub use conn::TlsConnectInfo;
#[cfg(feature = "_tls-any")]
use self::service::TlsAcceptor;
#[cfg(unix)]
pub use unix::UdsConnectInfo;
pub use incoming::TcpIncoming;
#[cfg(feature = "_tls-any")]
use crate::transport::Error;
use self::service::{ConnectInfoLayer, ServerIo};
use super::service::GrpcTimeout;
use crate::body::Body;
use crate::service::RecoverErrorLayer;
use crate::transport::server::display_error_stack::DisplayErrorStack;
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::BodyExt;
use hyper::{body::Incoming, service::Service as HyperService};
use pin_project::pin_project;
use std::{
fmt,
future::{self, Future},
marker::PhantomData,
net::SocketAddr,
pin::{Pin, pin},
sync::Arc,
task::{Context, Poll, ready},
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::Stream;
use tower::{
Service, ServiceBuilder, ServiceExt,
layer::Layer,
layer::util::{Identity, Stack},
limit::concurrency::ConcurrencyLimitLayer,
load_shed::LoadShedLayer,
util::BoxCloneService,
};
/// Type-erased, cloneable service built for each accepted connection by `MakeSvc::call`.
type BoxService = tower::util::BoxCloneService<Request<Body>, Response<Body>, crate::BoxError>;
/// Callback installed by `Server::trace_fn`. `Svc::call` invokes it with the request head
/// (body replaced by `()`), and the request's future is then polled inside the returned span.
type TraceInterceptor = Arc<dyn Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static>;
/// Default value of `Server::http2_keepalive_timeout` (20 seconds).
const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
/// Builder and configuration for a gRPC server.
///
/// Middleware, per-connection limits, TCP, TLS, HTTP/2 and connection-age settings are
/// collected here and applied when one of the `serve*` methods is called.
#[derive(Clone)]
pub struct Server<L = Identity> {
    /// Optional per-request span factory (see `trace_fn`).
    trace_interceptor: Option<TraceInterceptor>,
    /// Per-connection limit on in-flight requests (`ConcurrencyLimitLayer`).
    concurrency_limit: Option<usize>,
    /// Whether `LoadShedLayer` is added to each connection's service stack.
    load_shed: bool,
    /// Timeout handed to `GrpcTimeout` for every request.
    timeout: Option<Duration>,
    #[cfg(feature = "_tls-any")]
    tls: Option<TlsAcceptor>,
    // The HTTP/2 options below are forwarded to hyper's connection builder in `serve_internal`.
    init_stream_window_size: Option<u32>,
    init_connection_window_size: Option<u32>,
    max_concurrent_streams: Option<u32>,
    // TCP options are only applied by `bind_incoming`, i.e. by `serve`/`serve_with_shutdown`.
    tcp_keepalive: Option<Duration>,
    tcp_keepalive_interval: Option<Duration>,
    tcp_keepalive_retries: Option<u32>,
    tcp_nodelay: bool,
    http2_keepalive_interval: Option<Duration>,
    http2_keepalive_timeout: Duration,
    http2_adaptive_window: Option<bool>,
    http2_max_pending_accept_reset_streams: Option<usize>,
    http2_max_local_error_reset_streams: Option<usize>,
    http2_max_header_list_size: Option<u32>,
    max_frame_size: Option<u32>,
    /// When `false`, connections are served as HTTP/2 only.
    accept_http1: bool,
    /// User-supplied tower layers wrapped around the service passed to `serve*`.
    service_builder: ServiceBuilder<L>,
    /// Age after which a connection starts being shut down (see `serve_connection`).
    max_connection_age: Option<Duration>,
    /// Extra time allowed after `max_connection_age` before the connection is dropped.
    max_connection_age_grace: Option<Duration>,
}
impl Default for Server<Identity> {
fn default() -> Self {
Self {
trace_interceptor: None,
concurrency_limit: None,
load_shed: false,
timeout: None,
#[cfg(feature = "_tls-any")]
tls: None,
init_stream_window_size: None,
init_connection_window_size: None,
max_concurrent_streams: None,
tcp_keepalive: None,
tcp_keepalive_interval: None,
tcp_keepalive_retries: None,
tcp_nodelay: true,
http2_keepalive_interval: None,
http2_keepalive_timeout: DEFAULT_HTTP2_KEEPALIVE_TIMEOUT,
http2_adaptive_window: None,
http2_max_pending_accept_reset_streams: None,
http2_max_local_error_reset_streams: None,
http2_max_header_list_size: None,
max_frame_size: None,
accept_http1: false,
service_builder: Default::default(),
max_connection_age: None,
max_connection_age_grace: None,
}
}
}
/// A [`Server`] configuration paired with the [`Routes`] it will serve.
///
/// Obtained from `Server::add_service`, `Server::add_optional_service` or
/// `Server::add_routes`. More services can be added before calling a `serve*` method.
#[cfg(feature = "router")]
#[derive(Clone, Debug)]
pub struct Router<L = Identity> {
    server: Server<L>,
    routes: Routes,
}
impl Server {
    /// Starts a new server configuration with every option at its default value.
    pub fn builder() -> Self {
        <Self as Default>::default()
    }
}
impl<L> Server<L> {
/// Serves accepted connections over TLS using `tls_config`.
///
/// # Errors
///
/// Returns an error if a TLS acceptor cannot be built from `tls_config`.
#[cfg(feature = "_tls-any")]
pub fn tls_config(mut self, tls_config: ServerTlsConfig) -> Result<Self, Error> {
    let acceptor = tls_config.tls_acceptor().map_err(Error::from_source)?;
    self.tls = Some(acceptor);
    Ok(self)
}

/// Caps the number of concurrent in-flight requests on each individual connection.
#[must_use]
pub fn concurrency_limit_per_connection(mut self, limit: usize) -> Self {
    self.concurrency_limit = Some(limit);
    self
}

/// Toggles tower's `LoadShedLayer` on every connection's service stack.
#[must_use]
pub fn load_shed(mut self, load_shed: bool) -> Self {
    self.load_shed = load_shed;
    self
}

/// Sets the timeout passed to `GrpcTimeout` for every request.
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
    self.timeout = Some(timeout);
    self
}

/// Sets the HTTP/2 initial stream-level flow-control window (forwarded to hyper).
#[must_use]
pub fn initial_stream_window_size(mut self, sz: impl Into<Option<u32>>) -> Self {
    self.init_stream_window_size = sz.into();
    self
}

/// Sets the HTTP/2 initial connection-level flow-control window (forwarded to hyper).
#[must_use]
pub fn initial_connection_window_size(mut self, sz: impl Into<Option<u32>>) -> Self {
    self.init_connection_window_size = sz.into();
    self
}

/// Sets the HTTP/2 `SETTINGS_MAX_CONCURRENT_STREAMS` value advertised per connection.
#[must_use]
pub fn max_concurrent_streams(mut self, max: impl Into<Option<u32>>) -> Self {
    self.max_concurrent_streams = max.into();
    self
}

/// Sets the age after which the server begins shutting a connection down.
#[must_use]
pub fn max_connection_age(mut self, max_connection_age: Duration) -> Self {
    self.max_connection_age = Some(max_connection_age);
    self
}

/// Sets how long a connection may outlive `max_connection_age` before it is dropped.
#[must_use]
pub fn max_connection_age_grace(mut self, max_connection_age_grace: Duration) -> Self {
    self.max_connection_age_grace = Some(max_connection_age_grace);
    self
}

/// Sets the interval of HTTP/2 keep-alive pings; `None` is forwarded to hyper as-is.
#[must_use]
pub fn http2_keepalive_interval(mut self, http2_keepalive_interval: Option<Duration>) -> Self {
    self.http2_keepalive_interval = http2_keepalive_interval;
    self
}

/// Sets how long to wait for a keep-alive ping acknowledgement.
///
/// `None` leaves the current value in place (20 seconds unless changed).
#[must_use]
pub fn http2_keepalive_timeout(mut self, http2_keepalive_timeout: Option<Duration>) -> Self {
    self.http2_keepalive_timeout =
        http2_keepalive_timeout.unwrap_or(self.http2_keepalive_timeout);
    self
}

/// Enables HTTP/2 adaptive flow control; `None` is treated as disabled.
#[must_use]
pub fn http2_adaptive_window(mut self, enabled: Option<bool>) -> Self {
    self.http2_adaptive_window = enabled;
    self
}

/// Sets hyper's `max_pending_accept_reset_streams` for each connection.
#[must_use]
pub fn http2_max_pending_accept_reset_streams(mut self, max: Option<usize>) -> Self {
    self.http2_max_pending_accept_reset_streams = max;
    self
}

/// Sets hyper's `max_local_error_reset_streams` for each connection.
#[must_use]
pub fn http2_max_local_error_reset_streams(mut self, max: Option<usize>) -> Self {
    self.http2_max_local_error_reset_streams = max;
    self
}

/// Sets the TCP keep-alive time for sockets bound by `serve`/`serve_with_shutdown`.
///
/// Has no effect on streams passed to `serve_with_incoming*`.
#[must_use]
pub fn tcp_keepalive(mut self, tcp_keepalive: Option<Duration>) -> Self {
    self.tcp_keepalive = tcp_keepalive;
    self
}

/// Sets the TCP keep-alive probe interval for sockets bound by `serve`/`serve_with_shutdown`.
#[must_use]
pub fn tcp_keepalive_interval(mut self, tcp_keepalive_interval: Option<Duration>) -> Self {
    self.tcp_keepalive_interval = tcp_keepalive_interval;
    self
}

/// Sets the TCP keep-alive retry count for sockets bound by `serve`/`serve_with_shutdown`.
#[must_use]
pub fn tcp_keepalive_retries(mut self, tcp_keepalive_retries: Option<u32>) -> Self {
    self.tcp_keepalive_retries = tcp_keepalive_retries;
    self
}

/// Sets `TCP_NODELAY` on sockets bound by `serve`/`serve_with_shutdown` (default `true`).
#[must_use]
pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
    self.tcp_nodelay = enabled;
    self
}

/// Sets the HTTP/2 max header list size; only applied to hyper when `Some`.
#[must_use]
pub fn http2_max_header_list_size(mut self, max: impl Into<Option<u32>>) -> Self {
    self.http2_max_header_list_size = max.into();
    self
}

/// Sets the HTTP/2 maximum frame size (forwarded to hyper).
#[must_use]
pub fn max_frame_size(mut self, frame_size: impl Into<Option<u32>>) -> Self {
    self.max_frame_size = frame_size.into();
    self
}

/// Allows HTTP/1 connections in addition to HTTP/2 (default: HTTP/2 only).
#[must_use]
pub fn accept_http1(mut self, accept_http1: bool) -> Self {
    self.accept_http1 = accept_http1;
    self
}

/// Installs a callback that builds the tracing span each request is served in.
///
/// The callback receives the request head with its body replaced by `()`.
#[must_use]
pub fn trace_fn<F>(mut self, f: F) -> Self
where
    F: Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static,
{
    let interceptor: TraceInterceptor = Arc::new(f);
    self.trace_interceptor = Some(interceptor);
    self
}
/// Starts a [`Router`] containing `svc`, using a copy of this server's configuration.
#[cfg(feature = "router")]
pub fn add_service<S>(&mut self, svc: S) -> Router<L>
where
    S: Service<Request<Body>, Error = Infallible>
        + NamedService
        + Clone
        + Send
        + Sync
        + 'static,
    S::Response: axum::response::IntoResponse,
    S::Future: Send + 'static,
    L: Clone,
{
    let server = self.clone();
    Router::new(server, Routes::new(svc))
}

/// Starts a [`Router`] containing `svc` if present, or empty routes otherwise.
#[cfg(feature = "router")]
pub fn add_optional_service<S>(&mut self, svc: Option<S>) -> Router<L>
where
    S: Service<Request<Body>, Error = Infallible>
        + NamedService
        + Clone
        + Send
        + Sync
        + 'static,
    S::Response: axum::response::IntoResponse,
    S::Future: Send + 'static,
    L: Clone,
{
    let routes = match svc {
        Some(svc) => Routes::new(svc),
        None => Routes::default(),
    };
    Router::new(self.clone(), routes)
}

/// Starts a [`Router`] from an already-assembled set of [`Routes`].
#[cfg(feature = "router")]
pub fn add_routes(&mut self, routes: Routes) -> Router<L>
where
    L: Clone,
{
    let server = self.clone();
    Router::new(server, routes)
}
/// Adds a tower layer around the service passed to `serve*`, keeping every other setting.
pub fn layer<NewLayer>(self, new_layer: NewLayer) -> Server<Stack<NewLayer, L>> {
    // Exhaustive destructuring: adding a field to `Server` without carrying it over here
    // becomes a compile error instead of a silently dropped setting.
    let Server {
        trace_interceptor,
        concurrency_limit,
        load_shed,
        timeout,
        #[cfg(feature = "_tls-any")]
        tls,
        init_stream_window_size,
        init_connection_window_size,
        max_concurrent_streams,
        tcp_keepalive,
        tcp_keepalive_interval,
        tcp_keepalive_retries,
        tcp_nodelay,
        http2_keepalive_interval,
        http2_keepalive_timeout,
        http2_adaptive_window,
        http2_max_pending_accept_reset_streams,
        http2_max_local_error_reset_streams,
        http2_max_header_list_size,
        max_frame_size,
        accept_http1,
        service_builder,
        max_connection_age,
        max_connection_age_grace,
    } = self;

    Server {
        service_builder: service_builder.layer(new_layer),
        trace_interceptor,
        concurrency_limit,
        load_shed,
        timeout,
        #[cfg(feature = "_tls-any")]
        tls,
        init_stream_window_size,
        init_connection_window_size,
        max_concurrent_streams,
        tcp_keepalive,
        tcp_keepalive_interval,
        tcp_keepalive_retries,
        tcp_nodelay,
        http2_keepalive_interval,
        http2_keepalive_timeout,
        http2_adaptive_window,
        http2_max_pending_accept_reset_streams,
        http2_max_header_list_size,
        http2_max_local_error_reset_streams,
        max_frame_size,
        accept_http1,
        max_connection_age,
        max_connection_age_grace,
    }
}
fn bind_incoming(&self, addr: SocketAddr) -> Result<TcpIncoming, super::Error> {
Ok(TcpIncoming::bind(addr)
.map_err(super::Error::from_source)?
.with_nodelay(Some(self.tcp_nodelay))
.with_keepalive(self.tcp_keepalive)
.with_keepalive_interval(self.tcp_keepalive_interval)
.with_keepalive_retries(self.tcp_keepalive_retries))
}
pub async fn serve<S, ResBody>(self, addr: SocketAddr, svc: S) -> Result<(), super::Error>
where
L: Layer<S>,
L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
Into<crate::BoxError> + Send + 'static,
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = self.bind_incoming(addr)?;
self.serve_with_incoming(svc, incoming).await
}
pub async fn serve_with_shutdown<S, F, ResBody>(
self,
addr: SocketAddr,
svc: S,
signal: F,
) -> Result<(), super::Error>
where
L: Layer<S>,
L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
Into<crate::BoxError> + Send + 'static,
F: Future<Output = ()>,
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = self.bind_incoming(addr)?;
self.serve_with_incoming_shutdown(svc, incoming, signal)
.await
}
pub async fn serve_with_incoming<S, I, IO, IE, ResBody>(
self,
svc: S,
incoming: I,
) -> Result<(), super::Error>
where
L: Layer<S>,
L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
Into<crate::BoxError> + Send + 'static,
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
self.serve_internal(svc, incoming, Option::<future::Ready<()>>::None)
.await
}
pub async fn serve_with_incoming_shutdown<S, I, F, IO, IE, ResBody>(
self,
svc: S,
incoming: I,
signal: F,
) -> Result<(), super::Error>
where
L: Layer<S>,
L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
<<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
Into<crate::BoxError> + Send + 'static,
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
F: Future<Output = ()>,
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
self.serve_internal(svc, incoming, Some(signal)).await
}
/// Shared implementation behind every `serve*` method.
///
/// This method:
/// - wraps `svc` in the configured layers;
/// - accepts connections from `incoming` (through TLS when configured);
/// - spawns one task per connection via `serve_connection`.
///
/// Accept errors are logged at `trace` level and skipped. The accept loop ends when
/// `incoming` is exhausted or, if `signal` is `Some`, when `signal` completes. In the
/// latter case every spawned connection is asked to shut down gracefully, and this call
/// waits until all of them have finished.
///
/// # Errors
///
/// Returns an error if the per-connection service factory fails to become ready or fails
/// to build a service for an accepted connection.
async fn serve_internal<S, I, F, IO, IE, ResBody>(
    self,
    svc: S,
    incoming: I,
    signal: Option<F>,
) -> Result<(), super::Error>
where
    L: Layer<S>,
    L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
    <<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
    <<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
        Into<crate::BoxError> + Send + 'static,
    I: Stream<Item = Result<IO, IE>>,
    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
    IE: Into<crate::BoxError>,
    F: Future<Output = ()>,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::BoxError>,
{
    let trace_interceptor = self.trace_interceptor.clone();
    let concurrency_limit = self.concurrency_limit;
    let load_shed = self.load_shed;
    let init_connection_window_size = self.init_connection_window_size;
    let init_stream_window_size = self.init_stream_window_size;
    let max_concurrent_streams = self.max_concurrent_streams;
    let timeout = self.timeout;
    let max_header_list_size = self.http2_max_header_list_size;
    let max_frame_size = self.max_frame_size;
    let http2_only = !self.accept_http1;
    let http2_keepalive_interval = self.http2_keepalive_interval;
    let http2_keepalive_timeout = self.http2_keepalive_timeout;
    let http2_adaptive_window = self.http2_adaptive_window;
    let http2_max_pending_accept_reset_streams = self.http2_max_pending_accept_reset_streams;
    let http2_max_local_error_reset_streams = self.http2_max_local_error_reset_streams;
    let max_connection_age = self.max_connection_age;
    let max_connection_age_grace = self.max_connection_age_grace;

    let svc = self.service_builder.service(svc);

    let incoming = io_stream::ServerIoStream::new(
        incoming,
        #[cfg(feature = "_tls-any")]
        self.tls,
    );

    let mut svc = MakeSvc {
        inner: svc,
        concurrency_limit,
        load_shed,
        timeout,
        trace_interceptor,
        _io: PhantomData,
    };

    let server = {
        let mut builder = ConnectionBuilder::new(TokioExecutor::new());

        if http2_only {
            builder = builder.http2_only();
        }

        builder
            .http2()
            .timer(TokioTimer::new())
            .initial_connection_window_size(init_connection_window_size)
            .initial_stream_window_size(init_stream_window_size)
            .max_concurrent_streams(max_concurrent_streams)
            .keep_alive_interval(http2_keepalive_interval)
            .keep_alive_timeout(http2_keepalive_timeout)
            .adaptive_window(http2_adaptive_window.unwrap_or_default())
            .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams)
            .max_local_error_reset_streams(http2_max_local_error_reset_streams)
            .max_frame_size(max_frame_size);

        if let Some(max_header_list_size) = max_header_list_size {
            builder.http2().max_header_list_size(max_header_list_size);
        }

        builder
    };

    // Each connection task holds a clone of `signal_rx` (only in graceful mode). The
    // `closed()` wait below returns once every clone has been dropped.
    let (signal_tx, signal_rx) = tokio::sync::watch::channel(());
    let signal_tx = Arc::new(signal_tx);

    let graceful = signal.is_some();
    let mut sig = pin!(Fuse { inner: signal });
    let mut incoming = pin!(incoming);

    loop {
        tokio::select! {
            _ = &mut sig => {
                trace!("signal received, shutting down");
                break;
            },
            io = incoming.next() => {
                let io = match io {
                    Some(Ok(io)) => io,
                    Some(Err(e)) => {
                        trace!("error accepting connection: {}", DisplayErrorStack(&*e));
                        continue;
                    },
                    None => {
                        break
                    },
                };

                trace!("connection accepted");

                // `tower::Service` requires `poll_ready` to report readiness before
                // `call`. `MakeSvc` is currently always ready, but honour the contract
                // instead of relying on that implementation detail.
                future::poll_fn(|cx| svc.poll_ready(cx))
                    .await
                    .map_err(super::Error::from_source)?;

                let req_svc = svc
                    .call(&io)
                    .await
                    .map_err(super::Error::from_source)?;

                let hyper_io = TokioIo::new(io);
                let hyper_svc = TowerToHyperService::new(req_svc.map_request(|req: Request<Incoming>| req.map(Body::new)));

                serve_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()), max_connection_age, max_connection_age_grace);
            }
        }
    }

    if graceful {
        // Notify every connection task, then drop our own receiver so it does not count
        // as an open connection.
        let _ = signal_tx.send(());
        drop(signal_rx);
        trace!(
            "waiting for {} connections to close",
            signal_tx.receiver_count()
        );

        signal_tx.closed().await;
    }

    Ok(())
}
}
/// How `serve_connection` should end a connection once `connection_timeout_future` resolves.
enum TimeoutAction {
    /// Ask the connection to shut down gracefully and keep driving it.
    GracefulShutdown,
    /// Stop driving the connection without waiting for it to drain.
    ForcefulShutdown,
}

/// Waits out the configured connection lifetime and reports how the connection should end.
///
/// The outcome depends on which settings are present:
/// - `max_connection_age` is `None`: the future never resolves.
/// - Only `max_connection_age` is set: resolves with `GracefulShutdown` once that age has
///   elapsed.
/// - Both are set: resolves with `ForcefulShutdown` after age + grace, and never yields
///   `GracefulShutdown`.
///
/// The timers start when the future is first polled.
async fn connection_timeout_future(
    max_connection_age: Option<Duration>,
    max_connection_age_grace: Option<Duration>,
) -> TimeoutAction {
    let Some(age) = max_connection_age else {
        return future::pending().await;
    };

    tokio::time::sleep(age).await;

    match max_connection_age_grace {
        None => TimeoutAction::GracefulShutdown,
        Some(grace) => {
            tokio::time::sleep(grace).await;
            TimeoutAction::ForcefulShutdown
        }
    }
}
/// Spawns a task that drives a single accepted connection to completion.
///
/// The connection is asked to shut down gracefully when either of these happens:
/// - the server-wide shutdown `watcher` fires (only present in graceful mode);
/// - `max_connection_age` elapses.
///
/// If `max_connection_age_grace` is also set, the connection is dropped outright once that
/// extra grace period has passed as well. The task drops `watcher` on exit, which lets the
/// server's graceful shutdown observe that the connection is gone.
fn serve_connection<B, IO, S, E>(
    hyper_io: IO,
    hyper_svc: S,
    builder: ConnectionBuilder<E>,
    mut watcher: Option<tokio::sync::watch::Receiver<()>>,
    max_connection_age: Option<Duration>,
    max_connection_age_grace: Option<Duration>,
) where
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
    IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    S: HyperService<Request<Incoming>, Response = Response<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
    E: HttpServerConnExec<S::Future, B> + Send + Sync + 'static,
{
    tokio::spawn(async move {
        {
            let mut sig = pin!(Fuse {
                inner: watcher.as_mut().map(|w| w.changed()),
            });

            let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc));

            // `connection_timeout_future` resolves at most once, but the loop keeps polling
            // it while the connection drains. Fusing it makes later polls return `Pending`
            // instead of panicking with "`async fn` resumed after completion".
            let mut connection_timeout = pin!(Fuse {
                inner: Some(connection_timeout_future(
                    max_connection_age,
                    max_connection_age_grace,
                )),
            });

            // With a grace period configured, `connection_timeout_future` only resolves
            // (with `ForcefulShutdown`) after age + grace. The graceful phase at
            // `max_connection_age` is therefore armed separately here.
            let mut graceful_deadline = pin!(Fuse {
                inner: max_connection_age
                    .filter(|_| max_connection_age_grace.is_some())
                    .map(tokio::time::sleep),
            });

            loop {
                tokio::select! {
                    rv = &mut conn => {
                        if let Err(err) = rv {
                            debug!("failed serving connection: {}", DisplayErrorStack(&*err));
                        }
                        break;
                    },
                    _ = &mut graceful_deadline => {
                        conn.as_mut().graceful_shutdown();
                    },
                    timeout_action = &mut connection_timeout => {
                        match timeout_action {
                            TimeoutAction::GracefulShutdown => {
                                conn.as_mut().graceful_shutdown();
                            },
                            TimeoutAction::ForcefulShutdown => {
                                debug!("forcefully closed connection");
                                break;
                            }
                        }
                    },
                    _ = &mut sig => {
                        conn.as_mut().graceful_shutdown();
                    },
                }
            }
        }

        drop(watcher);
        trace!("connection closed");
    });
}
#[cfg(feature = "router")]
impl<L> Router<L> {
    /// Pairs a server configuration with the routes it will serve.
    pub(crate) fn new(server: Server<L>, routes: Routes) -> Self {
        Router { server, routes }
    }
}
#[cfg(feature = "router")]
impl<L> Router<L> {
    /// Adds `svc` to the routes served by this router.
    pub fn add_service<S>(self, svc: S) -> Self
    where
        S: Service<Request<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Response: axum::response::IntoResponse,
        S::Future: Send + 'static,
    {
        let Self { server, routes } = self;
        Self {
            server,
            routes: routes.add_service(svc),
        }
    }

    /// Adds `svc` if it is `Some`; otherwise returns the router unchanged.
    pub fn add_optional_service<S>(self, svc: Option<S>) -> Self
    where
        S: Service<Request<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Response: axum::response::IntoResponse,
        S::Future: Send + 'static,
    {
        match svc {
            Some(svc) => self.add_service(svc),
            None => self,
        }
    }

    /// Binds `addr` and serves the accumulated routes (see [`Server::serve`]).
    ///
    /// # Errors
    ///
    /// Returns an error if binding fails or if serving fails.
    pub async fn serve<ResBody>(self, addr: SocketAddr) -> Result<(), super::Error>
    where
        L: Layer<Routes> + Clone,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <L::Service as Service<Request<Body>>>::Future: Send,
        <L::Service as Service<Request<Body>>>::Error: Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let Self { server, routes } = self;
        server.serve(addr, routes.prepare()).await
    }

    /// Like [`Router::serve`], but stops accepting when `signal` completes.
    ///
    /// # Errors
    ///
    /// Returns an error if binding fails or if serving fails.
    pub async fn serve_with_shutdown<F: Future<Output = ()>, ResBody>(
        self,
        addr: SocketAddr,
        signal: F,
    ) -> Result<(), super::Error>
    where
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <L::Service as Service<Request<Body>>>::Future: Send,
        <L::Service as Service<Request<Body>>>::Error: Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let Self { server, routes } = self;
        server
            .serve_with_shutdown(addr, routes.prepare(), signal)
            .await
    }

    /// Serves the accumulated routes on connections from `incoming`.
    ///
    /// # Errors
    ///
    /// Returns an error if serving fails.
    pub async fn serve_with_incoming<I, IO, IE, ResBody>(
        self,
        incoming: I,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IE: Into<crate::BoxError>,
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <L::Service as Service<Request<Body>>>::Future: Send,
        <L::Service as Service<Request<Body>>>::Error: Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let Self { server, routes } = self;
        server
            .serve_with_incoming(routes.prepare(), incoming)
            .await
    }

    /// Like [`Router::serve_with_incoming`], but stops accepting when `signal` completes.
    ///
    /// # Errors
    ///
    /// Returns an error if serving fails.
    pub async fn serve_with_incoming_shutdown<I, IO, IE, F, ResBody>(
        self,
        incoming: I,
        signal: F,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IE: Into<crate::BoxError>,
        F: Future<Output = ()>,
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <L::Service as Service<Request<Body>>>::Future: Send,
        <L::Service as Service<Request<Body>>>::Error: Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let Self { server, routes } = self;
        server
            .serve_with_incoming_shutdown(routes.prepare(), incoming, signal)
            .await
    }
}
impl<L> fmt::Debug for Server<L> {
    /// Prints only the type label; the configuration itself is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less `debug_struct(..).finish()` writes just the name, so this is identical.
        f.write_str("Builder")
    }
}
/// Per-request wrapper that runs the optional trace interceptor and converts the inner
/// service's response into [`Response<Body>`].
#[derive(Clone)]
struct Svc<S> {
    inner: S,
    trace_interceptor: Option<TraceInterceptor>,
}

impl<S, ResBody> Service<Request<Body>> for Svc<S>
where
    S: Service<Request<Body>, Response = Response<ResBody>>,
    S::Error: Into<crate::BoxError>,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::BoxError>,
{
    type Response = Response<Body>;
    type Error = crate::BoxError;
    type Future = SvcFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
            Poll::Pending => Poll::Pending,
        }
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let (req, span) = match self.trace_interceptor.as_deref() {
            Some(interceptor) => {
                // The interceptor only sees the head, so temporarily swap the body for `()`.
                let (parts, body) = req.into_parts();
                let head_only = Request::from_parts(parts, ());
                let span = interceptor(&head_only);
                let (parts, ()) = head_only.into_parts();
                (Request::from_parts(parts, body), span)
            }
            None => (req, tracing::Span::none()),
        };

        SvcFuture {
            inner: self.inner.call(req),
            span,
        }
    }
}
/// Future returned by `Svc::call`. It polls the inner future inside the request's span and
/// boxes the response body into [`Body`].
#[pin_project]
struct SvcFuture<F> {
    #[pin]
    inner: F,
    span: tracing::Span,
}

impl<F, E, ResBody> Future for SvcFuture<F>
where
    F: Future<Output = Result<Response<ResBody>, E>>,
    E: Into<crate::BoxError>,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::BoxError>,
{
    type Output = Result<Response<Body>, crate::BoxError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let _entered = projected.span.enter();

        match ready!(projected.inner.poll(cx)) {
            Ok(response) => {
                let response: Response<ResBody> = response;
                Poll::Ready(Ok(response.map(|body| Body::new(body.map_err(Into::into)))))
            }
            Err(err) => Poll::Ready(Err(err.into())),
        }
    }
}
impl<S> fmt::Debug for Svc<S> {
    /// Prints only the type name; the wrapped service need not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Svc")
    }
}
/// Factory that builds the per-connection service stack for each accepted connection.
#[derive(Clone)]
struct MakeSvc<S, IO> {
    concurrency_limit: Option<usize>,
    load_shed: bool,
    timeout: Option<Duration>,
    inner: S,
    trace_interceptor: Option<TraceInterceptor>,
    _io: PhantomData<fn() -> IO>,
}

impl<S, ResBody, IO> Service<&ServerIo<IO>> for MakeSvc<S, IO>
where
    IO: Connected + 'static,
    S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send,
    S::Error: Into<crate::BoxError> + Send,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::BoxError>,
{
    type Response = BoxService;
    type Error = crate::BoxError;
    type Future = future::Ready<Result<Self::Response, Self::Error>>;

    /// Always ready: building a service needs no external resources.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    /// Builds the service stack for the connection `io`, capturing its connect info.
    ///
    /// The stack is `RecoverErrorLayer`, then optional `LoadShedLayer` and
    /// `ConcurrencyLimitLayer`, then `GrpcTimeout` around the user service. It is wrapped
    /// in `Svc` for tracing, and the connection's connect info is attached before boxing.
    fn call(&mut self, io: &ServerIo<IO>) -> Self::Future {
        let conn_info = io.connect_info();

        let svc = self.inner.clone();
        let concurrency_limit = self.concurrency_limit;
        let timeout = self.timeout;
        let trace_interceptor = self.trace_interceptor.clone();

        let svc = ServiceBuilder::new()
            .layer(RecoverErrorLayer::new())
            .option_layer(self.load_shed.then_some(LoadShedLayer::new()))
            .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
            .layer_fn(|s| GrpcTimeout::new(s, timeout))
            .service(svc);

        // `conn_info` is owned and used only here, so it is moved rather than cloned.
        let svc = ServiceBuilder::new()
            .layer(BoxCloneService::layer())
            .layer(ConnectInfoLayer::new(conn_info))
            .service(Svc {
                inner: svc,
                trace_interceptor,
            });

        future::ready(Ok(svc))
    }
}
#[pin_project]
struct Fuse<F> {
#[pin]
inner: Option<F>,
}
impl<F> Future for Fuse<F>
where
F: Future,
{
type Output = F::Output;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project().inner.as_pin_mut() {
Some(fut) => fut.poll(cx).map(|output| {
self.project().inner.set(None);
output
}),
None => Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::Server;
    use std::time::Duration;
    // All async tests start with tokio's clock paused, so the sleeps below advance virtual
    // time instead of actually waiting.
    #[tokio::test(start_paused = true)]
    async fn test_connection_timeout_no_max_age() {
        // Without a max age, the timeout future must stay pending indefinitely.
        let future = connection_timeout_future(None, None);
        tokio::select! {
            _ = future => {
                panic!("timeout future should never complete when max_connection_age is None");
            }
            _ = tokio::time::sleep(Duration::from_secs(1000)) => {
            }
        }
    }
    #[tokio::test(start_paused = true)]
    async fn test_connection_timeout_with_max_connection_age() {
        // Age only: resolves with a graceful shutdown.
        let future = connection_timeout_future(Some(Duration::from_secs(10)), None);
        let action = future.await;
        assert!(matches!(action, TimeoutAction::GracefulShutdown));
    }
    #[tokio::test(start_paused = true)]
    async fn test_connection_timeout_with_max_connection_age_grace() {
        // Age + grace: still pending at 9s and at 13s, then resolves as forceful (at 15s).
        let mut future = pin!(connection_timeout_future(
            Some(Duration::from_secs(10)),
            Some(Duration::from_secs(5)),
        ));
        tokio::select! {
            _ = &mut future => {
                panic!("should not complete before max_connection_age");
            }
            _ = tokio::time::sleep(Duration::from_secs(9)) => {}
        }
        tokio::select! {
            _ = &mut future => {
                panic!("should not complete before max_connection_age_grace");
            }
            _ = tokio::time::sleep(Duration::from_secs(4)) => {}
        }
        let action = future.await;
        assert!(matches!(action, TimeoutAction::ForcefulShutdown));
    }
    // `builder()` and `default()` must agree on TCP defaults, and the TCP setters must
    // store the values they are given.
    #[test]
    fn server_tcp_defaults() {
        const EXAMPLE_TCP_KEEPALIVE: Duration = Duration::from_secs(10);
        const EXAMPLE_TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(5);
        const EXAMPLE_TCP_KEEPALIVE_RETRIES: u32 = 3;
        let server_via_builder = Server::builder();
        assert!(server_via_builder.tcp_nodelay);
        assert_eq!(server_via_builder.tcp_keepalive, None);
        assert_eq!(server_via_builder.tcp_keepalive_interval, None);
        assert_eq!(server_via_builder.tcp_keepalive_retries, None);
        let server_via_default = Server::default();
        assert!(server_via_default.tcp_nodelay);
        assert_eq!(server_via_default.tcp_keepalive, None);
        assert_eq!(server_via_default.tcp_keepalive_interval, None);
        assert_eq!(server_via_default.tcp_keepalive_retries, None);
        let server_via_builder = Server::builder()
            .tcp_nodelay(false)
            .tcp_keepalive(Some(EXAMPLE_TCP_KEEPALIVE))
            .tcp_keepalive_interval(Some(EXAMPLE_TCP_KEEPALIVE_INTERVAL))
            .tcp_keepalive_retries(Some(EXAMPLE_TCP_KEEPALIVE_RETRIES));
        assert!(!server_via_builder.tcp_nodelay);
        assert_eq!(
            server_via_builder.tcp_keepalive,
            Some(EXAMPLE_TCP_KEEPALIVE)
        );
        assert_eq!(
            server_via_builder.tcp_keepalive_interval,
            Some(EXAMPLE_TCP_KEEPALIVE_INTERVAL)
        );
        assert_eq!(
            server_via_builder.tcp_keepalive_retries,
            Some(EXAMPLE_TCP_KEEPALIVE_RETRIES)
        );
    }
}