use std::fmt::Debug;
use axum::body::Body;
use axum::response::IntoResponse;
use futures::Future;
use http::request::Parts;
use http::HeaderValue;
use http_body_util::BodyExt;
use hyper::{header, Request, Response};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::time::{Duration, Instant};
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse};
#[derive(Debug, Clone, Copy, Default)]
enum BodyFormat {
#[default]
JsonPb,
Pb,
}
impl BodyFormat {
fn from_content_type<T>(req: &Request<T>) -> BodyFormat {
match req
.headers()
.get(header::CONTENT_TYPE)
.map(|x| x.as_bytes())
{
Some(CONTENT_TYPE_PROTOBUF) => BodyFormat::Pb,
_ => BodyFormat::JsonPb,
}
}
}
pub(crate) async fn handle_request<S, F, Fut, In, Out>(
service: S,
req: Request<Body>,
f: F,
) -> Response<Body>
where
F: FnOnce(S, http::Request<In>) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<http::Response<Out>, TwirpErrorResponse>> + Send,
In: prost::Message + Default + serde::de::DeserializeOwned,
Out: prost::Message + Default + serde::Serialize,
{
let mut timings = req
.extensions()
.get::<Timings>()
.copied()
.unwrap_or_else(|| Timings::new(Instant::now()));
let (parts, req, resp_fmt) = match parse_request::<In>(req, &mut timings).await {
Ok(tuple) => tuple,
Err(err) => {
return error::malformed("bad request")
.with_meta("error", &err)
.with_generic_error(err)
.into_response();
}
};
let r = Request::from_parts(parts, req);
let res = f(service, r).await;
timings.set_response_handled();
let mut resp = match write_response(res, resp_fmt) {
Ok(resp) => resp,
Err(err) => {
return error::internal("error serializing response")
.with_meta("error", &err)
.with_generic_error(err)
.into_response();
}
};
timings.set_response_written();
resp.extensions_mut().insert(timings);
resp
}
/// Reads the whole request body and decodes it as `T`.
///
/// The encoding (protobuf or JSON) is chosen from the request's `Content-Type`
/// header via [`BodyFormat::from_content_type`]. The `received` checkpoint in
/// `timings` is stamped once the body has been collected. The `parsed` checkpoint
/// is stamped once decoding succeeds.
///
/// Returns the request head, the decoded message, and the detected format, so the
/// caller can rebuild the request and reply in the same encoding.
///
/// # Errors
/// Fails if the body stream errors or the bytes do not decode as `T`.
async fn parse_request<T>(
    req: Request<Body>,
    timings: &mut Timings,
) -> Result<(Parts, T, BodyFormat), GenericError>
where
    T: prost::Message + Default + DeserializeOwned,
{
    let fmt = BodyFormat::from_content_type(&req);
    let (head, body) = req.into_parts();

    let collected = body.collect().await?;
    let payload = collected.to_bytes();
    timings.set_received();

    let message: T = match fmt {
        BodyFormat::JsonPb => serde_json::from_slice(&payload)?,
        BodyFormat::Pb => T::decode(&payload[..])?,
    };
    timings.set_parsed();

    Ok((head, message, fmt))
}
/// Turns a service handler's result into an HTTP response.
///
/// For `Err`, the Twirp error renders itself. For `Ok`, the message is encoded in
/// `out_format`, and the handler's extensions and headers are carried over.
/// `Content-Type` is then set to match the encoding, overriding any value the
/// handler supplied. The handler's status and version are not copied; the response
/// uses the defaults from `Response::new`.
///
/// # Errors
/// Fails if JSON serialization fails or the content-type constant is not a valid
/// header value.
fn write_response<T>(
    out: Result<http::Response<T>, TwirpErrorResponse>,
    out_format: BodyFormat,
) -> Result<Response<Body>, GenericError>
where
    T: prost::Message + Default + Serialize,
{
    let reply = match out {
        Ok(reply) => reply,
        Err(err) => return Ok(err.into_response()),
    };

    let (head, message) = reply.into_parts();
    let (encoded, content_type) = match out_format {
        BodyFormat::JsonPb => {
            let json = serde_json::to_string(&message)?;
            (Body::from(json), CONTENT_TYPE_JSON)
        }
        BodyFormat::Pb => (
            Body::from(serialize_proto_message(message)),
            CONTENT_TYPE_PROTOBUF,
        ),
    };
    let content_type = HeaderValue::from_bytes(content_type)?;

    let mut resp = Response::new(encoded);
    resp.extensions_mut().extend(head.extensions);
    let headers = resp.headers_mut();
    headers.extend(head.headers);
    headers.insert(header::CONTENT_TYPE, content_type);
    Ok(resp)
}
pub async fn not_found_handler() -> Response<Body> {
error::bad_route("not found").into_response()
}
/// Checkpoints recorded while a single Twirp request is served.
///
/// `start` is supplied at construction. The remaining checkpoints are stamped with
/// the current time by the crate-private setters as the request moves through its
/// phases. The public accessors report how long each phase took.
#[derive(Debug, Clone, Copy)]
pub struct Timings {
    /// Reference instant the first phase is measured from.
    start: Instant,
    /// When the full request body had been collected.
    request_received: Option<Instant>,
    /// When the request body had been decoded.
    request_parsed: Option<Instant>,
    /// When the service handler had returned.
    response_handled: Option<Instant>,
    /// When the response had been encoded.
    response_written: Option<Instant>,
}

impl Timings {
    /// Creates a record measured from `start` with no phases completed yet.
    pub fn new(start: Instant) -> Self {
        Self {
            start,
            request_received: None,
            request_parsed: None,
            response_handled: None,
            response_written: None,
        }
    }

    fn set_received(&mut self) {
        self.request_received = Some(Instant::now());
    }

    fn set_parsed(&mut self) {
        self.request_parsed = Some(Instant::now());
    }

    fn set_response_handled(&mut self) {
        self.response_handled = Some(Instant::now());
    }

    fn set_response_written(&mut self) {
        self.response_written = Some(Instant::now());
    }

    /// Time from `start` until the request body was fully received, if it was.
    pub fn received(&self) -> Option<Duration> {
        self.request_received.map(|received| received - self.start)
    }

    /// Time spent decoding the body; `None` unless both the received and parsed
    /// checkpoints were recorded.
    pub fn parsed(&self) -> Option<Duration> {
        let (parsed, received) = self.request_parsed.zip(self.request_received)?;
        Some(parsed - received)
    }

    /// Time spent in the service handler; `None` unless both the parsed and handled
    /// checkpoints were recorded.
    pub fn response_handled(&self) -> Option<Duration> {
        let (handled, parsed) = self.response_handled.zip(self.request_parsed)?;
        Some(handled - parsed)
    }

    /// Time spent writing the response, measured from the latest earlier checkpoint.
    ///
    /// Falls back handled → parsed → received → `start`, so a response written after
    /// an early failure (when some phases never happened) still reports a duration.
    /// Returns `None` only if the response was never written.
    pub fn response_written(&self) -> Option<Duration> {
        let written = self.response_written?;
        let previous = self
            .response_handled
            .or(self.request_parsed)
            .or(self.request_received)
            .unwrap_or(self.start);
        Some(written - previous)
    }

    /// Time elapsed from `start` until the moment this method is called.
    pub fn total_duration(&self) -> Duration {
        self.start.elapsed()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::*;
    use axum::middleware::{self, Next};
    use tower::Service;

    /// A timing record starting now, for attaching to test requests.
    fn timings() -> Timings {
        let now = Instant::now();
        Timings::new(now)
    }

    /// Unknown paths fall through to the `bad_route` handler.
    #[tokio::test]
    async fn test_bad_route() {
        let mut app = test_api_router();
        let request = Request::get("/nothing")
            .extension(timings())
            .body(Body::empty())
            .unwrap();
        let response = app.call(request).await.unwrap();
        let err = read_err_body(response.into_body()).await;
        assert_eq!(err, error::bad_route("not found"));
    }

    /// A well-formed JSON ping echoes the name back.
    #[tokio::test]
    async fn test_ping_success() {
        let mut app = test_api_router();
        let request = gen_ping_request("hi");
        let response = app.call(request).await.unwrap();
        assert!(response.status().is_success(), "{:?}", response);
        let pong: PingResponse = read_json_body(response.into_body()).await;
        assert_eq!(&pong.name, "hi");
    }

    /// An empty body is not valid JSON, so the request is rejected as malformed.
    #[tokio::test]
    async fn test_ping_invalid_request() {
        let mut app = test_api_router();
        let request = Request::post("/twirp/test.TestAPI/Ping")
            .extension(timings())
            .body(Body::empty())
            .unwrap();
        let response = app.call(request).await.unwrap();
        assert!(response.status().is_client_error(), "{:?}", response);
        let err = read_err_body(response.into_body()).await;
        let expected = error::malformed("bad request")
            .with_meta("error", "EOF while parsing a value at line 1 column 0");
        assert_eq!(err, expected);
    }

    /// An error returned by the service surfaces as a server error.
    #[tokio::test]
    async fn test_boom() {
        let mut app = test_api_router();
        let payload = PingRequest {
            name: "hi".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        let request = Request::post("/twirp/test.TestAPI/Boom")
            .extension(timings())
            .body(Body::from(json))
            .unwrap();
        let response = app.call(request).await.unwrap();
        assert!(response.status().is_server_error(), "{:?}", response);
        let err = read_err_body(response.into_body()).await;
        assert_eq!(err, error::internal("boom!"));
    }

    /// Extensions inserted by middleware are visible to the service handler.
    #[tokio::test]
    async fn test_middleware() {
        let mut app = test_api_router().layer(middleware::from_fn(request_id_middleware));

        // No `x-request-id` header: the name comes back unchanged.
        let response = app.call(gen_ping_request("hi")).await.unwrap();
        assert!(response.status().is_success(), "{:?}", response);
        let pong: PingResponse = read_json_body(response.into_body()).await;
        assert_eq!(&pong.name, "hi");

        // With the header, the expected echo includes the request id.
        let json = serde_json::to_string(&PingRequest {
            name: "hello".to_string(),
        })
        .expect("will always be valid json");
        let request = Request::post("/twirp/test.TestAPI/Ping")
            .header("x-request-id", "abcd")
            .body(Body::from(json))
            .expect("always a valid twirp request");
        let response = app.call(request).await.unwrap();
        assert!(response.status().is_success(), "{:?}", response);
        let pong: PingResponse = read_json_body(response.into_body()).await;
        assert_eq!(&pong.name, "hello-abcd");
    }

    /// Copies a visible-ASCII `x-request-id` header into a `RequestId` extension.
    async fn request_id_middleware(
        mut request: http::Request<Body>,
        next: Next,
    ) -> http::Response<Body> {
        let request_id = request
            .headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        if let Some(id) = request_id {
            request.extensions_mut().insert(RequestId(id));
        }
        next.run(request).await
    }
}