#![warn(missing_docs)]
use crate::app::{api_error::ApiError, settings::Http2Config};
use std::cell::RefCell;
use crate::{
helpers::{exec_post_middleware, exec_pre_middleware},
middlewares::{Middleware, MiddlewareType},
req::HttpRequest,
res::HttpResponse,
router::Router,
types::{HttpMethods, RouterFns, Routes},
};
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::{header, http::StatusCode, Method, Request, Response};
use hyper_staticfile::Static;
use routerify_ng::{ext::RequestExt, RouterService};
use settings::AppSettings;
use std::{collections::HashMap, net::SocketAddr, path::Path, sync::Arc};
use tokio::net::TcpListener;
pub(crate) mod api_error;
mod h2;
pub mod handler;
pub mod middlewares;
pub mod settings;
/// An HTTP application: its route table, its registered middlewares and the
/// server settings that [`App::listen`] turns into a running server.
pub struct App {
    /// Handlers keyed by path and then by HTTP method; exposed through [`RouterFns::routes`].
    routes: Routes,
    /// Middlewares that [`App::listen`] installs, in this order, as routerify pre/post middlewares.
    pub(crate) middlewares: Vec<Arc<Middleware>>,
    /// Server configuration (host, HTTP/2, graceful shutdown, static mounts, ...).
    pub(crate) settings: AppSettings,
}
impl RouterFns for App {
fn routes(&mut self) -> &mut Routes {
&mut self.routes
}
}
impl App {
pub fn new() -> Self {
App {
routes: HashMap::new(),
middlewares: Vec::new(),
settings: AppSettings::default(),
}
}
pub fn host(&mut self, host: &str) -> &mut Self {
self.settings.host = host.to_string();
self
}
pub fn http2_config(&mut self, config: Http2Config) -> &mut Self {
self.settings.http2_config = config;
self
}
pub fn with_graceful_shutdown(&mut self) {
self.settings.graceful_shutdown = true
}
pub fn router(&mut self, mut router: Router) {
let base_path = router.base_path;
for (path, methods) in router.routes() {
for (method, handler) in methods.to_owned() {
if path == "/" {
self.add_route(method, &base_path, move |req: HttpRequest, res| {
(handler)(req, res)
});
} else {
let full_path = format!("{}{}", base_path, path);
self.add_route(method, &full_path, move |req: HttpRequest, res| {
(handler)(req, res)
});
}
}
}
}
pub fn static_files(
&mut self,
path: &'static str,
file: &'static str,
) -> Result<(), &'static str> {
if file == "/" {
return Err("Serving from filesystem root '/' is not allowed for security reasons");
}
if path.is_empty() {
return Err("Mount path cannot be empty");
}
if file.is_empty() {
return Err("File path cannot be empty");
}
if !path.starts_with('/') {
return Err("Mount path must start with '/'");
}
self.settings.static_files.insert(path, file);
Ok(())
}
#[inline]
pub fn disable_http2(&mut self) -> &mut Self {
self.settings.http2_config.is_enabled = false;
self
}
pub async fn listen<F: FnOnce()>(&self, port: u16, cb: F) {
let mut router = routerify_ng::Router::<ApiError>::builder();
#[cfg(feature = "with-wynd")]
if let Some(middleware) = self.settings.wynd_config.clone() {
router = router.middleware(routerify_ng::Middleware::pre({
use crate::helpers::exec_wynd_middleware;
let middleware = Arc::new(middleware);
move |req| exec_wynd_middleware(req, Arc::clone(&middleware))
}));
}
for middleware in &self.middlewares {
match middleware.middleware_type {
MiddlewareType::Post => {
let middleware = Arc::clone(middleware);
router = router.middleware(routerify_ng::Middleware::post_with_info(
move |res, info| exec_post_middleware(res, Arc::clone(&middleware), info),
));
}
_ => {
let middleware = Arc::clone(middleware);
router = router.middleware(routerify_ng::Middleware::pre(move |req| {
exec_pre_middleware(req, Arc::clone(&middleware))
}));
}
}
}
for (path, methods) in &self.routes {
for (method, handler) in methods {
let handler = Arc::clone(handler);
let method = match method {
HttpMethods::GET => Method::GET,
HttpMethods::POST => Method::POST,
HttpMethods::PUT => Method::PUT,
HttpMethods::DELETE => Method::DELETE,
HttpMethods::PATCH => Method::PATCH,
HttpMethods::HEAD => Method::HEAD,
HttpMethods::OPTIONS => Method::OPTIONS,
};
router = router.add(path, vec![method], move |mut req| {
let handler = Arc::clone(&handler);
async move {
let mut our_req = match HttpRequest::from_hyper_request(&mut req).await {
Ok(r) => r,
Err(e) => {
return Err(ApiError::Generic(
HttpResponse::new().bad_request().text(e.to_string()),
));
}
};
req.params().iter().for_each(|(key, value)| {
our_req.set_param(key, value);
});
let mut response = handler(our_req, HttpResponse::new()).await;
let _ = crate::next::PENDING_HEADERS.try_with(|pending| {
for (k, v) in pending.borrow_mut().drain(..) {
response = std::mem::take(&mut response).set_header(k, v);
}
});
let _ = crate::next::PENDING_COOKIES.try_with(|pending| {
for cookie in pending.borrow_mut().drain(..) {
response = std::mem::take(&mut response).set_cookie_raw(cookie);
}
});
let hyper_response = response.to_hyper_response().await;
Ok(hyper_response.unwrap())
}
});
}
}
for (mount_path, serve_from) in self.settings.static_files.iter() {
let serve_from = (*serve_from).to_string();
let mount_root = (*mount_path).to_string();
let route_pattern_owned = if mount_root == "/" {
"/*".to_string()
} else {
format!("{}/{}", mount_root, "*")
};
let serve_from_clone = serve_from.clone();
let mount_root_clone = mount_root.clone();
router = router.get(route_pattern_owned, move |req| {
let serve_from = serve_from_clone.clone();
let mount_root = mount_root_clone.clone();
async move {
match Self::serve_static_with_headers(req, mount_root, serve_from).await {
Ok(res) => Ok(res),
Err(e) => Err(ApiError::Generic(
HttpResponse::new()
.internal_server_error()
.text(e.to_string()),
)),
}
}
});
}
router = router.err_handler(Self::error_handler);
let router = router.build().unwrap();
cb();
let addr = format!("{}:{}", self.settings.host, port)
.parse::<SocketAddr>()
.unwrap();
let listener = TcpListener::bind(addr).await;
if let Err(e) = listener {
eprintln!("Error binding to address {}: {}", addr, e);
return;
}
let listener = listener.unwrap();
let router_service = Arc::new(RouterService::new(router).unwrap());
let http2_enabled = self.settings.http2_config.is_enabled;
let http2_config = self.settings.http2_config.clone();
let mut shutdown = if self.settings.graceful_shutdown {
Some(Box::pin(tokio::signal::ctrl_c()))
} else {
None
};
loop {
let accept_result = if let Some(ref mut sig) = shutdown {
tokio::select! {
result = listener.accept() => Some(result),
_ = sig.as_mut() => None,
}
} else {
Some(listener.accept().await)
};
match accept_result {
Some(Ok((stream, _))) => {
let service = Arc::clone(&router_service);
let http2_config = http2_config.clone();
tokio::task::spawn(async move {
crate::next::PENDING_HEADERS.scope(
RefCell::new(Vec::new()),
crate::next::PENDING_COOKIES.scope(
RefCell::new(Vec::new()),
Self::handle_connection(stream, service, http2_enabled, http2_config),
),
)
.await;
});
}
Some(Err(e)) => {
eprintln!("Error accepting connection: {}", e);
}
None => {
break;
}
}
}
}
pub(crate) async fn error_handler(
err: routerify_ng::RouteError,
) -> Response<Full<hyper::body::Bytes>> {
let api_err = err.downcast::<ApiError>().unwrap_or_else(|_| {
return Box::new(ApiError::Generic(
HttpResponse::new()
.internal_server_error()
.text("Unhandled error"),
));
});
match *api_err {
ApiError::WebSocketUpgrade(response) => response,
ApiError::Generic(res) => {
let hyper_res = <HttpResponse as Clone>::clone(&res)
.to_hyper_response()
.await
.map_err(ApiError::from)
.unwrap();
hyper_res
}
}
}
pub(crate) async fn serve_static_with_headers<B>(
req: Request<B>,
mount_root: String,
fs_root: String,
) -> Result<Response<Full<hyper::body::Bytes>>, std::io::Error>
where
B: hyper::body::Body<Data = hyper::body::Bytes> + Send + 'static,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let (mut parts, body) = req.into_parts();
let original_uri = parts.uri.clone();
let original_path = original_uri.path();
let if_none_match = parts
.headers
.get(header::IF_NONE_MATCH)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let trimmed_path = if mount_root == "/" {
original_path
} else if original_path.starts_with(&mount_root) {
let remaining = &original_path[mount_root.len()..];
if remaining.is_empty() {
"/"
} else {
remaining
}
} else {
original_path
};
let normalized_path = if trimmed_path.is_empty() {
"/"
} else {
trimmed_path
};
let new_path_and_query = if let Some(query) = original_uri.query() {
format!("{}?{}", normalized_path, query)
} else {
normalized_path.to_string()
};
parts.uri = match new_path_and_query.parse() {
Ok(uri) => uri,
Err(e) => {
eprintln!(
"Error parsing URI: {} (original: {}, mount_root: {}, trimmed: {}, normalized: {})",
e, original_path, mount_root, trimmed_path, normalized_path
);
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Invalid URI after rewriting: {}", e),
));
}
};
let rewritten_req = Request::from_parts(parts, body);
let static_service = Static::new(Path::new(fs_root.as_str()));
match static_service.serve(rewritten_req).await {
Ok(mut response) => {
response
.headers_mut()
.insert("Cache-Control", "public, max-age=86400".parse().unwrap());
response
.headers_mut()
.insert("X-Served-By", "hyper-staticfile".parse().unwrap());
if let Some(if_none_match_value) = if_none_match {
if let Some(etag) = response.headers().get(header::ETAG) {
if let Ok(etag_value) = etag.to_str() {
if if_none_match_value == etag_value {
let mut builder =
Response::builder().status(StatusCode::NOT_MODIFIED);
if let Some(h) = builder.headers_mut() {
for (k, v) in response.headers().iter() {
h.insert(k.clone(), v.clone());
}
h.remove(header::CONTENT_LENGTH);
}
return Ok(builder.body(Full::from(Bytes::new())).unwrap());
}
}
}
}
let (parts, body) = response.into_parts();
let collected = body.collect().await.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to collect body: {}", e),
)
})?;
let body_bytes = collected.to_bytes();
let full_body = Full::from(body_bytes);
Ok(Response::from_parts(parts, full_body))
}
Err(e) => Err(e),
}
}
pub(crate) fn _build_router(&self) -> routerify_ng::Router<ApiError> {
routerify_ng::Router::builder()
.err_handler(Self::error_handler)
.build()
.unwrap()
}
}