use std::any::Any;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::time::Instant;
use futures_util::FutureExt;
use http::Request;
use tracing::Instrument;
use tracing::Level;
use crate::app::{fire_request_hooks, fire_response_hooks, request_info, AppInner};
use crate::body::ReqBody;
use crate::error::Error;
use crate::extract::RequestContext;
use crate::hooks::RequestInfo;
use crate::logging::Logger;
use crate::response::{IntoResponse, Response};
use crate::router::matcher::Match;
impl AppInner {
/// Routes an incoming request to its handler and produces the final response.
///
/// The request is split into head and body, and the matcher is queried with
/// the request method and URI path. On a match the handler is run with a
/// freshly built `RequestContext`; around it this method optionally fires
/// the route's request/response hooks, wraps the handler future in a
/// `request` tracing span, catches handler panics, and emits a structured
/// request log. Unmatched requests are turned into "method not allowed" /
/// "resource not found" errors via `finish_error`.
///
/// # Panics
///
/// When `catch_panics()` is false, a panic raised by the handler future is
/// not caught here and propagates to the caller.
pub async fn dispatch(&self, request: Request<ReqBody>) -> Response {
let (head, body) = request.into_parts();
// Owned copy of the path: `head` is moved into the request context below.
let path = head.uri.path().to_owned();
let needs_info = self.needs_request_info();
match self.matcher().find(&head.method, &path) {
Match::Found { route, params } => {
let handler = route.handler().clone();
// `RequestInfo` is only built when the app asks for it or the route has hooks.
let info = (needs_info || route.has_hooks()).then(|| {
request_info(
&head.method,
&head.uri,
&head.headers,
Some(route.path().to_owned()),
)
});
if let Some(info) = &info {
fire_request_hooks(route.request_hooks(), info).await;
}
// Timer for the response hooks; started after the request hooks have run
// and only when there is a `RequestInfo` to report against.
let started = info.as_ref().map(|_| Instant::now());
let tracing_enabled = tracing::enabled!(Level::INFO);
// Metadata is captured up front because `head` is consumed by the context.
// It is present whenever tracing or request logging is on, which is what
// the two `expect` calls further down rely on.
let request_meta = if tracing_enabled || self.request_logs() {
Some(match &info {
Some(info) => DispatchMeta {
method: info.method().clone(),
request_path: info.path().to_owned(),
route_path: info.route().unwrap_or(route.path()).to_owned(),
request_id: info.request_id().map(str::to_owned),
},
// No `RequestInfo`: read the same data straight from the request head.
None => DispatchMeta {
method: head.method.clone(),
request_path: head.uri.path().to_owned(),
route_path: route.path().to_owned(),
request_id: head
.headers
.get("x-request-id")
.and_then(|value| value.to_str().ok())
.map(str::to_owned),
},
})
} else {
None
};
// Separate timer used for the `duration_ms` field of the request log.
let request_started = Instant::now();
let ctx = RequestContext::new(head, params, self.state().clone(), body);
// Box the handler future so the instrumented and plain variants share one type.
let future: Pin<
Box<dyn std::future::Future<Output = crate::error::Result<Response>> + Send>,
> = if tracing_enabled {
let meta = request_meta.as_ref().expect("request meta for tracing");
let span = tracing::info_span!(
"request",
method = %meta.method,
route = %meta.route_path,
request_id = meta.request_id.as_deref().unwrap_or("")
);
Box::pin(handler(ctx).instrument(span))
} else {
Box::pin(handler(ctx))
};
let result = if self.catch_panics() {
match AssertUnwindSafe(future).catch_unwind().await {
Ok(result) => result,
Err(payload) => {
if let Some(info) = &info {
self.fire_panic(info, &panic_message(payload.as_ref()))
.await;
}
// NOTE(review): this early return bypasses `render_error`, the
// response hooks and the request log below — confirm that panicked
// requests are meant to be excluded from them.
return Error::internal("handler panicked").into_response();
}
}
} else {
future.await
};
let response = match result {
Ok(response) => response,
// Handler errors go through `render_error` only when a `RequestInfo` exists.
Err(error) => match &info {
Some(info) => self.render_error(error, info, Some(route)).await,
None => error.into_response(),
},
};
if let Some(info) = &info {
let elapsed = started.map(|start| start.elapsed()).unwrap_or_default();
fire_response_hooks(route.response_hooks(), info, response.status(), elapsed)
.await;
}
if self.request_logs() {
let meta = request_meta.as_ref().expect("request meta for request log");
let status = response.status().as_u16();
let mut logger = Logger::framework("HTTP");
if let Some(request_id) = &meta.request_id {
logger = logger.with_field("request_id", request_id.clone());
}
logger
.info(format!("{} {} {status}", meta.method, meta.request_path))
.field("method", meta.method.as_str())
.field("path", &meta.request_path)
.field("route", &meta.route_path)
.field("status", status)
.field("duration_ms", request_started.elapsed().as_millis() as u64)
.emit();
}
response
}
// No route matched: no route hooks, tracing span or request log apply here.
Match::MethodNotAllowed => {
let info =
needs_info.then(|| request_info(&head.method, &head.uri, &head.headers, None));
self.finish_error(Error::method_not_allowed("method not allowed"), info)
.await
}
Match::NotFound => {
let info =
needs_info.then(|| request_info(&head.method, &head.uri, &head.headers, None));
self.finish_error(Error::not_found("resource not found"), info)
.await
}
}
}
/// Dispatches a request whose context is built with a duplex upgrade stream.
///
/// The route is looked up by method and path like in `dispatch`, but the
/// handler receives a `RequestContext` created through
/// `RequestContext::with_duplex_upgrade`, carrying `duplex`. This path runs
/// no hooks, tracing span, panic catching or request logging: handler errors
/// and routing failures are converted straight into responses.
#[allow(dead_code)]
pub(crate) async fn dispatch_upgrade(
    &self,
    request: Request<ReqBody>,
    duplex: tokio::io::DuplexStream,
) -> Response {
    let (parts, payload) = request.into_parts();
    // Owned copy of the path, since `parts` is moved into the context below.
    let lookup_path = parts.uri.path().to_owned();

    match self.matcher().find(&parts.method, &lookup_path) {
        Match::NotFound => Error::not_found("resource not found").into_response(),
        Match::MethodNotAllowed => {
            Error::method_not_allowed("method not allowed").into_response()
        }
        Match::Found { route, params } => {
            let run = route.handler().clone();
            let state = self.state().clone();
            let ctx = RequestContext::with_duplex_upgrade(parts, params, state, payload, duplex);
            // A handler error becomes its own response; nothing else renders it.
            run(ctx)
                .await
                .unwrap_or_else(|error| error.into_response())
        }
    }
}
/// Converts a routing error into a response.
///
/// With a `RequestInfo` the error is rendered through `render_error`
/// (passing no route); without one it is converted directly with
/// `into_response`.
async fn finish_error(&self, error: Error, info: Option<RequestInfo>) -> Response {
    if let Some(request_info) = info {
        self.render_error(error, &request_info, None).await
    } else {
        error.into_response()
    }
}
}
/// Request metadata captured by `dispatch` before the request head is moved
/// into the handler context; used for the tracing span and the request log.
struct DispatchMeta {
/// HTTP method of the request.
method: http::Method,
/// Path component of the request URI.
request_path: String,
/// Path of the matched route, as reported by the request info or the route itself.
route_path: String,
/// Request id taken from the request info or the `x-request-id` header, if any.
request_id: Option<String>,
}
/// Extracts a human-readable message from a caught panic payload.
///
/// Payloads of type `&str` or `String` (what `panic!` produces) are returned
/// as text, capped at 1024 bytes. Any other payload type yields the literal
/// `"panic"`.
///
/// The cap is applied on a UTF-8 character boundary: if byte 1024 falls in
/// the middle of a multi-byte character, the cut moves back to the start of
/// that character. Slicing at a fixed byte offset would itself panic in that
/// case, which must not happen while a panic is already being reported.
fn panic_message(payload: &(dyn Any + Send)) -> String {
    /// Maximum number of bytes of the panic message that is kept.
    const MAX_LEN: usize = 1024;

    let message: &str = if let Some(text) = payload.downcast_ref::<&str>() {
        *text
    } else if let Some(text) = payload.downcast_ref::<String>() {
        text.as_str()
    } else {
        return "panic".to_owned();
    };

    if message.len() <= MAX_LEN {
        return message.to_owned();
    }

    // Walk back to the nearest char boundary; index 0 is always a boundary,
    // so this terminates.
    let mut end = MAX_LEN;
    while !message.is_char_boundary(end) {
        end -= 1;
    }
    message[..end].to_owned()
}
#[cfg(test)]
mod tests {
    //! Unit tests for request dispatch and panic-payload formatting.

    use super::panic_message;
    use crate::app::{App, AppInner};
    use crate::body::{box_body, ReqBody};
    use crate::error::Result;
    use crate::extract::RequestContext;
    use crate::response::Response;
    use crate::router::{BoxFuture, HandlerFn, Route, Router};
    use crate::{json_response, Method, StatusCode};
    use bytes::Bytes;
    use http_body_util::{BodyExt, Full};
    use std::sync::Arc;

    /// Handler that answers with `{"id": <user_id path parameter>}`.
    fn echo_id_handler() -> HandlerFn {
        Arc::new(
            |ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async move {
                    let id = ctx.path_param("user_id").unwrap_or_default().to_owned();
                    let payload = serde_json::json!({ "id": id });
                    Ok(json_response(StatusCode::OK, &payload))
                })
            },
        )
    }

    /// App with a single `GET /users/{user_id}` route.
    fn test_app() -> AppInner {
        let user_route = Route::new(Method::GET, "/{user_id}", echo_id_handler());
        let users = Router::new().prefix("/users").route(user_route);
        App::new().include_router(users).build().unwrap()
    }

    /// Builds a request with an empty body.
    fn request(method: Method, uri: &str) -> http::Request<ReqBody> {
        let empty_body = box_body(Full::new(Bytes::new()));
        http::Request::builder()
            .method(method)
            .uri(uri)
            .body(empty_body)
            .unwrap()
    }

    /// Collects the whole response body into a UTF-8 string.
    async fn body_to_string(response: Response) -> String {
        let collected = response.into_body().collect().await.unwrap();
        String::from_utf8(collected.to_bytes().to_vec()).unwrap()
    }

    /// Boxes a value the way a caught panic payload is represented.
    fn payload_of<T: std::any::Any + Send>(value: T) -> Box<dyn std::any::Any + Send> {
        Box::new(value)
    }

    #[tokio::test]
    async fn dispatches_to_matching_route() {
        let app = test_app();
        let response = app.dispatch(request(Method::GET, "/users/42")).await;
        assert_eq!(response.status(), StatusCode::OK);
        let body = body_to_string(response).await;
        assert!(body.contains("\"42\""));
    }

    #[tokio::test]
    async fn unknown_path_yields_not_found() {
        let app = test_app();
        let response = app.dispatch(request(Method::GET, "/nope")).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn wrong_method_yields_method_not_allowed() {
        let app = test_app();
        let response = app.dispatch(request(Method::POST, "/users/42")).await;
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn panic_message_handles_str_payload() {
        let payload = payload_of("oops");
        assert_eq!(panic_message(&*payload), "oops");
    }

    #[test]
    fn panic_message_handles_string_payload() {
        let payload = payload_of(String::from("from string"));
        assert_eq!(panic_message(&*payload), "from string");
    }

    #[test]
    fn panic_message_handles_unknown_payload() {
        let payload = payload_of(42_i32);
        assert_eq!(panic_message(&*payload), "panic");
    }
}