use axum::extract::FromRequestParts;
use hashbrown::HashMap;
use http::HeaderMap;
use http::request::Parts;
use polaris_graph::hooks::api::BoxedHook;
use polaris_graph::hooks::schedule::OnGraphStart;
use polaris_graph::hooks::{GraphEvent, HooksAPI};
use polaris_system::param::SystemContext;
use polaris_system::plugin::{Plugin, ScheduleId, Version};
use polaris_system::resource::LocalResource;
use polaris_system::server::Server;
use std::any::TypeId;
use std::convert::Infallible;
/// Identifiers that accompany a single request.
///
/// Built either from HTTP headers (`RequestContext::from_headers`, which reads
/// `x-trace-id`, `x-correlation-id` and `x-request-id`) or via `Default`,
/// which generates a trace id and leaves the optional ids unset.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal and must go through one of the constructors.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RequestContext {
/// Trace identifier: the `x-trace-id` header value when built from headers
/// that carry one, otherwise a locally generated id.
pub trace_id: String,
/// Value of the `x-correlation-id` header, if one was read.
pub correlation_id: Option<String>,
/// Value of the `x-request-id` header, if one was read.
pub request_id: Option<String>,
/// Additional string key/value pairs. Both constructors in this file start
/// with an empty map; nothing here populates it.
pub extras: HashMap<String, String>,
}
// Empty `LocalResource` impl; `RequestContextPlugin::build` registers this
// type through `Server::register_local`.
impl LocalResource for RequestContext {}
impl Default for RequestContext {
    /// Creates a context with a freshly generated trace id, no correlation or
    /// request id, and an empty `extras` map.
    fn default() -> Self {
        let trace_id = generate_trace_id();
        let extras = HashMap::new();
        Self {
            trace_id,
            correlation_id: None,
            request_id: None,
            extras,
        }
    }
}
impl RequestContext {
    /// Builds a request context from HTTP headers.
    ///
    /// Reads `x-trace-id`, `x-correlation-id` and `x-request-id`. A header
    /// whose value cannot be rendered by `HeaderValue::to_str` is treated as
    /// absent.
    ///
    /// `trace_id` is always usable: when `x-trace-id` is missing, unreadable
    /// or empty, a generated id is used instead. The two optional ids are
    /// passed through as read, and `extras` starts out empty.
    #[must_use]
    pub fn from_headers(headers: &HeaderMap) -> Self {
        let header_str = |name: &str| {
            headers
                .get(name)
                .and_then(|v| v.to_str().ok())
                .map(String::from)
        };
        Self {
            // An empty header value is legal HTTP but useless as a trace id,
            // so it takes the same fallback path as a missing header.
            trace_id: header_str("x-trace-id")
                .filter(|id| !id.is_empty())
                .unwrap_or_else(generate_trace_id),
            correlation_id: header_str("x-correlation-id"),
            request_id: header_str("x-request-id"),
            extras: HashMap::new(),
        }
    }
}
/// Axum extractor: builds the context from the request's headers and never
/// rejects (`Rejection = Infallible`). The state argument is unused.
impl<S> FromRequestParts<S> for RequestContext
where
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Infallible> {
        let context = Self::from_headers(&parts.headers);
        Ok(context)
    }
}
/// Generates a fallback trace id for requests that did not supply one.
///
/// The id has the shape `<nanos-hex>-<ThreadId debug>-<seq-hex>`:
/// nanoseconds since the Unix epoch (0 if the system clock is before the
/// epoch), the current thread's id, and a process-wide sequence number.
///
/// The sequence number is what makes ids distinct within a process. The
/// timestamp alone is not enough: back-to-back calls on one thread can read
/// the same clock value on platforms with a coarse clock, and always do when
/// the clock is before the epoch.
fn generate_trace_id() -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Shared by every thread in the process; only ever incremented.
    static NEXT_SEQ: AtomicUsize = AtomicUsize::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let thread_id = std::thread::current().id();
    // Relaxed is sufficient: only the atomicity of the increment matters,
    // not its ordering relative to other memory operations.
    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("{nanos:x}-{thread_id:?}-{seq:x}")
}
/// Newtype wrapper around an HTTP `HeaderMap`.
///
/// The `OnGraphStart` hook installed by `RequestContextPlugin` looks this
/// resource up in the context and, when it is available, builds the
/// `RequestContext` from the wrapped headers.
#[derive(Debug, Clone, Default)]
pub struct HttpHeaders(pub HeaderMap);
// Empty `LocalResource` impl; callers place an `HttpHeaders` value into a
// context with `ctx.insert(..)` (as the tests in this file do).
impl LocalResource for HttpHeaders {}
/// Plugin that makes `RequestContext` available on a `Server`.
///
/// Its `build` registers `RequestContext::default` as the local-resource
/// factory and adds an `OnGraphStart` hook that rebuilds the context from
/// `HttpHeaders` when that resource can be read.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestContextPlugin;
impl Plugin for RequestContextPlugin {
    const ID: &'static str = "polaris::request_context";
    const VERSION: Version = Version::new(0, 0, 1);

    /// Wires request-context support into `server`.
    ///
    /// 1. Registers `RequestContext::default` as the factory for the local
    ///    `RequestContext` resource.
    /// 2. Ensures a `HooksAPI` exists, inserting a new one if needed.
    /// 3. Registers the `request_context_from_headers` hook on `OnGraphStart`:
    ///    when `HttpHeaders` can be read from the context, the hook inserts a
    ///    `RequestContext` built from those headers; otherwise it does nothing.
    ///
    /// # Panics
    ///
    /// Panics if the `HooksAPI` cannot be obtained or if hook registration
    /// returns an error.
    fn build(&self, server: &mut Server) {
        server.register_local(RequestContext::default);

        if !server.contains_api::<HooksAPI>() {
            server.insert_api(HooksAPI::new());
        }

        let hooks_api = server
            .api::<HooksAPI>()
            .expect("HooksAPI must be present after initialization");

        let header_hook = BoxedHook::new(
            |ctx: &mut SystemContext<'_>, _event: &GraphEvent| {
                // The lookup result must be fully consumed before
                // `ctx.insert`, which needs `ctx` again.
                let from_headers = match ctx.get_resource::<HttpHeaders>() {
                    Ok(headers) => RequestContext::from_headers(&headers.0),
                    Err(_) => return,
                };
                ctx.insert(from_headers);
            },
            // NOTE(review): presumably the resource types this hook provides;
            // confirm against `BoxedHook::new`.
            vec![TypeId::of::<RequestContext>()],
        );

        hooks_api
            .register_boxed(
                ScheduleId::of::<OnGraphStart>(),
                "request_context_from_headers",
                header_hook,
            )
            .expect("RequestContextPlugin hook registration must not fail");
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderValue, Request};

    /// Builds a server with only `RequestContextPlugin` added, then finishes it.
    async fn finished_server() -> Server {
        let mut server = Server::new();
        server.add_plugins(RequestContextPlugin);
        server.finish().await;
        server
    }

    /// Builds a `HeaderMap` from static `(name, value)` pairs.
    fn header_map(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// A `GraphStart` event with a fresh run id, no labels and no nodes.
    fn graph_start_event() -> GraphEvent {
        GraphEvent::GraphStart {
            run_id: polaris_graph::RunId::new(),
            labels: polaris_graph::RunLabels::empty(),
            node_count: 0,
            node_map: Vec::new(),
        }
    }

    #[test]
    fn default_request_context_has_trace_id() {
        let fresh = RequestContext::default();
        assert!(!fresh.trace_id.is_empty());
        assert!(fresh.correlation_id.is_none());
        assert!(fresh.request_id.is_none());
        assert!(fresh.extras.is_empty());
    }

    #[test]
    fn default_trace_ids_are_unique() {
        let first = RequestContext::default();
        let second = RequestContext::default();
        assert_ne!(first.trace_id, second.trace_id);
    }

    #[test]
    fn from_headers_populates_all_known_fields() {
        let headers = header_map(&[
            ("x-trace-id", "trace-abc"),
            ("x-correlation-id", "corr-123"),
            ("x-request-id", "req-xyz"),
        ]);
        let parsed = RequestContext::from_headers(&headers);
        assert_eq!(parsed.trace_id, "trace-abc");
        assert_eq!(parsed.correlation_id.as_deref(), Some("corr-123"));
        assert_eq!(parsed.request_id.as_deref(), Some("req-xyz"));
    }

    #[test]
    fn from_headers_falls_back_when_trace_id_missing() {
        let parsed = RequestContext::from_headers(&HeaderMap::new());
        assert!(!parsed.trace_id.is_empty());
        assert!(parsed.correlation_id.is_none());
        assert!(parsed.request_id.is_none());
    }

    #[tokio::test]
    async fn plugin_registers_local_resource() {
        let server = finished_server().await;
        let ctx = server.create_context();
        assert!(ctx.contains_resource::<RequestContext>());
        let req = ctx.get_resource::<RequestContext>().unwrap();
        assert!(!req.trace_id.is_empty());
    }

    #[tokio::test]
    async fn injected_context_overrides_default() {
        let server = finished_server().await;
        let mut ctx = server.create_context();
        let injected = RequestContext {
            trace_id: "custom-123".into(),
            correlation_id: Some("corr-456".into()),
            ..Default::default()
        };
        ctx.insert(injected);
        let req = ctx.get_resource::<RequestContext>().unwrap();
        assert_eq!(req.trace_id, "custom-123");
        assert_eq!(req.correlation_id.as_deref(), Some("corr-456"));
    }

    #[tokio::test]
    async fn on_graph_start_hook_builds_from_http_headers() {
        let server = finished_server().await;
        let headers = header_map(&[("x-trace-id", "from-header"), ("x-request-id", "req-42")]);
        let mut ctx = server.create_context();
        ctx.insert(HttpHeaders(headers));

        let hooks = server.api::<HooksAPI>().expect("HooksAPI present");
        hooks.invoke(
            ScheduleId::of::<OnGraphStart>(),
            &mut ctx,
            &graph_start_event(),
        );

        let req = ctx.get_resource::<RequestContext>().unwrap();
        assert_eq!(req.trace_id, "from-header");
        assert_eq!(req.request_id.as_deref(), Some("req-42"));
    }

    #[tokio::test]
    async fn from_request_parts_reads_headers() {
        let request = Request::builder()
            .header("x-trace-id", "trace-extract")
            .header("x-request-id", "req-extract")
            .body(())
            .unwrap();
        let (mut parts, _body) = request.into_parts();
        let extracted = RequestContext::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(extracted.trace_id, "trace-extract");
        assert_eq!(extracted.request_id.as_deref(), Some("req-extract"));
    }

    #[tokio::test]
    async fn on_graph_start_hook_no_ops_without_http_headers() {
        let server = finished_server().await;
        let mut ctx = server.create_context();
        // Remember the default-generated id so we can check the hook left it alone.
        let original_trace_id = ctx
            .get_resource::<RequestContext>()
            .unwrap()
            .trace_id
            .clone();

        let hooks = server.api::<HooksAPI>().expect("HooksAPI present");
        hooks.invoke(
            ScheduleId::of::<OnGraphStart>(),
            &mut ctx,
            &graph_start_event(),
        );

        let req = ctx.get_resource::<RequestContext>().unwrap();
        assert_eq!(req.trace_id, original_trace_id);
        assert!(req.request_id.is_none());
    }
}