use crate::otel_warn;
#[cfg(feature = "trace")]
use crate::trace::context::SynchronizedSpan;
use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt;
use std::hash::{BuildHasherDefault, Hasher};
use std::marker::PhantomData;
use std::sync::Arc;
#[cfg(feature = "futures")]
mod future_ext;
#[cfg(feature = "futures")]
pub use future_ext::{FutureExt, WithContext};
thread_local! {
static CURRENT_CONTEXT: RefCell<ContextStack> = RefCell::new(ContextStack::default());
}
/// An immutable bundle of per-type values, an optional active span (with the
/// `trace` feature) and a telemetry-suppression flag.
///
/// A `Context` is never modified in place: `with_value`,
/// `with_telemetry_suppressed` and friends all return a new `Context`.
/// Cloning is cheap because the span and the entry map are shared behind
/// `Arc`s, and the flag is a plain `bool`. `Context::default()` has no
/// entries, no span, and telemetry not suppressed.
#[derive(Clone, Default)]
pub struct Context {
    // The span associated with this context, if any (only with `trace`).
    #[cfg(feature = "trace")]
    pub(crate) span: Option<Arc<SynchronizedSpan>>,
    // Values keyed by their `TypeId`; `None` until the first value is added.
    // The map is shared between contexts and cloned before each insertion.
    entries: Option<Arc<EntryMap>>,
    // Whether this context marks telemetry as suppressed; read through
    // `is_telemetry_suppressed` / `is_current_telemetry_suppressed`.
    suppress_telemetry: bool,
}
/// Type-keyed value storage for `Context`: at most one value per Rust type.
/// It uses `IdHasher`, which passes the key's hash through unchanged. That
/// presumably relies on `TypeId` hashes already being well distributed.
type EntryMap = HashMap<TypeId, Arc<dyn Any + Sync + Send>, BuildHasherDefault<IdHasher>>;
impl Context {
pub fn new() -> Self {
Context::default()
}
pub fn current() -> Self {
Self::map_current(|cx| cx.clone())
}
pub fn map_current<T>(f: impl FnOnce(&Context) -> T) -> T {
CURRENT_CONTEXT.with(|cx| cx.borrow().map_current_cx(f))
}
pub fn current_with_value<T: 'static + Send + Sync>(value: T) -> Self {
Self::map_current(|cx| cx.with_value(value))
}
pub fn get<T: 'static>(&self) -> Option<&T> {
self.entries
.as_ref()?
.get(&TypeId::of::<T>())?
.downcast_ref()
}
pub fn with_value<T: 'static + Send + Sync>(&self, value: T) -> Self {
let entries = if let Some(current_entries) = &self.entries {
let mut inner_entries = (**current_entries).clone();
inner_entries.insert(TypeId::of::<T>(), Arc::new(value));
Some(Arc::new(inner_entries))
} else {
let mut entries = EntryMap::default();
entries.insert(TypeId::of::<T>(), Arc::new(value));
Some(Arc::new(entries))
};
Context {
entries,
#[cfg(feature = "trace")]
span: self.span.clone(),
suppress_telemetry: self.suppress_telemetry,
}
}
pub fn attach(self) -> ContextGuard {
let cx_id = CURRENT_CONTEXT.with(|cx| cx.borrow_mut().push(self));
ContextGuard {
cx_pos: cx_id,
_marker: PhantomData,
}
}
#[inline]
pub fn is_telemetry_suppressed(&self) -> bool {
self.suppress_telemetry
}
pub fn with_telemetry_suppressed(&self) -> Self {
Context {
entries: self.entries.clone(),
#[cfg(feature = "trace")]
span: self.span.clone(),
suppress_telemetry: true,
}
}
pub fn enter_telemetry_suppressed_scope() -> ContextGuard {
Self::map_current(|cx| cx.with_telemetry_suppressed()).attach()
}
#[inline]
pub fn is_current_telemetry_suppressed() -> bool {
Self::map_current(|cx| cx.is_telemetry_suppressed())
}
#[cfg(feature = "trace")]
pub(crate) fn current_with_synchronized_span(value: SynchronizedSpan) -> Self {
Self::map_current(|cx| Context {
span: Some(Arc::new(value)),
entries: cx.entries.clone(),
suppress_telemetry: cx.suppress_telemetry,
})
}
#[cfg(feature = "trace")]
pub(crate) fn with_synchronized_span(&self, value: SynchronizedSpan) -> Self {
Context {
span: Some(Arc::new(value)),
entries: self.entries.clone(),
suppress_telemetry: self.suppress_telemetry,
}
}
}
impl fmt::Debug for Context {
    /// Prints the span's context (with `trace`), how many entries are
    /// stored (an attached span counts as one), and the suppression flag.
    /// Stored values are not printed because they are type-erased.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Context");
        let stored = self.entries.as_ref().map_or(0, |map| map.len());

        #[cfg(feature = "trace")]
        let stored = match &self.span {
            Some(span) => {
                builder.field("span", &span.span_context());
                stored + 1
            }
            None => {
                builder.field("span", &"None");
                stored
            }
        };

        builder
            .field("entries count", &stored)
            .field("suppress_telemetry", &self.suppress_telemetry)
            .finish()
    }
}
/// Guard returned by `Context::attach`. Dropping it detaches that context
/// from the calling thread.
///
/// The `*const ()` marker makes the guard `!Send` and `!Sync`. It can
/// therefore only be dropped on the thread whose context stack it refers to.
#[derive(Debug)]
pub struct ContextGuard {
    // Position returned by `ContextStack::push`. `ContextStack::MAX_POS`
    // means the attach was rejected, so dropping this guard does nothing.
    cx_pos: u16,
    _marker: PhantomData<*const ()>,
}
impl Drop for ContextGuard {
    fn drop(&mut self) {
        let id = self.cx_pos;
        // The base position is never handed out and the overflow sentinel
        // never pushed anything, so neither has a context to pop.
        if id > ContextStack::BASE_POS && id < ContextStack::MAX_POS {
            // `try_with` instead of `with`: if this guard is dropped while the
            // thread-local stack is being torn down (e.g. the guard lives in
            // another thread-local), `with` would panic inside a destructor.
            // The stack is being destroyed anyway, so skipping the pop is
            // harmless.
            let _ = CURRENT_CONTEXT
                .try_with(|context_stack| context_stack.borrow_mut().pop_id(id));
        }
    }
}
/// Pass-through hasher for `TypeId`-keyed maps.
///
/// A `u64` written with `write_u64` becomes the hash unchanged. This
/// presumes `TypeId` hashes are already well distributed.
#[derive(Clone, Default, Debug)]
struct IdHasher(u64);
impl Hasher for IdHasher {
    /// Fallback for input that does not arrive through `write_u64`.
    ///
    /// The standard library currently hashes `TypeId` with `write_u64`, but
    /// that is an implementation detail. Byte slices, and `write_u128`
    /// (whose default implementation forwards here), are folded into the
    /// state rather than panicking inside map operations.
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            // FNV-1a style step: cheap, and every byte affects the result.
            self.0 = (self.0 ^ u64::from(byte)).wrapping_mul(0x0000_0100_0000_01b3);
        }
    }
    /// Uses the written value directly as the hash.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }
    /// Returns the current hash state.
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}
/// Per-thread bookkeeping behind `Context::current` and `Context::attach`.
///
/// `current_cx` is the active context. `stack` saves the contexts it
/// replaced: once non-empty, `stack[0]` holds the base context and
/// `stack[k]` (k >= 1) holds the context attached at position `k`.
struct ContextStack {
    // The context `Context::current()` returns on this thread.
    current_cx: Context,
    // Saved contexts; a `None` slot belongs to a guard that was dropped out
    // of order and is skipped when the contexts above it are popped.
    stack: Vec<Option<Context>>,
    // Raw-pointer marker: makes the type `!Send` and `!Sync`.
    _marker: PhantomData<*const ()>,
}
impl ContextStack {
    // Position reserved for the base context. `push` never returns it and
    // `pop_id` refuses to pop it.
    const BASE_POS: u16 = 0;
    // Sentinel returned by `push` when the stack is full. It is never a
    // valid position, so guards holding it do nothing on drop.
    const MAX_POS: u16 = u16::MAX;
    // Capacity pre-allocated for `stack` by `Default`.
    const INITIAL_CAPACITY: usize = 8;

    /// Makes `cx` the current context and returns its position (guard id).
    ///
    /// The previously current context is saved at the end of `stack`, so
    /// after a successful call `stack.len()` equals the returned position.
    /// Positions run from 1 to `MAX_POS - 1`. Beyond that, the attach is
    /// rejected: a warning is emitted, `current_cx` is left unchanged and
    /// `MAX_POS` is returned.
    #[inline(always)]
    fn push(&mut self, cx: Context) -> u16 {
        let next_id = self.stack.len() + 1;
        // `MAX_POS` itself is the overflow sentinel, hence the strict `<`.
        if next_id < ContextStack::MAX_POS.into() {
            let current_cx = std::mem::replace(&mut self.current_cx, cx);
            self.stack.push(Some(current_cx));
            // Cannot truncate: `next_id < MAX_POS`.
            next_id as u16
        } else {
            otel_warn!(
                name: "Context.AttachFailed",
                message = format!("Too many contexts. Max limit is {}. \
Context::current() remains unchanged as this attach failed. \
Dropping the returned ContextGuard will have no impact on Context::current().",
                ContextStack::MAX_POS)
            );
            ContextStack::MAX_POS
        }
    }

    /// Detaches the context that was attached at position `pos`.
    ///
    /// The context at position `stack.len()` is `current_cx`; lower positions
    /// live in `stack[pos]`.
    /// - `pos == stack.len()` (in-order drop): drops trailing `None` slots
    ///   (guards already released out of order), then pops the nearest saved
    ///   context and makes it current.
    /// - `pos < stack.len()` (out-of-order drop): clears `stack[pos]` to
    ///   `None` and leaves `current_cx` unchanged.
    /// - `BASE_POS`, `MAX_POS` or `pos > stack.len()`: warns and does nothing.
    #[inline(always)]
    fn pop_id(&mut self, pos: u16) {
        if pos == ContextStack::BASE_POS || pos == ContextStack::MAX_POS {
            // The base context cannot be popped, and the overflow sentinel
            // never pushed anything.
            otel_warn!(
                name: "Context.OutOfOrderDrop",
                position = pos,
                message = if pos == ContextStack::BASE_POS {
                    "Attempted to pop the base context which is not allowed"
                } else {
                    "Attempted to pop the overflow position which is not allowed"
                }
            );
            return;
        }
        // `push` caps the length below `MAX_POS`, so this cast cannot truncate.
        let len: u16 = self.stack.len() as u16;
        if pos == len {
            // Top of the stack: discard slots already released out of order.
            while let Some(None) = self.stack.last() {
                _ = self.stack.pop();
            }
            // Restore the nearest still-attached context, or the base one.
            if let Some(Some(next_cx)) = self.stack.pop() {
                self.current_cx = next_cx;
            }
        } else {
            if pos >= len {
                // Not a live position on this stack; ignore it.
                otel_warn!(
                    name: "Context.PopOutOfBounds",
                    position = pos,
                    stack_length = len,
                    message = "Attempted to pop beyond the end of the context stack"
                );
                return;
            }
            // Out-of-order drop: mark the slot so later pops skip it.
            _ = self.stack[pos as usize].take();
        }
    }

    /// Calls `f` with a reference to the current context and returns its result.
    #[inline(always)]
    fn map_current_cx<T>(&self, f: impl FnOnce(&Context) -> T) -> T {
        f(&self.current_cx)
    }
}
impl Default for ContextStack {
    /// An empty stack whose current context is `Context::new()`, with room
    /// for `INITIAL_CAPACITY` saved contexts before reallocating.
    fn default() -> Self {
        Self {
            current_cx: Context::new(),
            stack: Vec::with_capacity(Self::INITIAL_CAPACITY),
            _marker: PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    // Two distinct wrapper types: a `Context` stores at most one value per
    // type, so holding two values at once needs two types.
    #[derive(Debug, PartialEq)]
    struct ValueA(u64);
    #[derive(Debug, PartialEq)]
    struct ValueB(u64);

    // `with_value` returns a new context and never mutates its receiver.
    #[test]
    fn context_immutable() {
        let cx = Context::current();
        assert_eq!(cx.get::<ValueA>(), None);
        assert_eq!(cx.get::<ValueB>(), None);
        let cx_new = cx.with_value(ValueA(1));
        assert_eq!(cx.get::<ValueA>(), None);
        assert_eq!(cx.get::<ValueB>(), None);
        assert_eq!(cx_new.get::<ValueA>(), Some(&ValueA(1)));
        let cx_newer = cx_new.with_value(ValueB(1));
        assert_eq!(cx.get::<ValueA>(), None);
        assert_eq!(cx.get::<ValueB>(), None);
        assert_eq!(cx_new.get::<ValueA>(), Some(&ValueA(1)));
        assert_eq!(cx_new.get::<ValueB>(), None);
        assert_eq!(cx_newer.get::<ValueA>(), Some(&ValueA(1)));
        assert_eq!(cx_newer.get::<ValueB>(), Some(&ValueB(1)));
    }

    // Guards dropped in LIFO order restore the enclosing context.
    #[test]
    fn nested_contexts() {
        let _outer_guard = Context::new().with_value(ValueA(1)).attach();
        let current = Context::current();
        assert_eq!(current.get(), Some(&ValueA(1)));
        assert_eq!(current.get::<ValueB>(), None);
        {
            let _inner_guard = Context::current_with_value(ValueB(42)).attach();
            let current = Context::current();
            assert_eq!(current.get(), Some(&ValueA(1)));
            assert_eq!(current.get(), Some(&ValueB(42)));
            assert!(Context::map_current(|cx| {
                assert_eq!(cx.get(), Some(&ValueA(1)));
                assert_eq!(cx.get(), Some(&ValueB(42)));
                true
            }));
        }
        let current = Context::current();
        assert_eq!(current.get(), Some(&ValueA(1)));
        assert_eq!(current.get::<ValueB>(), None);
        assert!(Context::map_current(|cx| {
            assert_eq!(cx.get(), Some(&ValueA(1)));
            assert_eq!(cx.get::<ValueB>(), None);
            true
        }));
    }

    // Dropping the outer guard first leaves the inner context current;
    // dropping the inner guard afterwards restores the base context.
    #[test]
    fn overlapping_contexts() {
        let outer_guard = Context::new().with_value(ValueA(1)).attach();
        let current = Context::current();
        assert_eq!(current.get(), Some(&ValueA(1)));
        assert_eq!(current.get::<ValueB>(), None);
        let inner_guard = Context::current_with_value(ValueB(42)).attach();
        let current = Context::current();
        assert_eq!(current.get(), Some(&ValueA(1)));
        assert_eq!(current.get(), Some(&ValueB(42)));
        assert!(Context::map_current(|cx| {
            assert_eq!(cx.get(), Some(&ValueA(1)));
            assert_eq!(cx.get(), Some(&ValueB(42)));
            true
        }));
        drop(outer_guard);
        let current = Context::current();
        assert_eq!(current.get(), Some(&ValueA(1)));
        assert_eq!(current.get(), Some(&ValueB(42)));
        drop(inner_guard);
        let current = Context::current();
        assert_eq!(current.get::<ValueA>(), None);
        assert_eq!(current.get::<ValueB>(), None);
    }

    // Fills the thread's stack to `MAX_POS - 1` positions and checks three
    // things:
    // - further attaches are rejected (guard holds `MAX_POS`, current context
    //   unchanged);
    // - dropping those guards is a no-op;
    // - freeing one position allows exactly one new attach.
    #[test]
    fn too_many_contexts() {
        let mut guards: Vec<ContextGuard> = Vec::with_capacity(ContextStack::MAX_POS as usize);
        let stack_max_pos = ContextStack::MAX_POS as u64;
        for i in 1..stack_max_pos {
            let cx_guard = Context::current().with_value(ValueB(i)).attach();
            assert_eq!(Context::current().get(), Some(&ValueB(i)));
            assert_eq!(cx_guard.cx_pos, i as u16);
            guards.push(cx_guard);
        }
        for _ in 0..16 {
            let cx_guard = Context::current().with_value(ValueA(1)).attach();
            assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
            assert_eq!(Context::current().get::<ValueA>(), None);
            assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
            guards.push(cx_guard);
        }
        for _ in 0..16 {
            guards.pop();
            assert_eq!(Context::current().get::<ValueA>(), None);
            assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
        }
        guards.pop();
        assert_eq!(Context::current().get::<ValueA>(), None);
        assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
        let cx_guard = Context::current().with_value(ValueA(2)).attach();
        assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS - 1);
        assert_eq!(Context::current().get(), Some(&ValueA(2)));
        assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
        guards.push(cx_guard);
        for _ in 0..16 {
            let cx_guard = Context::current().with_value(ValueA(1)).attach();
            assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
            assert_eq!(Context::current().get::<ValueA>(), Some(&ValueA(2)));
            assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
            guards.push(cx_guard);
        }
    }

    // A fresh stack pre-allocates `INITIAL_CAPACITY` slots.
    #[test]
    fn test_initial_capacity() {
        let stack = ContextStack::default();
        assert_eq!(stack.stack.capacity(), ContextStack::INITIAL_CAPACITY);
    }

    // `map_current_cx` hands the closure the stack's current context.
    #[test]
    fn test_map_current_cx() {
        let mut stack = ContextStack::default();
        let test_value = ValueA(42);
        stack.current_cx = Context::new().with_value(test_value);
        let result = stack.map_current_cx(|cx| {
            assert_eq!(cx.get::<ValueA>(), Some(&ValueA(42)));
            true
        });
        assert!(result);
    }

    // Popping a middle position only clears its slot. The later in-order
    // pop skips that cleared slot and restores the context below it.
    #[test]
    fn test_pop_id_out_of_order() {
        let mut stack = ContextStack::default();
        let cx1 = Context::new().with_value(ValueA(1));
        let cx2 = Context::new().with_value(ValueA(2));
        let cx3 = Context::new().with_value(ValueA(3));
        let id1 = stack.push(cx1);
        let id2 = stack.push(cx2);
        let id3 = stack.push(cx3);
        stack.pop_id(id2);
        assert_eq!(stack.current_cx.get::<ValueA>(), Some(&ValueA(3)));
        assert_eq!(stack.stack.len(), 3);
        stack.pop_id(id3);
        assert_eq!(stack.current_cx.get::<ValueA>(), Some(&ValueA(1)));
        assert_eq!(stack.stack.len(), 1);
        stack.pop_id(id1);
        assert_eq!(stack.current_cx.get::<ValueA>(), None);
        assert_eq!(stack.stack.len(), 0);
    }

    // Reserved and out-of-range positions are ignored on an empty stack.
    #[test]
    fn test_pop_id_edge_cases() {
        let mut stack = ContextStack::default();
        stack.pop_id(ContextStack::BASE_POS);
        assert_eq!(stack.stack.len(), 0);
        stack.pop_id(ContextStack::MAX_POS);
        assert_eq!(stack.stack.len(), 0);
        stack.pop_id(1000);
        assert_eq!(stack.stack.len(), 0);
        stack.pop_id(1);
        assert_eq!(stack.stack.len(), 0);
    }

    // Direct pushes until overflow. Note: the loop's last iteration
    // (i == MAX_POS - 1) is already rejected and returns `MAX_POS`, which
    // happens to equal `(i + 1) as u16`, so the in-loop assertion still holds.
    // The last accepted push carries `ValueA(max_pos - 2)`.
    #[test]
    fn test_push_overflow() {
        let mut stack = ContextStack::default();
        let max_pos = ContextStack::MAX_POS as usize;
        for i in 0..max_pos {
            let cx = Context::new().with_value(ValueA(i as u64));
            let id = stack.push(cx);
            assert_eq!(id, (i + 1) as u16);
        }
        let cx = Context::new().with_value(ValueA(max_pos as u64));
        let id = stack.push(cx);
        assert_eq!(id, ContextStack::MAX_POS);
        assert_eq!(
            stack.current_cx.get::<ValueA>(),
            Some(&ValueA((max_pos - 2) as u64))
        );
    }

    // Contexts bound with `with_context` are visible inside the future,
    // including across await points, and are gone once the future completes.
    #[tokio::test]
    async fn test_async_context_propagation() {
        async fn nested_operation() {
            assert_eq!(
                Context::current().get::<ValueA>(),
                Some(&ValueA(42)),
                "Parent context value should be available in async operation"
            );
            let cx_with_both = Context::current()
                .with_value(ValueA(43))
                .with_value(ValueB(24));
            async {
                assert_eq!(
                    Context::current().get::<ValueA>(),
                    Some(&ValueA(43)),
                    "Parent value should still be available after adding new value"
                );
                assert_eq!(
                    Context::current().get::<ValueB>(),
                    Some(&ValueB(24)),
                    "New value should be available in async operation"
                );
                sleep(Duration::from_millis(10)).await;
                assert_eq!(
                    Context::current().get::<ValueA>(),
                    Some(&ValueA(43)),
                    "Parent value should persist across await points"
                );
                assert_eq!(
                    Context::current().get::<ValueB>(),
                    Some(&ValueB(24)),
                    "New value should persist across await points"
                );
            }
            .with_context(cx_with_both)
            .await;
        }
        let parent_cx = Context::new().with_value(ValueA(42));
        nested_operation().with_context(parent_cx.clone()).await;
        assert_eq!(
            parent_cx.get::<ValueA>(),
            Some(&ValueA(42)),
            "Parent context should be unchanged"
        );
        assert_eq!(
            parent_cx.get::<ValueB>(),
            None,
            "Parent context should not see values added in async operation"
        );
        assert_eq!(
            Context::current().get::<ValueA>(),
            None,
            "Current context should be back to default"
        );
        assert_eq!(
            Context::current().get::<ValueB>(),
            None,
            "Current context should not have async operation's values"
        );
    }

    // A future bound (inside an outer context) to `Context::current()` is
    // awaited after the outer future has finished. It still sees the captured
    // value, and the thread ends up back at the base context.
    #[tokio::test]
    async fn test_out_of_order_context_detachment_futures() {
        async fn create_a_future() -> impl std::future::Future<Output = ()> {
            async {
                assert_eq!(Context::current().get::<ValueA>(), Some(&ValueA(42)));
                sleep(Duration::from_millis(50)).await;
            }
            .with_context(Context::current())
        }
        let parent_cx = Context::new().with_value(ValueA(42));
        let future = create_a_future().with_context(parent_cx).await;
        let _a = future.await;
        assert_eq!(Context::current().get::<ValueA>(), None);
        assert_eq!(Context::current().get::<ValueB>(), None);
    }

    // The suppression flag is off by default and set by `with_telemetry_suppressed`.
    #[test]
    fn test_is_telemetry_suppressed() {
        let cx = Context::new();
        assert!(!cx.is_telemetry_suppressed());
        let suppressed = cx.with_telemetry_suppressed();
        assert!(suppressed.is_telemetry_suppressed());
    }

    // `with_telemetry_suppressed` leaves the receiver untouched and keeps its values.
    #[test]
    fn test_with_telemetry_suppressed() {
        let cx = Context::new();
        assert!(!cx.is_telemetry_suppressed());
        let suppressed = cx.with_telemetry_suppressed();
        assert!(!cx.is_telemetry_suppressed());
        assert!(suppressed.is_telemetry_suppressed());
        let cx_with_value = cx.with_value(ValueA(42));
        let suppressed_with_value = cx_with_value.with_telemetry_suppressed();
        assert!(!cx_with_value.is_telemetry_suppressed());
        assert!(suppressed_with_value.is_telemetry_suppressed());
        assert_eq!(suppressed_with_value.get::<ValueA>(), Some(&ValueA(42)));
    }

    // A suppressed scope keeps the current values and ends when its guard drops.
    #[test]
    fn test_enter_telemetry_suppressed_scope() {
        let _reset_guard = Context::new().attach();
        assert!(!Context::is_current_telemetry_suppressed());
        let cx_with_value = Context::current().with_value(ValueA(42));
        let _guard_with_value = cx_with_value.attach();
        assert_eq!(Context::current().get::<ValueA>(), Some(&ValueA(42)));
        assert!(!Context::is_current_telemetry_suppressed());
        {
            let _guard = Context::enter_telemetry_suppressed_scope();
            assert!(Context::is_current_telemetry_suppressed());
            assert!(Context::current().is_telemetry_suppressed());
            assert_eq!(Context::current().get::<ValueA>(), Some(&ValueA(42)));
        }
        assert!(!Context::is_current_telemetry_suppressed());
        assert!(!Context::current().is_telemetry_suppressed());
        assert_eq!(Context::current().get::<ValueA>(), Some(&ValueA(42)));
    }

    // Contexts derived from a suppressed current context inherit the flag,
    // while a fresh `Context::new()` attached inside the scope does not.
    #[test]
    fn test_nested_suppression_scopes() {
        let _reset_guard = Context::new().attach();
        assert!(!Context::is_current_telemetry_suppressed());
        {
            let _outer = Context::enter_telemetry_suppressed_scope();
            assert!(Context::is_current_telemetry_suppressed());
            {
                let _inner = Context::current().with_value(ValueA(1)).attach();
                assert!(Context::is_current_telemetry_suppressed());
                assert_eq!(Context::current().get::<ValueA>(), Some(&ValueA(1)));
            }
            {
                let _inner = Context::new().with_value(ValueA(1)).attach();
                assert!(!Context::is_current_telemetry_suppressed());
                assert_eq!(Context::current().get::<ValueA>(), Some(&ValueA(1)));
            }
            assert!(Context::is_current_telemetry_suppressed());
        }
        assert!(!Context::is_current_telemetry_suppressed());
    }

    // Suppression carried via `with_context` survives await points on a
    // multi-threaded runtime and does not leak into the caller's context.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_async_suppression() {
        async fn nested_operation() {
            assert!(Context::is_current_telemetry_suppressed());
            let cx_with_additional_value = Context::current().with_value(ValueB(24));
            async {
                assert_eq!(
                    Context::current().get::<ValueB>(),
                    Some(&ValueB(24)),
                    "Parent value should still be available after adding new value"
                );
                assert!(Context::is_current_telemetry_suppressed());
                sleep(Duration::from_millis(10)).await;
                assert_eq!(
                    Context::current().get::<ValueB>(),
                    Some(&ValueB(24)),
                    "Parent value should still be available after adding new value"
                );
                assert!(Context::is_current_telemetry_suppressed());
            }
            .with_context(cx_with_additional_value)
            .await;
        }
        let suppressed_parent = Context::new().with_telemetry_suppressed();
        assert!(!Context::is_current_telemetry_suppressed());
        nested_operation()
            .with_context(suppressed_parent.clone())
            .await;
        assert!(suppressed_parent.is_telemetry_suppressed());
        assert!(!Context::is_current_telemetry_suppressed());
    }
}