use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
/// A typed, named handle for a value stored in a [`Context`].
///
/// The type parameter `T` fixes the type of the value the key refers to, so
/// lookups through the key are type-checked at compile time. A key is just a
/// `&'static str` name plus a zero-sized marker, so it is `Copy` and can be
/// declared as a `static` via the `const fn` constructor.
pub struct ContextKey<T> {
    name: &'static str,
    // `fn() -> T` rather than `T`: the key never owns a `T`, so it should not
    // inherit `T`'s `Send`/`Sync` or drop-check obligations. Function pointers
    // are always `Send + Sync`, so every `ContextKey<T>` can live in a `static`.
    _marker: std::marker::PhantomData<fn() -> T>,
}

impl<T> ContextKey<T> {
    /// Creates a key with the given name. Usable in `const` and `static` items.
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns the name this key was created with.
    pub const fn name(&self) -> &'static str {
        self.name
    }
}

// Hand-written rather than derived: `#[derive]` would add `T: Clone` /
// `T: Debug` bounds even though the key never holds a `T`.
impl<T> Clone for ContextKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ContextKey<T> {}

impl<T> std::fmt::Debug for ContextKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ContextKey").field("name", &self.name).finish()
    }
}
/// A type-safe, heterogeneous map of values addressed by [`ContextKey`]s.
///
/// Entries are keyed by the pair (`TypeId` of the value type, key name), so two
/// keys that share a name but differ in value type address distinct entries.
/// Values are stored behind `Arc`, so cloning a `Context` is shallow: the clone
/// shares the stored values with the original instead of copying them.
#[derive(Default, Clone)]
pub struct Context {
    values: HashMap<(TypeId, &'static str), Arc<dyn Any + Send + Sync>>,
}

impl Context {
    /// Creates an empty context.
    pub fn new() -> Self {
        Self::default()
    }

    /// Internal map key for `key`: the value type's `TypeId` plus the key name.
    fn slot<T: 'static>(key: &ContextKey<T>) -> (TypeId, &'static str) {
        (TypeId::of::<T>(), key.name)
    }

    /// Stores `value` under `key`, replacing any value previously stored under it.
    pub fn insert<T: Send + Sync + 'static>(&mut self, key: &ContextKey<T>, value: T) {
        self.values.insert(Self::slot(key), Arc::new(value));
    }

    /// Returns a reference to the value stored under `key`, or `None` if absent.
    pub fn get<T: Send + Sync + 'static>(&self, key: &ContextKey<T>) -> Option<&T> {
        self.values
            .get(&Self::slot(key))
            // The `TypeId` half of the map key guarantees the stored value is a `T`.
            .and_then(|value| value.downcast_ref::<T>())
    }

    /// Removes the value stored under `key` and returns it, or `None` if absent.
    ///
    /// When no clone of this context still shares the value, it is moved out
    /// without cloning. Otherwise it is cloned, and the other contexts keep
    /// their copy.
    pub fn remove<T: Send + Sync + Clone + 'static>(&mut self, key: &ContextKey<T>) -> Option<T> {
        // One lookup: take the entry out of the map, then recover its type.
        let erased = self.values.remove(&Self::slot(key))?;
        let typed = erased.downcast::<T>().ok()?;
        // `try_unwrap` fails only if a cloned `Context` still holds this `Arc`.
        Some(Arc::try_unwrap(typed).unwrap_or_else(|shared| (*shared).clone()))
    }

    /// Returns `true` if a value is stored under `key`.
    pub fn contains<T: Send + Sync + 'static>(&self, key: &ContextKey<T>) -> bool {
        self.values.contains_key(&Self::slot(key))
    }

    /// Removes every value from this context. Clones of it are unaffected.
    pub fn clear(&mut self) {
        self.values.clear();
    }
}
/// Predefined keys for values commonly attached to a [`Context`](super::Context).
pub mod keys {
    use super::ContextKey;
    use std::time::Instant;

    /// `String` stored under `"correlation_id"`. Presumably an identifier that
    /// ties related operations together (NOTE(review): confirm with producers).
    pub static CORRELATION_ID: ContextKey<String> = ContextKey::new("correlation_id");

    /// `u32` stored under `"retry_attempt"`.
    /// NOTE(review): whether counting starts at 0 or 1 is not fixed here.
    pub static RETRY_ATTEMPT: ContextKey<u32> = ContextKey::new("retry_attempt");

    /// `String` stored under `"operation_name"`.
    pub static OPERATION_NAME: ContextKey<String> = ContextKey::new("operation_name");

    /// [`Instant`] stored under `"start_time"`.
    pub static START_TIME: ContextKey<Instant> = ContextKey::new("start_time");
}
#[cfg(test)]
mod tests {
    use super::*;

    static STRING_KEY: ContextKey<String> = ContextKey::new("string");
    static INT_KEY: ContextKey<i32> = ContextKey::new("int");
    // Same name as `STRING_KEY` but a different value type: must be a separate entry.
    static STRING_NAMED_INT_KEY: ContextKey<i32> = ContextKey::new("string");

    #[test]
    fn test_insert_and_get() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "hello".to_string());
        ctx.insert(&INT_KEY, 42);
        assert_eq!(ctx.get(&STRING_KEY), Some(&"hello".to_string()));
        assert_eq!(ctx.get(&INT_KEY), Some(&42));
    }

    #[test]
    fn test_insert_overwrites() {
        let mut ctx = Context::new();
        ctx.insert(&INT_KEY, 1);
        ctx.insert(&INT_KEY, 2);
        assert_eq!(ctx.get(&INT_KEY), Some(&2));
    }

    #[test]
    fn test_same_name_different_type_is_distinct() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "text".to_string());
        ctx.insert(&STRING_NAMED_INT_KEY, 7);
        assert_eq!(ctx.get(&STRING_KEY), Some(&"text".to_string()));
        assert_eq!(ctx.get(&STRING_NAMED_INT_KEY), Some(&7));
        assert_eq!(ctx.remove(&STRING_NAMED_INT_KEY), Some(7));
        assert!(ctx.contains(&STRING_KEY));
    }

    #[test]
    fn test_get_missing_key() {
        let ctx = Context::new();
        assert_eq!(ctx.get(&STRING_KEY), None);
    }

    #[test]
    fn test_contains() {
        let mut ctx = Context::new();
        assert!(!ctx.contains(&STRING_KEY));
        ctx.insert(&STRING_KEY, "value".to_string());
        assert!(ctx.contains(&STRING_KEY));
    }

    #[test]
    fn test_remove() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "value".to_string());
        let removed = ctx.remove(&STRING_KEY);
        assert_eq!(removed, Some("value".to_string()));
        assert!(!ctx.contains(&STRING_KEY));
    }

    #[test]
    fn test_remove_missing_key() {
        let mut ctx = Context::new();
        assert_eq!(ctx.remove(&STRING_KEY), None);
    }

    #[test]
    fn test_remove_does_not_affect_clone() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "shared".to_string());
        let snapshot = ctx.clone();
        assert_eq!(ctx.remove(&STRING_KEY), Some("shared".to_string()));
        assert!(!ctx.contains(&STRING_KEY));
        assert_eq!(snapshot.get(&STRING_KEY), Some(&"shared".to_string()));
    }

    #[test]
    fn test_clear() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "value".to_string());
        ctx.insert(&INT_KEY, 42);
        ctx.clear();
        assert!(!ctx.contains(&STRING_KEY));
        assert!(!ctx.contains(&INT_KEY));
    }

    #[test]
    fn test_clone() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "value".to_string());
        let ctx2 = ctx.clone();
        assert_eq!(ctx2.get(&STRING_KEY), Some(&"value".to_string()));
    }

    #[test]
    fn test_well_known_key_names() {
        assert_eq!(keys::CORRELATION_ID.name(), "correlation_id");
        assert_eq!(keys::RETRY_ATTEMPT.name(), "retry_attempt");
        assert_eq!(keys::OPERATION_NAME.name(), "operation_name");
        assert_eq!(keys::START_TIME.name(), "start_time");
    }
}