Skip to main content

neumann_server/
correlation.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Request correlation and trace ID propagation.
3//!
4//! This module provides utilities for extracting and propagating trace IDs
5//! across gRPC requests to enable distributed tracing.
6
7use tonic::metadata::MetadataValue;
8use tonic::{Request, Response};
9use tracing::Span;
10use uuid::Uuid;
11
/// Name of the request header that carries a caller-supplied trace ID.
pub const TRACE_ID_HEADER: &str = "x-request-id";

/// gRPC metadata key under which the trace ID is written into responses.
///
/// Equal to [`TRACE_ID_HEADER`], so requests and responses use the same key.
pub const TRACE_ID_METADATA: &str = TRACE_ID_HEADER;
17
18/// Extract trace ID from request metadata or generate a new one.
19pub fn extract_or_generate<T>(request: &Request<T>) -> String {
20    request
21        .metadata()
22        .get(TRACE_ID_HEADER)
23        .and_then(|v| v.to_str().ok())
24        .filter(|s| !s.is_empty())
25        .map_or_else(|| Uuid::new_v4().to_string(), ToString::to_string)
26}
27
28/// Create a tracing span with trace ID and request metadata.
29pub fn request_span(trace_id: &str, service: &str, method: &str) -> Span {
30    tracing::info_span!(
31        "grpc_request",
32        trace_id = %trace_id,
33        service = %service,
34        method = %method,
35    )
36}
37
38/// Add trace ID to response metadata.
39pub fn add_trace_id_to_response<T>(response: &mut Response<T>, trace_id: &str) {
40    if let Ok(value) = trace_id.parse::<MetadataValue<_>>() {
41        response.metadata_mut().insert(TRACE_ID_METADATA, value);
42    }
43}
44
45/// Extract trace ID from request and create a span.
46///
47/// Returns the trace ID and enters the span. The span will be active
48/// for the duration of the returned guard.
49pub fn start_request_span<T>(request: &Request<T>, service: &str, method: &str) -> RequestSpan {
50    let trace_id = extract_or_generate(request);
51    let span = request_span(&trace_id, service, method);
52    let guard = span.clone().entered();
53
54    tracing::debug!(
55        trace_id = %trace_id,
56        service = %service,
57        method = %method,
58        "Request started"
59    );
60
61    RequestSpan {
62        trace_id,
63        span,
64        guard,
65    }
66}
67
68/// A wrapper holding trace ID and its associated span.
69pub struct RequestSpan {
70    trace_id: String,
71    span: Span,
72    /// Guard that keeps the span entered. Dropped when `RequestSpan` is dropped.
73    #[allow(dead_code)]
74    guard: tracing::span::EnteredSpan,
75}
76
77impl RequestSpan {
78    /// Get the trace ID.
79    #[must_use]
80    pub fn trace_id(&self) -> &str {
81        &self.trace_id
82    }
83
84    /// Get a reference to the span.
85    #[must_use]
86    pub const fn span(&self) -> &Span {
87        &self.span
88    }
89
90    /// Add the trace ID to a response.
91    pub fn add_to_response<T>(&self, response: &mut Response<T>) {
92        add_trace_id_to_response(response, &self.trace_id);
93    }
94
95    /// Create a response with the trace ID in metadata.
96    pub fn into_response<T>(self, inner: T) -> Response<T> {
97        let mut response = Response::new(inner);
98        add_trace_id_to_response(&mut response, &self.trace_id);
99        response
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::metadata::MetadataValue;

    #[test]
    fn test_extract_from_header() {
        let mut request = Request::new(());
        let trace_id = "test-trace-id-12345";
        request.metadata_mut().insert(
            TRACE_ID_HEADER,
            MetadataValue::try_from(trace_id).expect("valid value"),
        );

        let extracted = extract_or_generate(&request);
        assert_eq!(extracted, trace_id);
    }

    #[test]
    fn test_generate_when_missing() {
        let request = Request::new(());
        let trace_id = extract_or_generate(&request);

        // Should be a valid UUID
        assert!(Uuid::parse_str(&trace_id).is_ok());
    }

    #[test]
    fn test_generate_when_empty() {
        let mut request = Request::new(());
        request.metadata_mut().insert(
            TRACE_ID_HEADER,
            MetadataValue::try_from("").expect("valid value"),
        );

        let trace_id = extract_or_generate(&request);

        // Should generate a new UUID, not use the empty string
        assert!(Uuid::parse_str(&trace_id).is_ok());
    }

    #[test]
    fn test_span_creation() {
        let span = request_span("trace-123", "QueryService", "execute");

        // The span should record the fields
        span.in_scope(|| {
            tracing::info!("Inside span");
        });
    }

    #[test]
    fn test_response_metadata() {
        let trace_id = "response-trace-id";
        let mut response = Response::new(());

        add_trace_id_to_response(&mut response, trace_id);

        let value = response
            .metadata()
            .get(TRACE_ID_METADATA)
            .expect("should have trace ID");
        assert_eq!(value.to_str().expect("valid str"), trace_id);
    }

    #[test]
    fn test_response_metadata_skips_invalid_value() {
        let mut response = Response::new(());

        // A newline is never a valid byte in an ASCII metadata value, so the
        // call must neither panic nor insert anything.
        add_trace_id_to_response(&mut response, "bad\nvalue");

        assert!(response.metadata().get(TRACE_ID_METADATA).is_none());
    }

    #[test]
    fn test_response_metadata_replaces_existing_value() {
        let mut response = Response::new(());

        add_trace_id_to_response(&mut response, "first-trace-id");
        add_trace_id_to_response(&mut response, "second-trace-id");

        let value = response
            .metadata()
            .get(TRACE_ID_METADATA)
            .expect("should have trace ID");
        assert_eq!(value.to_str().expect("valid str"), "second-trace-id");
        assert_eq!(response.metadata().len(), 1);
    }

    #[test]
    fn test_uuid_format() {
        let request = Request::new(());
        let trace_id = extract_or_generate(&request);

        // Verify it's a valid UUID v4 format
        let uuid = Uuid::parse_str(&trace_id).expect("should be valid UUID");
        assert_eq!(uuid.get_version_num(), 4);
    }

    #[test]
    fn test_request_span_trace_id() {
        let request = Request::new(());
        let req_span = start_request_span(&request, "TestService", "test");

        // Should have a valid UUID as trace ID
        assert!(Uuid::parse_str(req_span.trace_id()).is_ok());
    }

    #[test]
    fn test_request_span_with_existing_trace_id() {
        let mut request = Request::new(());
        let expected_trace_id = "existing-trace-123";
        request.metadata_mut().insert(
            TRACE_ID_HEADER,
            MetadataValue::try_from(expected_trace_id).expect("valid value"),
        );

        let req_span = start_request_span(&request, "TestService", "test");
        assert_eq!(req_span.trace_id(), expected_trace_id);
    }

    #[test]
    fn test_request_span_into_response() {
        let request = Request::new(());
        let req_span = start_request_span(&request, "TestService", "test");
        let trace_id = req_span.trace_id().to_string();

        let response = req_span.into_response("result");

        let value = response
            .metadata()
            .get(TRACE_ID_METADATA)
            .expect("should have trace ID");
        assert_eq!(value.to_str().expect("valid str"), trace_id);
        assert_eq!(*response.get_ref(), "result");
    }

    #[test]
    fn test_request_span_add_to_response() {
        let request = Request::new(());
        let req_span = start_request_span(&request, "TestService", "test");
        let trace_id = req_span.trace_id().to_string();

        let mut response = Response::new(42);
        req_span.add_to_response(&mut response);

        let value = response
            .metadata()
            .get(TRACE_ID_METADATA)
            .expect("should have trace ID");
        assert_eq!(value.to_str().expect("valid str"), trace_id);
    }

    #[test]
    fn test_span_accessor() {
        let request = Request::new(());
        let req_span = start_request_span(&request, "TestService", "test");

        // Should be able to access the span
        req_span.span().in_scope(|| {
            tracing::info!("In request span");
        });
    }
}