pub mod req_res_logging;
use crate::app::context::AppContext;
use crate::service::http::middleware::Middleware;
use crate::util::tracing::optional_trace_field;
use axum::Router;
use axum::extract::{FromRef, MatchedPath};
use axum::http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use itertools::Itertools;
use opentelemetry_semantic_conventions::trace::{
HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, HTTP_ROUTE, NETWORK_PROTOCOL_VERSION, URL_PATH,
URL_QUERY,
};
use serde_derive::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::{BTreeMap, HashSet};
use std::iter::{IntoIterator, Iterator};
use std::str::FromStr;
use std::time::Duration;
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{Span, error_span, field, info};
use url::Url;
use validator::Validate;
/// Configuration controlling which request/response details the HTTP tracing
/// middleware is allowed to record.
///
/// All fields default to "nothing allowed": the flags default to `false` and the
/// name lists default to empty. Field names are (de)serialized in kebab-case,
/// e.g. `request-headers-allow-all`.
#[derive(Debug, Clone, Default, Validate, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case", default)]
#[non_exhaustive]
pub struct TracingConfig {
/// When `true`, every request header is included in the request log event,
/// regardless of `request_header_names`.
#[serde(default)]
pub request_headers_allow_all: bool,
/// When `true`, every response header is included in the response log event,
/// regardless of `response_header_names`.
#[serde(default)]
pub response_headers_allow_all: bool,
/// When `true`, query parameter values are recorded as-is instead of being
/// replaced with `REDACTED`, regardless of `query_param_names`.
#[serde(default)]
pub query_params_allow_all: bool,
/// Names of the request headers that may be logged. Entries that do not parse
/// as a valid HTTP header name are skipped.
#[serde(default)]
pub request_header_names: Vec<String>,
/// Names of the response headers that may be logged. Entries that do not parse
/// as a valid HTTP header name are skipped.
#[serde(default)]
pub response_header_names: Vec<String>,
/// Names of the query parameters whose values may be recorded unredacted.
/// Names are lower-cased before use and compared against the lower-cased
/// parameter key, so matching is case-insensitive.
#[serde(default)]
pub query_param_names: Vec<String>,
}
/// HTTP middleware that adds a [`TraceLayer`] to the router so that each request
/// gets a tracing span plus "started"/"finished" log events.
pub struct TracingMiddleware;
impl<S> Middleware<S> for TracingMiddleware
where
    S: 'static + Send + Sync + Clone,
    AppContext: FromRef<S>,
{
    type Error = crate::error::Error;

    /// Name of this middleware: `"tracing"`.
    fn name(&self) -> String {
        String::from("tracing")
    }

    /// Reports whether the middleware is enabled, delegating the decision to the
    /// `common` section of the tracing middleware config.
    fn enabled(&self, state: &S) -> bool {
        let context = AppContext::from_ref(state);
        let common = &context
            .config()
            .service
            .http
            .custom
            .middleware
            .tracing
            .common;
        common.enabled(state)
    }

    /// Returns the priority configured in the `common` section of the tracing
    /// middleware config.
    fn priority(&self, state: &S) -> i32 {
        let context = AppContext::from_ref(state);
        let common = &context
            .config()
            .service
            .http
            .custom
            .middleware
            .tracing
            .common;
        common.priority
    }

    /// Wraps `router` in a [`TraceLayer`] configured with this module's span
    /// maker and request/response hooks.
    ///
    /// The request id header name is taken from the `set_request_id` middleware
    /// config; everything else comes from the tracing middleware's own config.
    fn install(&self, state: &S, router: Router) -> Result<Router, Self::Error> {
        let context = AppContext::from_ref(state);
        let all_middleware = &context.config().service.http.custom.middleware;
        let tracing_config = &all_middleware.tracing.custom;
        let request_id_header = &all_middleware.set_request_id.custom.common.header_name;

        let span_maker =
            CustomMakeSpan::new(request_id_header).with_tracing_config(tracing_config);
        let trace_layer = TraceLayer::new_for_http()
            .make_span_with(span_maker)
            .on_request(CustomOnRequest::new(tracing_config))
            .on_response(CustomOnResponse::new(tracing_config));

        Ok(router.layer(trace_layer))
    }
}
/// Span factory used by the trace layer: creates the per-request `HTTP` span.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CustomMakeSpan {
/// Name of the request header whose value is recorded as the span's
/// `request_id` field.
pub request_id_header_name: String,
/// When `true`, query parameter values are recorded without redaction.
pub query_params_allow_all: bool,
/// Query parameter names whose values may be recorded unredacted. Lookups use
/// the lower-cased parameter key, so entries should be stored lower-cased.
pub query_param_names: HashSet<String>,
}
impl CustomMakeSpan {
pub fn new(request_id_header_name: &str) -> Self {
Self {
request_id_header_name: request_id_header_name.to_owned(),
query_params_allow_all: false,
query_param_names: Default::default(),
}
}
pub fn with_tracing_config(mut self, tracing_config: &TracingConfig) -> Self {
self.query_params_allow_all = tracing_config.query_params_allow_all;
self.query_param_names = tracing_config
.query_param_names
.iter()
.map(|name| name.to_lowercase())
.collect();
self
}
}
impl<B> MakeSpan<B> for CustomMakeSpan {
    /// Creates the `HTTP` span for `request`.
    ///
    /// The span's `otel.name` is the request method followed by the matched
    /// route (or the method alone when no route was matched). The response
    /// status code field is declared empty so it can be recorded later.
    fn make_span(&mut self, request: &Request<B>) -> Span {
        let matched_path = get_matched_path(request);

        // Method and route separated by a single space; method only if the
        // request did not match a route.
        let method = request.method().as_str();
        let span_name = match matched_path {
            Some(path) => format!("{} {}", method, path),
            None => method.to_owned(),
        };

        let redacted_uri_query = get_query_redacted(
            self.query_params_allow_all,
            &self.query_param_names,
            request,
        );
        let request_id = get_request_id(&self.request_id_header_name, request);

        // NOTE(review): the span is created at ERROR level, presumably so it is
        // kept regardless of the configured level filter -- confirm.
        error_span!("HTTP",
            otel.name = span_name,
            otel.kind = "SERVER",
            { HTTP_REQUEST_METHOD } = %request.method(),
            { HTTP_ROUTE } = optional_trace_field(matched_path),
            { NETWORK_PROTOCOL_VERSION } = ?request.version(),
            { URL_PATH } = %request.uri().path(),
            { URL_QUERY } = optional_trace_field(redacted_uri_query),
            request_id = optional_trace_field(request_id),
            { HTTP_RESPONSE_STATUS_CODE } = field::Empty,
        )
    }
}
/// Returns the route stored in the request's [`MatchedPath`] extension, or
/// `None` when the extension is not present.
fn get_matched_path<B>(request: &Request<B>) -> Option<&str> {
    let matched = request.extensions().get::<MatchedPath>()?;
    Some(matched.as_str())
}
/// Returns the value of the header named `request_id_header_name`, or `None`
/// when the header is absent or its value cannot be viewed as a string.
fn get_request_id<'a, B>(
    request_id_header_name: &'a str,
    request: &'a Request<B>,
) -> Option<&'a str> {
    let header_value = request.headers().get(request_id_header_name)?;
    header_value.to_str().ok()
}
/// Builds a loggable version of the request's query string in which the values
/// of non-allowed parameters are replaced with `REDACTED`.
///
/// * `allow_all` - when `true`, no value is redacted.
/// * `allowed_names` - lower-cased parameter names whose values may be kept;
///   the parameter key is lower-cased before the lookup.
/// * `request` - the request whose URI query is inspected.
///
/// Returns `None` when the request URI has no query component, so that callers
/// leave the query field unset instead of recording an empty string. Otherwise
/// returns the `key=value` pairs joined with `&`. Parameter keys are never
/// redacted.
fn get_query_redacted<B>(
    allow_all: bool,
    allowed_names: &HashSet<String>,
    request: &Request<B>,
) -> Option<String> {
    // No query on the request: report "absent" rather than an empty string.
    let query = request.uri().query()?;

    // A `Url` is needed only for its query-pair parser; the host is a placeholder.
    let mut uri = Url::from_str("https://example.com").ok()?;
    uri.set_query(Some(query));

    let redacted = uri
        .query_pairs()
        .map(|(key, value)| {
            // An empty allow-list contains nothing, so no separate emptiness
            // check is needed.
            let value = if allow_all || allowed_names.contains(&key.to_lowercase()) {
                value
            } else {
                Cow::from("REDACTED")
            };
            format!("{key}={value}")
        })
        .join("&");
    Some(redacted)
}
/// Trace-layer hook that logs an event when request processing starts.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CustomOnRequest {
/// When `true`, all request headers are logged, regardless of
/// `request_header_names`.
pub allow_all_headers: bool,
/// Request headers that may be logged when `allow_all_headers` is `false`.
pub request_header_names: HashSet<HeaderName>,
}
impl CustomOnRequest {
pub fn new(tracing_config: &TracingConfig) -> Self {
let request_header_names = tracing_config
.request_header_names
.iter()
.filter_map(|name| HeaderName::from_str(name).ok())
.collect();
Self {
allow_all_headers: tracing_config.request_headers_allow_all,
request_header_names,
}
}
}
impl<B> OnRequest<B> for CustomOnRequest {
    /// Logs a "started processing request" event carrying the request headers
    /// that the configuration allows to be logged.
    fn on_request(&mut self, request: &Request<B>, _: &Span) {
        let loggable_headers = allowed_headers(
            request.headers(),
            &self.request_header_names,
            self.allow_all_headers,
        );

        info!(http.request.headers = ?loggable_headers, "started processing request");
    }
}
/// Trace-layer hook that records the response status on the span and logs an
/// event when request processing finishes.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CustomOnResponse {
/// When `true`, all response headers are logged, regardless of
/// `response_header_names`.
pub allow_all_headers: bool,
/// Response headers that may be logged when `allow_all_headers` is `false`.
pub response_header_names: HashSet<HeaderName>,
}
impl CustomOnResponse {
pub fn new(tracing_config: &TracingConfig) -> CustomOnResponse {
let response_header_names = tracing_config
.response_header_names
.iter()
.filter_map(|name| HeaderName::from_str(name).ok())
.collect();
CustomOnResponse {
allow_all_headers: tracing_config.response_headers_allow_all,
response_header_names,
}
}
}
impl<B> OnResponse<B> for CustomOnResponse {
    /// Records the numeric response status code on `span`, then logs a
    /// "finished processing request" event with the allowed response headers
    /// and the latency in whole milliseconds.
    fn on_response(self, response: &Response<B>, latency: Duration, span: &Span) {
        let status_code = response.status().as_u16();
        span.record(HTTP_RESPONSE_STATUS_CODE, status_code);

        let loggable_headers = allowed_headers(
            response.headers(),
            &self.response_header_names,
            self.allow_all_headers,
        );
        let latency_millis = latency.as_millis();

        info!(
            http.response.headers = ?loggable_headers,
            latency = latency_millis,
            "finished processing request"
        );
    }
}
/// Selects the headers that may be logged.
///
/// A header is kept when `allow_all` is `true` or its name is in
/// `allow_list_headers`. The result is keyed by header name and ordered by it;
/// because keys are unique, a name that occurs more than once in `headers`
/// ends up with a single value (the last one visited).
fn allowed_headers<'a>(
    headers: &'a HeaderMap,
    allow_list_headers: &'a HashSet<HeaderName>,
    allow_all: bool,
) -> BTreeMap<&'a str, &'a HeaderValue> {
    let mut loggable = BTreeMap::new();
    for (name, value) in headers.iter() {
        if allow_all || allow_list_headers.contains(name) {
            loggable.insert(name.as_str(), value);
        }
    }
    loggable
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AppConfig;
    use rstest::rstest;

    /// An explicit `enable` setting on the tracing middleware decides whether
    /// the middleware reports itself as enabled.
    #[rstest]
    #[case(false, Some(true), true)]
    #[case(false, Some(false), false)]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn tracing_enabled(
        #[case] default_enable: bool,
        #[case] enable: Option<bool>,
        #[case] expected_enabled: bool,
    ) {
        // Arrange
        let mut config = AppConfig::test(None).unwrap();
        let middleware_config = &mut config.service.http.custom.middleware;
        middleware_config.default_enable = default_enable;
        middleware_config.tracing.common.enable = enable;
        let context = AppContext::test(Some(config), None, None).unwrap();

        // Act
        let actual_enabled = TracingMiddleware.enabled(&context);

        // Assert
        assert_eq!(actual_enabled, expected_enabled);
    }

    /// The priority comes from the test config unless it is overridden.
    #[rstest]
    #[case(None, -9980)]
    #[case(Some(1234), 1234)]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn tracing_priority(#[case] override_priority: Option<i32>, #[case] expected_priority: i32) {
        // Arrange
        let mut config = AppConfig::test(None).unwrap();
        if let Some(priority) = override_priority {
            let common = &mut config.service.http.custom.middleware.tracing.common;
            common.priority = priority;
        }
        let context = AppContext::test(Some(config), None, None).unwrap();

        // Act
        let actual_priority = TracingMiddleware.priority(&context);

        // Assert
        assert_eq!(actual_priority, expected_priority);
    }
}