use async_trait::async_trait;
use hyper::header::HeaderName;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Instant;
/// A single timed operation within a trace.
#[derive(Debug, Clone)]
pub struct Span {
    /// Identifier of this span (a UUIDv7 string when created via `Span::new`).
    pub span_id: String,
    /// Identifier of the parent span, set via `Span::with_parent`; `None` for root spans.
    pub parent_span_id: Option<String>,
    /// Identifier of the trace this span belongs to.
    pub trace_id: String,
    /// Operation label (the tracing middleware uses `"METHOD path"`).
    pub operation_name: String,
    /// Monotonic timestamp captured when the span was created.
    pub start_time: Instant,
    /// Set by `Span::end`; `None` while the span is still open.
    pub end_time: Option<Instant>,
    /// Free-form key/value annotations attached to the span.
    pub tags: HashMap<String, String>,
    /// Current outcome of the span.
    pub status: SpanStatus,
}

/// Outcome of a [`Span`].
///
/// This is a fieldless enum, so it derives `Copy`, `Eq` and `Hash`. Callers can
/// copy it out of a span and use it as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpanStatus {
    /// Initial state assigned by `Span::new`.
    Active,
    /// Assigned by `Span::end` to a span that was still `Active`.
    Ok,
    /// Assigned by `Span::mark_error`; `Span::end` does not overwrite it.
    Error,
}
impl Span {
pub fn new(trace_id: String, operation_name: String) -> Self {
Self {
span_id: uuid::Uuid::now_v7().to_string(),
parent_span_id: None,
trace_id,
operation_name,
start_time: Instant::now(),
end_time: None,
tags: HashMap::new(),
status: SpanStatus::Active,
}
}
pub fn with_parent(mut self, parent_span_id: String) -> Self {
self.parent_span_id = Some(parent_span_id);
self
}
pub fn add_tag(&mut self, key: String, value: String) {
self.tags.insert(key, value);
}
pub fn end(&mut self) {
self.end_time = Some(Instant::now());
if self.status == SpanStatus::Active {
self.status = SpanStatus::Ok;
}
}
pub fn mark_error(&mut self) {
self.status = SpanStatus::Error;
}
pub fn duration_ms(&self) -> Option<f64> {
self.end_time
.map(|end| (end - self.start_time).as_secs_f64() * 1000.0)
}
}
const DEFAULT_MAX_SPANS: usize = 10_000;
#[derive(Debug)]
pub struct TraceStore {
spans: RwLock<HashMap<String, Span>>,
max_spans: usize,
}
impl Default for TraceStore {
fn default() -> Self {
Self {
spans: RwLock::new(HashMap::new()),
max_spans: DEFAULT_MAX_SPANS,
}
}
}
impl TraceStore {
    /// Creates an empty store with the default eviction threshold.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty store that evicts completed spans once more than
    /// `max_spans` spans are held.
    pub fn with_max_spans(max_spans: usize) -> Self {
        Self {
            spans: RwLock::new(HashMap::new()),
            max_spans,
        }
    }

    /// Acquires the read lock. If a previous holder panicked, the poisoned data
    /// is used anyway, so a panic elsewhere never makes this store unusable.
    fn read_spans(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, Span>> {
        self.spans.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Acquires the write lock, recovering from poisoning like `read_spans`.
    fn write_spans(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Span>> {
        self.spans.write().unwrap_or_else(|e| e.into_inner())
    }

    /// Applies `apply` to the span with `span_id`, if it is still stored.
    /// Unknown or already-evicted ids are silently ignored.
    fn with_span_mut(&self, span_id: &str, apply: impl FnOnce(&mut Span)) {
        if let Some(span) = self.write_spans().get_mut(span_id) {
            apply(span);
        }
    }

    /// Starts a new root span and returns its id.
    ///
    /// If the store now holds more than `max_spans` spans, every completed span
    /// is evicted. Active spans are never evicted, so the store can exceed
    /// `max_spans` while that many spans are still open.
    pub fn start_span(&self, trace_id: String, operation_name: String) -> String {
        let span = Span::new(trace_id, operation_name);
        let span_id = span.span_id.clone();
        let mut spans = self.write_spans();
        spans.insert(span_id.clone(), span);
        if spans.len() > self.max_spans {
            spans.retain(|_, s| s.end_time.is_none());
        }
        drop(spans);
        span_id
    }

    /// Ends the span with `span_id` (see `Span::end`); no-op for unknown ids.
    pub fn end_span(&self, span_id: &str) {
        self.with_span_mut(span_id, Span::end);
    }

    /// Marks the span with `span_id` as an error; no-op for unknown ids.
    pub fn mark_span_error(&self, span_id: &str) {
        self.with_span_mut(span_id, Span::mark_error);
    }

    /// Sets a tag on the span with `span_id`; no-op for unknown ids.
    pub fn add_span_tag(&self, span_id: &str, key: String, value: String) {
        self.with_span_mut(span_id, |span| span.add_tag(key, value));
    }

    /// Returns a clone of the span with `span_id`, or `None` if it is unknown or
    /// has been evicted.
    pub fn get_span(&self, span_id: &str) -> Option<Span> {
        self.read_spans().get(span_id).cloned()
    }

    /// Returns clones of all ended spans, in unspecified (hash-map) order.
    ///
    /// Like the other accessors, this recovers from a poisoned lock instead of
    /// panicking.
    pub fn completed_spans(&self) -> Vec<Span> {
        self.read_spans()
            .values()
            .filter(|s| s.end_time.is_some())
            .cloned()
            .collect()
    }

    /// Removes every ended span and keeps active ones. Recovers from a poisoned lock.
    pub fn clear_completed(&self) {
        self.write_spans()
            .retain(|_, span| span.end_time.is_none());
    }
}
/// Default header used to read an incoming trace id and to echo it on responses.
pub const TRACE_ID_HEADER: &str = "X-Trace-ID";
/// Default response header carrying the id of the span created for the request.
pub const SPAN_ID_HEADER: &str = "X-Span-ID";
/// Header name for a parent span id. It is not referenced by the middleware in
/// this module. NOTE(review): confirm where it is consumed.
pub const PARENT_SPAN_ID_HEADER: &str = "X-Parent-Span-ID";

/// Configuration for `TracingMiddleware`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct TracingConfig {
    /// When `false`, requests pass straight through without spans or tracing headers.
    pub enabled: bool,
    /// Fraction of non-excluded requests to trace. `with_sample_rate` keeps it
    /// within `[0.0, 1.0]`.
    pub sample_rate: f64,
    /// Header read for an incoming trace id and written on responses.
    pub trace_id_header: String,
    /// Header written on responses with the request's span id.
    pub span_id_header: String,
    /// Path prefixes that are never traced.
    pub exclude_paths: Vec<String>,
}

impl TracingConfig {
    /// Creates a config with tracing enabled and every request sampled.
    ///
    /// It uses the default header names and excludes `/health` and `/metrics`.
    pub fn new() -> Self {
        Self {
            enabled: true,
            sample_rate: 1.0,
            trace_id_header: TRACE_ID_HEADER.to_string(),
            span_id_header: SPAN_ID_HEADER.to_string(),
            exclude_paths: vec!["/health".to_string(), "/metrics".to_string()],
        }
    }

    /// Sets the sampling rate, clamped to `[0.0, 1.0]`.
    ///
    /// `NaN` is stored as `0.0` (sample nothing). `f64::clamp` returns `NaN`
    /// unchanged, and a `NaN` rate never selects a request anyway, since every
    /// comparison with it is false. Storing `0.0` keeps the field meaningful.
    pub fn with_sample_rate(mut self, rate: f64) -> Self {
        self.sample_rate = if rate.is_nan() {
            0.0
        } else {
            rate.clamp(0.0, 1.0)
        };
        self
    }

    /// Turns tracing off entirely.
    pub fn disabled(mut self) -> Self {
        self.enabled = false;
        self
    }

    /// Appends `paths` to the excluded prefixes. Existing entries, including the
    /// defaults, are kept.
    pub fn with_excluded_paths(mut self, paths: Vec<String>) -> Self {
        self.exclude_paths.extend(paths);
        self
    }
}

impl Default for TracingConfig {
    /// Same as `TracingConfig::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Middleware that records one span per traced request in a `TraceStore`.
/// It also echoes the trace and span ids back as response headers.
pub struct TracingMiddleware {
    config: TracingConfig,
    store: Arc<TraceStore>,
}

impl TracingMiddleware {
    /// Creates a middleware with `config` and a new store owned by this middleware.
    pub fn new(config: TracingConfig) -> Self {
        Self {
            config,
            store: Arc::new(TraceStore::new()),
        }
    }

    /// Creates a middleware using `TracingConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(TracingConfig::default())
    }

    /// Creates a middleware that records into an existing, possibly shared, store.
    pub fn from_arc(config: TracingConfig, store: Arc<TraceStore>) -> Self {
        Self { config, store }
    }

    /// Borrows the span store.
    pub fn store(&self) -> &TraceStore {
        &self.store
    }

    /// Returns another handle to the span store.
    pub fn store_arc(&self) -> Arc<TraceStore> {
        Arc::clone(&self.store)
    }

    /// Returns `true` if `path` starts with any configured excluded prefix.
    ///
    /// This is a plain string-prefix test, so `/health` also matches
    /// `/healthz` and `/health/live`.
    fn should_exclude(&self, path: &str) -> bool {
        self.config
            .exclude_paths
            .iter()
            .any(|p| path.starts_with(p))
    }

    /// Decides whether to trace the current request.
    ///
    /// A rate `>= 1.0` always samples and a rate `<= 0.0` never does. Otherwise the
    /// current `Instant` is hashed with a freshly created, randomly keyed
    /// `RandomState`, mapped to `[0, 1]` and compared with the rate. This is not a
    /// cryptographic RNG.
    fn should_sample(&self) -> bool {
        if self.config.sample_rate >= 1.0 {
            return true;
        }
        if self.config.sample_rate <= 0.0 {
            return false;
        }
        use std::collections::hash_map::RandomState;
        use std::hash::BuildHasher;
        let random_state = RandomState::new();
        let hash = random_state.hash_one(Instant::now());
        (hash as f64 / u64::MAX as f64) < self.config.sample_rate
    }

    /// Returns the incoming trace id, or a newly generated UUIDv7 string.
    ///
    /// The incoming id is used only if the configured header is present, is
    /// visible ASCII (`HeaderValue::to_str`) and is not blank. A blank value is
    /// rejected because it would group unrelated requests under the same empty
    /// trace id.
    fn get_or_generate_trace_id(&self, request: &Request) -> String {
        let incoming = request
            .headers
            .get(&self.config.trace_id_header)
            .and_then(|v| v.to_str().ok())
            .filter(|s| !s.trim().is_empty());
        match incoming {
            Some(id) => id.to_string(),
            None => uuid::Uuid::now_v7().to_string(),
        }
    }
}

impl Default for TracingMiddleware {
    /// Same as `TracingMiddleware::with_defaults()`.
    fn default() -> Self {
        Self::with_defaults()
    }
}
#[async_trait]
impl Middleware for TracingMiddleware {
async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
let path = request.uri.path();
if !self.config.enabled || self.should_exclude(path) {
return handler.handle(request).await;
}
if !self.should_sample() {
return handler.handle(request).await;
}
let trace_id = self.get_or_generate_trace_id(&request);
let operation_name = format!("{} {}", request.method.as_str(), path);
let span_id = self.store.start_span(trace_id.clone(), operation_name);
self.store.add_span_tag(
&span_id,
"http.method".to_string(),
request.method.as_str().to_string(),
);
self.store
.add_span_tag(&span_id, "http.path".to_string(), path.to_string());
let result = handler.handle(request).await;
match &result {
Ok(response) => {
self.store.add_span_tag(
&span_id,
"http.status_code".to_string(),
response.status.as_u16().to_string(),
);
if !response.status.is_success() {
self.store.mark_span_error(&span_id);
}
}
Err(_) => {
self.store.mark_span_error(&span_id);
}
}
self.store.end_span(&span_id);
let mut response = result?;
if let (Ok(trace_header), Ok(trace_value)) = (
self.config.trace_id_header.parse::<HeaderName>(),
trace_id.parse(),
) {
response.headers.insert(trace_header, trace_value);
}
if let (Ok(span_header), Ok(span_value)) = (
self.config.span_id_header.parse::<HeaderName>(),
span_id.parse(),
) {
response.headers.insert(span_header, span_value);
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Handler stub that answers every request with a fixed status and an `OK` body.
    struct TestHandler {
        status: StatusCode,
    }

    impl TestHandler {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(self.status).with_body(Bytes::from("OK")))
        }
    }

    /// Builds an HTTP/1.1 request with an empty body and the given headers.
    fn build_request_with_headers(
        method: Method,
        uri: &'static str,
        headers: HeaderMap,
    ) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds an HTTP/1.1 request with no headers and an empty body.
    fn build_request(method: Method, uri: &'static str) -> Request {
        build_request_with_headers(method, uri, HeaderMap::new())
    }

    /// Sends `req` through `middleware` to a `TestHandler` that returns `status`.
    async fn dispatch(middleware: &TracingMiddleware, status: StatusCode, req: Request) -> Response {
        let handler = Arc::new(TestHandler::new(status));
        middleware.process(req, handler).await.unwrap()
    }

    /// Starts a span whose trace id is `t{n}` and whose operation is `op{n}`.
    fn start_numbered_span(store: &TraceStore, n: usize) -> String {
        store.start_span(format!("t{n}"), format!("op{n}"))
    }

    #[tokio::test]
    async fn test_basic_tracing() {
        let middleware = TracingMiddleware::new(TracingConfig::new());
        let response = dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/test")).await;
        assert!(response.headers.contains_key(TRACE_ID_HEADER));
        assert!(response.headers.contains_key(SPAN_ID_HEADER));
        let spans = middleware.store.completed_spans();
        assert_eq!(spans.len(), 1);
        assert_eq!(spans[0].status, SpanStatus::Ok);
    }

    #[tokio::test]
    async fn test_propagate_trace_id() {
        let existing_trace_id = "existing-trace-123";
        let mut headers = HeaderMap::new();
        headers.insert(TRACE_ID_HEADER, existing_trace_id.parse().unwrap());
        let middleware = TracingMiddleware::new(TracingConfig::new());
        let req = build_request_with_headers(Method::GET, "/test", headers);
        let response = dispatch(&middleware, StatusCode::OK, req).await;
        assert_eq!(
            response.headers.get(TRACE_ID_HEADER).unwrap(),
            existing_trace_id
        );
    }

    #[tokio::test]
    async fn test_error_status() {
        let middleware = TracingMiddleware::new(TracingConfig::new());
        let _response = dispatch(
            &middleware,
            StatusCode::INTERNAL_SERVER_ERROR,
            build_request(Method::GET, "/error"),
        )
        .await;
        let spans = middleware.store.completed_spans();
        assert_eq!(spans.len(), 1);
        assert_eq!(spans[0].status, SpanStatus::Error);
    }

    #[tokio::test]
    async fn test_exclude_paths() {
        let middleware = TracingMiddleware::new(TracingConfig::new());
        let response = dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/health")).await;
        assert!(!response.headers.contains_key(TRACE_ID_HEADER));
        assert_eq!(middleware.store.completed_spans().len(), 0);
    }

    #[tokio::test]
    async fn test_disabled_tracing() {
        let middleware = TracingMiddleware::new(TracingConfig::new().disabled());
        let response = dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/test")).await;
        assert!(!response.headers.contains_key(TRACE_ID_HEADER));
        assert_eq!(middleware.store.completed_spans().len(), 0);
    }

    #[tokio::test]
    async fn test_span_metadata() {
        let middleware = TracingMiddleware::new(TracingConfig::new());
        let _response =
            dispatch(&middleware, StatusCode::OK, build_request(Method::POST, "/api/users")).await;
        let spans = middleware.store.completed_spans();
        assert_eq!(spans.len(), 1);
        let span = &spans[0];
        assert_eq!(span.operation_name, "POST /api/users");
        assert_eq!(span.tags.get("http.method").unwrap(), "POST");
        assert_eq!(span.tags.get("http.path").unwrap(), "/api/users");
        assert_eq!(span.tags.get("http.status_code").unwrap(), "200");
    }

    #[tokio::test]
    async fn test_span_duration() {
        let middleware = TracingMiddleware::new(TracingConfig::new());
        let _response = dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/test")).await;
        let spans = middleware.store.completed_spans();
        let duration = spans[0]
            .duration_ms()
            .expect("a completed span has a duration");
        assert!(duration >= 0.0);
    }

    #[tokio::test]
    async fn test_clear_completed_spans() {
        let middleware = TracingMiddleware::new(TracingConfig::new());
        for _ in 0..5 {
            let _response =
                dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/test")).await;
        }
        assert_eq!(middleware.store.completed_spans().len(), 5);
        middleware.store.clear_completed();
        assert_eq!(middleware.store.completed_spans().len(), 0);
    }

    #[tokio::test]
    async fn test_sample_rate_zero() {
        let middleware = TracingMiddleware::new(TracingConfig::new().with_sample_rate(0.0));
        let response = dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/test")).await;
        assert!(!response.headers.contains_key(TRACE_ID_HEADER));
    }

    #[tokio::test]
    async fn test_sample_rate_one() {
        let middleware = TracingMiddleware::new(TracingConfig::new().with_sample_rate(1.0));
        let response = dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/test")).await;
        assert!(response.headers.contains_key(TRACE_ID_HEADER));
    }

    #[tokio::test]
    async fn test_default_middleware() {
        let middleware = TracingMiddleware::default();
        let response = dispatch(&middleware, StatusCode::OK, build_request(Method::GET, "/test")).await;
        assert!(response.headers.contains_key(TRACE_ID_HEADER));
    }

    #[test]
    fn test_trace_store_with_max_spans() {
        let store = TraceStore::with_max_spans(3);
        let ids: Vec<String> = (1..=3).map(|n| start_numbered_span(&store, n)).collect();
        assert!(ids.iter().all(|id| store.get_span(id).is_some()));
    }

    #[test]
    fn test_trace_store_evicts_completed_spans_on_overflow() {
        let store = TraceStore::with_max_spans(3);
        let ids: Vec<String> = (1..=3).map(|n| start_numbered_span(&store, n)).collect();
        store.end_span(&ids[0]);
        store.end_span(&ids[1]);
        let fourth = start_numbered_span(&store, 4);
        assert!(store.get_span(&ids[0]).is_none());
        assert!(store.get_span(&ids[1]).is_none());
        assert!(store.get_span(&ids[2]).is_some());
        assert!(store.get_span(&fourth).is_some());
    }

    #[test]
    fn test_trace_store_no_eviction_when_under_limit() {
        let store = TraceStore::with_max_spans(10);
        let first = start_numbered_span(&store, 1);
        store.end_span(&first);
        let second = start_numbered_span(&store, 2);
        assert!(store.get_span(&first).is_some());
        assert!(store.get_span(&second).is_some());
    }

    #[rstest::rstest]
    fn test_rwlock_poison_recovery_trace_store() {
        let store = Arc::new(TraceStore::new());
        let span_id = store.start_span("trace-1".to_string(), "GET /test".to_string());
        let poisoner = Arc::clone(&store);
        let _ = std::thread::spawn(move || {
            let _guard = poisoner.spans.write().unwrap();
            panic!("intentional panic to poison lock");
        })
        .join();
        store.add_span_tag(&span_id, "key".to_string(), "value".to_string());
        store.mark_span_error(&span_id);
        store.end_span(&span_id);
        let span = store
            .get_span(&span_id)
            .expect("span survives lock poisoning");
        assert_eq!(span.status, SpanStatus::Error);
    }
}