use hashbrown::HashMap;
use polaris_system::plugin::{Plugin, Version};
use polaris_system::resource::LocalResource;
use polaris_system::server::Server;
/// Per-request metadata exposed to systems as a local resource.
///
/// `RequestContextPlugin` registers `RequestContext::default` as the factory
/// for this resource. Callers may also insert their own value into a context,
/// e.g. one filled in from incoming request data.
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// Identifier used to trace this request.
    pub trace_id: String,
    /// Optional identifier for correlating this request with related ones.
    pub correlation_id: Option<String>,
    /// Additional free-form key/value metadata carried with the request.
    pub extras: HashMap<String, String>,
}

impl LocalResource for RequestContext {}
impl Default for RequestContext {
    /// Creates a context with a freshly generated `trace_id`, no
    /// `correlation_id`, and an empty `extras` map.
    fn default() -> Self {
        let trace_id = generate_trace_id();
        let correlation_id = None;
        let extras = HashMap::new();
        Self {
            trace_id,
            correlation_id,
            extras,
        }
    }
}
/// Generates a trace identifier for a new `RequestContext`.
///
/// The identifier has the form `<nanos>-<thread>-<seq>`:
/// - `<nanos>`: wall-clock nanoseconds since the Unix epoch, in lowercase hex
///   (`0` if the system clock reads earlier than the epoch);
/// - `<thread>`: the `Debug` rendering of the calling thread's `ThreadId`;
/// - `<seq>`: a process-wide sequence number, in lowercase hex.
///
/// The sequence number is what makes IDs unique within a process. The
/// timestamp alone is not enough, because `SystemTime` precision is
/// platform-dependent. Two calls in quick succession on the same thread can
/// observe the same clock reading. Uniqueness across separate processes is
/// not guaranteed.
fn generate_trace_id() -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Shared by all threads; each call takes a distinct value.
    static SEQUENCE: AtomicUsize = AtomicUsize::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let thread_id = std::thread::current().id();
    // Relaxed is sufficient: only the atomicity of the increment matters here,
    // not ordering relative to other memory operations.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("{nanos:x}-{thread_id:?}-{seq:x}")
}
/// Plugin that registers `RequestContext` as a local resource on a [`Server`].
#[derive(Clone, Copy, Debug, Default)]
pub struct RequestContextPlugin;
impl Plugin for RequestContextPlugin {
    /// Identifier under which this plugin is registered.
    const ID: &'static str = "polaris::request_context";
    /// Version of this plugin.
    const VERSION: Version = Version::new(0, 0, 1);

    /// Registers `RequestContext::default` with the server as the factory for
    /// the `RequestContext` local resource.
    fn build(&self, server: &mut Server) {
        let make_context = RequestContext::default;
        server.register_local(make_context);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_request_context_has_trace_id() {
        let RequestContext {
            trace_id,
            correlation_id,
            extras,
        } = RequestContext::default();
        assert!(!trace_id.is_empty());
        assert!(correlation_id.is_none());
        assert!(extras.is_empty());
    }

    #[test]
    fn default_trace_ids_are_unique() {
        let first = RequestContext::default().trace_id;
        let second = RequestContext::default().trace_id;
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn plugin_registers_local_resource() {
        let mut server = Server::new();
        server.add_plugins(RequestContextPlugin);
        server.finish().await;

        let context = server.create_context();
        assert!(context.contains_resource::<RequestContext>());
        let request = context.get_resource::<RequestContext>().unwrap();
        assert!(!request.trace_id.is_empty());
    }

    #[tokio::test]
    async fn injected_context_overrides_default() {
        let mut server = Server::new();
        server.add_plugins(RequestContextPlugin);
        server.finish().await;

        // Build the override first; it does not depend on the context.
        let injected = RequestContext {
            trace_id: "custom-123".into(),
            correlation_id: Some("corr-456".into()),
            extras: HashMap::new(),
        };
        let mut context = server.create_context();
        context.insert(injected);

        let request = context.get_resource::<RequestContext>().unwrap();
        assert_eq!(request.trace_id, "custom-123");
        assert_eq!(request.correlation_id.as_deref(), Some("corr-456"));
    }

    #[tokio::test]
    async fn extras_propagate_additional_headers() {
        let mut server = Server::new();
        server.add_plugins(RequestContextPlugin);
        server.finish().await;

        let mut extras = HashMap::new();
        for (key, value) in [("x-forwarded-for", "10.0.0.1"), ("baggage", "env=prod")] {
            extras.insert(String::from(key), String::from(value));
        }

        let mut context = server.create_context();
        context.insert(RequestContext {
            extras,
            ..Default::default()
        });

        let request = context.get_resource::<RequestContext>().unwrap();
        assert_eq!(
            request.extras.get("x-forwarded-for").map(String::as_str),
            Some("10.0.0.1")
        );
        assert_eq!(
            request.extras.get("baggage").map(String::as_str),
            Some("env=prod")
        );
    }
}