#![doc(test(attr(deny(warnings))))]
#![allow(clippy::needless_doctest_main)]
#![forbid(unsafe_code)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use std::convert::Infallible;
use std::fmt::Debug;
use std::future::Future;
use std::io::Error as IoError;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use err_context::prelude::*;
use err_context::AnyError;
use hyper::body::Body;
use hyper::server::accept::Accept as HyperAccept;
use hyper::server::{Builder, Server};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Error as HyperError, Request, Response};
use log::{debug, trace};
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use spirit::fragment::driver::{CacheSimilar, Comparable, Comparison};
use spirit::fragment::{Fragment, Stackable, Transformation};
use spirit::utils::{deserialize_opt_duration, is_default, is_true, serialize_opt_duration};
use spirit::{log_error, Empty};
use spirit_tokio::net::limits::WithLimits;
use spirit_tokio::net::{Accept as SpiritAccept, TcpListen};
use spirit_tokio::runtime::{self, ShutGuard};
use spirit_tokio::FutureInstaller;
#[cfg(feature = "cfg-help")]
use structdoc::StructDoc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot::{self, Receiver, Sender};
use tokio::task::JoinHandle;
/// Default value of [`HyperCfg::http2_keep_alive_timeout`]: 20 seconds.
const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);

/// Serde `skip_serializing_if` predicate: `true` when `timeout` equals [`KEEPALIVE_TIMEOUT`],
/// so a timeout left at its default is not written out.
fn is_default_timeout(timeout: &Duration) -> bool {
    KEEPALIVE_TIMEOUT == *timeout
}
/// Which HTTP protocol versions the server speaks.
///
/// In configuration files it is written as `both`, `http1-only` or `http2-only`. The default
/// is [`HttpMode::Both`].
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize)]
#[cfg_attr(feature = "cfg-help", derive(StructDoc))]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum HttpMode {
/// No protocol restriction: both hyper's `http1_only` and `http2_only` are turned off.
Both,
/// Restrict the server to HTTP/1 (hyper's `http1_only`).
#[serde(rename = "http1-only")]
Http1Only,
/// Restrict the server to HTTP/2 (hyper's `http2_only`).
#[serde(rename = "http2-only")]
Http2Only,
}
impl Default for HttpMode {
fn default() -> Self {
HttpMode::Both
}
}
/// HTTP-level (hyper) configuration of a server.
///
/// [`HyperServer`] flattens it next to the transport configuration, so these options appear
/// directly in the server's configuration section, in kebab-case (e.g. `http1-keepalive`).
/// The whole structure is `#[serde(default)]`: omitted options take the values of the
/// [`Default`] implementation, and options still at those defaults are skipped when
/// serializing. [`HyperCfg::builder`] turns it into a configured hyper [`Builder`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize)]
#[cfg_attr(feature = "cfg-help", derive(StructDoc))]
#[serde(rename_all = "kebab-case", default)]
#[non_exhaustive]
pub struct HyperCfg {
/// Enable HTTP/1 keep-alive; passed to [`Builder::http1_keepalive`].
///
/// Default: `true`.
#[serde(skip_serializing_if = "is_true")]
pub http1_keepalive: bool,
/// Support HTTP/1 half-closed connections; passed to [`Builder::http1_half_close`].
///
/// Default: `true`.
#[serde(skip_serializing_if = "is_true")]
pub http1_half_close: bool,
/// Maximum HTTP/1 buffer size in bytes; passed to [`Builder::http1_max_buf_size`] only when
/// set, otherwise hyper's own default is kept.
///
/// NOTE: hyper does not accept values below its minimum (8192); see the linked method.
///
/// Default: unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub http1_max_buf_size: Option<usize>,
/// Initial HTTP/2 stream-level flow-control window; passed as-is to
/// [`Builder::http2_initial_stream_window_size`].
///
/// Default: unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub http2_initial_stream_window_size: Option<u32>,
/// Initial HTTP/2 connection-level flow-control window; passed as-is to
/// [`Builder::http2_initial_connection_window_size`].
///
/// Default: unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub http2_initial_connection_window_size: Option<u32>,
/// Use adaptive HTTP/2 flow control; passed to [`Builder::http2_adaptive_window`].
///
/// Default: `false`.
#[serde(default, skip_serializing_if = "is_default")]
pub http2_adaptive_window: bool,
/// Maximum number of concurrent HTTP/2 streams per connection; passed as-is to
/// [`Builder::http2_max_concurrent_streams`].
///
/// Default: unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub http2_max_concurrent_streams: Option<u32>,
/// Maximum HTTP/2 frame size; passed as-is to [`Builder::http2_max_frame_size`].
///
/// Default: unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub http2_max_frame_size: Option<u32>,
/// Interval of HTTP/2 keep-alive pings; passed as-is to
/// [`Builder::http2_keep_alive_interval`]. (De)serialized with spirit's optional-duration
/// helpers.
///
/// Default: unset.
#[serde(
deserialize_with = "deserialize_opt_duration",
serialize_with = "serialize_opt_duration",
skip_serializing_if = "Option::is_none"
)]
pub http2_keep_alive_interval: Option<Duration>,
/// Timeout for HTTP/2 keep-alive ping acknowledgements; passed to
/// [`Builder::http2_keep_alive_timeout`].
///
/// Default: 20 seconds.
#[serde(skip_serializing_if = "is_default_timeout")]
pub http2_keep_alive_timeout: Duration,
/// Which HTTP versions to serve, see [`HttpMode`].
///
/// Default: [`HttpMode::Both`].
#[serde(default, skip_serializing_if = "is_default")]
pub http_mode: HttpMode,
}
impl HyperCfg {
    /// Creates a hyper server [`Builder`] accepting connections from `incoming`, with every
    /// option of this configuration applied to it.
    ///
    /// `http1_max_buf_size` is only applied when it is set; otherwise hyper's default stays.
    ///
    /// # Panics
    ///
    /// hyper's [`Builder::http1_max_buf_size`] panics when given a value below its minimum
    /// (8192 bytes), so such a value in `http1_max_buf_size` makes this method panic too.
    pub fn builder<I>(&self, incoming: I) -> Builder<I> {
        let mut builder = Server::builder(incoming)
            .http1_keepalive(self.http1_keepalive)
            .http1_half_close(self.http1_half_close)
            .http2_initial_connection_window_size(self.http2_initial_connection_window_size)
            .http2_initial_stream_window_size(self.http2_initial_stream_window_size)
            .http2_adaptive_window(self.http2_adaptive_window)
            .http2_max_concurrent_streams(self.http2_max_concurrent_streams)
            .http2_max_frame_size(self.http2_max_frame_size)
            .http2_keep_alive_interval(self.http2_keep_alive_interval)
            .http2_keep_alive_timeout(self.http2_keep_alive_timeout);
        // hyper keeps a single connection mode that both `http1_only` and `http2_only` write,
        // and calling either with `false` resets it to "serve both". The restricting `true`
        // call therefore has to come last. Otherwise `http2_only(false)` would undo
        // `http1_only(true)`, and `Http1Only` would silently accept HTTP/2 as well.
        builder = match self.http_mode {
            HttpMode::Both => builder.http1_only(false).http2_only(false),
            HttpMode::Http1Only => builder.http2_only(false).http1_only(true),
            HttpMode::Http2Only => builder.http1_only(false).http2_only(true),
        };
        if let Some(size) = self.http1_max_buf_size {
            builder = builder.http1_max_buf_size(size);
        }
        builder
    }
}
impl Default for HyperCfg {
    /// The default settings are:
    ///
    /// - HTTP/1 keep-alive and half-close enabled;
    /// - adaptive HTTP/2 window disabled;
    /// - a 20 second HTTP/2 keep-alive timeout;
    /// - [`HttpMode::Both`];
    /// - every optional setting left unset.
    fn default() -> Self {
        Self {
            http_mode: Default::default(),
            http1_keepalive: true,
            http1_half_close: true,
            http2_adaptive_window: false,
            http2_keep_alive_timeout: KEEPALIVE_TIMEOUT,
            // Optional knobs, all unset by default.
            http1_max_buf_size: None,
            http2_initial_connection_window_size: None,
            http2_initial_stream_window_size: None,
            http2_max_concurrent_streams: None,
            http2_max_frame_size: None,
            http2_keep_alive_interval: None,
        }
    }
}
/// Adapter exposing a spirit-tokio listener ([`SpiritAccept`]) as a hyper [`HyperAccept`]
/// source of incoming connections.
///
/// It is what [`HyperCfg::builder`] receives when [`HyperServer`] creates its resource. The
/// wrapped listener is structurally pinned, so it is polled in place.
#[pin_project]
#[derive(Copy, Clone, Debug)]
pub struct Acceptor<A>(#[pin] A);
/// Lets hyper pull connections out of the wrapped spirit-tokio listener.
///
/// Every accepted connection is yielded as `Some(Ok(..))` and every accept error as
/// `Some(Err(..))`. The stream never signals its end (`None`).
impl<A: SpiritAccept> HyperAccept for Acceptor<A> {
    type Conn = A::Connection;
    type Error = IoError;

    fn poll_accept(
        self: Pin<&mut Self>,
        ctx: &mut Context,
    ) -> Poll<Option<Result<Self::Conn, IoError>>> {
        let listener = self.project().0;
        match listener.poll_accept(ctx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))),
            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
        }
    }
}
/// Configuration fragment of an HTTP server: a transport plus the HTTP options.
///
/// Both parts are `#[serde(flatten)]`ed, so their options share one configuration section.
/// As a [`Fragment`] it produces a hyper [`Builder`] accepting from the transport's resource.
/// Apply the [`BuildServer`] transformation (for example with [`server_from_handler`]) to turn
/// that builder into a running server.
///
/// See [`HttpServer`] for the TCP-based variant.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize)]
#[cfg_attr(feature = "cfg-help", derive(StructDoc))]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub struct HyperServer<Transport> {
/// Configuration of the underlying transport (the listener).
#[serde(flatten)]
pub transport: Transport,
/// HTTP-level options applied to the hyper builder.
#[serde(flatten)]
pub hyper_cfg: HyperCfg,
}
/// Servers whose transports compare as [`Comparison::Same`] but whose HTTP options differ are
/// reported as [`Comparison::Similar`]. In every other case the transport's verdict is used
/// unchanged.
impl<Transport: Comparable> Comparable for HyperServer<Transport> {
    fn compare(&self, other: &Self) -> Comparison {
        match self.transport.compare(&other.transport) {
            Comparison::Same if self.hyper_cfg != other.hyper_cfg => Comparison::Similar,
            verdict => verdict,
        }
    }
}
/// A [`HyperServer`] reuses the transport's seed and, as its resource, wraps the transport's
/// resource in a hyper [`Builder`] configured from [`HyperCfg`].
///
/// Combined with the [`Comparable`] implementation (HTTP-only changes compare as `Similar`),
/// the [`CacheSimilar`] driver lets an HTTP-options change rebuild just the resource.
impl<Transport> Fragment for HyperServer<Transport>
where
    Transport: Fragment + Debug + Clone + Comparable,
{
    type Driver = CacheSimilar<Self>;
    type Installer = ();
    type Seed = Transport::Seed;
    type Resource = Builder<Acceptor<Transport::Resource>>;

    fn make_seed(&self, name: &'static str) -> Result<Self::Seed, AnyError> {
        self.transport.make_seed(name)
    }

    /// Creates the transport resource and a hyper [`Builder`] accepting from it.
    ///
    /// # Errors
    ///
    /// Fails when the transport fails to create its resource, or when `http1-max-buf-size` is
    /// set below the minimum hyper accepts. hyper would panic on such a value.
    fn make_resource(
        &self,
        seed: &mut Self::Seed,
        name: &'static str,
    ) -> Result<Self::Resource, AnyError> {
        // Smallest value hyper's `Builder::http1_max_buf_size` accepts; it panics below it.
        const MIN_HTTP1_MAX_BUF_SIZE: usize = 8192;

        debug!("Creating HTTP server {}", name);
        // The value comes from user configuration. Turn a would-be panic into a recoverable
        // configuration error, and do it before touching the transport.
        if let Some(size) = self
            .hyper_cfg
            .http1_max_buf_size
            .filter(|&size| size < MIN_HTTP1_MAX_BUF_SIZE)
        {
            return Err(format!(
                "http1-max-buf-size of HTTP server {} is {}, but must be at least {}",
                name, size, MIN_HTTP1_MAX_BUF_SIZE
            )
            .into());
        }
        let transport = self.transport.make_resource(seed, name)?;
        let builder = self.hyper_cfg.builder(Acceptor(transport));
        Ok(builder)
    }
}
/// A [`HyperServer`] is [`Stackable`] whenever its transport is.
impl<Transport> Stackable for HyperServer<Transport> where Transport: Stackable {}
/// A [`HyperServer`] listening on TCP.
///
/// The transport is spirit-tokio's [`TcpListen`] wrapped in [`WithLimits`]. `ExtraCfg` is the
/// extra-configuration type parameter of [`TcpListen`] and defaults to [`Empty`].
pub type HttpServer<ExtraCfg = Empty> = HyperServer<WithLimits<TcpListen<ExtraCfg>>>;
/// The resource produced by [`BuildServer`]: a future that starts an HTTP server and waits for
/// it to terminate.
///
/// The server future is not built until the first poll; it is then spawned onto the tokio
/// runtime. Dropping an `Activate` sends the shutdown signal to a started server. The spawned
/// task is not aborted; it is up to the server future to react to that signal.
pub struct Activate<Fut> {
/// Deferred constructor of the server future, given the shutdown receiver. It is taken
/// (leaving `None`) on the first poll.
build_server: Option<Box<dyn FnOnce(Receiver<()>) -> Fut + Send>>,
/// Guard from `runtime::shut_guard`. It is moved into the spawned server task, so it lives as
/// long as the server runs.
shut_guard: Option<ShutGuard>,
/// Sending half of the shutdown channel. It is set once the server is spawned and fired on
/// drop.
sender: Option<Sender<()>>,
/// Join handle of the spawned server task; `None` until the first poll.
join: Option<JoinHandle<()>>,
/// Name of the configuration fragment, used in log messages.
name: &'static str,
}
/// Dropping the handle asks a started server to shut down gracefully.
impl<Fut> Drop for Activate<Fut> {
    fn drop(&mut self) {
        // `sender` is only populated once the server task has been spawned.
        let shutdown = self.sender.take();
        if let Some(tx) = shutdown {
            // A send error only means the receiving side (the server) is already gone.
            let _ = tx.send(());
        }
    }
}
/// Starts the server on the first poll, then waits for it to terminate.
///
/// The first poll does two things. It calls the stored `build_server` closure with the
/// receiving end of a fresh shutdown channel. It then spawns the resulting future with
/// [`tokio::spawn`], so it must happen inside a tokio runtime. The spawned task holds the
/// shutdown guard while it runs and logs any error the server future returns.
///
/// Every poll, including the first, polls the task's join handle. This future resolves with
/// `()` once the task finishes: successfully, after the server returned an error, or because
/// the task itself failed (that failure is logged here).
impl<Fut, E> Future for Activate<Fut>
where
Fut: Future<Output = Result<(), E>> + Send + 'static,
E: Into<AnyError>,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<()> {
trace!("Poll on Activate({})", self.name);
// `build_server` is present only until the first poll; taking it starts the server once.
if let Some(build_server) = self.build_server.take() {
trace!("Activating {}", self.name);
// The receiver goes to the server; the sender is kept in `self` so dropping `Activate`
// can signal shutdown.
let (sender, receiver) = oneshot::channel();
let server = build_server(receiver);
let name = self.name;
let shut_guard = self.shut_guard.take();
let server = async move {
// Bound only to keep the guard alive until the server future finishes.
let _shut_guard = shut_guard;
if let Err(e) = server.await {
log_error!(
Error,
e.into()
.context(format!("HTTP server error {}", name))
.into()
);
}
trace!("Server {} terminated", name);
};
let join = tokio::spawn(server);
self.join = Some(join);
self.sender = Some(sender);
}
// Set by the branch above on the first poll, so it is always present by now.
match Pin::new(self.join.as_mut().expect("Missing join handle")).poll(ctx) {
Poll::Ready(Ok(())) => {
debug!("Future of server {} terminated", self.name);
Poll::Ready(())
}
// tokio `JoinError`: the spawned task panicked or was cancelled.
Poll::Ready(Err(e)) => {
debug!("Future of server {} errored out", self.name);
log_error!(
Error,
e.context(format!("HTTP server {} failed", self.name))
.into()
);
Poll::Ready(())
}
Poll::Pending => {
trace!("Future of server {} is still pending", self.name);
Poll::Pending
}
}
}
}
/// A [`Transformation`] that turns the hyper [`Builder`] resource of a [`HyperServer`] into a
/// running server.
///
/// The wrapped value is a [`ServerBuilder`]: for example the result of
/// [`server_from_handler`], or a closure with the signature of [`ServerBuilder::build`]. The
/// transformed resource is an [`Activate`] future, installed through spirit-tokio's
/// [`FutureInstaller`].
pub struct BuildServer<BS>(pub BS);
/// Wraps the configured hyper builder into an [`Activate`] future.
///
/// The server itself is built lazily, when the future is first polled.
impl<Tr, Inst, BS> Transformation<Builder<Acceptor<Tr::Resource>>, Inst, HyperServer<Tr>>
    for BuildServer<BS>
where
    Tr: Fragment + Clone + Send + 'static,
    Tr::Resource: Send,
    BS: ServerBuilder<Tr> + Clone + Send + 'static,
    BS::OutputFut: Future<Output = Result<(), HyperError>>,
{
    type OutputResource = Activate<BS::OutputFut>;
    type OutputInstaller = FutureInstaller;

    fn installer(&mut self, _ii: Inst, _name: &'static str) -> Self::OutputInstaller {
        FutureInstaller::default()
    }

    fn transform(
        &mut self,
        builder: Builder<Acceptor<Tr::Resource>>,
        cfg: &HyperServer<Tr>,
        name: &'static str,
    ) -> Result<Self::OutputResource, AnyError> {
        // The closure must own everything it uses, because it runs later on the first poll.
        let server_builder = self.0.clone();
        let owned_cfg = cfg.clone();
        let start = move |shutdown| server_builder.build(builder, &owned_cfg, name, shutdown);
        Ok(Activate {
            name,
            build_server: Some(Box::new(start)),
            shut_guard: runtime::shut_guard(),
            sender: None,
            join: None,
        })
    }
}
/// Creates the server future out of a prepared hyper [`Builder`].
///
/// Used by [`BuildServer`]. It is implemented for every `Fn` with the signature of
/// [`ServerBuilder::build`], and [`server_from_handler`] creates one from a request handler.
pub trait ServerBuilder<Tr>
where
Tr: Fragment,
{
/// The server future. It resolves when the server stops, with hyper's error if it failed.
type OutputFut: Future<Output = Result<(), HyperError>> + Send;
/// Builds the server future.
///
/// * `builder`: the hyper builder, as produced by [`HyperServer`]'s resource creation.
/// * `cfg`: the whole configuration fragment the server is created from.
/// * `name`: name of the fragment, for logging.
/// * `shutdown`: receiving end of the shutdown signal. The returned future is expected to
///   terminate once it resolves.
fn build(
&self,
builder: Builder<Acceptor<Tr::Resource>>,
cfg: &HyperServer<Tr>,
name: &'static str,
shutdown: Receiver<()>,
) -> Self::OutputFut;
}
/// Any function or closure with the signature of [`ServerBuilder::build`] is a server builder.
impl<F, Tr, Fut> ServerBuilder<Tr> for F
where
    Tr: Fragment,
    F: Fn(Builder<Acceptor<Tr::Resource>>, &HyperServer<Tr>, &'static str, Receiver<()>) -> Fut,
    Fut: Future<Output = Result<(), HyperError>> + Send,
{
    type OutputFut = Fut;

    fn build(
        &self,
        builder: Builder<Acceptor<Tr::Resource>>,
        cfg: &HyperServer<Tr>,
        name: &'static str,
        shutdown: Receiver<()>,
    ) -> Self::OutputFut {
        // Forward all arguments unchanged to the wrapped callable.
        let callable = self;
        callable(builder, cfg, name, shutdown)
    }
}
/// Creates a [`ServerBuilder`] that answers every request with `handler`.
///
/// Every connection gets its own service holding a clone of `handler`, and every request
/// calls a further clone of it. The handler cannot fail: its response is always sent, because
/// the service's error type is [`Infallible`]. The server shuts down gracefully when the
/// shutdown receiver resolves, either because a signal was sent or because the sender was
/// dropped. The configuration argument of [`ServerBuilder::build`] is ignored.
///
/// Use it together with [`BuildServer`].
pub fn server_from_handler<H, Tr, S>(handler: H) -> impl ServerBuilder<Tr> + Clone + Send
where
Tr: Fragment,
Tr::Resource: SpiritAccept + Unpin,
<Tr::Resource as SpiritAccept>::Connection: AsyncRead + AsyncWrite + Unpin,
H: Clone + Send + Sync + Fn(Request<Body>) -> S + 'static,
S: Future<Output = Response<Body>> + Send + 'static,
{
move |builder: Builder<Acceptor<Tr::Resource>>, _: &_, name, shutdown| {
debug!("Creating server instance {}", name);
let handler = handler.clone();
builder
// Called once per accepted connection to create that connection's service.
.serve(make_service_fn(move |_conn| {
trace!("Creating a service for {}", name);
let handler = handler.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
let handler = handler.clone();
async move { Ok::<_, Infallible>(handler(req).await) }
}))
}
}))
// Both a sent signal and a dropped sender complete the receiver; either one triggers
// the graceful shutdown.
.with_graceful_shutdown(async move {
let _ = shutdown.await;
})
}
}