use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
/// Type-keyed bag of shared values: holds at most one value per concrete type.
///
/// The map lives behind an [`Arc`], so cloning an `Extensions` only bumps a
/// reference count and every clone shares the same underlying storage.
#[derive(Default, Clone)]
pub struct Extensions {
    /// Values keyed by a `TypeId`; each value is stored type-erased and
    /// individually reference-counted so lookups can hand out cheap handles.
    inner: Arc<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
}
impl Extensions {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
#[must_use]
pub fn get<T>(&self) -> Option<Arc<T>>
where
T: Send + Sync + 'static,
{
self.inner
.get(&TypeId::of::<T>())
.and_then(|entry| Arc::clone(entry).downcast::<T>().ok())
}
#[must_use]
pub fn contains<T>(&self) -> bool
where
T: Send + Sync + 'static,
{
self.inner.contains_key(&TypeId::of::<T>())
}
#[must_use]
pub(crate) fn inserted<T>(&self, value: T) -> Self
where
T: Send + Sync + 'static,
{
let mut next: HashMap<TypeId, Arc<dyn Any + Send + Sync>> = (*self.inner).clone();
next.insert(TypeId::of::<T>(), Arc::new(value));
Self {
inner: Arc::new(next),
}
}
}
/// Renders only the number of stored entries: the values are type-erased
/// (`dyn Any + Send + Sync`) and carry no `Debug` bound, so they cannot be printed.
impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.inner.len();
        let mut out = f.debug_struct("Extensions");
        out.field("len", &entry_count);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    // Lookups are compared through `Option::as_deref`, so a missing entry shows
    // up as an assertion failure with both sides printed rather than a panic.
    use super::*;

    /// Sample payload type; stored under its own `TypeId`.
    #[derive(Debug, PartialEq, Eq)]
    struct Workspace(&'static str);

    /// Second, unrelated payload type used to check that entries coexist.
    #[derive(Debug, PartialEq, Eq)]
    struct RequestId(u64);

    #[test]
    fn empty_extensions_have_no_entries() {
        let empty = Extensions::new();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
        assert!(!empty.contains::<Workspace>());
        assert!(empty.get::<Workspace>().is_none());
    }

    #[test]
    fn insert_and_get_round_trip() {
        let ext = Extensions::new().inserted(Workspace("repo-a"));
        assert_eq!(ext.len(), 1);
        assert_eq!(ext.get::<Workspace>().as_deref(), Some(&Workspace("repo-a")));
    }

    #[test]
    fn multiple_distinct_types_coexist() {
        let ext = Extensions::new()
            .inserted(Workspace("repo-a"))
            .inserted(RequestId(42));
        assert_eq!(ext.len(), 2);
        assert_eq!(ext.get::<Workspace>().as_deref(), Some(&Workspace("repo-a")));
        assert_eq!(ext.get::<RequestId>().as_deref(), Some(&RequestId(42)));
    }

    #[test]
    fn second_insert_of_same_type_replaces() {
        let first = Extensions::new().inserted(Workspace("repo-a"));
        let ext = first.inserted(Workspace("repo-b"));
        assert_eq!(ext.len(), 1, "one entry per type");
        assert_eq!(ext.get::<Workspace>().as_deref(), Some(&Workspace("repo-b")));
    }

    #[test]
    fn copy_on_write_does_not_mutate_original() {
        let original = Extensions::new().inserted(Workspace("repo-a"));
        let extended = original.inserted(RequestId(7));

        // The receiver of `inserted` must be left exactly as it was.
        assert_eq!(original.len(), 1);
        assert!(original.get::<RequestId>().is_none());

        assert_eq!(extended.len(), 2);
        assert!(extended.get::<RequestId>().is_some());
    }

    #[test]
    fn absent_type_returns_none() {
        let ext = Extensions::new().inserted(Workspace("repo-a"));
        assert!(ext.get::<RequestId>().is_none());
    }

    #[test]
    fn contains_reflects_insertion() {
        let before = Extensions::new();
        assert!(!before.contains::<Workspace>());
        let after = before.inserted(Workspace("repo-a"));
        assert!(after.contains::<Workspace>());
    }

    #[test]
    fn debug_surfaces_cardinality() {
        let ext = Extensions::new().inserted(Workspace("repo-a"));
        let rendered = format!("{ext:?}");
        assert!(rendered.contains("len: 1"), "{rendered}");
    }

    #[test]
    fn arc_returned_from_get_outlives_extensions_clone() {
        let ext = Extensions::new().inserted(Workspace("repo-a"));
        let held = ext.get::<Workspace>();
        drop(ext);
        assert_eq!(held.as_deref(), Some(&Workspace("repo-a")));
    }
}