use crate::errcode::{Kind, Origin};
use core::fmt::{Display, Formatter};
use core::str::FromStr;
use opentelemetry::propagation::{Extractor, Injector};
use opentelemetry::{global, Context};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use tracing_opentelemetry::OtelData;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::Registry;
/// Name given to the helper spans created by `OpenTelemetryContext::current` and
/// `OpenTelemetryContext::set_as_parent_context`.
const TRACE_CONTEXT_PROPAGATION_SPAN: &str = "trace context propagation";
/// Public tracer-name constant. It is not referenced in the code visible here;
/// presumably it is the name used when obtaining the OpenTelemetry tracer — confirm at the use sites.
pub const OCKAM_TRACER_NAME: &str = "ockam";
/// Serializable carrier of string key/value pairs.
///
/// The wrapped map is the storage written through the `Injector` impl and read
/// through the `Extractor` impl, which lets a text-map propagator move an
/// OpenTelemetry `Context` in and out of this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct OpenTelemetryContext(HashMap<String, String>);
impl Hash for OpenTelemetryContext {
    /// Feeds the key/value pairs into `state` in sorted order.
    ///
    /// A `HashMap` does not iterate in a stable order, so two contexts holding the
    /// same pairs could previously produce different hashes when hashing their
    /// serialized form. Sorting the entries first guarantees that contexts which
    /// compare equal also hash equally.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let mut entries: Vec<(&String, &String)> = self.0.iter().collect();
        // Keys are unique, so an unstable sort yields a single canonical order.
        entries.sort_unstable();
        entries.hash(state)
    }
}
impl PartialOrd for OpenTelemetryContext {
    /// Always yields `Some`, deferring to the total order defined by the `Ord` impl.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for OpenTelemetryContext {
    /// Orders two contexts by comparing their key/value pairs in sorted order.
    ///
    /// Comparing the serialized form is unreliable because a `HashMap` does not
    /// iterate in a stable order: two contexts with the same pairs could compare
    /// as unequal. Comparing sorted entry lists returns `Ordering::Equal` exactly
    /// when the two maps are equal, and gives a stable order otherwise.
    fn cmp(&self, other: &Self) -> Ordering {
        let mut mine: Vec<(&String, &String)> = self.0.iter().collect();
        let mut theirs: Vec<(&String, &String)> = other.0.iter().collect();
        // Keys are unique within a map, so an unstable sort is canonical.
        mine.sort_unstable();
        theirs.sort_unstable();
        mine.cmp(&theirs)
    }
}
impl OpenTelemetryContext {
pub fn extract(&self) -> Context {
global::get_text_map_propagator(|propagator| propagator.extract(self))
}
pub fn inject(context: &Context) -> Self {
global::get_text_map_propagator(|propagator| {
let mut propagation_context = OpenTelemetryContext::empty();
propagator.inject_context(context, &mut propagation_context);
propagation_context
})
}
pub fn update(mut self) -> OpenTelemetryContext {
let _guard = self.extract().attach();
let updated = OpenTelemetryContext::current();
self.0 = updated.0;
self
}
pub fn current() -> OpenTelemetryContext {
let span = tracing::trace_span!(TRACE_CONTEXT_PROPAGATION_SPAN);
let mut result = None;
tracing::dispatcher::get_default(|dispatcher| {
if let Some(registry) = dispatcher.downcast_ref::<Registry>() {
if let Some(id) = span.id() {
if let Some(span) = registry.span(&id) {
let mut extensions = span.extensions_mut();
if let Some(OtelData {
builder: _,
parent_cx,
}) = extensions.remove::<OtelData>()
{
result = Some(OpenTelemetryContext::inject(&parent_cx))
}
}
}
};
});
result.unwrap_or_else(|| OpenTelemetryContext::inject(&Context::current()))
}
fn set_as_parent_context(self) {
let parent_cx = self.extract();
let span = tracing::trace_span!(TRACE_CONTEXT_PROPAGATION_SPAN);
tracing::dispatcher::get_default(|dispatcher| {
if let Some(registry) = dispatcher.downcast_ref::<Registry>() {
if let Some(id) = span.id() {
if let Some(span) = registry.span(&id) {
if let Some(parent) = span.parent() {
let mut extensions = parent.extensions_mut();
if let Some(otel_data) = extensions.get_mut::<OtelData>() {
otel_data.parent_cx = parent_cx.clone();
}
}
{
let mut extensions = span.extensions_mut();
extensions.remove::<OtelData>();
}
}
}
};
})
}
pub fn current_context() -> Context {
OpenTelemetryContext::current().extract()
}
pub fn from_remote_context(tracing_context: &str) -> OpenTelemetryContext {
let result: Option<OpenTelemetryContext> = tracing_context.try_into().ok();
if let Some(tc) = result {
tc.set_as_parent_context()
};
OpenTelemetryContext::current()
}
fn empty() -> Self {
Self(HashMap::new())
}
pub fn as_map(&self) -> HashMap<String, String> {
self.0.clone()
}
}
impl Display for OpenTelemetryContext {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
f.write_str(&serde_json::to_string(&self).map_err(|_| core::fmt::Error)?)
}
}
impl Injector for OpenTelemetryContext {
    /// Stores `value` under `key`, replacing any value already stored for that key.
    fn set(&mut self, key: &str, value: String) {
        let owned_key = String::from(key);
        self.0.insert(owned_key, value);
    }
}
impl Extractor for OpenTelemetryContext {
    /// Returns the value stored under `key`, or `None` when the key is absent.
    fn get(&self, key: &str) -> Option<&str> {
        // A `HashMap<String, _>` can be queried with a `&str` directly, so no
        // temporary `String` needs to be allocated for the lookup.
        self.0.get(key).map(String::as_str)
    }

    /// Returns every stored key, in the map's (unspecified) iteration order.
    fn keys(&self) -> Vec<&str> {
        self.0.keys().map(String::as_str).collect()
    }
}
impl TryFrom<&str> for OpenTelemetryContext {
    type Error = crate::Error;

    /// Parses `value` with [`opentelemetry_context_parser`].
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        opentelemetry_context_parser(value)
    }
}
impl FromStr for OpenTelemetryContext {
    type Err = crate::Error;

    /// Parses `s` by delegating to the `TryFrom<&str>` conversion.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        <Self as TryFrom<&str>>::try_from(s)
    }
}
impl TryFrom<String> for OpenTelemetryContext {
    type Error = crate::Error;

    /// Parses the contents of `value` with [`opentelemetry_context_parser`].
    fn try_from(value: String) -> Result<Self, Self::Error> {
        opentelemetry_context_parser(value.as_str())
    }
}
pub fn opentelemetry_context_parser(input: &str) -> crate::Result<OpenTelemetryContext> {
serde_json::from_str(input).map_err(|e| {
crate::Error::new(
Origin::Api,
Kind::Serialization,
format!("Invalid OpenTelemetry context: {input}. Got error: {e:?}"),
)
})
}