use std::future::Future;
use std::io::Result as IoResult;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
#[cfg(feature = "http1")]
use hyper::server::conn::http1;
#[cfg(feature = "http2")]
use hyper::server::conn::http2;
use tokio::sync::Notify;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
#[cfg(feature = "quinn")]
use crate::conn::quinn;
use crate::conn::{Accepted, Acceptor, Holding, HttpBuilder};
use crate::http::{HeaderValue, HttpConnection, Version};
use crate::Service;
pub struct Server<A> {
acceptor: A,
builder: HttpBuilder,
idle_timeout: Option<Duration>,
}
impl<A: Acceptor + Send> Server<A> {
#[inline]
pub fn new(acceptor: A) -> Self {
Server {
acceptor,
builder: HttpBuilder {
#[cfg(feature = "http1")]
http1: http1::Builder::new(),
#[cfg(feature = "http2")]
http2: http2::Builder::new(crate::rt::tokio::TokioExecutor::new()),
#[cfg(feature = "quinn")]
quinn: crate::conn::quinn::Builder::new(),
},
idle_timeout: None,
}
}
#[inline]
pub fn holdings(&self) -> &[Holding] {
self.acceptor.holdings()
}
cfg_feature! {
#![feature = "http1"]
pub fn http1_mut(&mut self) -> &mut http1::Builder {
&mut self.builder.http1
}
}
cfg_feature! {
#![feature = "http2"]
pub fn http2_mut(&mut self) -> &mut http2::Builder<crate::rt::tokio::TokioExecutor> {
&mut self.builder.http2
}
}
cfg_feature! {
#![feature = "quinn"]
pub fn quinn_mut(&mut self) -> &mut quinn::Builder {
&mut self.builder.quinn
}
}
#[must_use]
pub fn idle_timeout(mut self, timeout: Duration) -> Self {
self.idle_timeout = Some(timeout);
self
}
#[inline]
pub async fn serve<S>(self, service: S)
where
S: Into<Service> + Send,
{
self.try_serve(service).await.unwrap();
}
#[inline]
pub async fn try_serve<S>(self, service: S) -> IoResult<()>
where
S: Into<Service> + Send,
{
self.try_serve_with_graceful_shutdown(service, futures_util::future::pending(), None)
.await
}
#[inline]
pub async fn serve_with_graceful_shutdown<S, G>(self, service: S, signal: G, timeout: Option<Duration>)
where
S: Into<Service> + Send,
G: Future<Output = ()> + Send + 'static,
{
self.try_serve_with_graceful_shutdown(service, signal, timeout)
.await
.unwrap();
}
#[inline]
pub async fn try_serve_with_graceful_shutdown<S, G>(
self,
service: S,
signal: G,
timeout: Option<Duration>,
) -> IoResult<()>
where
S: Into<Service> + Send,
G: Future<Output = ()> + Send + 'static,
{
let Self {
mut acceptor,
builder,
idle_timeout,
} = self;
let alive_connections = Arc::new(AtomicUsize::new(0));
let notify = Arc::new(Notify::new());
let timeout_token = CancellationToken::new();
let server_shutdown_token = CancellationToken::new();
tokio::pin!(signal);
let mut alt_svc_h3 = None;
for holding in acceptor.holdings() {
tracing::info!("listening {}", holding);
if holding.http_versions.contains(&Version::HTTP_3) {
if let Some(addr) = holding.local_addr.clone().into_std() {
let port = addr.port();
alt_svc_h3 = Some(
format!(r#"h3=":{port}"; ma=2592000,h3-29=":{port}"; ma=2592000"#)
.parse::<HeaderValue>()
.unwrap(),
);
}
}
}
let service = Arc::new(service.into());
let builder = Arc::new(builder);
loop {
tokio::select! {
_ = &mut signal => {
server_shutdown_token.cancel();
if let Some(timeout) = timeout {
tracing::info!(
timeout_in_seconds = timeout.as_secs_f32(),
"initiate graceful shutdown",
);
let timeout_token = timeout_token.clone();
tokio::spawn(async move {
tokio::time::sleep(timeout).await;
timeout_token.cancel();
});
} else {
tracing::info!("initiate graceful shutdown");
}
break;
},
accepted = acceptor.accept() => {
match accepted {
Ok(Accepted { conn, local_addr, remote_addr, http_scheme, ..}) => {
alive_connections.fetch_add(1, Ordering::Release);
let service = service.clone();
let alive_connections = alive_connections.clone();
let notify = notify.clone();
let handler = service.hyper_handler(local_addr, remote_addr, http_scheme, alt_svc_h3.clone());
let builder = builder.clone();
let timeout_token = timeout_token.clone();
let server_shutdown_token = server_shutdown_token.clone();
tokio::spawn(async move {
let conn = conn.serve(handler, builder, server_shutdown_token, idle_timeout);
if timeout.is_some() {
tokio::select! {
_ = conn => {
},
_ = timeout_token.cancelled() => {
}
}
} else {
conn.await.ok();
}
if alive_connections.fetch_sub(1, Ordering::Acquire) == 1 {
notify.notify_waiters();
}
});
},
Err(e) => {
tracing::error!(error = ?e, "accept connection failed");
}
}
}
}
}
if alive_connections.load(Ordering::Acquire) > 0 {
tracing::info!("wait for all connections to close.");
notify.notified().await;
}
tracing::info!("server stopped");
Ok(())
}
}
#[cfg(test)]
mod tests {
    use serde::Serialize;

    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    /// Drives a `Service` through `TestClient` and checks three things:
    /// - the plain-text route;
    /// - the JSON route;
    /// - the 404 body rendered for each requested `accept` type.
    #[tokio::test]
    async fn test_server() {
        #[handler]
        async fn hello() -> Result<&'static str, ()> {
            Ok("Hello World")
        }

        #[handler]
        async fn json(res: &mut Response) {
            #[derive(Serialize, Debug)]
            struct User {
                name: String,
            }
            res.render(Json(User { name: "jobs".into() }));
        }

        let routes = Router::new()
            .get(hello)
            .push(Router::with_path("json").get(json));
        let service = Service::new(routes);
        let base_url = "http://127.0.0.1:5800";

        let body = TestClient::get(base_url)
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert_eq!(body, "Hello World");

        let body = TestClient::get(format!("{base_url}/json"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert_eq!(body, r#"{"name":"jobs"}"#);

        // Without an explicit `accept` header the default 404 page is returned.
        let body = TestClient::get(format!("{base_url}/not_exist"))
            .send(&service)
            .await
            .take_string()
            .await
            .unwrap();
        assert!(body.contains("Not Found"));

        // Each `accept` type must produce a 404 body in the matching format.
        let negotiated = [
            ("application/json", r#""code":404"#),
            ("text/plain", "code: 404"),
            ("application/xml", "<code>404</code>"),
        ];
        for (accept, expected) in negotiated {
            let body = TestClient::get(format!("{base_url}/not_exist"))
                .add_header("accept", accept, true)
                .send(&service)
                .await
                .take_string()
                .await
                .unwrap();
            assert!(body.contains(expected));
        }
    }
}