use crate::tonic::common::get_status_code_from_headers;
use crate::tonic::Body;
use http::{HeaderMap, HeaderValue};
use opentelemetry::propagation::TextMapPropagator;
use opentelemetry::trace::FutureExt;
use opentelemetry::trace::WithContext;
use opentelemetry_http::HeaderInjector;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use qcs_api_client_common::tracing_configuration::HeaderAttributesFilter;
use qcs_api_client_common::tracing_configuration::{
IncludeExclude, TracingConfiguration, TracingFilter,
};
use tonic::client::GrpcService;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::shared::make_grpc_request_span;
use super::shared::should_trace_request;
/// Which side of an RPC a piece of gRPC metadata was observed on.
///
/// The `Display` form is embedded into span attribute keys, e.g.
/// `rpc.grpc.request.metadata.<key>`.
#[derive(Clone, Copy, Debug)]
enum MetadataAttributeType {
    /// Metadata sent with the outgoing request.
    Request,
    /// Metadata received with the response (headers or trailers).
    Response,
}

impl std::fmt::Display for MetadataAttributeType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            MetadataAttributeType::Request => "request",
            MetadataAttributeType::Response => "response",
        };
        f.write_str(label)
    }
}
/// Copies the headers selected by `include_exclude` onto `span` as OpenTelemetry attributes.
///
/// Each selected `(key, value)` pair becomes an attribute named
/// `rpc.grpc.<request|response>.metadata.<key>`. The side is chosen by `metadata_attribute_type`.
fn set_metadata_attribute(
    span: &tracing::Span,
    include_exclude: &IncludeExclude<String>,
    headers: &HeaderMap<HeaderValue>,
    metadata_attribute_type: MetadataAttributeType,
) {
    include_exclude
        .get_header_attributes(headers)
        .into_iter()
        .for_each(|(key, value)| {
            let attribute_name = format!("rpc.grpc.{metadata_attribute_type}.metadata.{key}");
            span.set_attribute(attribute_name, value);
        });
}
/// [`tower_http::trace::MakeSpan`] implementation for outgoing gRPC requests.
///
/// A real span is produced only when tracing is enabled and `should_trace_request` accepts
/// the request. Otherwise a disabled span (`tracing::Span::none()`) is returned.
#[derive(Clone, Debug)]
pub struct MakeSpan {
    /// Master switch; when `false` no spans are created.
    enabled: bool,
    /// Selects which request headers are recorded as span attributes.
    request_headers: IncludeExclude<String>,
    /// Optional per-request filter forwarded to `should_trace_request`.
    filter: Option<TracingFilter>,
    /// Base URL forwarded to `should_trace_request`.
    base_url: String,
}

impl<B> tower_http::trace::MakeSpan<B> for MakeSpan {
    fn make_span(&mut self, request: &http::Request<B>) -> tracing::Span {
        let should_trace = self.enabled
            && should_trace_request(self.base_url.as_str(), request, self.filter.as_ref());
        if !should_trace {
            return tracing::Span::none();
        }

        let request_span = make_grpc_request_span(request);
        // Parent the new span on whatever OpenTelemetry context is current at this point.
        // The result (if any) of `set_parent` is intentionally ignored.
        let _ = request_span.set_parent(opentelemetry::Context::current());
        set_metadata_attribute(
            &request_span,
            &self.request_headers,
            request.headers(),
            MetadataAttributeType::Request,
        );
        request_span
    }
}
/// [`tower_http::trace::OnEos`] hook that enriches the request span when a response body ends.
///
/// When trailers are present, it does two things before delegating to tower-http's default
/// end-of-stream handling:
/// - records the numeric gRPC status as the string attribute `rpc.grpc.status_code`;
/// - records the trailer entries selected by `response_headers`.
#[derive(Clone, Debug)]
pub struct OnEos {
    /// Selects which trailer entries are copied onto the span.
    response_headers: IncludeExclude<String>,
    /// Default tower-http behavior, run after the custom attributes are recorded.
    inner: tower_http::trace::DefaultOnEos,
}

impl tower_http::trace::OnEos for OnEos {
    fn on_eos(
        self,
        trailers: Option<&HeaderMap>,
        stream_duration: std::time::Duration,
        span: &Span,
    ) {
        // `OpenTelemetrySpanExt` is already imported at file level; no local `use` is needed.
        if let Some(trailers) = trailers {
            // If the status cannot be extracted (the helper returns `Err`), the attribute is
            // simply not recorded.
            if let Ok(status_code) = get_status_code_from_headers(trailers) {
                span.set_attribute("rpc.grpc.status_code", (status_code as u8).to_string());
            }
            set_metadata_attribute(
                span,
                &self.response_headers,
                trailers,
                MetadataAttributeType::Response,
            );
        }
        self.inner.on_eos(trailers, stream_duration, span);
    }
}
/// [`tower_http::trace::OnResponse`] hook that copies selected response headers onto the
/// request span, then runs tower-http's default response handling.
#[derive(Clone, Debug)]
pub struct OnResponse {
    /// Selects which response headers become span attributes.
    response_headers: IncludeExclude<String>,
    /// Default tower-http handling, executed after the attributes are recorded.
    inner: tower_http::trace::DefaultOnResponse,
}

impl Default for OnResponse {
    /// Records no response headers and uses tower-http's default response handler.
    fn default() -> Self {
        let response_headers = IncludeExclude::include_none();
        let inner = tower_http::trace::DefaultOnResponse::default();
        Self {
            response_headers,
            inner,
        }
    }
}

impl<B> tower_http::trace::OnResponse<B> for OnResponse {
    fn on_response(self, response: &http::Response<B>, latency: std::time::Duration, span: &Span) {
        let response_metadata = response.headers();
        set_metadata_attribute(
            span,
            &self.response_headers,
            response_metadata,
            MetadataAttributeType::Response,
        );
        self.inner.on_response(response, latency, span);
    }
}
/// Failure classifier used by both the base trace layer and the service it produces
/// (tower-http's `GrpcErrorsAsFailures`, wrapped in a `SharedClassifier`).
type GrpcFailureClassifier =
    tower_http::classify::SharedClassifier<tower_http::classify::GrpcErrorsAsFailures>;

/// tower-http `TraceLayer` wired up with this module's span, response, and end-of-stream hooks.
type BaseTraceLayer = tower_http::trace::TraceLayer<
    GrpcFailureClassifier,
    MakeSpan,
    tower_http::trace::DefaultOnRequest,
    OnResponse,
    tower_http::trace::DefaultOnBodyChunk,
    OnEos,
    super::shared::OnFailure,
>;

/// A tonic `Channel` wrapped in the `Trace` service, with the same hooks as [`BaseTraceLayer`].
type BaseTraceService = tower_http::trace::Trace<
    tonic::transport::Channel,
    GrpcFailureClassifier,
    MakeSpan,
    tower_http::trace::DefaultOnRequest,
    OnResponse,
    tower_http::trace::DefaultOnBodyChunk,
    OnEos,
    super::shared::OnFailure,
>;
/// gRPC service wrapping [`BaseTraceService`].
///
/// Optionally injects trace-context headers into outgoing requests. Every returned future runs
/// under the OpenTelemetry context that was current when `call` was invoked.
#[derive(Clone)]
pub struct CustomTraceService {
    /// Whether to inject trace-context headers into outgoing requests.
    propagate_trace_id: bool,
    /// Optional filter forwarded to `should_trace_request`.
    filter: Option<TracingFilter>,
    /// Base URL forwarded to `should_trace_request`.
    base_url: String,
    /// The traced channel that actually performs the request.
    inner: BaseTraceService,
}

impl CustomTraceService {
    /// Wraps `inner`, storing the propagation flag and request-filtering settings.
    pub fn new(
        propagate_trace_id: bool,
        base_url: String,
        filter: Option<TracingFilter>,
        inner: BaseTraceService,
    ) -> Self {
        Self {
            inner,
            base_url,
            filter,
            propagate_trace_id,
        }
    }
}

impl GrpcService<Body> for CustomTraceService {
    type ResponseBody = <BaseTraceService as GrpcService<Body>>::ResponseBody;
    type Error = <BaseTraceService as GrpcService<Body>>::Error;
    type Future = WithContext<<BaseTraceService as GrpcService<Body>>::Future>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // Readiness is entirely determined by the wrapped traced channel.
        <BaseTraceService as GrpcService<Body>>::poll_ready(&mut self.inner, cx)
    }

    fn call(&mut self, mut request: http::Request<Body>) -> Self::Future {
        let inject_headers = self.propagate_trace_id
            && should_trace_request(self.base_url.as_str(), &request, self.filter.as_ref());
        if inject_headers {
            // Write the current OpenTelemetry context into the request headers via
            // `TraceContextPropagator`.
            let mut header_injector = HeaderInjector(request.headers_mut());
            TraceContextPropagator::new()
                .inject_context(&opentelemetry::Context::current(), &mut header_injector);
        }
        self.inner.call(request).with_current_context()
    }
}
/// Tower layer that wraps a tonic `Channel` in two stages.
///
/// First it applies [`BaseTraceLayer`]. Then it wraps the result in a [`CustomTraceService`],
/// which handles optional trace-context propagation.
#[derive(Clone, Debug)]
pub struct CustomTraceLayer {
    /// Forwarded to each [`CustomTraceService`] this layer builds.
    propagate_trace_id: bool,
    /// Forwarded to each [`CustomTraceService`] this layer builds.
    filter: Option<TracingFilter>,
    /// Base URL of the wrapped channel; also forwarded to each service.
    pub(super) base_url: String,
    /// The tower-http trace layer applied to the channel first.
    base_trace_layer: BaseTraceLayer,
}

impl CustomTraceLayer {
    /// Builds a layer from the propagation flag, filtering settings, and base trace layer.
    pub fn new(
        propagate_trace_id: bool,
        base_url: String,
        filter: Option<TracingFilter>,
        base_trace_layer: BaseTraceLayer,
    ) -> Self {
        Self {
            base_trace_layer,
            base_url,
            filter,
            propagate_trace_id,
        }
    }
}

impl tower::Layer<tonic::transport::Channel> for CustomTraceLayer {
    type Service = CustomTraceService;

    fn layer(&self, inner: tonic::transport::Channel) -> Self::Service {
        let base_service = self.base_trace_layer.layer(inner);
        let (base_url, filter) = (self.base_url.clone(), self.filter.clone());
        CustomTraceService::new(self.propagate_trace_id, base_url, filter, base_service)
    }
}
/// Builds the tower-http gRPC trace layer configured with this module's hooks.
///
/// - With `configuration == None`, `MakeSpan` is disabled (it returns `Span::none()`) and no
///   request or response headers are selected.
/// - Otherwise the header selections and optional filter are taken from `configuration`.
///
/// `base_url` is moved into the span maker; it is not cloned.
#[must_use]
fn build_base_trace_layer(
    base_url: String,
    configuration: Option<&TracingConfiguration>,
) -> BaseTraceLayer {
    // `configuration` is already `Option<&_>` (and `Copy`), so no `.as_ref()` is needed.
    let request_headers = configuration
        .map(|configuration| configuration.request_headers().clone())
        .unwrap_or_else(IncludeExclude::include_none);
    // The same response-header selection feeds both `OnEos` (trailers) and `OnResponse`
    // (headers); compute it once.
    let response_headers = configuration
        .map(|configuration| configuration.response_headers().clone())
        .unwrap_or_else(IncludeExclude::include_none);
    let filter = configuration
        .and_then(|configuration| configuration.filter())
        .cloned();

    tower_http::trace::TraceLayer::new_for_grpc()
        .on_eos(OnEos {
            inner: tower_http::trace::DefaultOnEos::default(),
            response_headers: response_headers.clone(),
        })
        .make_span_with(MakeSpan {
            enabled: configuration.is_some(),
            request_headers,
            filter,
            base_url,
        })
        .on_failure(super::shared::OnFailure {
            inner: tower_http::trace::DefaultOnFailure::default(),
        })
        .on_response(OnResponse {
            inner: tower_http::trace::DefaultOnResponse::default(),
            response_headers,
        })
}
/// Builds the complete tracing layer for a tonic channel pointed at `base_url`.
///
/// Trace-context propagation is enabled only when `configuration` is present and its
/// `propagate_otel_context()` is `true`. The optional request filter is taken from
/// `configuration` as well.
#[must_use]
pub fn build_layer(
    base_url: String,
    configuration: Option<&TracingConfiguration>,
) -> CustomTraceLayer {
    let base_trace_layer = build_base_trace_layer(base_url.clone(), configuration);
    let propagate_trace_id = match configuration {
        Some(config) => config.propagate_otel_context(),
        None => false,
    };
    let filter = configuration.and_then(|config| config.filter()).cloned();
    CustomTraceLayer::new(propagate_trace_id, base_url, filter, base_trace_layer)
}