#[cfg(feature = "tls")]
use std::path::Path;
use std::sync::Arc;
use hyper::http;
use hyper::service::service_fn;
use hyper_util::rt::{TokioExecutor, TokioIo};
use lazy_static::lazy_static;
use tokio::net::{TcpListener, ToSocketAddrs};
use crate::endpoint::Handler;
use crate::endpoint::{Endpoint, RouterEndpoint};
use crate::error::Error;
use crate::middleware::{Middleware, WithState};
use crate::register_method;
use crate::request::{Request, RequestCtx};
use crate::response::Response;
use crate::router::Router;
lazy_static! {
    /// Server identification string: `"Lieweb "` followed by this crate's
    /// version as recorded by Cargo at build time (`CARGO_PKG_VERSION`).
    pub static ref SERVER_ID: String =
        String::from(concat!("Lieweb ", env!("CARGO_PKG_VERSION")));
}
/// The top-level web application.
///
/// An `App` owns a single [`Router`]. Routes, middleware and the not-found
/// handler registered through the app's methods are forwarded to that router.
/// The serving entry points (`run`, `respond`, and `run_with_tls` when the
/// `tls` feature is enabled) consume the app and share the router behind an
/// `Arc`.
pub struct App {
// The router that all registration calls delegate to.
router: Router,
}
impl App {
pub fn new() -> App {
App {
router: Router::new(),
}
}
pub fn with_state<T>(state: T) -> App
where
T: Send + Sync + 'static + Clone,
{
let mut app = App::new();
app.middleware(WithState::new(state));
app
}
pub fn merge(
&mut self,
prefix: impl AsRef<str>,
router: Router,
) -> Result<(), crate::error::Error> {
self.router.merge(prefix, router)
}
pub fn register<H, T>(&mut self, method: http::Method, path: impl AsRef<str>, handler: H)
where
H: Handler<T> + Send + Sync + 'static,
T: 'static,
{
self.router.register(method, path, handler)
}
register_method!(options, http::Method::OPTIONS);
register_method!(get, http::Method::GET);
register_method!(head, http::Method::HEAD);
register_method!(post, http::Method::POST);
register_method!(put, http::Method::PUT);
register_method!(delete, http::Method::DELETE);
register_method!(trace, http::Method::TRACE);
register_method!(connect, http::Method::CONNECT);
register_method!(patch, http::Method::PATCH);
pub fn middleware(&mut self, m: impl Middleware) -> &mut Self {
self.router.middleware(m);
self
}
pub fn handle_not_found<H, T>(&mut self, handler: H) -> &mut Self
where
H: Handler<T> + Send + Sync + 'static,
T: 'static,
{
self.router.set_not_found_handler(handler);
self
}
pub async fn respond(self, req: Request) -> Response {
let mut req = req;
RequestCtx::init(&mut req, None);
let App { router } = self;
let router = Arc::new(router);
let endpoint = RouterEndpoint::new(router);
endpoint.call(req).await
}
pub async fn run(self, addr: impl ToSocketAddrs) -> Result<(), Error> {
let App { router } = self;
let router = Arc::new(router);
let listener = TcpListener::bind(addr).await.unwrap();
while let Ok((socket, remote_addr)) = listener.accept().await {
let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
let router = router.clone();
tokio::task::spawn(async move {
let router = router.clone();
let ret = server.serve_connection_with_upgrades(
TokioIo::new(socket),
service_fn(|mut req| {
let router = router.clone();
RequestCtx::init(&mut req, Some(remote_addr));
async move {
let endpoint = RouterEndpoint::new(router);
let resp = endpoint.call(req).await;
Ok::<_, Error>(resp)
}
}),
);
if let Err(e) = ret.await {
tracing::error!("serve_connection error: {:?}", e);
}
});
}
Ok(())
}
#[cfg(feature = "tls")]
pub async fn run_with_tls(
self,
addr: impl ToSocketAddrs,
cert: impl AsRef<Path>,
key: impl AsRef<Path>,
) -> Result<(), Error> {
let App { router } = self;
let router = Arc::new(router);
let tls_acceptor = crate::tls::new_tls_acceptor(cert, key)?;
let listener = TcpListener::bind(addr).await.unwrap();
while let Ok((socket, remote_addr)) = listener.accept().await {
let tls_acceptor = tls_acceptor.clone();
let router = router.clone();
tokio::task::spawn(async move {
let tls_acceptor = tls_acceptor.clone();
let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
let router = router.clone();
match tls_acceptor.accept(socket).await {
Ok(stream) => {
let stream = TokioIo::new(stream);
let ret = server.serve_connection(
stream,
service_fn(|mut req| {
let router = router.clone();
RequestCtx::init(&mut req, Some(remote_addr));
async move {
let endpoint = RouterEndpoint::new(router);
let resp = endpoint.call(req).await;
Ok::<_, Error>(resp)
}
}),
);
if let Err(e) = ret.await {
tracing::error!("serve_connection error: {:?}", e);
}
}
Err(err) => {
tracing::error!("tls accept failed, {:?}", err);
}
}
});
}
Ok(())
}
}
impl Default for App {
fn default() -> Self {
Self::new()
}
}
/// Returns the server identification string, `"Lieweb <crate version>"`,
/// borrowed from the lazily initialised [`SERVER_ID`] static.
pub fn server_id() -> &'static str {
    SERVER_ID.as_str()
}