use crate::{
GMT_FORMAT, HeadersDebug, LatencyDisplay, SERVER, WrappedRedactedHashingAlg, X_REQUEST_ID,
X_VIA, get_host, get_http_version_str, response_for_panic,
};
use ahash::AHasher;
use axum::Router;
use axum::extract::Request;
use axum::http::{HeaderValue, StatusCode, header};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::{Engine as B64Engine, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use http_body_util::BodyExt;
use hyper::HeaderMap;
use ordinary_config::{HttpCache, HttpEtagAlgorithm, XXH3Variation};
use std::hash::Hasher;
use std::sync::Arc;
use std::time::Duration;
use time::UtcDateTime;
use tower::ServiceBuilder;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::classify::ServerErrorsFailureClass;
use tower_http::compression::CompressionLayer;
use tower_http::decompression::RequestDecompressionLayer;
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::set_header::{SetRequestHeaderLayer, SetResponseHeaderLayer};
use tower_http::trace::TraceLayer;
use tracing::Span;
use uuid::Uuid;
/// The kind of service a router serves.
///
/// `apply_common_middleware` uses it to choose the name of the per-request
/// tracing span.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ServiceKind {
    /// Requests are traced under an `app` span.
    App,
    /// Requests are traced under an `api` span.
    Api,
    /// Requests are traced under a `redirect` span (which carries no `domain` field).
    Redirect,
    /// Requests are traced under a `proxy` span.
    Proxy,
}
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
pub fn apply_common_middleware<T>(
router: Router<T>,
state: &T,
server_span: Option<Span>,
domain: String,
log_headers: bool,
log_ips: bool,
redacted_hash: Arc<Option<WrappedRedactedHashingAlg>>,
kind: ServiceKind,
via_domain: Option<String>,
) -> Router
where
T: Clone + Send + Sync + 'static,
{
let redacted_hash_clone = redacted_hash.clone();
router
.with_state(state.clone())
.layer(
ServiceBuilder::new()
.layer(CatchPanicLayer::custom(response_for_panic))
.layer(RequestDecompressionLayer::new())
.layer(CompressionLayer::new()),
)
.layer(
ServiceBuilder::new()
.layer(SetRequestIdLayer::new(X_REQUEST_ID, MakeRequestUuid))
.layer(
TraceLayer::new_for_http()
.make_span_with(move |req: &axum::http::Request<_>| {
let request_id = req
.headers()
.get(X_REQUEST_ID)
.and_then(|rid| {
rid.to_str()
.ok()
.and_then(|rid| Uuid::parse_str(rid).ok())
.map(tracing::field::display)
})
.unwrap_or(tracing::field::display(Uuid::new_v4()));
let host =
get_host(req.headers(), req.uri()).map(tracing::field::display);
let ip = crate::get_display_ip(log_ips, req);
let query = req.uri().query().map(tracing::field::display);
match kind {
ServiceKind::App => {
if let Some(server_span) = &server_span {
server_span.in_scope(|| {
tracing::info_span!(
"app",
%domain,
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
})
} else {
tracing::info_span!(
"app",
%domain,
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
}
}
ServiceKind::Proxy => {
if let Some(server_span) = &server_span {
server_span.in_scope(|| {
tracing::info_span!(
"proxy",
%domain,
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
})
} else {
tracing::info_span!(
"proxy",
%domain,
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
}
}
ServiceKind::Api => {
if let Some(server_span) = &server_span {
server_span.in_scope(|| {
tracing::info_span!(
"api",
%domain,
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
})
} else {
tracing::info_span!(
"api",
%domain,
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
}
}
ServiceKind::Redirect => {
if let Some(server_span) = &server_span {
server_span.in_scope(|| {
tracing::info_span!(
"redirect",
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
})
} else {
tracing::info_span!(
"redirect",
host,
rid = %request_id,
ip,
path = %req.uri().path(),
query,
)
}
}
}
})
.on_request(move |req: &axum::http::Request<_>, _: &Span| {
let hd = log_headers
.then_some(HeadersDebug(req.headers(), redacted_hash.clone()));
#[cfg(tracing_unstable)]
let headers = log_headers.then_some(tracing::field::valuable(&hd));
#[cfg(not(tracing_unstable))]
let headers = log_headers.then_some(tracing::field::debug(&hd));
tracing::info!(
version = ?req.version(),
method = %req.method(),
headers,
"req"
);
})
.on_response(move |res: &Response<_>, latency: Duration, _: &Span| {
let hd = log_headers.then_some(HeadersDebug(
res.headers(),
redacted_hash_clone.clone(),
));
#[cfg(tracing_unstable)]
let headers = log_headers.then_some(tracing::field::valuable(&hd));
#[cfg(not(tracing_unstable))]
let headers = log_headers.then_some(tracing::field::debug(&hd));
let status = res.status().as_u16();
let latency = LatencyDisplay(latency.as_nanos() as f64);
if status >= 500 {
tracing::error!(status, headers, %latency, "res");
} else if status >= 400 {
tracing::warn!(status, headers, %latency, "res");
} else {
tracing::info!(status, headers, %latency, "res");
}
})
.on_failure(|error: ServerErrorsFailureClass, _: Duration, _: &Span| {
tracing::error!(
err = %error,
"fail"
);
}),
)
.layer(PropagateRequestIdLayer::new(X_REQUEST_ID))
.layer(SetResponseHeaderLayer::if_not_present(
header::SERVER,
HeaderValue::from_static(SERVER),
))
.option_layer(via_domain.map(|domain| {
SetRequestHeaderLayer::overriding(X_VIA, move |req: &axum::http::Request<_>| {
let req_version = get_http_version_str(req.version());
HeaderValue::from_str(&format!("{req_version} {domain} (ordinaryd)")).ok()
})
})),
)
}
#[allow(clippy::similar_names)]
pub async fn http_cache_middleware(
last_modified: UtcDateTime,
req_headers: HeaderMap,
request: Request,
next: Next,
) -> Response {
let response = next.run(request).await;
let (mut parts, body) = response.into_parts();
let body_bytes = if let Ok(collected) = body.collect().await {
collected.to_bytes()
} else {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
};
let mut res_headers = HeaderMap::new();
let etag_string = get_etag_hash(body_bytes.as_ref(), None);
let etag_str = etag_string.as_str();
if let Some(if_none_match) = req_headers.get(header::IF_NONE_MATCH)
&& let Ok(if_none_match_str) = if_none_match.to_str()
&& if_none_match_str == etag_str
{
res_headers.insert(header::ETAG, if_none_match.to_owned());
return (StatusCode::NOT_MODIFIED, res_headers).into_response();
} else if let Ok(etag_header) = HeaderValue::from_str(etag_str) {
if let Some(if_modified_since) = req_headers.get(header::IF_MODIFIED_SINCE)
&& let Ok(if_modified_since_str) = if_modified_since.to_str()
&& let Ok(if_modified_since) = UtcDateTime::parse(if_modified_since_str, &GMT_FORMAT)
&& if_modified_since >= last_modified
{
res_headers.insert(header::ETAG, etag_header);
return (StatusCode::NOT_MODIFIED, res_headers).into_response();
}
parts.headers.insert(header::ETAG, etag_header);
}
(parts, body_bytes).into_response()
}
/// Hashes `content` into an unpadded, URL-safe base64 string suitable for an `ETag`.
///
/// The algorithm is read from `http_cache.etag.alg`. If `http_cache`, its `etag`
/// section or the `alg` setting is missing, AHash is used. 64-bit digests (AHash,
/// XXH3-64, FxHash) encode to 11 characters. 128-bit digests (XXH3-128 and BLAKE3
/// truncated to 16 bytes) encode to 22 characters.
///
/// NOTE(review): `AHasher::default()` may derive its keys per process depending on
/// ahash's enabled features. Confirm AHash ETags stay stable across restarts and
/// replicas.
#[must_use]
pub fn get_etag_hash(content: &[u8], http_cache: Option<&HttpCache>) -> String {
    // Shared by the explicit AHash choice and the unconfigured fallback.
    fn ahash64(bytes: &[u8]) -> u64 {
        let mut state = AHasher::default();
        state.write(bytes);
        state.finish()
    }

    let configured = http_cache
        .and_then(|cache| cache.etag.as_ref())
        .and_then(|etag| etag.alg.as_ref());

    match configured {
        None | Some(HttpEtagAlgorithm::AHash) => b64.encode(ahash64(content).to_be_bytes()),
        Some(HttpEtagAlgorithm::XXH3(XXH3Variation::Bit64)) => {
            let digest = xxhash_rust::xxh3::xxh3_64(content);
            b64.encode(digest.to_be_bytes())
        }
        Some(HttpEtagAlgorithm::XXH3(XXH3Variation::Bit128)) => {
            let digest = xxhash_rust::xxh3::xxh3_128(content);
            b64.encode(digest.to_be_bytes())
        }
        Some(HttpEtagAlgorithm::Rustc) => {
            let mut state = rustc_hash::FxHasher::default();
            state.write(content);
            b64.encode(state.finish().to_be_bytes())
        }
        Some(HttpEtagAlgorithm::Blake3) => {
            let digest = blake3::hash(content);
            b64.encode(&digest.as_bytes()[..16])
        }
    }
}
/// Copies the request's `X_VIA` header, when present, onto the response as `Via`.
///
/// `headers` holds the incoming request headers (presumably supplied by an axum
/// extractor). The response from `next` is otherwise returned untouched.
#[must_use]
pub async fn x_via(headers: HeaderMap, request: Request, next: Next) -> Response {
    let forwarded_via = headers.get(X_VIA).cloned();
    let mut outgoing = next.run(request).await;
    if let Some(via_value) = forwarded_via {
        outgoing.headers_mut().insert(header::VIA, via_value);
    }
    outgoing
}
pub fn modify_etag_for_encoding(res: &Response) -> Option<HeaderValue> {
let headers = res.headers();
if let Some(curr_etag) = headers.get(header::ETAG)
&& let Ok(curr_etag_str) = curr_etag.to_str()
{
let etag_len = curr_etag_str.len();
if (etag_len == 22 || etag_len == 11)
&& let Some(compression) = headers.get(header::CONTENT_ENCODING)
&& let Ok(compression_str) = compression.to_str()
{
let mut etag_string = curr_etag_str.to_owned();
match compression_str {
"gzip" => etag_string.push('1'),
"zstd" => etag_string.push('2'),
"br" => etag_string.push('3'),
"deflate" => etag_string.push('4'),
_ => (),
}
match HeaderValue::from_str(etag_string.as_str()) {
Ok(v) => return Some(v),
Err(err) => tracing::error!(%err),
}
} else {
return Some(curr_etag.clone());
}
}
None
}
/// Checks the request's `If-None-Match` value against `etag`.
///
/// Returns:
/// * `Some(etag)` when the client's value equals `etag` and is 11, 12, 22 or 23
///   bytes long;
/// * `Some(client value)` when dropping the client's last character yields `etag`
///   (presumably the encoding marker added by `modify_etag_for_encoding`; confirm);
/// * `None` when the header is missing, is not visible ASCII, is shorter than 11
///   bytes, or matches in neither way.
pub fn check_if_none_match<'a>(headers: &'a HeaderMap, etag: &'a str) -> Option<&'a str> {
    let client_tag = headers.get(header::IF_NONE_MATCH)?.to_str().ok()?;
    if client_tag.len() < 11 {
        return None;
    }
    // Equal strings have equal lengths, so a single length check covers both sides.
    if client_tag == etag && matches!(client_tag.len(), 11 | 12 | 22 | 23) {
        return Some(etag);
    }
    // `to_str` only succeeds on single-byte ASCII, so this split is on a char boundary.
    let (stem, _marker) = client_tag.split_at(client_tag.len() - 1);
    (stem == etag).then_some(client_tag)
}