use std::any::{Any, TypeId};
use std::collections::HashMap;
use tokio::sync::RwLock;
/// A type-keyed store of shared state for plugins.
///
/// At most one value is kept per concrete type `T`: entries are keyed by
/// [`TypeId::of`] for that type, so inserting a second value of the same type
/// replaces the first. All access goes through an async [`RwLock`], and the
/// context can be shared between tasks (for example behind an `Arc`).
// Stored values are type-erased, so `Debug` cannot show their contents; it is
// derived so the context can still appear in diagnostics of containing types.
#[derive(Debug)]
pub struct PluginContext {
    /// Stored values, keyed by the `TypeId` of their concrete type. Values are
    /// boxed as `Send + Sync` so the whole context can be shared across tasks.
    state: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}
impl PluginContext {
pub fn new() -> Self {
Self { state: RwLock::new(HashMap::new()) }
}
pub async fn insert<T: Send + Sync + 'static>(&self, value: T) {
self.state.write().await.insert(TypeId::of::<T>(), Box::new(value));
}
pub async fn get<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
self.state.read().await.get(&TypeId::of::<T>()).and_then(|v| v.downcast_ref::<T>()).cloned()
}
pub async fn contains<T: Send + Sync + 'static>(&self) -> bool {
self.state.read().await.contains_key(&TypeId::of::<T>())
}
pub async fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
self.state
.write()
.await
.remove(&TypeId::of::<T>())
.and_then(|v| v.downcast::<T>().ok())
.map(|b| *b)
}
}
impl Default for PluginContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct Counter(u32);

    #[derive(Clone, Debug, PartialEq)]
    struct Name(String);

    /// Deliberately not `Clone`: `insert`, `contains` and `remove` must work
    /// without that bound (only `get` needs it).
    struct NotClone(u8);

    #[tokio::test]
    async fn test_new_context_is_empty() {
        let ctx = PluginContext::new();
        assert!(!ctx.contains::<Counter>().await);
        assert_eq!(ctx.get::<Counter>().await, None);
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(42)).await;
        let value = ctx.get::<Counter>().await;
        assert_eq!(value, Some(Counter(42)));
    }

    #[tokio::test]
    async fn test_insert_overwrites_previous() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(1)).await;
        ctx.insert(Counter(99)).await;
        assert_eq!(ctx.get::<Counter>().await, Some(Counter(99)));
    }

    #[tokio::test]
    async fn test_multiple_types() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(10)).await;
        ctx.insert(Name("hello".to_string())).await;
        assert_eq!(ctx.get::<Counter>().await, Some(Counter(10)));
        assert_eq!(ctx.get::<Name>().await, Some(Name("hello".to_string())));
    }

    #[tokio::test]
    async fn test_contains() {
        let ctx = PluginContext::new();
        assert!(!ctx.contains::<Counter>().await);
        ctx.insert(Counter(0)).await;
        assert!(ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_remove_returns_value() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(7)).await;
        let removed = ctx.remove::<Counter>().await;
        assert_eq!(removed, Some(Counter(7)));
    }

    #[tokio::test]
    async fn test_remove_makes_get_return_none() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(7)).await;
        ctx.remove::<Counter>().await;
        assert_eq!(ctx.get::<Counter>().await, None);
        assert!(!ctx.contains::<Counter>().await);
    }

    #[tokio::test]
    async fn test_remove_nonexistent_returns_none() {
        let ctx = PluginContext::new();
        let removed = ctx.remove::<Counter>().await;
        assert_eq!(removed, None);
    }

    #[tokio::test]
    async fn test_remove_only_affects_requested_type() {
        let ctx = PluginContext::new();
        ctx.insert(Counter(3)).await;
        ctx.insert(Name("kept".to_string())).await;
        assert_eq!(ctx.remove::<Counter>().await, Some(Counter(3)));
        assert!(!ctx.contains::<Counter>().await);
        assert_eq!(ctx.get::<Name>().await, Some(Name("kept".to_string())));
    }

    #[tokio::test]
    async fn test_non_clone_values_can_be_stored_and_removed() {
        let ctx = PluginContext::new();
        ctx.insert(NotClone(5)).await;
        assert!(ctx.contains::<NotClone>().await);
        assert_eq!(ctx.remove::<NotClone>().await.map(|v| v.0), Some(5));
        assert!(!ctx.contains::<NotClone>().await);
    }

    #[tokio::test]
    async fn test_default_creates_empty_context() {
        let ctx = PluginContext::default();
        assert!(!ctx.contains::<Counter>().await);
    }

    // Note: `#[tokio::test]` uses a current-thread runtime by default, so the
    // two tasks interleave at `.await` points rather than running in parallel.
    #[tokio::test]
    async fn test_concurrent_access() {
        use std::sync::Arc;
        let ctx = Arc::new(PluginContext::new());
        ctx.insert(Counter(0)).await;
        let ctx_clone = Arc::clone(&ctx);
        let writer = tokio::spawn(async move {
            for i in 1..=100 {
                ctx_clone.insert(Counter(i)).await;
            }
        });
        let ctx_clone2 = Arc::clone(&ctx);
        let reader = tokio::spawn(async move {
            let mut last_seen = 0;
            for _ in 0..100 {
                // Counter(0) was inserted before either task was spawned, so a
                // read must always find a value. With a single writer storing
                // 1..=100 in order, observed values must never go backwards.
                let Counter(seen) = ctx_clone2
                    .get::<Counter>()
                    .await
                    .expect("counter should be present for the whole test");
                assert!(
                    (last_seen..=100).contains(&seen),
                    "read {seen} after {last_seen}"
                );
                last_seen = seen;
            }
        });
        writer.await.unwrap();
        reader.await.unwrap();
        assert_eq!(ctx.get::<Counter>().await, Some(Counter(100)));
    }
}