use crate::trace::{self, TraceId};
use opentelemetry::trace::TraceContextExt;
use static_assertions::assert_impl_all;
use std::{
convert::TryFrom,
time::{Duration, Instant},
};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Request-scoped information: the deadline of a request and its trace context.
///
/// With the `serde1` feature enabled the type is serializable. The deadline is
/// then encoded through the `absolute_to_relative_time` helper module rather
/// than as a raw [`Instant`], and falls back to `ten_seconds_from_now` when the
/// field is absent from the serialized input.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
/// The instant by which the request is expected to be complete.
#[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))]
#[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))]
pub deadline: Instant,
/// Trace information (trace ID, span ID and sampling decision) associated
/// with the request.
pub trace_context: trace::Context,
}
/// Serde adapter that encodes an absolute [`Instant`] deadline as the
/// [`Duration`] remaining until that deadline, and decodes it back relative to
/// the local clock.
///
/// An `Instant` is only meaningful within the process that created it, so the
/// wire format carries the time left instead of the instant itself.
#[cfg(feature = "serde1")]
mod absolute_to_relative_time {
    pub use serde::{Deserialize, Deserializer, Serialize, Serializer};
    pub use std::time::{Duration, Instant};

    /// Serializes `deadline` as the duration between now and the deadline.
    ///
    /// A deadline that has already passed is serialized as a zero duration.
    pub fn serialize<S>(deadline: &Instant, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `duration_since` panicked on older toolchains when the deadline was
        // already in the past and is documented as free to do so again; the
        // saturating variant pins the clamp-to-zero behavior explicitly.
        let remaining = deadline.saturating_duration_since(Instant::now());
        remaining.serialize(serializer)
    }

    /// Deserializes a remaining-time [`Duration`] and converts it into an
    /// absolute deadline measured from the current instant.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error if the duration is malformed, or if it
    /// is so large that the resulting deadline cannot be represented as an
    /// [`Instant`].
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Instant, D::Error>
    where
        D: Deserializer<'de>,
    {
        let remaining = Duration::deserialize(deserializer)?;
        // The duration comes from the wire, so `Instant + Duration` must not
        // be allowed to panic on overflow; reject the value instead.
        Instant::now().checked_add(remaining).ok_or_else(|| {
            <D::Error as serde::de::Error>::custom(
                "deadline is too far in the future to be represented",
            )
        })
    }

    #[cfg(test)]
    #[derive(serde::Serialize, serde::Deserialize)]
    struct AbsoluteToRelative(#[serde(with = "self")] Instant);

    #[test]
    fn test_serialize() {
        let now = Instant::now();
        let deadline = now + Duration::from_secs(10);
        let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap();
        let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap();
        // TODO: how to avoid flakiness?
        assert!(deserialized_deadline > Duration::from_secs(9));
    }

    #[test]
    fn test_serialize_expired_deadline() {
        // `checked_sub` may fail when the clock's origin is under a second ago.
        if let Some(expired) = Instant::now().checked_sub(Duration::from_secs(1)) {
            let serialized_deadline = bincode::serialize(&AbsoluteToRelative(expired)).unwrap();
            let remaining: Duration = bincode::deserialize(&serialized_deadline).unwrap();
            assert_eq!(remaining, Duration::from_secs(0));
        }
    }

    #[test]
    fn test_deserialize() {
        let deadline = Duration::from_secs(10);
        let serialized_deadline = bincode::serialize(&deadline).unwrap();
        let AbsoluteToRelative(deserialized_deadline) =
            bincode::deserialize(&serialized_deadline).unwrap();
        // TODO: how to avoid flakiness?
        assert!(deserialized_deadline > Instant::now() + Duration::from_secs(9));
    }

    #[test]
    fn test_deserialize_overflowing_deadline() {
        let too_long = Duration::new(u64::MAX, 0);
        let serialized_deadline = bincode::serialize(&too_long).unwrap();
        let result: Result<AbsoluteToRelative, _> = bincode::deserialize(&serialized_deadline);
        assert!(result.is_err());
    }
}
// Compile-time check that `Context` implements both `Send` and `Sync`.
assert_impl_all!(Context: Send, Sync);
/// Returns the instant ten seconds after the moment of the call.
///
/// Used as the fallback deadline when none has been set.
fn ten_seconds_from_now() -> Instant {
    let now = Instant::now();
    let ten_seconds = Duration::from_secs(10);
    now + ten_seconds
}
/// Returns the [`Context`] for the current scope.
///
/// Convenience free function that delegates to [`Context::current`].
pub fn current() -> Context {
Context::current()
}
/// Newtype around a request deadline.
///
/// It is stored as a value in the `opentelemetry::Context` attached to a span
/// by `SpanExt::set_context`, and looked up by type in `Context::current`.
#[derive(Clone)]
struct Deadline(Instant);
impl Default for Deadline {
fn default() -> Self {
Self(ten_seconds_from_now())
}
}
impl Context {
pub fn current() -> Self {
let span = tracing::Span::current();
Self {
trace_context: trace::Context::try_from(&span)
.unwrap_or_else(|_| trace::Context::default()),
deadline: span
.context()
.get::<Deadline>()
.cloned()
.unwrap_or_default()
.0,
}
}
pub fn trace_id(&self) -> &TraceId {
&self.trace_context.trace_id
}
}
/// Crate-internal extension trait for attaching a request [`Context`] to a span.
pub(crate) trait SpanExt {
/// Associates `context` with this span.
fn set_context(&self, context: &Context);
}
impl SpanExt for tracing::Span {
    /// Sets this span's parent to an OpenTelemetry context built from `context`.
    ///
    /// The parent carries a remote span context made from the trace ID, span
    /// ID and sampling decision of `context.trace_context`, plus the request
    /// deadline stored as a `Deadline` value.
    fn set_context(&self, context: &Context) {
        let trace = &context.trace_context;

        let remote_span = opentelemetry::trace::SpanContext::new(
            opentelemetry::trace::TraceId::from(trace.trace_id),
            opentelemetry::trace::SpanId::from(trace.span_id),
            opentelemetry::trace::TraceFlags::from(trace.sampling_decision),
            true,
            opentelemetry::trace::TraceState::default(),
        );

        let parent = opentelemetry::Context::new()
            .with_remote_span_context(remote_span)
            .with_value(Deadline(context.deadline));

        self.set_parent(parent);
    }
}