use crate::params::resolve_hreq_params;
use crate::params::HReqParams;
use crate::proto::Protocol;
use crate::AsyncRuntime;
use crate::Body;
use crate::Error;
use crate::Stream;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing_futures::Instrument;
mod chain;
mod conn;
mod handler;
mod middle;
mod path;
mod peek;
mod reply;
mod resb_ext;
mod route;
mod router;
mod serv_handle;
mod serv_req_ext;
use conn::Connection;
use peek::Peekable;
use serv_handle::EndFut;
pub use chain::Next;
pub use handler::{Handler, StateHandler};
pub use middle::{Middleware, StateMiddleware};
pub use reply::Reply;
pub use resb_ext::ResponseBuilderExt;
pub use route::{MethodHandlers, Route, StateRoute};
pub use router::Router;
pub use serv_handle::ServerHandle;
pub use serv_req_ext::ServerRequestExt;
/// An HTTP server: a [`Router`] of routes/middleware plus an application
/// `State` value held behind an `Arc`.
///
/// Cloning the server clones the router and bumps the state's `Arc` refcount.
#[derive(Clone)]
pub struct Server<State> {
/// Application state. A clone of this `Arc` is handed to the router for each
/// request (see `handle`) and to every listener's connection driver.
state: Arc<State>,
/// Routes and middleware that requests are dispatched through.
router: Router<State>,
}
impl Server<()> {
    /// Creates a server without application state (its state is `()`).
    ///
    /// Equivalent to `Server::with_state(())`.
    pub fn new() -> Server<()> {
        Server::with_state(())
    }
}

impl Default for Server<()> {
    /// Creates a stateless server; equivalent to [`Server::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<State> Server<State>
where
    State: Clone + Unpin + Send + Sync + 'static,
{
    /// Creates a server with an empty [`Router`] whose handlers share `state`.
    ///
    /// The state is wrapped in an `Arc`; handlers receive clones of it.
    pub fn with_state(state: State) -> Self {
        Server {
            state: Arc::new(state),
            router: Router::new(),
        }
    }

    /// Returns a clone of the server's application state.
    pub fn state(&self) -> State {
        (*self.state).clone()
    }

    /// Returns the [`Route`] for `path`, on which handlers and middleware are
    /// registered. Delegates to the server's [`Router`].
    pub fn at(&mut self, path: &str) -> Route<'_, State> {
        self.router.at(path)
    }

    /// Starts accepting plain (non-TLS) connections on `0.0.0.0:port`.
    ///
    /// The accept loop runs in a task spawned on the [`AsyncRuntime`]; this
    /// method returns as soon as the listener is bound.
    ///
    /// Returns the [`ServerHandle`] created for this listener and the local
    /// address reported by the bound listener.
    ///
    /// # Errors
    ///
    /// Fails if the listener cannot be bound or its local address cannot be read.
    pub async fn listen(&self, port: u16) -> Result<(ServerHandle, SocketAddr), Error> {
        #[cfg(feature = "tls")]
        {
            self.do_listen(port, None).await
        }
        #[cfg(not(feature = "tls"))]
        {
            self.do_listen(port).await
        }
    }

    /// Like [`Server::listen`], but serves connections over TLS using `tls`.
    ///
    /// The configuration is passed through `crate::tls::configure_tls_server`
    /// before use.
    ///
    /// # Errors
    ///
    /// Same as [`Server::listen`].
    #[cfg(feature = "tls")]
    pub async fn listen_tls(
        &self,
        port: u16,
        tls: rustls::ServerConfig,
    ) -> Result<(ServerHandle, SocketAddr), Error> {
        self.do_listen(port, Some(tls)).await
    }

    /// Shared implementation of [`Server::listen`] / `listen_tls`: binds the
    /// listener, then spawns a task that accepts connections and spawns one
    /// `Driver::connect` task per connection.
    #[instrument(name = "server_listen", skip(self, port, tls))]
    async fn do_listen(
        &self,
        port: u16,
        #[cfg(feature = "tls")] tls: Option<rustls::ServerConfig>,
    ) -> Result<(ServerHandle, SocketAddr), Error> {
        // All IPv4 interfaces. Built directly rather than via string parsing,
        // which could never fail here anyway.
        let bind_addr = SocketAddr::from(([0, 0, 0, 0], port));
        let mut listener = AsyncRuntime::listen(bind_addr).await?;
        let local_addr = listener.local_addr()?;

        let (shut, end) = ServerHandle::new().await;

        // One driver per listener, shared by all connection tasks. It carries
        // its own clone of the end signal so connection loops can stop too.
        let driver = Arc::new(Driver::new(
            self.router.clone(),
            self.state.clone(),
            end.clone(),
        ));

        #[cfg(feature = "tls")]
        let tls = tls.map(|mut tls| {
            crate::tls::configure_tls_server(&mut tls);
            Arc::new(tls)
        });

        let task = async move {
            loop {
                trace!("Waiting for connection");
                // `race` yields `None` (presumably once the end signal fires);
                // `?` then returns from this task, ending the accept loop.
                let next = end.race(listener.accept()).await?;
                match next {
                    Ok((stream, remote_addr)) => {
                        trace!("Connection from: {}", remote_addr);
                        let span =
                            trace_span!("conn_task", remote_addr = &remote_addr.to_string()[..]);
                        span.follows_from(tracing::span::Span::current());
                        let driver = driver.clone();
                        #[cfg(feature = "tls")]
                        let tls = tls.clone();
                        let conn_task = async move {
                            #[cfg(feature = "tls")]
                            let res = driver.connect(stream, local_addr, remote_addr, tls).await;
                            #[cfg(not(feature = "tls"))]
                            let res = driver.connect(stream, local_addr, remote_addr).await;
                            // A failing client connection must not take down the listener.
                            if let Err(e) = res {
                                debug!("Client connection failed: {}", e);
                            }
                        }
                        .instrument(span);
                        AsyncRuntime::spawn(conn_task);
                    }
                    Err(e) => {
                        // Accept errors are not fatal: wait a second, then retry.
                        warn!("Listen failed: {}, retrying…", e);
                        AsyncRuntime::timeout(Duration::from_secs(1)).await;
                    }
                }
            }
            // The loop only exits through `?` above; this trailing value just
            // fixes the async block's output type as `Option<()>`.
            #[allow(unreachable_code)]
            Some(())
        }
        .instrument(trace_span!("listen_task"));

        AsyncRuntime::spawn(task);

        Ok((shut, local_addr))
    }

    /// Runs `req` through this server's router in-process, without a socket.
    ///
    /// Request path:
    /// 1. The body is configured with the request's [`HReqParams`] and passed
    ///    through `crate::client::configure_request`.
    /// 2. It is then re-wrapped and configured with fresh params before routing.
    ///
    /// Response path:
    /// 1. The router's response is configured with server-side params copied
    ///    from the request.
    /// 2. It is then re-wrapped, configured with the original request params,
    ///    and passed through `conn::configure_response`.
    ///
    /// This appears intended to mimic a client→server round trip (encoding and
    /// decoding included) without networking.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the router for this request.
    ///
    /// # Panics
    ///
    /// If `resolve_hreq_params` leaves no [`HReqParams`] in the request
    /// extensions, which the original code treats as impossible.
    pub async fn handle<B: Into<Body>>(
        &self,
        req: http::Request<B>,
    ) -> Result<http::Response<Body>, Error> {
        // Client side: prepare the outgoing request.
        let (mut parts, body, client_req_params) = {
            let (parts, body) = req.into_parts();
            let mut parts = resolve_hreq_params(parts);
            let mut body: Body = body.into();
            let params = parts
                .extensions
                .get::<HReqParams>()
                .cloned()
                .expect("resolve_hreq_params must insert HReqParams");
            body.configure(&params, &parts.headers, false);
            crate::client::configure_request(&mut parts, &body, false);
            (parts, body, params)
        };

        // Server side: read the prepared body as an incoming request.
        let (req, server_req_params) = {
            let len = body.content_encoded_length();
            let mut body = Body::from_async_read(body, len);
            let params = HReqParams::new();
            body.configure(&params, &parts.headers, true);
            parts.extensions.insert(params.clone());
            (http::Request::from_parts(parts, body), params)
        };

        let state = self.state.clone();
        let res = self.router.run(state, req).await.into_inner()?;

        // Server side: prepare the outgoing response.
        let (mut parts, body) = {
            let (parts, mut body) = res.into_parts();
            let mut server_res_params = parts
                .extensions
                .get::<HReqParams>()
                .cloned()
                .unwrap_or_else(HReqParams::new);
            server_res_params.copy_from_request(&server_req_params);
            body.configure(&server_res_params, &parts.headers, false);
            (parts, body)
        };

        // Client side: read the prepared body as an incoming response.
        let (parts, body) = {
            let len = body.content_encoded_length();
            let mut body = Body::from_async_read(body, len);
            body.configure(&client_req_params, &parts.headers, true);
            conn::configure_response(&mut parts, &body, false);
            parts.extensions.insert(client_req_params);
            (parts, body)
        };

        Ok(http::Response::from_parts(parts, body))
    }
}
/// Per-listener context shared (via `Arc`) by every connection task spawned
/// from one `do_listen` call.
struct Driver<State> {
/// Router used to dispatch each incoming request.
router: Router<State>,
/// Application state; a clone of this `Arc` is handed to the router per request.
state: Arc<State>,
/// End signal raced against accepting requests. Presumably resolves when the
/// server is shut down through its `ServerHandle` — confirm in `serv_handle`.
end: EndFut,
}
impl<State> Driver<State>
where
    State: Clone + Unpin + Send + Sync + 'static,
{
    /// Bundles the router, shared state and end signal for connection tasks.
    fn new(router: Router<State>, state: Arc<State>, end: EndFut) -> Self {
        Driver { router, state, end }
    }

    /// Serves one accepted connection until it stops yielding requests or the
    /// server ends.
    ///
    /// Protocol selection:
    /// - With the `tls` feature and a `config`, the stream is first wrapped in
    ///   TLS. A protocol negotiated via ALPN is then used as-is.
    /// - Otherwise (no TLS, or ALPN gave `Protocol::Unknown`), the first bytes
    ///   are peeked. The HTTP/2 connection preface selects HTTP/2; anything else
    ///   is treated as HTTP/1.1.
    ///
    /// # Errors
    ///
    /// Propagates failures from the TLS handshake, peeking, or
    /// [`Driver::handle_incoming`].
    pub(crate) async fn connect(
        self: Arc<Self>,
        tcp: impl Stream,
        local_addr: SocketAddr,
        remote_addr: SocketAddr,
        #[cfg(feature = "tls")] config: Option<Arc<rustls::ServerConfig>>,
    ) -> Result<(), Error> {
        let (stream, alpn_proto) = {
            #[cfg(feature = "tls")]
            {
                use crate::either::Either;
                use crate::tls::wrap_tls_server;
                if let Some(config) = config {
                    let (tls, proto) = wrap_tls_server(tcp, config).await?;
                    (Either::A(tls), proto)
                } else {
                    (Either::B(tcp), Protocol::Unknown)
                }
            }
            #[cfg(not(feature = "tls"))]
            {
                (tcp, Protocol::Unknown)
            }
        };

        // Fixed preface every HTTP/2 client sends first (RFC 7540 §3.5); it is
        // what distinguishes plaintext HTTP/2 from HTTP/1.1 without ALPN.
        const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
        // Presumably `Peekable` replays peeked bytes to later readers so the
        // protocol handler still sees them — confirm in `peek`.
        let mut peek = Peekable::new(stream, H2_PREFACE.len());

        let proto = if alpn_proto == Protocol::Unknown {
            // NOTE(review): if `peek` waits for the full preface length, a client
            // sending a shorter HTTP/1.x request and then waiting may stall here.
            let peeked = peek.peek(H2_PREFACE.len()).await?;
            let p = if peeked == H2_PREFACE {
                Protocol::Http2
            } else {
                Protocol::Http11
            };
            trace!("Protocol by peek ({}): {:?}", remote_addr, p);
            p
        } else {
            trace!("Protocol by ALPN ({}): {:?}", remote_addr, alpn_proto);
            alpn_proto
        };

        self.handle_incoming(peek, local_addr, remote_addr, proto)
            .await
    }

    /// Performs the HTTP/2 (for `Protocol::Http2`) or HTTP/1.1 server handshake
    /// on `stream`, then accepts requests in a loop.
    ///
    /// Each request is handled in its own spawned task that runs the router and
    /// sends the response; send failures are only logged.
    ///
    /// Returns `Ok(())` when the accept race yields no request — presumably
    /// because the connection closed or the server's end signal fired.
    ///
    /// # Errors
    ///
    /// Fails on an HTTP/2 handshake error or an error accepting a request.
    pub(crate) async fn handle_incoming(
        self: Arc<Self>,
        stream: impl Stream,
        local_addr: SocketAddr,
        remote_addr: SocketAddr,
        proto: Protocol,
    ) -> Result<(), Error> {
        let mut conn = if proto == Protocol::Http2 {
            let h2conn = hreq_h2::server::handshake(stream).await?;
            Connection::H2(h2conn)
        } else {
            let h1conn = hreq_h1::server::handshake(stream);
            Connection::H1(h1conn)
        };

        debug!("Handshake done, waiting for requests: {}", remote_addr);

        // Sequence number used only to label each request's tracing span.
        let mut req_no = 0;

        loop {
            let inc = self.end.race(conn.accept(local_addr, remote_addr)).await;
            let next = match inc {
                Some(Some(r)) => r?,
                // The end signal won the race, or no further request arrived.
                _ => return Ok(()),
            };

            req_no += 1;

            let driver = self.clone();
            let req_task = async move {
                let (req, send) = next;
                let params = req
                    .extensions()
                    .get::<HReqParams>()
                    .expect("Missing hreq_params in request")
                    .clone();
                let state = driver.state.clone();
                let result = driver.router.run(state, req).await.into_inner();
                if let Err(err) = send.send_response(result, params).await {
                    debug!("Error sending response: {}", err);
                }
            }
            .instrument(debug_span!("req_task", no = req_no));

            AsyncRuntime::spawn(req_task);
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use http::{Request, Response};
    use std::io;

    // Placeholder application state for the stateful routes below.
    #[derive(Clone)]
    pub struct App;

    // Registers one route per supported handler/middleware signature. The main
    // value of this test is that it compiles: each registration must
    // type-check against the routing API.
    #[test]
    fn ensure_type_signatures() {
        let mut server = Server::with_state(App);

        server.at("/p1").get(|_req| async { "yo" });
        server.at("/p2").get(return_scalar);
        server.at("/p3").get(return_io_result).post(return_io_result);
        server.at("/p4").middleware(mid_nostate).get(return_response);
        server.at("/p5").get(return_result_response);
        server.at("/op").get(return_option);
        server
            .at("/p6")
            .with_state()
            .middleware(mid_state)
            .get(return_result_response_state);
    }

    // Handler returning a plain owned string.
    async fn return_scalar(_req: Request<Body>) -> String {
        format!("Yo {}", "world")
    }

    // Stateless middleware that forwards the request unchanged.
    async fn mid_nostate(req: Request<Body>, next: Next) -> Result<Response<Body>, Error> {
        next.run(req).await
    }

    // Stateful middleware that ignores the state and forwards the request.
    async fn mid_state(_st: App, req: Request<Body>, next: Next) -> Result<Response<Body>, Error> {
        next.run(req).await
    }

    // Handler returning a `Result` with a non-hreq error type.
    async fn return_io_result(_req: Request<Body>) -> Result<String, io::Error> {
        Ok("yo".into())
    }

    // Handler returning a full `http::Response`.
    async fn return_response(_req: Request<Body>) -> Response<String> {
        Response::builder().body("yo".into()).unwrap()
    }

    // Handler returning an `Option`.
    async fn return_option(_req: Request<Body>) -> Option<String> {
        None
    }

    // Handler returning a response builder result.
    async fn return_result_response(_req: Request<Body>) -> Result<Response<String>, http::Error> {
        Response::builder().body("yo".into())
    }

    // Stateful handler returning a response builder result.
    async fn return_result_response_state(
        _state: App,
        _req: Request<Body>,
    ) -> Result<Response<String>, http::Error> {
        Response::builder().body("yo".into())
    }
}