use std::{
fmt::{self, Display},
sync::OnceLock,
time::Duration,
};
use anyhow::{Context, bail, ensure};
use http::{HeaderName, HeaderValue};
use lexe_common::time::DisplayMs;
use lexe_crypto::rng::{RngCore, ThreadFastRng};
use tracing::{Dispatch, span, warn};
#[cfg(doc)]
use crate::rest::RestClient;
/// `tracing` target used by the spans and events emitted from this module.
pub(crate) const TARGET: &str = "lxapi";
/// HTTP header used to propagate a [`TraceId`] between services.
///
/// The server side reads it when building each request span
/// (`server::LxMakeSpan`). Presumably [`RestClient`] sets it on outgoing
/// requests — NOTE(review): not visible in this module, confirm.
pub(crate) static TRACE_ID_HEADER_NAME: HeaderName =
HeaderName::from_static("lexe-trace-id");
/// Identifier used to correlate the logs of a single request across client
/// and server.
///
/// It is stored in `tracing` span extensions (see the `define_trace_id_fns!`
/// macro) and sent over HTTP in the `lexe-trace-id` header. Ids produced by
/// [`TraceId::generate`] / [`TraceId::from_rng`] are 16 random alphanumeric
/// bytes.
#[derive(Clone, PartialEq)]
pub struct TraceId(HeaderValue);
/// Hook that reads the [`TraceId`] associated with a span, given the span's
/// id and its [`Dispatch`]. Returns `Ok(None)` if no [`TraceId`] was found.
///
/// This is a function pointer because accessing span extensions requires
/// downcasting the [`Dispatch`] to the concrete subscriber type, which is
/// known only to the logger. The `define_trace_id_fns!` macro generates a
/// matching implementation for a given subscriber type. This static is
/// expected to be initialized by `lexe_logger::try_init()`.
pub static GET_TRACE_ID_FN: OnceLock<
fn(&span::Id, &Dispatch) -> anyhow::Result<Option<TraceId>>,
> = OnceLock::new();
/// Hook that stores a [`TraceId`] in a span's extensions, given the span's id
/// and its [`Dispatch`]. Returns the previously stored [`TraceId`], if any.
///
/// It is a function pointer for the same reason as [`GET_TRACE_ID_FN`], and
/// is likewise expected to be initialized by `lexe_logger::try_init()`.
pub static INSERT_TRACE_ID_FN: OnceLock<
fn(&span::Id, &Dispatch, TraceId) -> anyhow::Result<Option<TraceId>>,
> = OnceLock::new();
impl TraceId {
    /// Number of bytes in a [`TraceId`].
    const LENGTH: usize = 16;

    /// Generates a new random [`TraceId`] using [`ThreadFastRng`].
    pub fn generate() -> Self {
        Self::from_rng(&mut ThreadFastRng::new())
    }

    /// Generates a [`TraceId`] of 16 random alphanumeric bytes drawn from
    /// `rng`.
    pub fn from_rng(rng: &mut impl RngCore) -> Self {
        use lexe_crypto::rng::RngExt;
        let buf: [u8; Self::LENGTH] = rng.gen_alphanum_bytes();
        let header_value = HeaderValue::from_bytes(&buf).expect(
            "All alphanumeric bytes are in range (32..=255), \
             and none are byte 127 (DEL). This is also checked in tests.",
        );
        Self(header_value)
    }

    /// Returns this [`TraceId`] as a string slice.
    ///
    /// # Panics
    ///
    /// Panics if the underlying bytes are not valid UTF-8. That can only
    /// happen if a constructor admitted non-ASCII bytes, which is an internal
    /// invariant violation.
    pub fn as_str(&self) -> &str {
        // Deliberately checked rather than `from_utf8_unchecked`:
        // `TraceId`s are parsed from untrusted request headers, so any lapse
        // in validation must surface as a panic, never as undefined behavior.
        // Validating 16 bytes is negligible.
        std::str::from_utf8(self.0.as_bytes())
            .expect("TraceId must only contain ASCII alphanumeric bytes")
    }

    /// Returns a clone of this [`TraceId`] as a [`HeaderValue`].
    pub fn to_header_value(&self) -> HeaderValue {
        self.0.clone()
    }

    /// Looks up the [`TraceId`] associated with `span` via
    /// [`GET_TRACE_ID_FN`].
    ///
    /// Returns `None` if there is no [`TraceId`]. Also returns `None`, after
    /// logging a warning, if the lookup fails (hook not set, span disabled,
    /// or hook error). In test builds a disabled span returns `None` without
    /// a warning.
    fn get_from_span(span: &tracing::Span) -> Option<Self> {
        #[cfg(any(test, feature = "test-utils"))]
        if span.is_disabled() {
            return None;
        }

        let try_get_trace_id = || {
            let get_trace_id_fn = GET_TRACE_ID_FN.get().context(
                "GET_TRACE_ID_FN not set. Did lexe_logger::try_init() \
                 initialize the TraceId statics?",
            )?;
            let maybe_trace_id = span
                .with_subscriber(|(id, dispatch)| get_trace_id_fn(id, dispatch))
                .context("Span is not enabled")?
                .context("get_trace_id_fn (get_trace_id_from_span) failed")?;
            Ok::<_, anyhow::Error>(maybe_trace_id)
        };

        try_get_trace_id()
            .inspect_err(|e| warn!("Failed to check for trace id: {e:#}"))
            .unwrap_or_default()
    }

    /// Stores this [`TraceId`] in `span`'s extensions via
    /// [`INSERT_TRACE_ID_FN`], logging a warning if it replaced an existing
    /// [`TraceId`].
    ///
    /// Failures (hook not set, span disabled, or hook error) are silently
    /// ignored, unlike in [`TraceId::get_from_span`].
    fn insert_into_span(self, span: &tracing::Span) {
        let try_insert_trace_id = || {
            let insert_trace_id_fn = INSERT_TRACE_ID_FN.get().context(
                "INSERT_TRACE_ID_FN not set. Did lexe_logger::try_init() \
                 initialize the TraceId statics?",
            )?;
            let maybe_replaced = span
                .with_subscriber(|(id, dispatch)| {
                    insert_trace_id_fn(id, dispatch, self)
                })
                .context("Span is not enabled")?
                .context("insert_trace_id_into_span failed")?;
            Ok::<_, anyhow::Error>(maybe_replaced)
        };

        try_insert_trace_id()
            .unwrap_or_default()
            .inspect(|replaced| warn!("Replaced existing TraceId: {replaced}"));
    }

    /// Test helper that exercises [`TraceId::get_from_span`] and
    /// [`TraceId::insert_into_span`] against the installed subscriber.
    ///
    /// It checks that an id inserted into an outer span is visible both from
    /// that span and from a nested inner span.
    ///
    /// # Panics
    ///
    /// Panics if [`GET_TRACE_ID_FN`] or [`INSERT_TRACE_ID_FN`] is not set, or
    /// if any of the assertions fail.
    #[cfg(any(test, feature = "test-utils"))]
    pub fn get_and_insert_test_impl() {
        use tracing::{error_span, info};

        GET_TRACE_ID_FN.get().expect("GET_TRACE_ID_FN not set");
        INSERT_TRACE_ID_FN
            .get()
            .expect("INSERT_TRACE_ID_FN not set");

        let trace_id1 = TraceId::generate();
        let outer_span = error_span!("(outer)", trace_id=%trace_id1);
        assert!(TraceId::get_from_span(&outer_span).is_none());
        trace_id1.clone().insert_into_span(&outer_span);

        outer_span.in_scope(|| {
            info!("This msg should contain (outer) and `trace_id`");
            let current_span = tracing::Span::current();
            let trace_id2 = TraceId::get_from_span(&current_span)
                .expect("No trace id returned");
            assert_eq!(trace_id1, trace_id2);

            let inner_span =
                error_span!("(inner)", trace_id = tracing::field::Empty);
            inner_span.in_scope(|| {
                info!("This msg should have (outer):(inner) and `trace_id`");
                let current_span = tracing::Span::current();
                let trace_id3 = TraceId::get_from_span(&current_span)
                    .expect("No trace id returned");
                assert_eq!(trace_id2, trace_id3);
            });
        });

        info!("Test complete");
    }
}
impl TryFrom<HeaderValue> for TraceId {
    type Error = anyhow::Error;

    /// Parses a [`TraceId`] from a header value, e.g. the trace-id header of
    /// an incoming (untrusted) request.
    ///
    /// # Errors
    ///
    /// Fails unless the value is exactly 16 bytes long and every byte is an
    /// ASCII letter or digit.
    fn try_from(src: HeaderValue) -> Result<Self, Self::Error> {
        let src_bytes = src.as_bytes();
        if src_bytes.len() != Self::LENGTH {
            bail!("Source header value had wrong length");
        }

        // Must be *ASCII* alphanumeric. `char::is_alphanumeric(byte as char)`
        // would also accept Latin-1 letters/digits in 0x80..=0xFF (e.g. 0xC0
        // 'À', 0xB2 '²'). `HeaderValue` permits those bytes, but they can form
        // invalid UTF-8, and `TraceId::as_str` relies on the contents being
        // UTF-8.
        let all_alphanumeric =
            src_bytes.iter().all(u8::is_ascii_alphanumeric);
        ensure!(
            all_alphanumeric,
            "Source header value contained non-alphanumeric bytes"
        );

        Ok(Self(src))
    }
}
impl Display for TraceId {
    /// Writes the bare trace id string, ignoring any width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
impl fmt::Debug for TraceId {
    /// Debug output is the same bare trace id string as `Display`, so ids
    /// read identically in `?`- and `%`-formatted log fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id_str = self.as_str();
        f.write_str(id_str)
    }
}
#[cfg(any(test, feature = "test-utils"))]
mod arbitrary_impl {
    use lexe_crypto::rng::FastRng;
    use proptest::{
        arbitrary::{Arbitrary, any},
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;

    /// Arbitrary [`TraceId`]s are produced from an arbitrary [`FastRng`], so
    /// they go through the same `from_rng` path as generated ids.
    impl Arbitrary for TraceId {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            let rng_strategy = any::<FastRng>();
            rng_strategy
                .prop_map(|mut fast_rng| TraceId::from_rng(&mut fast_rng))
                .boxed()
        }
    }
}
/// Defines `get_trace_id_from_span` and `insert_trace_id_into_span` for the
/// concrete subscriber type `$subscriber`.
///
/// The generated functions match the signatures of the `GET_TRACE_ID_FN` and
/// `INSERT_TRACE_ID_FN` statics in this module. Both downcast the
/// `tracing::Dispatch` to `$subscriber` so they can reach span extensions
/// through `tracing_subscriber::registry::LookupSpan`.
///
/// NOTE: the expansion also brings `anyhow::Context`, `TraceId`, and
/// `LookupSpan` into scope at the call site. The calling crate must depend on
/// `anyhow`, `tracing`, and `tracing_subscriber`.
#[macro_export]
macro_rules! define_trace_id_fns {
    ($subscriber:ty) => {
        use anyhow::Context;
        // `$crate` (rather than a hard-coded crate name) resolves to the
        // crate defining this macro however the caller refers to it,
        // including from within the defining crate itself.
        use $crate::trace::TraceId;
        use tracing_subscriber::registry::LookupSpan;

        /// Returns the first `TraceId` found in the extensions of the span
        /// with the given `id`, searching that span first and then its
        /// ancestors from leaf to root.
        fn get_trace_id_from_span(
            id: &tracing::span::Id,
            dispatch: &tracing::Dispatch,
        ) -> anyhow::Result<Option<TraceId>> {
            let subscriber = dispatch.downcast_ref::<$subscriber>().context(
                "Downcast failed. Did lexe_logger::try_init() define the trace_id \
                 fns with the correct subscriber type?",
            )?;
            let span_ref = subscriber
                .span(id)
                .context("Failed to get SpanRef from id")?;
            let maybe_trace_id = span_ref
                .scope()
                .find_map(|span| span.extensions().get::<TraceId>().cloned());
            Ok(maybe_trace_id)
        }

        /// Stores `trace_id` in the extensions of the span with the given
        /// `id`, returning the `TraceId` it replaced, if any.
        fn insert_trace_id_into_span(
            id: &tracing::span::Id,
            dispatch: &tracing::Dispatch,
            trace_id: TraceId,
        ) -> anyhow::Result<Option<TraceId>> {
            let subscriber = dispatch.downcast_ref::<$subscriber>().context(
                "Downcast failed. Did lexe_logger::try_init() define the trace_id \
                 fns with the correct subscriber type?",
            )?;
            let span_ref = subscriber.span(id).context("No span ref for id")?;
            let maybe_replaced = span_ref.extensions_mut().replace(trace_id);
            Ok(maybe_replaced)
        }
    };
}
pub(crate) mod client {
    use tracing::info_span;

    use super::*;

    /// Creates the client-side span for an outgoing request. Returns it along
    /// with the [`TraceId`] the request should carry.
    ///
    /// If a [`TraceId`] can be found via the current span, it is reused and
    /// the new span's `trace_id` field stays empty. Otherwise a fresh id is
    /// generated, recorded in the span's `trace_id` field, and stored in the
    /// span's extensions.
    pub(crate) fn request_span(
        req: &reqwest::Request,
        from: &str,
        to: &'static str,
    ) -> (tracing::Span, TraceId) {
        let req_span = info_span!(
            target: TARGET,
            "(req)(cli)",
            trace_id = tracing::field::Empty,
            %from,
            %to,
            method = %req.method(),
            url = %req.url(),
            attempts_left = tracing::field::Empty,
        );

        let trace_id = TraceId::get_from_span(&tracing::Span::current())
            .unwrap_or_else(|| {
                let fresh_id = TraceId::generate();
                req_span.record("trace_id", fresh_id.as_str());
                fresh_id.clone().insert_into_span(&req_span);
                fresh_id
            });

        (req_span, trace_id)
    }
}
pub(crate) mod server {
    use anyhow::anyhow;
    use http::header::USER_AGENT;
    use tower_http::{
        classify::{
            ClassifiedResponse, ClassifyResponse, NeverClassifyEos,
            SharedClassifier,
        },
        trace::{
            MakeSpan, OnEos, OnFailure, OnRequest, OnResponse, TraceLayer,
        },
    };
    use tracing::{debug, error, info_span, warn};

    use super::*;

    /// Builds the `tower_http` [`TraceLayer`] used by API servers.
    ///
    /// Every request span is created as a child of `api_span`. Body chunks
    /// are not traced (the `()` handler).
    pub(crate) fn trace_layer(
        api_span: tracing::Span,
    ) -> TraceLayer<
        SharedClassifier<LxClassifyResponse>,
        LxMakeSpan,
        LxOnRequest,
        LxOnResponse,
        (),
        LxOnEos,
        LxOnFailure,
    > {
        let classifier = SharedClassifier::new(LxClassifyResponse);
        let span_maker = LxMakeSpan { api_span };
        TraceLayer::new(classifier)
            .make_span_with(span_maker)
            .on_request(LxOnRequest)
            .on_response(LxOnResponse)
            .on_body_chunk(())
            .on_eos(LxOnEos)
            .on_failure(LxOnFailure)
    }

    /// Classifier that never treats a response as a failure. Only errors
    /// from the inner service are classified, as [`anyhow::Error`].
    #[derive(Clone)]
    pub(crate) struct LxClassifyResponse;

    impl ClassifyResponse for LxClassifyResponse {
        type FailureClass = anyhow::Error;
        type ClassifyEos = NeverClassifyEos<Self::FailureClass>;

        fn classify_response<B>(
            self,
            _response: &http::Response<B>,
        ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
            // Status codes are reported by `LxOnResponse` instead.
            ClassifiedResponse::Ready(Ok(()))
        }

        fn classify_error<E: Display + 'static>(
            self,
            error: &E,
        ) -> Self::FailureClass {
            anyhow!("{error:#}")
        }
    }

    /// Creates each server request span as a child of `api_span` and attaches
    /// a [`TraceId`] to it. The id is the client-supplied one if it is present
    /// and valid, otherwise a freshly generated one.
    #[derive(Clone)]
    pub(crate) struct LxMakeSpan {
        api_span: tracing::Span,
    }

    impl<B> MakeSpan<B> for LxMakeSpan {
        fn make_span(&mut self, request: &http::Request<B>) -> tracing::Span {
            let path_and_query = request
                .uri()
                .path_and_query()
                .map_or("/", |pq| pq.as_str());

            let headers = request.headers();

            // Missing or invalid client trace ids are silently replaced.
            let trace_id = match headers.get(&TRACE_ID_HEADER_NAME) {
                Some(raw) => TraceId::try_from(raw.clone())
                    .unwrap_or_else(|_| TraceId::generate()),
                None => TraceId::generate(),
            };

            let user_agent = match headers.get(USER_AGENT) {
                Some(raw) => raw.to_str().unwrap_or("(non-ascii)"),
                None => "(none)",
            };

            let request_span = info_span!(
                target: TARGET,
                parent: self.api_span.clone(),
                "(req)(srv)",
                %trace_id,
                from = %user_agent,
                method = %request.method().as_str(),
                url = %path_and_query,
                version = ?request.version(),
            );
            trace_id.insert_into_span(&request_span);
            request_span
        }
    }

    /// Logs each incoming request, and separately its headers, at DEBUG.
    #[derive(Clone)]
    pub(crate) struct LxOnRequest;

    impl<B> OnRequest<B> for LxOnRequest {
        fn on_request(
            &mut self,
            request: &http::Request<B>,
            _request_span: &tracing::Span,
        ) {
            debug!(target: TARGET, "New server request");
            let headers = request.headers();
            debug!(target: TARGET, ?headers, "Server request (headers)");
        }
    }

    /// Logs each completed response at a level chosen by status:
    /// success → DEBUG, client error → WARN, 503 → WARN ("load shedded"),
    /// other server errors → ERROR, anything else → DEBUG. The response
    /// headers are then always logged at DEBUG.
    #[derive(Clone)]
    pub(crate) struct LxOnResponse;

    impl<B> OnResponse<B> for LxOnResponse {
        fn on_response(
            self,
            response: &http::Response<B>,
            resp_time: Duration,
            _request_span: &tracing::Span,
        ) {
            let resp_time = DisplayMs(resp_time);
            let status = response.status();

            if status.is_success() {
                debug!(target: TARGET, %resp_time, ?status, "Done (success)");
            } else if status.is_client_error() {
                warn!(target: TARGET, %resp_time, ?status, "Done (client error)");
            } else if status == http::StatusCode::SERVICE_UNAVAILABLE {
                warn!(target: TARGET, %resp_time, ?status, "Done (load shedded)");
            } else if status.is_server_error() {
                error!(target: TARGET, %resp_time, ?status, "Done (server error)");
            } else {
                debug!(target: TARGET, %resp_time, ?status, "Done (other)");
            }

            let headers = response.headers();
            debug!(
                target: TARGET, %resp_time, ?status, ?headers,
                "Done (headers)",
            );
        }
    }

    /// Logs the end of a streamed body, with the number of trailers, at DEBUG.
    #[derive(Clone)]
    pub(crate) struct LxOnEos;

    impl OnEos for LxOnEos {
        fn on_eos(
            self,
            trailers: Option<&http::HeaderMap>,
            stream_time: Duration,
            _request_span: &tracing::Span,
        ) {
            let stream_time = DisplayMs(stream_time);
            let num_trailers = trailers.map(|map| map.len());
            debug!(target: TARGET, %stream_time, ?num_trailers, "Stream ended");
        }
    }

    /// Logs classified failures (inner service errors) at WARN.
    #[derive(Clone)]
    pub(crate) struct LxOnFailure;

    impl<FailureClass: Display> OnFailure<FailureClass> for LxOnFailure {
        fn on_failure(
            &mut self,
            fail_class: FailureClass,
            fail_time: Duration,
            _request_span: &tracing::Span,
        ) {
            let fail_time = DisplayMs(fail_time);
            warn!(target: TARGET, %fail_time, %fail_class, "Other failure");
        }
    }
}
#[cfg(test)]
mod test {
    use proptest::{prop_assert_eq, proptest};

    use super::*;

    /// Every arbitrary [`TraceId`] can be viewed as `&str` without panicking
    /// and round-trips through its [`HeaderValue`] representation.
    #[test]
    fn trace_id_proptest() {
        proptest!(|(original: TraceId)| {
            let _ = original.as_str();
            let header_value = original.to_header_value();
            let parsed = TraceId::try_from(header_value).unwrap();
            prop_assert_eq!(&original, &parsed);
        });
    }
}