use async_trait::async_trait;
use reqwest::{header::HeaderValue, Request, Response};
use reqwest_middleware::{Middleware, Next, RequestBuilder};
use task_local_extensions::Extensions;
#[cfg(not(feature = "uuid"))]
fn generate_id() -> String {
nanoid::nanoid!()
}
#[cfg(feature = "uuid")]
fn generate_id() -> String {
uuid::Uuid::new_v4().to_string()
}
#[derive(Debug, Clone)]
pub struct RequestId {
pub request_id: String,
}
impl Default for RequestId {
fn default() -> Self {
Self {
request_id: generate_id(),
}
}
}
impl RequestId {
pub fn new(request_id: impl ToString) -> Self {
Self {
request_id: request_id.to_string(),
}
}
}
#[derive(Debug, Clone)]
pub struct TraceId {
pub trace_id: String,
pub span_id: Option<String>,
}
impl Default for TraceId {
fn default() -> Self {
Self {
trace_id: generate_id(),
span_id: None,
}
}
}
impl TraceId {
pub fn new(trace_id: impl ToString, span_id: Option<impl ToString>) -> Self {
Self {
trace_id: trace_id.to_string(),
span_id: span_id.map(|id| id.to_string()),
}
}
}
struct DefaultId(String);
impl Default for DefaultId {
fn default() -> Self {
Self(generate_id())
}
}
#[derive(Default)]
pub(crate) struct RequestTraceIdInjector;
impl RequestTraceIdInjector {
pub(crate) fn inject_to_builder(req: RequestBuilder) -> RequestBuilder {
let mut req = req;
let request_id = if let Some(request_id) = req
.extensions()
.get::<RequestId>()
.map(|id| id.request_id.clone())
{
req = req.header("X-Request-ID", &request_id);
Some(request_id)
} else {
None
};
let trace_id = if let Some((trace_id, span_id)) = req
.extensions()
.get::<TraceId>()
.map(|id| (id.trace_id.clone(), id.span_id.clone()))
{
req = req.header("X-Trace-ID", &trace_id);
if let Some(span_id) = span_id {
req = req.header("X-Span-ID", span_id);
}
Some(trace_id)
} else {
None
};
match (request_id, trace_id) {
(Some(id), None) | (None, Some(id)) => req = req.with_extension(DefaultId(id)),
(None, None) => req = req.with_extension(DefaultId::default()),
_ => {}
};
req
}
pub(crate) fn inject_to_request(req: Request, extensions: &Extensions) -> Request {
let mut req = req;
let headers = req.headers_mut();
let default_id = extensions
.get::<DefaultId>()
.map(|id| id.0.clone())
.unwrap_or_else(generate_id);
if !headers.contains_key("X-Request-ID") {
let request_id = extensions
.get::<RequestId>()
.map(|id| id.request_id.clone())
.unwrap_or_else(|| default_id.clone());
headers.insert("X-Request-ID", HeaderValue::from_str(&request_id).unwrap());
}
if !headers.contains_key("X-Trace-ID") {
let (trace_id, span_id) = match extensions.get::<TraceId>() {
Some(id) => (id.trace_id.clone(), id.span_id.clone()),
None => (default_id, None),
};
headers.insert("X-Trace-ID", HeaderValue::from_str(&trace_id).unwrap());
if let Some(span_id) = span_id {
headers.insert("X-Span-ID", HeaderValue::from_str(&span_id).unwrap());
}
}
req
}
}
#[async_trait]
impl Middleware for RequestTraceIdInjector {
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response, reqwest_middleware::Error> {
let req = Self::inject_to_request(req, extensions);
next.run(req, extensions).await
}
}