use serde::{Deserialize, Serialize};
use crate::*;
/// String-backed request identifier; the main serializable context value in these tests.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
struct RequestId(String);

/// Numeric user identifier; a second serializable type for mismatch and multi-key tests.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
struct UserId(u64);
/// Joins `prefix` and `suffix` with `_` and leaks the result so it can be used where a
/// `&'static str` key is required. Each call leaks one small allocation, which is
/// acceptable for a test process.
fn unique_key(prefix: &str, suffix: &str) -> &'static str {
    let joined: Box<str> = [prefix, suffix].join("_").into_boxed_str();
    Box::leak(joined)
}
/// A registered key that has never been set reads back as `RequestId::default()`.
#[test]
fn test_register_and_get_default() {
    let rid_key = unique_key("reg_default", "rid");
    register::<RequestId>(rid_key);
    let fetched = get_context::<RequestId>(rid_key);
    assert_eq!(fetched, RequestId::default());
}
/// Registering the same key twice with the same type succeeds both times.
#[test]
fn test_try_register_idempotent() {
    let rid_key = unique_key("reg_idem", "rid");
    for _attempt in 0..2 {
        try_register::<RequestId>(rid_key).unwrap();
    }
}
/// Re-registering a key under a different type fails with `AlreadyRegistered`.
#[test]
fn test_try_register_conflict() {
    let shared_key = unique_key("reg_conflict", "val");
    try_register::<RequestId>(shared_key).unwrap();
    let conflict = try_register::<UserId>(shared_key);
    assert!(matches!(
        conflict.unwrap_err(),
        ContextError::AlreadyRegistered(_)
    ));
}
/// The panicking getter panics when the key was never registered.
#[test]
#[should_panic]
fn test_get_unregistered_panics() {
    let missing_key = unique_key("unreg_panic", "missing");
    let _ = get_context::<RequestId>(missing_key);
}
/// The fallible getter reports `NotRegistered` for a key that was never registered.
#[test]
fn test_try_get_unregistered_returns_err() {
    let missing_key = unique_key("unreg_err", "missing");
    let lookup = try_get_context::<RequestId>(missing_key);
    let reported_unregistered = matches!(lookup, Err(ContextError::NotRegistered(_)));
    assert!(reported_unregistered);
}
/// A value written with `set_context` is returned by the next `get_context`.
#[test]
fn test_set_and_get() {
    let rid_key = unique_key("set_get", "rid");
    register::<RequestId>(rid_key);
    let stored = RequestId("req-42".into());
    set_context(rid_key, stored);
    assert_eq!(get_context::<RequestId>(rid_key).0, "req-42");
}
/// The fallible setter reports `NotRegistered` for a key that was never registered.
#[test]
fn test_set_unregistered_returns_err() {
    let missing_key = unique_key("set_unreg", "val");
    let outcome = try_set_context(missing_key, RequestId("x".into()));
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, ContextError::NotRegistered(_)));
}
/// Reading a key as a type other than the one it was registered with yields `TypeMismatch`.
#[test]
fn test_type_mismatch_get() {
    let rid_key = unique_key("type_mm_get", "val");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("x".into()));
    let mismatch = try_get_context::<UserId>(rid_key).unwrap_err();
    assert!(matches!(mismatch, ContextError::TypeMismatch(..)));
}
/// Writing a value of the wrong type to a registered key yields `TypeMismatch`.
#[test]
fn test_type_mismatch_set() {
    let rid_key = unique_key("type_mm_set", "val");
    register::<RequestId>(rid_key);
    let wrong_value = UserId(1);
    let mismatch = try_set_context(rid_key, wrong_value).unwrap_err();
    assert!(matches!(mismatch, ContextError::TypeMismatch(..)));
}
/// A value set after `enter_scope()` shadows the parent value only while the guard lives.
#[test]
fn test_scope_shadows_and_reverts() {
    let rid_key = unique_key("scope_shadow", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("parent".into()));
    let current = || get_context::<RequestId>(rid_key).0;
    {
        let _child_scope = enter_scope();
        set_context(rid_key, RequestId("child".into()));
        assert_eq!(current(), "child");
    }
    assert_eq!(current(), "parent");
}
/// Two nested scopes each shadow the level above and unwind in LIFO order.
#[test]
fn test_nested_scopes() {
    let rid_key = unique_key("nested_scope", "rid");
    register::<RequestId>(rid_key);
    let current = || get_context::<RequestId>(rid_key).0;
    set_context(rid_key, RequestId("root".into()));
    {
        let _level1 = enter_scope();
        set_context(rid_key, RequestId("level1".into()));
        {
            let _level2 = enter_scope();
            set_context(rid_key, RequestId("level2".into()));
            assert_eq!(current(), "level2");
        }
        assert_eq!(current(), "level1");
    }
    assert_eq!(current(), "root");
}
/// Writes made inside the closure passed to `scope` are not visible after it returns.
#[test]
fn test_scope_fn() {
    let rid_key = unique_key("scope_fn", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("before".into()));
    scope(|| {
        set_context(rid_key, RequestId("inside".into()));
        let inner = get_context::<RequestId>(rid_key);
        assert_eq!(inner.0, "inside");
    });
    let after = get_context::<RequestId>(rid_key);
    assert_eq!(after.0, "before");
}
/// Code inside `scope` sees values set by the enclosing context.
#[test]
fn test_scope_inherits_parent() {
    let rid_key = unique_key("scope_inherit", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("parent_val".into()));
    scope(|| {
        let inherited: RequestId = get_context(rid_key);
        assert_eq!(inherited.0, "parent_val");
    });
}
/// Overriding one key inside a scope leaves other keys inherited and the parent untouched.
#[test]
fn test_scope_partial_override() {
    let rid_key = unique_key("scope_partial", "a");
    let uid_key = unique_key("scope_partial", "b");
    register::<RequestId>(rid_key);
    register::<UserId>(uid_key);
    set_context(rid_key, RequestId("a_parent".into()));
    set_context(uid_key, UserId(10));
    let rid = || get_context::<RequestId>(rid_key).0;
    let uid = || get_context::<UserId>(uid_key).0;
    scope(|| {
        set_context(rid_key, RequestId("a_child".into()));
        assert_eq!(rid(), "a_child");
        assert_eq!(uid(), 10);
    });
    assert_eq!(rid(), "a_parent");
    assert_eq!(uid(), 10);
}
/// A snapshot freezes the values current at capture time. Attaching it shows them, and
/// dropping the attach guard brings back the later values.
#[test]
fn test_snapshot_captures_current() {
    let rid_key = unique_key("snap_capture", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("snapped".into()));
    let frozen = snapshot();
    set_context(rid_key, RequestId("modified".into()));
    {
        let _attached = attach(frozen);
        let seen: RequestId = get_context(rid_key);
        assert_eq!(seen.0, "snapped");
    }
    let restored: RequestId = get_context(rid_key);
    assert_eq!(restored.0, "modified");
}
/// Attaching `ContextSnapshot::empty()` is accepted. Dropping its guard restores the value
/// that was visible before the attach, mirroring `test_snapshot_captures_current`.
#[test]
fn test_snapshot_empty_context() {
    let key = unique_key("snap_empty", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("outer".into()));
    let snap = ContextSnapshot::empty();
    {
        let _guard = attach(snap);
        // NOTE(review): whether an empty snapshot hides `key` while attached
        // (replace vs. merge semantics) is not pinned here; confirm and assert if intended.
    }
    assert_eq!(get_context::<RequestId>(key).0, "outer");
}
/// A thread started with `spawn_with_context` observes the spawner's context values.
#[test]
fn test_spawn_with_context() {
    let rid_key = unique_key("thread_spawn", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("main-thread".into()));
    let worker = move || get_context::<RequestId>(rid_key);
    let handle = spawn_with_context("test-worker", worker).unwrap();
    let seen = handle.join().unwrap();
    assert_eq!(seen.0, "main-thread");
}
/// `wrap_with_context` captures the context at wrap time. Later writes on the wrapping
/// thread are not seen when the closure runs on another thread.
#[test]
fn test_wrap_with_context_fn_once() {
    let rid_key = unique_key("wrap_once", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("wrapped".into()));
    let task = wrap_with_context(move || get_context::<RequestId>(rid_key));
    set_context(rid_key, RequestId("changed".into()));
    let observed = std::thread::spawn(task).join().unwrap();
    assert_eq!(observed.0, "wrapped");
}
/// A closure wrapped with `wrap_with_context_fn` can be called repeatedly with the same
/// captured context.
#[test]
fn test_wrap_with_context_fn_multi() {
    let rid_key = unique_key("wrap_multi", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("multi".into()));
    let reusable = wrap_with_context_fn(move || get_context::<RequestId>(rid_key));
    let outputs = [reusable(), reusable()];
    for out in &outputs {
        assert_eq!(out.0, "multi");
    }
}
/// Bytes from `serialize_context` restore the same value via `deserialize_context`.
#[test]
fn test_serialize_deserialize_roundtrip() {
    let rid_key = unique_key("serde_rt", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("serialized".into()));
    let payload = serialize_context().unwrap();
    {
        let _restored = deserialize_context(&payload).unwrap();
        let value: RequestId = get_context(rid_key);
        assert_eq!(value.0, "serialized");
    }
}
/// With the `base64` feature, the string encoding round-trips and overrides a value set
/// in the receiving scope.
#[cfg(feature = "base64")]
#[test]
fn test_serialize_deserialize_string_roundtrip() {
    let rid_key = unique_key("serde_str", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("base64val".into()));
    let text = serialize_context_string().unwrap();
    assert!(!text.is_empty());
    scope(|| {
        set_context(rid_key, RequestId("cleared".into()));
        let _restored = deserialize_context_string(&text).unwrap();
        assert_eq!(get_context::<RequestId>(rid_key).0, "base64val");
    });
}
/// Deserialization tolerates payload entries for keys that are not registered.
///
/// The first half is a plain round trip of a registered key. The second half feeds a
/// payload whose only entry belongs to a key that was never registered. Decoding must
/// succeed, and that key must still report `NotRegistered`.
#[test]
fn test_deserialize_unknown_keys_skipped() {
    let key = unique_key("serde_skip", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("known".into()));
    let bytes = serialize_context().unwrap();
    {
        let _guard = deserialize_context(&bytes).unwrap();
        assert_eq!(get_context::<RequestId>(key).0, "known");
    }

    // Built with the same helper `test_migration_end_to_end` uses to hand-craft wire bytes.
    let unknown = unique_key("serde_skip", "never_registered");
    let value_bytes = bincode::serialize(&RequestId("ghost".into())).unwrap();
    let wire = crate::wire::test_helpers::make_wire_bytes(unknown, 1, &value_bytes);
    {
        let _guard = deserialize_context(&wire)
            .expect("entries for unregistered keys should be skipped, not rejected");
        assert!(matches!(
            try_get_context::<RequestId>(unknown),
            Err(ContextError::NotRegistered(_))
        ));
    }
}
/// Several keys of different types survive a serialize/deserialize round trip together.
#[test]
fn test_serialize_multiple_keys() {
    let rid_key = unique_key("serde_multi", "a");
    let uid_key = unique_key("serde_multi", "b");
    register::<RequestId>(rid_key);
    register::<UserId>(uid_key);
    set_context(rid_key, RequestId("req-multi".into()));
    set_context(uid_key, UserId(42));
    let payload = serialize_context().unwrap();
    scope(|| {
        let _restored = deserialize_context(&payload).unwrap();
        let rid: RequestId = get_context(rid_key);
        assert_eq!(rid.0, "req-multi");
        let uid: UserId = get_context(uid_key);
        assert_eq!(uid.0, 42);
    });
}
/// The fallible getter returns `Ok(None)` for a registered key that was never set.
#[test]
fn test_try_get_registered_but_unset() {
    let rid_key = unique_key("try_get_none", "rid");
    register::<RequestId>(rid_key);
    let maybe_value = try_get_context::<RequestId>(rid_key).unwrap();
    assert!(maybe_value.is_none());
}
/// Inside `force_thread_local`, a value that was set can be read back.
#[test]
fn test_force_thread_local_basic() {
    let rid_key = unique_key("force_tl", "rid");
    register::<RequestId>(rid_key);
    force_thread_local(|| {
        set_context(rid_key, RequestId("forced".into()));
        let read_back: RequestId = get_context(rid_key);
        assert_eq!(read_back.0, "forced");
    });
}
/// Nested `force_thread_local` calls share state. The inner call sees the outer value,
/// and the outer call sees the inner write afterwards.
#[test]
fn test_force_thread_local_nesting() {
    let rid_key = unique_key("force_tl_nest", "rid");
    register::<RequestId>(rid_key);
    let current = || get_context::<RequestId>(rid_key).0;
    force_thread_local(|| {
        set_context(rid_key, RequestId("outer".into()));
        force_thread_local(|| {
            assert_eq!(current(), "outer");
            set_context(rid_key, RequestId("inner".into()));
        });
        assert_eq!(current(), "inner");
    });
}
/// A panic inside `force_thread_local` is caught by the caller. A later
/// `force_thread_local` call still works and sees the value written before the panic.
#[test]
fn test_force_thread_local_panic_safety() {
    let rid_key = unique_key("force_tl_panic", "rid");
    register::<RequestId>(rid_key);
    let panicking = std::panic::AssertUnwindSafe(|| {
        force_thread_local(|| {
            set_context(rid_key, RequestId("before_panic".into()));
            panic!("intentional panic");
        });
    });
    let outcome = std::panic::catch_unwind(panicking);
    assert!(outcome.is_err());
    force_thread_local(|| {
        let survivor: RequestId = get_context(rid_key);
        assert_eq!(survivor.0, "before_panic");
    });
}
/// Statically declared typed key used by `test_context_key_register_and_get`.
#[cfg(feature = "context-key")]
static TEST_CK_KEY: crate::ContextKey<RequestId> = crate::ContextKey::new("test_ck_rid");

/// Once its raw key is registered, a typed `ContextKey` round-trips a value via `set`/`get`.
#[cfg(feature = "context-key")]
#[test]
fn test_context_key_register_and_get() {
    let raw_key = TEST_CK_KEY.key();
    register::<RequestId>(raw_key);
    TEST_CK_KEY.set(RequestId("ck-val".into()));
    let fetched = TEST_CK_KEY.get();
    assert_eq!(fetched.0, "ck-val");
}
/// `ContextKey::try_get` returns `Ok(None)` for a registered key that was never set.
#[cfg(feature = "context-key")]
#[test]
fn test_context_key_try_get_none() {
    let uid_ck: crate::ContextKey<UserId> =
        crate::ContextKey::new(unique_key("ck_none", "uid"));
    register::<UserId>(uid_ck.key());
    let fetched = uid_ck.try_get().unwrap();
    assert!(fetched.is_none());
}
/// Registers and sets two keys of different types and reads both back.
// NOTE(review): despite its name, this test never invokes a registration macro; it calls
// `register` directly. Confirm whether the macro was meant to be exercised here.
#[test]
fn test_register_contexts_macro() {
    let rid_key = unique_key("macro_reg", "a");
    let uid_key = unique_key("macro_reg", "b");
    register::<RequestId>(rid_key);
    register::<UserId>(uid_key);
    set_context(rid_key, RequestId("macro-a".into()));
    set_context(uid_key, UserId(77));
    let rid: RequestId = get_context(rid_key);
    assert_eq!(rid.0, "macro-a");
    let uid: UserId = get_context(uid_key);
    assert_eq!(uid.0, 77);
}
/// `with_scope!` sets the listed values for the body only. The outer value is visible
/// again afterwards.
#[test]
fn test_with_scope_macro() {
    let key = unique_key("macro_scope", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("before".into()));
    with_scope! {
        key => RequestId("inside-macro".into()),
        => {
            let inside: RequestId = get_context(key);
            assert_eq!(inside.0, "inside-macro");
        }
    }
    let after: RequestId = get_context(key);
    assert_eq!(after.0, "before");
}
/// With a 5-byte limit, `serialize_context` fails with `ContextTooLarge`. After the limit
/// is set back to `0`, serialization succeeds again.
#[test]
fn test_set_max_context_size_enforced() {
    // Restores the limit to `0` (the value this test ends with) when dropped. A failing
    // assertion therefore cannot leave the 5-byte limit in place for later tests.
    struct ResetMaxSize;
    impl Drop for ResetMaxSize {
        fn drop(&mut self) {
            set_max_context_size(0);
        }
    }

    let key = unique_key("size_limit", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("some-value".into()));
    // NOTE(review): if this limit is process-global rather than per-thread, tests running
    // in parallel may observe it while it is set; consider serializing this test.
    let _reset = ResetMaxSize;
    set_max_context_size(5);
    let result = serialize_context();
    assert!(
        matches!(result, Err(ContextError::ContextTooLarge { .. })),
        "expected ContextTooLarge, got {:?}",
        result
    );
    set_max_context_size(0);
    let result = serialize_context();
    assert!(result.is_ok(), "expected success with limit 0, got {:?}", result);
}
/// Stand-in for a non-serializable resource (no serde derives), used with `register_local`.
#[derive(Clone, Debug, Default, PartialEq)]
struct DbPool {
    // Arbitrary marker value that the tests write and read back.
    connection_count: u32,
}
/// A local-only key stores and returns a non-serializable value within a scope.
#[test]
fn test_register_local_and_get() {
    let pool_key = unique_key("local_basic", "pool");
    register_local::<DbPool>(pool_key);
    let _scope = enter_scope();
    set_context_local(pool_key, DbPool { connection_count: 5 });
    let pool: DbPool = get_context(pool_key);
    assert_eq!(pool.connection_count, 5);
}
/// Local-only values are left out of `serialize_context`. A receiving thread sees the
/// normal key's value, but only the default for the local key.
#[test]
fn test_local_excluded_from_serialization() {
    let pool_key = unique_key("local_serde", "pool");
    let rid_key = unique_key("local_serde", "rid");
    register_local::<DbPool>(pool_key);
    register::<RequestId>(rid_key);
    let _scope = enter_scope();
    set_context_local(pool_key, DbPool { connection_count: 42 });
    set_context(rid_key, RequestId("req-local-test".into()));
    let payload = serialize_context().expect("serialization should succeed");
    std::thread::spawn(move || {
        register_local::<DbPool>(pool_key);
        register::<RequestId>(rid_key);
        let _restored = deserialize_context(&payload).expect("deserialization should succeed");
        let rid: RequestId = get_context(rid_key);
        assert_eq!(rid.0, "req-local-test");
        let pool: DbPool = get_context(pool_key);
        assert_eq!(pool, DbPool::default());
    })
    .join()
    .unwrap();
}
/// Unlike serialization, an in-process snapshot carries local-only values.
#[test]
fn test_local_propagates_via_snapshot() {
    let pool_key = unique_key("local_snap", "pool");
    register_local::<DbPool>(pool_key);
    let _scope = enter_scope();
    set_context_local(pool_key, DbPool { connection_count: 10 });
    let _attached = attach(snapshot());
    let pool: DbPool = get_context(pool_key);
    assert_eq!(pool.connection_count, 10);
}
/// A local-only value overridden inside `scope` reverts once the scope ends.
#[test]
fn test_local_scope_isolation() {
    let pool_key = unique_key("local_scope", "pool");
    register_local::<DbPool>(pool_key);
    let _scope = enter_scope();
    set_context_local(pool_key, DbPool { connection_count: 1 });
    let count = || get_context::<DbPool>(pool_key).connection_count;
    scope(|| {
        set_context_local(pool_key, DbPool { connection_count: 99 });
        assert_eq!(count(), 99);
    });
    assert_eq!(count(), 1);
}
/// Legacy (version 1) trace payload used as a migration source.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
struct TraceV1 {
    trace_id: String,
}

/// Current (version 2) trace payload. `span_id` does not exist in `TraceV1`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
struct TraceV2 {
    trace_id: String,
    span_id: String,
}
/// A registered v1→v2 migration installs a version-1 deserializer that produces a
/// `TraceV2`.
#[test]
fn test_migration_v1_to_v2() {
    let key = unique_key("migrate_v1v2", "trace");
    register_with::<TraceV2>(key, |o| o.version(2));
    register_migration::<TraceV1, TraceV2>(key, 1, |old| TraceV2 {
        trace_id: old.trace_id,
        span_id: "migrated".into(),
    });
    let legacy_bytes = bincode::serialize(&TraceV1 { trace_id: "tid-old".into() }).unwrap();
    let outcome = crate::registry::with_registration(key, |reg| {
        let v1_deserializer = reg.deserializers.get(&1).expect("v1 deserializer missing");
        v1_deserializer(&legacy_bytes)
    });
    let decoded = outcome.unwrap().unwrap();
    let upgraded = decoded.as_any().downcast_ref::<TraceV2>().unwrap();
    assert_eq!(upgraded.trace_id, "tid-old");
    assert_eq!(upgraded.span_id, "migrated");
}
/// A hand-built version-1 wire payload is migrated to `TraceV2` when deserialized on
/// another thread.
#[test]
fn test_migration_end_to_end() {
    let key = unique_key("migrate_e2e", "ctx");
    register_with::<TraceV2>(key, |o| o.version(2));
    register_migration::<TraceV1, TraceV2>(key, 1, |old| TraceV2 {
        trace_id: old.trace_id,
        span_id: "default-span".into(),
    });
    let legacy = TraceV1 { trace_id: "from-v1-sender".into() };
    let legacy_bytes = bincode::serialize(&legacy).unwrap();
    let wire = crate::wire::test_helpers::make_wire_bytes(key, 1, &legacy_bytes);
    std::thread::spawn(move || {
        let _restored = deserialize_context(&wire).unwrap();
        let upgraded: TraceV2 = get_context(key);
        assert_eq!(upgraded.trace_id, "from-v1-sender");
        assert_eq!(upgraded.span_id, "default-span");
    })
    .join()
    .unwrap();
}
/// Registering at version 2 with no migration leaves no version-1 deserializer.
// NOTE(review): the name says "errors", but only the absence of the v1 deserializer is
// checked. Decoding v1 wire bytes (e.g. via `make_wire_bytes`) would cover the error path.
#[test]
fn test_migration_unknown_version_errors() {
    let key = unique_key("migrate_unknown", "ctx");
    register_with::<TraceV2>(key, |o| o.version(2));
    let has_v1_deserializer =
        crate::registry::with_registration(key, |reg| reg.deserializers.contains_key(&1));
    assert_eq!(has_v1_deserializer, Some(false));
}
/// Adding a migration does not disturb the current-version (v2) round trip.
#[test]
fn test_migration_current_version_still_works() {
    let key = unique_key("migrate_current", "ctx");
    register_with::<TraceV2>(key, |o| o.version(2));
    register_migration::<TraceV1, TraceV2>(key, 1, |old| TraceV2 {
        trace_id: old.trace_id,
        span_id: "migrated".into(),
    });
    let _scope = enter_scope();
    let current = TraceV2 {
        trace_id: "current".into(),
        span_id: "current-span".into(),
    };
    set_context(key, current);
    let payload = serialize_context().unwrap();
    scope(|| {
        let _restored = deserialize_context(&payload).unwrap();
        let decoded: TraceV2 = get_context(key);
        assert_eq!(decoded.trace_id, "current");
        assert_eq!(decoded.span_id, "current-span");
    });
}
/// A migration whose source version equals the registered current version (2) is
/// rejected with `DeserializationFailed`.
#[test]
fn test_migration_rejects_current_version() {
    let key = unique_key("migrate_reject", "ctx");
    register_with::<TraceV2>(key, |o| o.version(2));
    let result = try_register_migration::<TraceV1, TraceV2>(key, 2, |v1| TraceV2 {
        trace_id: v1.trace_id,
        span_id: "should-fail".into(),
    });
    // `expect_err` already fails the test on `Ok`, so a separate `is_err()` check is redundant.
    let err = result.expect_err("a migration from the current version must be rejected");
    assert!(
        matches!(err, ContextError::DeserializationFailed(_)),
        "unexpected error variant: {:?}",
        err
    );
}
/// A key registered with a custom JSON codec round-trips through serialize/deserialize.
#[test]
fn test_register_with_json_codec() {
    let rid_key = unique_key("codec_json", "rid");
    register_with::<RequestId>(rid_key, |o| o.codec(
        |val| serde_json::to_vec(val).map_err(|e| e.to_string()),
        |bytes| serde_json::from_slice(bytes).map_err(|e| e.to_string()),
    ));
    let _scope = enter_scope();
    set_context(rid_key, RequestId("json-encoded".into()));
    let payload = serialize_context().unwrap();
    scope(|| {
        let _restored = deserialize_context(&payload).unwrap();
        let decoded: RequestId = get_context(rid_key);
        assert_eq!(decoded.0, "json-encoded");
    });
}
/// With a JSON codec, the serialized context carries the value's JSON encoding, and a
/// fresh thread registered with the same codec decodes it.
#[test]
fn test_json_codec_wire_bytes_are_json() {
    let key = unique_key("codec_json_verify", "rid");
    register_with::<RequestId>(key, |o| o.codec(
        |val| serde_json::to_vec(val).map_err(|e| e.to_string()),
        |bytes| serde_json::from_slice(bytes).map_err(|e| e.to_string()),
    ));
    let _guard = enter_scope();
    set_context(key, RequestId("verify-json".into()));
    let wire_bytes = serialize_context().unwrap();

    // A round trip alone cannot tell which codec produced the bytes. Check that the JSON
    // encoding of the value appears verbatim in the payload.
    // NOTE(review): assumes the wire format embeds per-key value bytes unmodified, as
    // `make_wire_bytes(key, version, value_bytes)` suggests; confirm.
    let json_value = serde_json::to_vec(&RequestId("verify-json".into())).unwrap();
    assert!(
        wire_bytes
            .windows(json_value.len())
            .any(|window| window == json_value.as_slice()),
        "serialized context should embed the JSON-encoded value {:?}",
        String::from_utf8_lossy(&json_value)
    );

    let handle = std::thread::spawn(move || {
        register_with::<RequestId>(key, |o| o.codec(
            |val| serde_json::to_vec(val).map_err(|e| e.to_string()),
            |bytes| serde_json::from_slice(bytes).map_err(|e| e.to_string()),
        ));
        let _guard = deserialize_context(&wire_bytes).unwrap();
        let val: RequestId = get_context(key);
        assert_eq!(val.0, "verify-json");
    });
    handle.join().unwrap();
}
/// A plain `register` (default codec) still round-trips alongside custom-codec keys.
#[test]
fn test_default_codec_still_works() {
    let rid_key = unique_key("codec_default", "rid");
    register::<RequestId>(rid_key);
    let _scope = enter_scope();
    set_context(rid_key, RequestId("bincode-default".into()));
    let payload = serialize_context().unwrap();
    scope(|| {
        let _restored = deserialize_context(&payload).unwrap();
        let decoded: RequestId = get_context(rid_key);
        assert_eq!(decoded.0, "bincode-default");
    });
}
/// Combining `local_only()` with a custom codec is rejected with `SerializationFailed`.
#[test]
fn test_local_only_rejects_codec() {
    let key = unique_key("local_codec", "rid");
    let result = try_register_with::<RequestId>(key, |o| {
        o.local_only().codec(
            |val| serde_json::to_vec(val).map_err(|e| e.to_string()),
            |bytes| serde_json::from_slice(bytes).map_err(|e| e.to_string()),
        )
    });
    // `expect_err` already fails on `Ok`; no separate `is_err()` check is needed.
    let err = result.expect_err("local_only() combined with codec() must be rejected");
    assert!(
        matches!(err, ContextError::SerializationFailed(_)),
        "unexpected error variant: {:?}",
        err
    );
}
/// Combining `local_only()` with an explicit version is rejected with `SerializationFailed`.
#[test]
fn test_local_only_rejects_version() {
    let key = unique_key("local_version", "rid");
    let result = try_register_with::<RequestId>(key, |o| o.local_only().version(2));
    // `expect_err` already fails on `Ok`; no separate `is_err()` check is needed.
    let err = result.expect_err("local_only() combined with version() must be rejected");
    assert!(
        matches!(err, ContextError::SerializationFailed(_)),
        "unexpected error variant: {:?}",
        err
    );
}
/// A key made local-only through the builder is absent after deserializing on another
/// thread.
#[test]
fn test_local_only_builder_excludes_from_serialization() {
    let rid_key = unique_key("local_builder_ser", "rid");
    register_with::<RequestId>(rid_key, |o| o.local_only());
    let _scope = enter_scope();
    set_context(rid_key, RequestId("should-not-serialize".into()));
    let payload = serialize_context().unwrap();
    let receiver = std::thread::spawn(move || {
        let _restored = deserialize_context(&payload).unwrap();
        let received = try_get_context::<RequestId>(rid_key).unwrap();
        assert!(
            received.is_none(),
            "local_only value registered via builder should not survive serialization"
        );
    });
    receiver.join().unwrap();
}
#[cfg(feature = "context-future")]
mod context_future_tests {
    use super::*;

    /// Minimal executor for these tests. It polls `fut` to completion on the current
    /// thread and yields between `Pending` polls.
    ///
    /// Wake-ups are ignored and the future is simply re-polled, so this only suits futures
    /// that make progress when re-polled; every future in this module does. Heap pinning
    /// (`Box::pin`) and a `Wake`-based waker mean no `unsafe` is needed.
    fn block_on_simple<F: std::future::Future>(fut: F) -> F::Output {
        use std::sync::Arc;
        use std::task::{Context, Poll, Wake, Waker};

        // Waker whose notifications are no-ops; the loop below re-polls unconditionally.
        struct NoopWake;
        impl Wake for NoopWake {
            fn wake(self: Arc<Self>) {}
        }

        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        loop {
            match std::future::Future::poll(fut.as_mut(), &mut cx) {
                Poll::Ready(val) => return val,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    #[test]
    fn test_context_future_basic() {
        let key = unique_key("ctx_fut_basic", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("future-123".into()));
        let fut = with_context_future(async move { get_context::<RequestId>(key) });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId("future-123".into()));
    }

    #[test]
    fn test_context_future_mutation_propagates() {
        let key = unique_key("ctx_fut_mut", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("initial".into()));
        let fut = with_context_future(async move {
            set_context(key, RequestId("mutated".into()));
            get_context::<RequestId>(key)
        });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId("mutated".into()));
    }

    #[test]
    fn test_context_future_isolation() {
        let key = unique_key("ctx_fut_iso", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("outer".into()));
        let fut = with_context_future(async move {
            set_context(key, RequestId("inner-change".into()));
        });
        block_on_simple(fut);
        let val: RequestId = get_context(key);
        assert_eq!(val, RequestId("outer".into()));
    }

    #[test]
    fn test_context_future_empty_snapshot() {
        let key = unique_key("ctx_fut_empty", "rid");
        register::<RequestId>(key);
        let fut = ContextFuture::new(ContextSnapshot::empty(), async move {
            get_context::<RequestId>(key)
        });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId::default());
    }

    #[test]
    fn test_context_future_regular_async_fn() {
        let key = unique_key("ctx_fut_regular", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("propagated".into()));
        async fn inner_fn(key: &'static str) -> RequestId {
            get_context(key)
        }
        let fut = with_context_future(async move { inner_fn(key).await });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId("propagated".into()));
    }

    #[test]
    fn test_context_future_deep_await_chain() {
        let key = unique_key("ctx_fut_deep", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("deep-value".into()));
        async fn level_3(key: &'static str) -> RequestId {
            get_context(key)
        }
        async fn level_2(key: &'static str) -> RequestId {
            level_3(key).await
        }
        async fn level_1(key: &'static str) -> RequestId {
            level_2(key).await
        }
        let fut = with_context_future(async move { level_1(key).await });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId("deep-value".into()));
    }

    #[test]
    fn test_context_future_mutation_across_await() {
        let key = unique_key("ctx_fut_mutawait", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("original".into()));
        async fn writer(key: &'static str) {
            set_context(key, RequestId("written-by-async-fn".into()));
        }
        async fn reader(key: &'static str) -> RequestId {
            get_context(key)
        }
        let fut = with_context_future(async move {
            writer(key).await;
            reader(key).await
        });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId("written-by-async-fn".into()));
    }

    #[test]
    fn test_context_future_multi_poll() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;
        let key = unique_key("ctx_fut_multipoll", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("survives-suspend".into()));

        // Returns `Pending` once (after waking itself), then `Ready` on the next poll.
        struct YieldOnceFuture {
            yielded: Arc<AtomicBool>,
            key: &'static str,
        }
        impl std::future::Future for YieldOnceFuture {
            type Output = RequestId;
            fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
                -> std::task::Poll<Self::Output>
            {
                if self.yielded.swap(true, Ordering::SeqCst) {
                    let val: RequestId = get_context(self.key);
                    std::task::Poll::Ready(val)
                } else {
                    let val: RequestId = get_context(self.key);
                    assert_eq!(val, RequestId("survives-suspend".into()),
                        "context must be available on first poll");
                    cx.waker().wake_by_ref();
                    std::task::Poll::Pending
                }
            }
        }

        let yielded = Arc::new(AtomicBool::new(false));
        let fut = with_context_future(YieldOnceFuture {
            yielded: yielded.clone(),
            key,
        });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId("survives-suspend".into()));
        assert!(yielded.load(Ordering::SeqCst), "future must have yielded at least once");
    }

    #[test]
    fn test_context_future_mutation_survives_yield() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;
        let key = unique_key("ctx_fut_mutyield", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("before".into()));

        // Writes to the context on the first poll, yields, then reads it back.
        struct MutateAndYield {
            yielded: Arc<AtomicBool>,
            key: &'static str,
        }
        impl std::future::Future for MutateAndYield {
            type Output = RequestId;
            fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>)
                -> std::task::Poll<Self::Output>
            {
                if self.yielded.swap(true, Ordering::SeqCst) {
                    let val: RequestId = get_context(self.key);
                    std::task::Poll::Ready(val)
                } else {
                    set_context(self.key, RequestId("mutated-before-yield".into()));
                    cx.waker().wake_by_ref();
                    std::task::Poll::Pending
                }
            }
        }

        let yielded = Arc::new(AtomicBool::new(false));
        let fut = with_context_future(MutateAndYield {
            yielded: yielded.clone(),
            key,
        });
        let result = block_on_simple(fut);
        assert_eq!(result, RequestId("mutated-before-yield".into()));
    }

    #[test]
    fn test_context_future_nested() {
        let key = unique_key("ctx_fut_nested", "rid");
        register::<RequestId>(key);
        let _guard = enter_scope();
        set_context(key, RequestId("outer-value".into()));
        let fut = with_context_future(async move {
            assert_eq!(get_context::<RequestId>(key).0, "outer-value");
            let inner = with_context_future(async move {
                set_context(key, RequestId("inner-value".into()));
                get_context::<RequestId>(key)
            });
            let inner_result = inner.await;
            assert_eq!(inner_result.0, "inner-value");
            get_context::<RequestId>(key)
        });
        let result = block_on_simple(fut);
        assert_eq!(result.0, "outer-value");
    }
}
mod async_tests {
    use super::*;

    /// A snapshot taken inside `force_thread_local` is visible in the future passed to
    /// `with_context`.
    #[tokio::test]
    async fn test_with_context_basic() {
        let rid_key = unique_key("async_basic", "rid");
        register::<RequestId>(rid_key);
        let captured = force_thread_local(|| {
            set_context(rid_key, RequestId("async-val".into()));
            snapshot()
        });
        let seen = with_context(captured, async { get_context::<RequestId>(rid_key) }).await;
        assert_eq!(seen.0, "async-val");
    }

    /// Writes inside `scope_async` are undone once that future completes.
    #[tokio::test]
    async fn test_scope_async() {
        let rid_key = unique_key("scope_async", "rid");
        register::<RequestId>(rid_key);
        let captured = force_thread_local(|| {
            set_context(rid_key, RequestId("before-async".into()));
            snapshot()
        });
        with_context(captured, async {
            assert_eq!(get_context::<RequestId>(rid_key).0, "before-async");
            let child = async {
                set_context(rid_key, RequestId("inside-async".into()));
                assert_eq!(get_context::<RequestId>(rid_key).0, "inside-async");
            };
            scope_async(child).await;
            assert_eq!(get_context::<RequestId>(rid_key).0, "before-async");
        })
        .await;
    }

    /// A task spawned with `spawn_with_context_async` inherits the spawning context.
    #[tokio::test]
    async fn test_spawn_with_context_async() {
        let rid_key = unique_key("async_spawn", "rid");
        register::<RequestId>(rid_key);
        let captured = force_thread_local(|| {
            set_context(rid_key, RequestId("spawned-async".into()));
            snapshot()
        });
        let join_handle = with_context(captured, async {
            spawn_with_context_async(async {
                get_context::<RequestId>(rid_key)
            })
        })
        .await;
        let seen = join_handle.await.unwrap();
        assert_eq!(seen.0, "spawned-async");
    }

    /// The synchronous `scope` helper isolates writes even inside an async context.
    #[tokio::test]
    async fn test_async_scope_isolation() {
        let rid_key = unique_key("async_scope_iso", "rid");
        register::<RequestId>(rid_key);
        let captured = force_thread_local(|| {
            set_context(rid_key, RequestId("outer".into()));
            snapshot()
        });
        with_context(captured, async {
            assert_eq!(get_context::<RequestId>(rid_key).0, "outer");
            scope(|| {
                set_context(rid_key, RequestId("inner".into()));
                assert_eq!(get_context::<RequestId>(rid_key).0, "inner");
            });
            assert_eq!(get_context::<RequestId>(rid_key).0, "outer");
        })
        .await;
    }

    /// Serialization inside an async context restores the captured value in a nested
    /// scope, overriding a value set there.
    #[tokio::test]
    async fn test_async_serialize_roundtrip() {
        let rid_key = unique_key("async_serde", "rid");
        register::<RequestId>(rid_key);
        let captured = force_thread_local(|| {
            set_context(rid_key, RequestId("async-serde".into()));
            snapshot()
        });
        with_context(captured, async {
            let payload = serialize_context().unwrap();
            scope(|| {
                set_context(rid_key, RequestId("cleared".into()));
                let _restored = deserialize_context(&payload).unwrap();
                assert_eq!(get_context::<RequestId>(rid_key).0, "async-serde");
            });
        })
        .await;
    }
}
/// With no named scopes entered, the scope chain is empty.
#[test]
fn test_scope_chain_empty_by_default() {
    let names = scope_chain();
    assert!(names.is_empty(), "default scope chain should be empty");
}
/// Named scopes push onto the chain, and pop off when their guards drop.
#[test]
fn test_scope_chain_named_scope() {
    let _outer_scope = enter_named_scope("outer");
    assert_eq!(scope_chain(), vec!["outer"]);
    {
        let _inner_scope = enter_named_scope("inner");
        let nested = vec!["outer", "inner"];
        assert_eq!(scope_chain(), nested);
    }
    assert_eq!(scope_chain(), vec!["outer"]);
}
/// Unnamed scopes (`enter_scope`) do not appear in the scope chain.
#[test]
fn test_scope_chain_unnamed_invisible() {
    let _named = enter_named_scope("named");
    let _anonymous = enter_scope();
    let _also_named = enter_named_scope("also-named");
    let expected = vec!["named", "also-named"];
    assert_eq!(scope_chain(), expected);
}
/// A snapshot records the scope chain. Attaching it restores that chain, and further
/// named scopes extend it.
#[test]
fn test_scope_chain_snapshot_preserves_chain() {
    let _handler_scope = enter_named_scope("request-handler");
    let captured = snapshot();
    assert_eq!(captured.scope_chain(), &["request-handler"]);
    scope(|| {
        let _attached = attach(captured.clone());
        assert_eq!(scope_chain(), vec!["request-handler"]);
        let _sub_scope = enter_named_scope("sub-handler");
        assert_eq!(scope_chain(), vec!["request-handler", "sub-handler"]);
    });
}
/// The scope chain travels with serialized context, and new named scopes append to it.
#[test]
fn test_scope_chain_serialize_roundtrip() {
    let rid_key = unique_key("sc_serde", "rid");
    register::<RequestId>(rid_key);
    let _app_scope = enter_named_scope("app");
    let _service_scope = enter_named_scope("service");
    set_context(rid_key, RequestId("req-1".into()));
    let payload = serialize_context().unwrap();
    scope(|| {
        let _restored = deserialize_context(&payload).unwrap();
        let restored_value: RequestId = get_context(rid_key);
        assert_eq!(restored_value.0, "req-1");
        assert_eq!(scope_chain(), vec!["app", "service"]);
        let _handler_scope = enter_named_scope("handler");
        assert_eq!(scope_chain(), vec!["app", "service", "handler"]);
    });
}
/// A wire-format v1 payload (which has no scope chain) still decodes. The resulting
/// chain is empty.
#[test]
fn test_scope_chain_wire_v1_compat() {
    let rid_key = unique_key("sc_v1", "rid");
    register::<RequestId>(rid_key);
    let encoded_value = bincode::serialize(&RequestId("v1-value".into())).unwrap();
    let legacy_wire = make_wire_bytes_v(1, rid_key, 1, &encoded_value);
    scope(|| {
        let _restored = deserialize_context(&legacy_wire).unwrap();
        let decoded: RequestId = get_context(rid_key);
        assert_eq!(decoded.0, "v1-value");
        assert!(scope_chain().is_empty());
    });
}
/// Nested deserializations restore scope chains in LIFO order. Leaving the inner one
/// returns to the chain that was active before it.
#[test]
fn test_scope_chain_remote_chain_lifo_restore() {
    let rid_key = unique_key("sc_lifo", "rid");
    register::<RequestId>(rid_key);
    let _root_scope = enter_named_scope("local-root");
    let _sender_scope = enter_named_scope("sender-scope");
    set_context(rid_key, RequestId("first".into()));
    let outer_payload = serialize_context().unwrap();
    let nested_chain = vec!["local-root", "sender-scope", "nested-scope"];
    scope(|| {
        let _outer_restored = deserialize_context(&outer_payload).unwrap();
        assert_eq!(scope_chain(), vec!["local-root", "sender-scope"]);
        let _nested_scope = enter_named_scope("nested-scope");
        let inner_payload = serialize_context().unwrap();
        scope(|| {
            let _inner_restored = deserialize_context(&inner_payload).unwrap();
            assert_eq!(scope_chain(), nested_chain);
        });
        assert_eq!(scope_chain(), nested_chain);
    });
}
mod async_scope_chain_tests {
    use super::*;

    /// A chain captured before `with_context` is visible inside it, and
    /// `named_scope_async` appends to it for the inner future.
    #[tokio::test]
    async fn test_scope_chain_with_context() {
        let _pre_send_scope = enter_named_scope("pre-send");
        let captured = snapshot();
        with_context(captured, async {
            assert_eq!(scope_chain(), vec!["pre-send"]);
            let handler = async {
                assert_eq!(scope_chain(), vec!["pre-send", "handler"]);
            };
            named_scope_async("handler", handler).await;
        })
        .await;
    }

    /// Nested `named_scope_async` calls build up the chain level by level on top of the
    /// captured root.
    #[tokio::test]
    async fn test_named_scope_async_basic() {
        let rooted = force_thread_local(|| {
            let _root_scope = enter_named_scope("root");
            snapshot()
        });
        with_context(rooted, async {
            named_scope_async("level-1", async {
                assert_eq!(scope_chain(), vec!["root", "level-1"]);
                let deepest = async {
                    assert_eq!(scope_chain(), vec!["root", "level-1", "level-2"]);
                };
                named_scope_async("level-2", deepest).await;
            })
            .await;
        })
        .await;
    }
}
/// Context value whose destructor itself queries the context. The drop tests below use
/// it to trigger context access while a stored value is being dropped.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct ReentrantDropVal(String);

impl Drop for ReentrantDropVal {
    fn drop(&mut self) {
        // The lookup result is irrelevant; only the nested access matters. The probe key
        // is presumably never registered, so this likely yields `Err(NotRegistered)`.
        let _ = crate::try_get_context::<RequestId>("__reentrant_drop_probe__");
    }
}
/// Immediately after `enter_scope()`, reads still see the parent's value.
#[test]
fn test_reentrant_read_during_scope_enter() {
    let rid_key = unique_key("reentrant_enter", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("parent-val".into()));
    let _entered = enter_scope();
    let inherited = get_context::<RequestId>(rid_key);
    assert_eq!(inherited.0, "parent-val");
}
/// After a child scope ends, reads see the base value again.
#[test]
fn test_reentrant_read_during_scope_leave() {
    let rid_key = unique_key("reentrant_leave", "rid");
    register::<RequestId>(rid_key);
    set_context(rid_key, RequestId("base".into()));
    {
        let _child_scope = enter_scope();
        set_context(rid_key, RequestId("child".into()));
        let in_child = get_context::<RequestId>(rid_key);
        assert_eq!(in_child.0, "child");
    }
    let after_leave = get_context::<RequestId>(rid_key);
    assert_eq!(after_leave.0, "base");
}
/// Two keys of the same type hold independent values.
#[test]
fn test_reentrant_read_during_set_context() {
    let first_key = unique_key("reentrant_set", "a");
    let second_key = unique_key("reentrant_set", "b");
    for &k in &[first_key, second_key] {
        register::<RequestId>(k);
    }
    set_context(first_key, RequestId("aaa".into()));
    set_context(second_key, RequestId("bbb".into()));
    assert_eq!(get_context::<RequestId>(first_key).0, "aaa");
    assert_eq!(get_context::<RequestId>(second_key).0, "bbb");
}
/// Overwriting a value whose destructor queries the context succeeds, and the new value
/// is stored.
#[test]
fn test_reentrant_drop_on_value_overwrite() {
    let val_key = unique_key("reentrant_drop_overwrite", "val");
    register::<ReentrantDropVal>(val_key);
    set_context(val_key, ReentrantDropVal("first".into()));
    set_context(val_key, ReentrantDropVal("second".into()));
    let latest: ReentrantDropVal = get_context(val_key);
    assert_eq!(latest.0, "second");
}
/// Leaving a scope presumably drops the values stored in it. Those drops call
/// back into the context API (`ReentrantDropVal::drop`), which must not
/// deadlock or panic, and the parent's value must be visible again afterwards.
#[test]
fn test_reentrant_drop_on_scope_leave() {
    // Register the key probed by `ReentrantDropVal::drop`. Without this, the
    // nested read may short-circuit on `NotRegistered` before it touches any
    // stored context state. Re-registering the same type is idempotent.
    try_register::<RequestId>("__reentrant_drop_probe__").unwrap();

    let key = unique_key("reentrant_drop_leave", "val");
    register::<ReentrantDropVal>(key);
    set_context(key, ReentrantDropVal("root".into()));
    {
        let _g = enter_scope();
        set_context(key, ReentrantDropVal("child-scope".into()));
    }
    let val: ReentrantDropVal = get_context(key);
    assert_eq!(val.0, "root");
}
/// Push 50 nested named scopes, each overriding the same key, then unwind them
/// one at a time. Every pop must expose the parent's value, ending at "root".
#[test]
fn test_scope_push_pop_integrity_across_many_levels() {
    const DEPTH: usize = 50;

    let key = unique_key("stress_scope", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("root".into()));

    let mut guards: Vec<ScopeGuard> = Vec::with_capacity(DEPTH);
    for level in 0..DEPTH {
        guards.push(enter_named_scope(format!("scope-{}", level)));
        set_context(key, RequestId(format!("val-{}", level)));
    }
    assert_eq!(
        get_context::<RequestId>(key).0,
        format!("val-{}", DEPTH - 1)
    );

    // After each pop, the remaining depth identifies which value is visible.
    while let Some(innermost) = guards.pop() {
        drop(innermost);
        let expected = match guards.len() {
            0 => "root".to_string(),
            remaining => format!("val-{}", remaining - 1),
        };
        assert_eq!(get_context::<RequestId>(key).0, expected);
    }
}
/// Repeatedly entering and leaving a named scope must leave nothing behind.
/// While a round's guard is alive, the chain is exactly that one name. Once
/// every guard has been dropped, the chain is empty.
#[test]
fn test_scope_chain_integrity_after_many_push_pops() {
    let key = unique_key("chain_stress", "rid");
    register::<RequestId>(key);
    for round in 0..10 {
        let name = format!("round-{}", round);
        let _g = enter_named_scope(&name);
        // Compare the whole chain, not just its tail. A scope leaked from an
        // earlier round would still leave `name` as the last element.
        assert_eq!(scope_chain(), vec![name.as_str()]);
    }
    assert!(scope_chain().is_empty());
}
/// `update_context` replaces the stored value with the callback's result.
#[test]
fn test_update_context_basic() {
    #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
    struct Counter(u64);

    let key = unique_key("update_ctx", "counter");
    register::<Counter>(key);
    set_context(key, Counter(10));
    update_context::<Counter>(key, |prev| Counter(prev.0 + 5));
    assert_eq!(get_context::<Counter>(key), Counter(15));
}
/// On a registered key with no stored value, the `update_context` callback
/// receives the type's default (0 here), so one increment yields 1.
#[test]
fn test_update_context_default_when_unset() {
    #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
    struct Counter(u64);

    let key = unique_key("update_default", "counter");
    register::<Counter>(key);
    update_context::<Counter>(key, |current| Counter(current.0 + 1));
    assert_eq!(get_context::<Counter>(key), Counter(1));
}
/// The `update_context` callback may itself read another key, and only the
/// target key is modified.
#[test]
fn test_update_context_callback_can_read_other_keys() {
    let target = unique_key("update_read_other", "a");
    let other = unique_key("update_read_other", "b");
    register::<RequestId>(target);
    register::<RequestId>(other);
    set_context(target, RequestId("aaa".into()));
    set_context(other, RequestId("bbb".into()));

    update_context::<RequestId>(target, |_previous| {
        let RequestId(from_other) = get_context::<RequestId>(other);
        RequestId(format!("merged-{}", from_other))
    });

    assert_eq!(get_context::<RequestId>(target).0, "merged-bbb");
    assert_eq!(get_context::<RequestId>(other).0, "bbb");
}
/// An `update_context` performed inside a child scope is discarded when that
/// scope's guard is dropped.
#[test]
fn test_update_context_in_scope_reverts() {
    #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
    struct Val(String);

    let key = unique_key("update_scope_revert", "val");
    register::<Val>(key);
    set_context(key, Val("root".into()));

    let child = enter_scope();
    update_context::<Val>(key, |_| Val("updated-in-child".into()));
    assert_eq!(get_context::<Val>(key).0, "updated-in-child");
    drop(child);

    assert_eq!(get_context::<Val>(key).0, "root");
}
/// `get_context_option` returns `Some` for a key with a stored value and
/// `None` for a key that is registered but was never set.
#[test]
fn test_get_context_option_some_and_none() {
    let key_set = unique_key("get_opt", "set");
    let key_unset = unique_key("get_opt", "unset");
    register::<RequestId>(key_set);
    register::<RequestId>(key_unset);
    set_context(key_set, RequestId("hello".into()));

    let present = get_context_option::<RequestId>(key_set);
    assert_eq!(present, Some(RequestId("hello".into())));

    let absent = get_context_option::<RequestId>(key_unset);
    assert!(absent.is_none(), "unexpected value: {:?}", absent);
}
/// A snapshot keeps the values that were visible when it was taken.
/// Attaching it shows the old value; dropping the attach guard shows the
/// live value again.
#[test]
fn test_snapshot_uses_arc_sharing() {
    let key = unique_key("snap_arc", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("original".into()));
    let frozen = snapshot();
    set_context(key, RequestId("modified".into()));

    let attached = attach(frozen);
    assert_eq!(get_context::<RequestId>(key).0, "original");
    drop(attached);

    assert_eq!(get_context::<RequestId>(key).0, "modified");
}
/// Repeated write-then-read cycles on one key always observe the latest write.
///
/// NOTE(review): despite the name, this runs sequentially on one thread and
/// enters no scopes.
#[test]
fn test_concurrent_scope_and_read_no_panic() {
    let key = unique_key("concurrent_rw", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("base".into()));
    (0..100).for_each(|i| {
        let expected = format!("iter-{}", i);
        set_context(key, RequestId(expected.clone()));
        assert_eq!(get_context::<RequestId>(key).0, expected);
    });
}
/// A key registered with `opts.cached()` still respects scoping. Root values
/// are visible through nested scopes, and an inner override vanishes when its
/// scope ends.
#[test]
fn test_cached_key_o1_read_in_nested_scopes() {
    let key = unique_key("cached_read", "rid");
    register_with::<RequestId>(key, |opts| opts.cached());
    set_context(key, RequestId("root-val".into()));
    let read = || get_context::<RequestId>(key).0;

    let outer = enter_scope();
    assert_eq!(read(), "root-val");

    let inner = enter_scope();
    assert_eq!(read(), "root-val");
    set_context(key, RequestId("inner-val".into()));
    assert_eq!(read(), "inner-val");
    drop(inner);
    assert_eq!(read(), "root-val");

    drop(outer);
    assert_eq!(read(), "root-val");
}
/// For a key registered without caching, a scope with no override reads the
/// nearest enclosing value, and each override disappears when its scope ends.
#[test]
fn test_non_cached_key_walks_parents() {
    let key = unique_key("non_cached", "rid");
    register::<RequestId>(key);
    set_context(key, RequestId("root-val".into()));
    let current = || get_context::<RequestId>(key).0;

    let outer = enter_scope();
    assert_eq!(current(), "root-val");
    set_context(key, RequestId("child-val".into()));
    assert_eq!(current(), "child-val");

    let inner = enter_scope();
    assert_eq!(current(), "child-val");
    drop(inner);
    assert_eq!(current(), "child-val");

    drop(outer);
    assert_eq!(current(), "root-val");
}
/// Writes in nested async scopes (`scope_async`, then `named_scope_async`) are
/// visible inside their scope and undone when each scope completes.
#[tokio::test]
async fn test_async_reentrant_safety() {
    let key = unique_key("async_reentrant", "rid");
    register::<RequestId>(key);
    let snap = force_thread_local(|| {
        set_context(key, RequestId("base".into()));
        snapshot()
    });

    // Futures are built innermost-first; nothing runs until awaited below.
    let deepest = async {
        assert_eq!(get_context::<RequestId>(key).0, "in-scope-async");
        set_context(key, RequestId("deep".into()));
        assert_eq!(get_context::<RequestId>(key).0, "deep");
    };
    let middle = async {
        set_context(key, RequestId("in-scope-async".into()));
        assert_eq!(get_context::<RequestId>(key).0, "in-scope-async");
        named_scope_async("inner", deepest).await;
        assert_eq!(get_context::<RequestId>(key).0, "in-scope-async");
    };
    with_context(snap, async {
        assert_eq!(get_context::<RequestId>(key).0, "base");
        scope_async(middle).await;
        assert_eq!(get_context::<RequestId>(key).0, "base");
    })
    .await;
}
#[tokio::test]
async fn test_fork_reads_parent_values() {
let key = unique_key("fork_read", "rid");
register::<RequestId>(key);
let snap = force_thread_local(|| {
let _scope = enter_scope();
set_context(key, RequestId("parent-val".into()));
fork()
});
let result = with_fork(snap, async {
get_context::<RequestId>(key)
}).await;
assert_eq!(result.0, "parent-val");
}
/// Writes made inside a forked context stay there. A second fork taken at the
/// same point still sees the parent's value after the first child writes, and
/// the caller's own thread-local context is unaffected.
#[tokio::test]
async fn test_fork_writes_are_isolated() {
    let key = unique_key("fork_isolate", "rid");
    register::<RequestId>(key);
    // Two independent forks of the same parent state. The second one detects
    // whether the first child's write leaked back into the shared parent.
    let (handle, sibling) = force_thread_local(|| {
        let _scope = enter_scope();
        set_context(key, RequestId("parent".into()));
        (fork(), fork())
    });
    with_fork(handle, async {
        set_context(key, RequestId("child-override".into()));
        let val: RequestId = get_context(key);
        assert_eq!(val.0, "child-override");
    })
    .await;

    let seen_by_sibling = with_fork(sibling, async { get_context::<RequestId>(key) }).await;
    assert_eq!(
        seen_by_sibling.0, "parent",
        "a child's write leaked into the forked parent state"
    );

    // The scope that set "parent" was closed inside the closure above, so the
    // caller's thread-local context reads the default again.
    let parent_val: RequestId = force_thread_local(|| get_context(key));
    assert_eq!(parent_val, RequestId::default());
}
#[tokio::test]
async fn test_fork_is_cheap_clone() {
let key = unique_key("fork_clone", "rid");
register::<RequestId>(key);
let handle = force_thread_local(|| {
let _scope = enter_scope();
set_context(key, RequestId("shared".into()));
fork()
});
let handle2 = handle.clone();
let r1 = with_fork(handle, async { get_context::<RequestId>(key) }).await;
let r2 = with_fork(handle2, async { get_context::<RequestId>(key) }).await;
assert_eq!(r1.0, "shared");
assert_eq!(r2.0, "shared");
}
/// Inside a fork, `scope_async` opens a nested scope whose writes are undone
/// when it completes, restoring the forked value.
#[tokio::test]
async fn test_fork_child_scopes_work() {
    let key = unique_key("fork_scope", "rid");
    register::<RequestId>(key);
    let forked = force_thread_local(|| {
        let _parent = enter_scope();
        set_context(key, RequestId("base".into()));
        fork()
    });
    let nested = async {
        set_context(key, RequestId("inner".into()));
        assert_eq!(get_context::<RequestId>(key).0, "inner");
    };
    with_fork(forked, async {
        assert_eq!(get_context::<RequestId>(key).0, "base");
        scope_async(nested).await;
        assert_eq!(get_context::<RequestId>(key).0, "base");
    })
    .await;
}
#[tokio::test]
async fn test_spawn_with_fork_async() {
let key = unique_key("fork_spawn", "rid");
register::<RequestId>(key);
force_thread_local(|| {
let _scope = enter_scope();
set_context(key, RequestId("spawned".into()));
});
force_thread_local(|| {
let _scope = enter_scope();
set_context(key, RequestId("spawned".into()));
});
let handle = force_thread_local(|| {
let _scope = enter_scope();
set_context(key, RequestId("for-spawn".into()));
fork()
});
let join = tokio::spawn(with_fork(handle, async {
get_context::<RequestId>(key)
}));
let result = join.await.unwrap();
assert_eq!(result.0, "for-spawn");
}
#[tokio::test]
async fn test_fork_empty_context() {
let handle = force_thread_local(|| fork());
with_fork(handle, async {
}).await;
}
/// The named scope active when `fork()` was called appears in the scope chain
/// seen inside `with_fork`.
#[tokio::test]
async fn test_fork_scope_chain_preserved() {
    let key = unique_key("fork_chain", "rid");
    register::<RequestId>(key);
    let handle = force_thread_local(|| {
        let _named = enter_named_scope("parent-scope");
        set_context(key, RequestId("chained".into()));
        fork()
    });
    with_fork(handle, async {
        let chain = scope_chain();
        let found = chain.iter().any(|name| name == "parent-scope");
        assert!(found, "fork should preserve parent scope chain: {:?}", chain);
    })
    .await;
}