use std::collections::HashMap;
use std::task::{Context as TaskContext, Poll};
use futures_util::future::BoxFuture;
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use opentelemetry::Context;
use opentelemetry::global;
use opentelemetry::propagation::{Extractor, Injector};
use tower::{Layer, Service};
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Returns an [`ExtractLayer`], the tower layer that wraps services in
/// [`ExtractService`].
pub fn extract_layer() -> ExtractLayer {
ExtractLayer
}
/// Stateless tower [`Layer`] that wraps a service in an [`ExtractService`].
#[derive(Clone, Copy, Default)]
pub struct ExtractLayer;

impl<S> Layer<S> for ExtractLayer {
    type Service = ExtractService<S>;

    /// Wraps `service` in an [`ExtractService`].
    fn layer(&self, service: S) -> Self::Service {
        ExtractService { inner: service }
    }
}
/// Tower [`Service`] wrapper that extracts a propagated trace context from the
/// incoming request headers and sets it as the parent of the current
/// `tracing` span before delegating to the wrapped service.
#[derive(Clone)]
pub struct ExtractService<S> {
    inner: S,
}

impl<S, B, Resp> Service<Request<B>> for ExtractService<S>
where
    S: Service<Request<B>, Response = Response<Resp>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Extracts the parent context from `req`'s headers using the globally
    /// registered text-map propagator, attaches it to `Span::current()`, and
    /// forwards the request to the wrapped service.
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let parent_cx =
            global::get_text_map_propagator(|prop| prop.extract(&HeaderExtractor(req.headers())));
        // NOTE(review): this re-parents whatever span is current when `call`
        // runs; presumably an outer layer has already entered the request
        // span — confirm against the layer ordering at the call site.
        Span::current().set_parent(parent_cx);
        // Dispatch on `self.inner` itself rather than on a fresh clone:
        // `poll_ready` drove *this* instance to readiness, and a clone of a
        // service is not guaranteed to be ready.
        Box::pin(self.inner.call(req))
    }
}
struct HeaderExtractor<'a>(&'a HeaderMap);
impl<'a> Extractor for HeaderExtractor<'a> {
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|v| v.to_str().ok())
}
fn keys(&self) -> Vec<&str> {
self.0.keys().map(|k| k.as_str()).collect()
}
}
/// Injects the OpenTelemetry context of the current `tracing` span into the
/// given tonic `metadata`, using the globally registered text-map propagator.
pub fn inject_current_context(metadata: &mut tonic::metadata::MetadataMap) {
    let span_cx = Span::current().context();
    let mut carrier = MetadataInjector(metadata);
    global::get_text_map_propagator(|propagator| {
        propagator.inject_context(&span_cx, &mut carrier)
    });
}
/// [`Injector`] that writes propagation fields into a tonic `MetadataMap`.
struct MetadataInjector<'a>(&'a mut tonic::metadata::MetadataMap);

impl Injector for MetadataInjector<'_> {
    /// Inserts `key`/`value` as ASCII metadata. An entry whose key or value
    /// fails to parse is skipped without reporting an error.
    fn set(&mut self, key: &str, value: String) {
        let Ok(name) = key.parse::<tonic::metadata::MetadataKey<tonic::metadata::Ascii>>() else {
            return;
        };
        let Ok(val) = value.parse() else {
            return;
        };
        self.0.insert(name, val);
    }
}
pub fn inject_current_context_http(headers: &mut HeaderMap) {
let cx = Context::current();
let mut injector = HttpInjector(headers);
global::get_text_map_propagator(|prop| prop.inject_context(&cx, &mut injector));
}
/// [`Injector`] that writes propagation fields into an HTTP [`HeaderMap`].
struct HttpInjector<'a>(&'a mut HeaderMap);

impl Injector for HttpInjector<'_> {
    /// Inserts `key`/`value` as a header. An entry whose name or value is not
    /// a valid HTTP header component is skipped without reporting an error.
    fn set(&mut self, key: &str, value: String) {
        let Ok(name) = key.parse::<HeaderName>() else {
            return;
        };
        let Ok(val) = HeaderValue::from_str(&value) else {
            return;
        };
        self.0.insert(name, val);
    }
}
/// Injects the OpenTelemetry context of the current `tracing` span into a
/// plain string map, using the globally registered text-map propagator.
pub fn inject_current_context_map(headers: &mut HashMap<String, String>) {
    let span_cx = Span::current().context();
    let mut carrier = MapInjector(headers);
    global::get_text_map_propagator(|propagator| {
        propagator.inject_context(&span_cx, &mut carrier)
    });
}
/// Extracts a [`Context`] from a plain string map, using the globally
/// registered text-map propagator.
pub fn extract_context_from_map(headers: &HashMap<String, String>) -> Context {
    let carrier = MapExtractor(headers);
    global::get_text_map_propagator(|propagator| propagator.extract(&carrier))
}
/// [`Injector`] that writes propagation fields into a `HashMap<String, String>`.
struct MapInjector<'a>(&'a mut HashMap<String, String>);

impl Injector for MapInjector<'_> {
    /// Stores `value` under an owned copy of `key`, replacing any existing
    /// entry for that key.
    fn set(&mut self, key: &str, value: String) {
        let owned_key = key.to_owned();
        self.0.insert(owned_key, value);
    }
}
/// Read-only [`Extractor`] view over a `HashMap<String, String>`.
struct MapExtractor<'a>(&'a HashMap<String, String>);

impl Extractor for MapExtractor<'_> {
    /// Returns the value stored under exactly `key` (the lookup is
    /// case-sensitive), or `None` when there is no such entry.
    fn get(&self, key: &str) -> Option<&str> {
        self.0.get(key).map(String::as_str)
    }

    /// Lists every key present in the map.
    fn keys(&self) -> Vec<&str> {
        self.0.keys().map(String::as_str).collect()
    }
}