mod context;
use std::any::TypeId;
use std::any::Any;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::Arc;
use std::future::Future;
use std::time::{Duration, Instant};
pub use context::{ContextImpl, CancelFn, ContextError};
/// Shared, reference-counted handle to a context node; all public constructors
/// in this module return this type.
pub type Context = Arc<ContextImpl>;
/// Type-tagged, serialized form of a context-value key.
///
/// Equality (and hashing) covers both the serialized bytes and the Rust type
/// of the original key, so keys of different types never compare equal even
/// when their encoded bytes happen to coincide.
#[derive(Hash, Eq, PartialEq, Clone)]
pub(crate) struct TypedKey {
    /// bincode encoding of the original key value.
    key: Vec<u8>,
    /// `TypeId` of the original key's type.
    type_id: TypeId,
}

impl TypedKey {
    /// Builds a `TypedKey` from any serializable `'static` key value.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to serialize key" if bincode cannot encode `key`.
    pub fn new<K: serde::Serialize + 'static>(key: &K) -> Self {
        let encoded = bincode::serialize(key).expect("Failed to serialize key");
        Self {
            key: encoded,
            type_id: TypeId::of::<K>(),
        }
    }

    /// Reports whether `key` (of type `K`) is equal to the key this value was built from.
    ///
    /// The cheap type comparison runs first; `key` is serialized only when the
    /// types agree (the `&&` short-circuits otherwise).
    ///
    /// # Panics
    ///
    /// Panics with "Failed to serialize key" if the types match but bincode
    /// cannot encode `key`.
    pub fn matches<K: serde::Serialize + 'static>(&self, key: &K) -> bool {
        let same_type = self.type_id == TypeId::of::<K>();
        same_type && bincode::serialize(key).expect("Failed to serialize key") == self.key
    }
}
/// Common interface of a context: completion signal, terminal error, deadline,
/// and typed key/value lookup.
///
/// NOTE(review): because `value` is generic, this trait cannot be used as
/// `dyn ContextTrait`; it is only usable through concrete types or generics.
pub trait ContextTrait: Send + Sync {
    /// Returns a boxed future, borrowing `self`, that completes once the context is done.
    ///
    /// In this file's tests it resolves after cancellation and after the
    /// deadline passes.
    fn done(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;

    /// Returns `None` while the context is active, otherwise why it ended.
    ///
    /// The tests observe `ContextError::Canceled` and `ContextError::DeadlineExceeded`.
    fn err(&self) -> Option<ContextError>;

    /// Returns the instant at which the context expires, or `None` if it has no deadline.
    fn deadline(&self) -> Option<Instant>;

    /// Looks up the value stored under `key`, returning a shared handle to it.
    ///
    /// Returns `None` when no value is found. Presumably it also returns `None`
    /// when the stored value is not a `V`; confirm against the implementation.
    fn value<K, V>(&self, key: &K) -> Option<Arc<V>>
    where
        K: Hash + Eq + Clone + Send + Sync + serde::Serialize + 'static,
        V: Any + Send + Sync + Clone;
}
/// Creates a new root context.
///
/// A freshly created root reports no error and no deadline. Every other
/// context in this module is derived from a parent such as this one.
pub fn background() -> Context {
    ContextImpl::background()
}
/// Derives a cancellable child of `parent`.
///
/// Returns the child context together with the [`CancelFn`] that cancels it.
/// Once cancellation has propagated, the child's `err()` reports
/// `ContextError::Canceled`.
pub fn with_cancel(parent: &Context) -> (Context, CancelFn) {
    let (child, cancel) = ContextImpl::with_cancel(parent);
    (child, cancel)
}
/// Derives a child of `parent` that expires at the absolute instant `deadline`.
///
/// The child's `deadline()` reports `Some(deadline)`. After that instant has
/// passed, its `err()` reports `ContextError::DeadlineExceeded`. The returned
/// [`CancelFn`] presumably allows ending the child earlier; see
/// `ContextImpl::with_deadline`.
pub fn with_deadline(parent: &Context, deadline: Instant) -> (Context, CancelFn) {
    ContextImpl::with_deadline(parent, deadline)
}
/// Derives a child of `parent` that expires after the relative duration `timeout`.
///
/// How the expiry instant is computed is delegated to
/// `ContextImpl::with_timeout`. Returns the child and its [`CancelFn`].
pub fn with_timeout(parent: &Context, timeout: Duration) -> (Context, CancelFn) {
    let (child, cancel) = ContextImpl::with_timeout(parent, timeout);
    (child, cancel)
}
/// Derives a child of `parent` that carries `value` under `key`.
///
/// The value is visible from the child and its descendants, but not from
/// `parent` or other ancestors. Values set on ancestors remain visible through
/// the child.
///
/// NOTE(review): `ContextTrait::value` requires `V: Clone` for retrieval,
/// which is not enforced here, so a non-`Clone` value may be storable but not
/// retrievable through that trait.
pub async fn with_value<K, V>(parent: &Context, key: K, value: V) -> Context
where
    K: Hash + Eq + Clone + Send + Sync + serde::Serialize + 'static,
    V: Any + Send + Sync,
{
    ContextImpl::with_value(parent, key, value).await
}
// Async tests for the public context API; each test runs on its own Tokio
// runtime via `#[tokio::test]`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use std::sync::Arc;
    use std::time::Duration;
    use std::time::Instant;

    #[tokio::test]
    async fn test_background() {
        let ctx = background();
        assert!(ctx.err().is_none());
        assert!(ctx.deadline().is_none());
    }

    #[tokio::test]
    async fn test_cancel() {
        let ctx = background();
        let (ctx, cancel) = with_cancel(&ctx);
        assert!(ctx.err().is_none());
        cancel.cancel().await;
        // Allow a short window for cancellation to propagate before checking.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(ctx.err(), Some(ContextError::Canceled));
    }

    #[tokio::test]
    async fn test_deadline() {
        let ctx = background();
        let deadline = Instant::now() + Duration::from_millis(50);
        // Unlike a bare `_` pattern (which drops the value immediately),
        // `_cancel` keeps the cancel handle alive until the end of the test, so
        // the context can only end through its deadline even if `CancelFn`
        // happens to cancel on drop.
        let (ctx, _cancel) = with_deadline(&ctx, deadline);
        assert!(ctx.err().is_none());
        assert_eq!(ctx.deadline(), Some(deadline));
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(ctx.err(), Some(ContextError::DeadlineExceeded));
    }

    #[tokio::test]
    async fn test_value() {
        let ctx = background();
        let ctx = with_value(&ctx, "test_key", "test_value".to_string()).await;
        let value: Option<Arc<String>> = ctx.value(&"test_key");
        assert_eq!(value.map(|v| (*v).clone()), Some("test_value".to_string()));
    }

    #[tokio::test]
    async fn test_multiple_values() {
        #[derive(Hash, Eq, PartialEq, Clone, Serialize)]
        struct BoolKey(String);

        let ctx = background();
        let ctx = with_value(&ctx, "str_key", "test".to_string()).await;
        let ctx = with_value(&ctx, 42_i32, 42_i64).await;
        let ctx = with_value(&ctx, BoolKey("bool".to_string()), true).await;

        let str_value: Option<Arc<String>> = ctx.value(&"str_key");
        let int_value: Option<Arc<i64>> = ctx.value(&42_i32);
        let bool_value: Option<Arc<bool>> = ctx.value(&BoolKey("bool".to_string()));

        assert_eq!(str_value.map(|v| (*v).clone()), Some("test".to_string()));
        // `i64` and `bool` are `Copy`: dereference instead of cloning.
        assert_eq!(int_value.map(|v| *v), Some(42_i64));
        assert_eq!(bool_value.map(|v| *v), Some(true));
    }

    #[tokio::test]
    async fn test_same_type_different_keys() {
        let ctx = background();
        let ctx = with_value(&ctx, "key1", "first".to_string()).await;
        let ctx = with_value(&ctx, "key2", "second".to_string()).await;
        let value1: Option<Arc<String>> = ctx.value(&"key1");
        let value2: Option<Arc<String>> = ctx.value(&"key2");
        assert_eq!(value1.map(|v| (*v).clone()), Some("first".to_string()));
        assert_eq!(value2.map(|v| (*v).clone()), Some("second".to_string()));
    }

    #[tokio::test]
    async fn test_nested_contexts() {
        let root = background();
        let root = with_value(&root, "root_key", "root_value".to_string()).await;
        let child = with_value(&root, "child_key", "child_value".to_string()).await;
        let deadline = Instant::now() + Duration::from_secs(1);
        // Keep the cancel handle alive for the whole test (see `test_deadline`).
        let (grandchild, _cancel) = with_deadline(&child, deadline);
        let grandchild =
            with_value(&grandchild, "grandchild_key", "grandchild_value".to_string()).await;

        // Descendants see every value set on their ancestors.
        let root_value: Option<Arc<String>> = grandchild.value(&"root_key");
        let child_value: Option<Arc<String>> = grandchild.value(&"child_key");
        let grandchild_value: Option<Arc<String>> = grandchild.value(&"grandchild_key");
        assert_eq!(root_value.map(|v| (*v).clone()), Some("root_value".to_string()));
        assert_eq!(child_value.map(|v| (*v).clone()), Some("child_value".to_string()));
        assert_eq!(
            grandchild_value.map(|v| (*v).clone()),
            Some("grandchild_value".to_string())
        );

        // Ancestors do not see values set on their descendants.
        let root_from_child: Option<Arc<String>> = child.value(&"root_key");
        let child_from_child: Option<Arc<String>> = child.value(&"child_key");
        let grandchild_from_child: Option<Arc<String>> = child.value(&"grandchild_key");
        assert_eq!(root_from_child.map(|v| (*v).clone()), Some("root_value".to_string()));
        assert_eq!(child_from_child.map(|v| (*v).clone()), Some("child_value".to_string()));
        assert_eq!(grandchild_from_child, None);

        // Deadlines are inherited downwards only.
        assert_eq!(grandchild.deadline(), Some(deadline));
        assert!(child.deadline().is_none());
    }

    #[tokio::test]
    async fn test_context_in_tokio_tasks() {
        let ctx = background();
        let (ctx, cancel) = with_cancel(&ctx);
        let handle = tokio::spawn({
            let ctx = ctx.clone();
            async move {
                tokio::select! {
                    _ = ctx.done() => "cancelled",
                    _ = tokio::time::sleep(Duration::from_secs(10)) => "timeout",
                }
            }
        });
        tokio::time::sleep(Duration::from_millis(100)).await;
        cancel.cancel().await;
        assert_eq!(handle.await.unwrap(), "cancelled");
    }

    #[tokio::test]
    async fn test_context_value_across_tasks() {
        let ctx = background();
        let ctx = with_value(&ctx, "key", "parent_value".to_string()).await;
        let mut handles = Vec::with_capacity(3);
        for i in 0..3 {
            let ctx = ctx.clone();
            handles.push(tokio::spawn(async move {
                match ctx.value::<_, String>(&"key") {
                    Some(value) => format!("task_{}_got_{}", i, value.as_ref()),
                    None => format!("task_{}_no_value", i),
                }
            }));
        }
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            assert_eq!(result, format!("task_{}_got_parent_value", i));
        }
    }

    #[tokio::test]
    async fn test_context_deadline_in_tasks() {
        let ctx = background();
        let deadline = Instant::now() + Duration::from_millis(100);
        // Keep the cancel handle alive for the whole test (see `test_deadline`).
        let (ctx, _cancel) = with_deadline(&ctx, deadline);
        let handle = tokio::spawn({
            let ctx = ctx.clone();
            async move {
                tokio::select! {
                    // Compare as `Option` so a missing error fails the assertion
                    // below rather than panicking inside the task.
                    _ = ctx.done() => ctx.err() == Some(ContextError::DeadlineExceeded),
                    _ = tokio::time::sleep(Duration::from_secs(1)) => false,
                }
            }
        });
        assert!(handle.await.unwrap());
    }

    #[tokio::test]
    async fn test_nested_tasks_with_context() {
        let mut current_ctx = background();
        let mut expected_values: Vec<i32> = Vec::new();
        for i in 0..=3_i32 {
            current_ctx = with_value(&current_ctx, format!("level_{}", i), i).await;
            expected_values.push(i);
            let ctx = current_ctx.clone();
            let handle = tokio::spawn(async move {
                let mut values = Vec::new();
                for level in 0..=i {
                    if let Some(value) = ctx.value::<_, i32>(&format!("level_{}", level)) {
                        values.push(*value);
                    }
                }
                values
            });
            assert_eq!(handle.await.unwrap(), expected_values);
        }
    }

    #[tokio::test]
    async fn test_task_cancellation_cleanup() {
        let ctx = background();
        let (ctx, cancel) = with_cancel(&ctx);
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let handle = tokio::spawn({
            let ctx = ctx.clone();
            let tx = tx.clone();
            async move {
                tokio::select! {
                    _ = ctx.done() => {
                        tx.send("cleanup").await.unwrap();
                        "cancelled"
                    }
                    _ = tokio::time::sleep(Duration::from_secs(10)) => "timeout",
                }
            }
        });
        tokio::time::sleep(Duration::from_millis(100)).await;
        cancel.cancel().await;
        assert_eq!(handle.await.unwrap(), "cancelled");
        assert_eq!(rx.recv().await.unwrap(), "cleanup");
    }
}