use std::{
any::{Any, TypeId},
collections::HashMap,
hash::{BuildHasherDefault, Hasher},
sync::{Arc, Mutex},
};
/// Shared, cheaply clonable map from a Rust type to one value of that type.
///
/// The storage lives behind an `Arc<Mutex<_>>`, so every clone of an `Extensions` is a handle
/// to the *same* map: values inserted or removed through one clone are visible through all of
/// them.
#[derive(Clone, Debug, Default)]
pub struct Extensions {
    inner: Arc<Mutex<ExtensionInner>>,
}

/// Type-erased, shared handle to one stored value.
type ExtensionValue = Arc<dyn Any + Send + Sync>;

/// Storage behind [`Extensions`].
///
/// Soundness invariant relied on by [`Extensions::get`]: once a map entry has
/// `ever_fetched == true`, the slot it points at is never cleared, overwritten, or removed for
/// as long as this storage lives. Slots are never removed from `values` at all (a superseded
/// slot just becomes orphaned), and a slot is only overwritten or cleared while the entry
/// pointing at it has never been fetched.
#[derive(Clone, Debug, Default)]
struct ExtensionInner {
    /// Current slot for each stored type.
    map: HashMap<TypeId, ExtensionItem, BuildHasherDefault<IdHasher>>,
    /// Value slots, indexed by `ExtensionItem::index`; `None` marks a removed value.
    values: Vec<Option<ExtensionValue>>,
}

impl ExtensionInner {
    /// Stores `value` as the current value for `type_id`.
    ///
    /// Returns whether a live value was superseded, together with the superseded handle when
    /// its slot was reused, so the caller can drop it after releasing the lock.
    fn put(
        &mut self,
        type_id: TypeId,
        value: ExtensionValue,
    ) -> (InsertEffect, Option<ExtensionValue>) {
        let reusable_slot = match self.map.get(&type_id) {
            // Never borrowed through `get`, so no `&T` can point into this slot: overwrite it.
            Some(item) if !item.ever_fetched => self.values.get_mut(item.index),
            _ => None,
        };
        if let Some(slot) = reusable_slot {
            let displaced = slot.replace(value);
            let effect = if displaced.is_some() {
                InsertEffect::Replaced
            } else {
                InsertEffect::New
            };
            return (effect, displaced);
        }

        // Either there is no entry yet, or its slot was borrowed and must stay untouched:
        // append a fresh slot and repoint the entry at it.
        let index = self.values.len();
        self.values.push(Some(value));
        let previous = self.map.insert(
            type_id,
            ExtensionItem {
                index,
                ever_fetched: false,
            },
        );
        let replaced = match previous {
            Some(item) => matches!(self.values.get(item.index), Some(Some(_))),
            None => false,
        };
        let effect = if replaced {
            InsertEffect::Replaced
        } else {
            InsertEffect::New
        };
        (effect, None)
    }

    /// Returns `(type, value)` pairs for every entry whose slot still holds a value.
    fn live_entries(&self) -> Vec<(TypeId, ExtensionValue)> {
        self.map
            .iter()
            .filter_map(|(&type_id, item)| {
                let value = self.values.get(item.index)?.clone()?;
                Some((type_id, value))
            })
            .collect()
    }
}

/// Bookkeeping for one stored type.
#[derive(Debug, Clone)]
struct ExtensionItem {
    /// Index of the value's slot in `ExtensionInner::values`.
    index: usize,
    /// Set once [`Extensions::get`] has handed out a borrow of this slot's value; from then on
    /// the slot must never be cleared or overwritten.
    ever_fetched: bool,
}

/// Outcome of [`Extensions::insert`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InsertEffect {
    /// A value of the same type was present and has been superseded.
    Replaced,
    /// No value of this type was present (or it had been removed).
    New,
}

/// Outcome of [`Extensions::remove`].
#[derive(Debug, Clone)]
pub enum Removed<T> {
    /// The value was taken out and is now exclusively owned by the caller.
    Removed(T),
    /// The value was taken out of the map, but other `Arc` handles to it are still alive.
    Referenced(Arc<T>),
    /// The value was borrowed through [`Extensions::get`], so it was left in place.
    Invalidated,
}

impl<T> Removed<T> {
    /// Returns the owned value.
    ///
    /// # Panics
    ///
    /// Panics if `self` is [`Removed::Referenced`] or [`Removed::Invalidated`]. The panic is
    /// reported at the caller's location.
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self {
            Removed::Removed(x) => x,
            Removed::Referenced(_) => panic!("extension is referenced"),
            Removed::Invalidated => panic!("extension is invalidated (was referenced)"),
        }
    }
}

impl Extensions {
    /// Creates an empty `Extensions`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `val` as the value for type `T`, superseding any previous value of that type.
    ///
    /// Returns [`InsertEffect::Replaced`] if a value of type `T` was present, and
    /// [`InsertEffect::New`] otherwise, including when the previous value had been removed.
    ///
    /// A superseded value that was never borrowed through [`Extensions::get`] has its slot
    /// reused, and this map's handle to it is released once the lock is dropped. A superseded
    /// value that *was* borrowed is kept alive until the storage itself is dropped, because
    /// `&T` borrows handed out by `get` may still point at it.
    pub fn insert<T: Send + Sync + 'static>(&self, val: T) -> InsertEffect {
        // The guard is a temporary released at the end of this statement, so the displaced
        // value's destructor runs without this map's lock held.
        let (effect, displaced) = self
            .inner
            .lock()
            .unwrap()
            .put(TypeId::of::<T>(), Arc::new(val));
        drop(displaced);
        effect
    }

    /// Borrows the value of type `T`, if one is present.
    ///
    /// The borrow is tied to `&self` rather than to the internal lock. To keep that sound, the
    /// entry is marked as fetched. From then on, [`Extensions::remove`] reports
    /// [`Removed::Invalidated`] for it, and [`Extensions::insert`] moves to a fresh slot
    /// instead of touching it.
    pub fn get<'a, T: Send + Sync + 'static>(&'a self) -> Option<&'a T> {
        let mut guard = self.inner.lock().unwrap();
        // Reborrow as a plain `&mut` so `map` and `values` can be borrowed disjointly.
        let inner = &mut *guard;
        let item = inner.map.get_mut(&TypeId::of::<T>())?;
        let value: &T = inner.values.get(item.index)?.as_deref()?.downcast_ref()?;
        item.ever_fetched = true;
        let value: *const T = value;
        // SAFETY: `value` points into the heap allocation of an `Arc` held in `inner.values`.
        // That allocation does not move when `values` reallocates. Because the entry is now
        // marked `ever_fetched`, no method clears, overwrites, or drops that slot (see the
        // invariant on `ExtensionInner`). The storage itself is owned through `self.inner`,
        // which cannot be dropped or replaced while `&'a self` is borrowed. `T: Sync` makes
        // sharing the resulting `&T` across threads sound.
        Some(unsafe { &*value })
    }

    /// Returns a shared handle to the value of type `T`, if one is present.
    ///
    /// Unlike [`Extensions::get`], this does not mark the entry as fetched; the returned `Arc`
    /// keeps the value alive on its own.
    pub fn get_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        let inner = self.inner.lock().unwrap();
        let item = inner.map.get(&TypeId::of::<T>())?;
        let value = inner.values.get(item.index)?.clone()?;
        Arc::downcast(value).ok()
    }

    /// Removes the value of type `T`.
    ///
    /// Returns `None` if no value of type `T` is present. Otherwise returns one of:
    /// - [`Removed::Invalidated`] if the value was ever borrowed through [`Extensions::get`].
    ///   The value is left in place, since such borrows may still be alive.
    /// - [`Removed::Removed`] with the owned value, if this map held the only handle to it.
    /// - [`Removed::Referenced`] if other `Arc` handles exist (from [`Extensions::get_arc`] or
    ///   from another map sharing it via [`Extensions::extend`]). The value is still taken out
    ///   of this map.
    pub fn remove<T: Send + Sync + 'static>(&self) -> Option<Removed<T>> {
        let mut guard = self.inner.lock().unwrap();
        let inner = &mut *guard;
        let item = inner.map.get(&TypeId::of::<T>())?;
        if item.ever_fetched {
            return Some(Removed::Invalidated);
        }
        let value = inner.values.get_mut(item.index)?.take()?;
        let value: Arc<T> = Arc::downcast(value).ok()?;
        match Arc::try_unwrap(value) {
            Ok(owned) => Some(Removed::Removed(owned)),
            Err(shared) => Some(Removed::Referenced(shared)),
        }
    }

    /// Copies every live value of `other` into `self`, superseding values of the same type.
    ///
    /// Values are shared, not cloned: both maps end up holding `Arc` handles to the same
    /// allocations. The two locks are never held at the same time. This means extending a map
    /// with itself (or with a clone of itself) and concurrent `a.extend(&b)` / `b.extend(&a)`
    /// calls cannot deadlock.
    pub fn extend(&self, other: &Extensions) {
        // `other`'s guard is a temporary dropped at the end of this statement.
        let snapshot = other.inner.lock().unwrap().live_entries();
        let mut displaced = Vec::new();
        {
            let mut inner = self.inner.lock().unwrap();
            for (type_id, value) in snapshot {
                let (_, old) = inner.put(type_id, value);
                displaced.extend(old);
            }
        }
        // Superseded handles are dropped only now, without `self`'s lock held.
        drop(displaced);
    }
}

/// Pass-through hasher for `TypeId` keys: `finish` returns the last `u64` written.
///
/// Relies on `TypeId`'s `Hash` impl feeding its value through `write_u64`. Byte-slice writes,
/// which the other integer `write_*` defaults forward to, panic.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    fn write(&mut self, _: &[u8]) {
        unreachable!("TypeId calls write_u64");
    }

    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}
/// Intended to mirror the private map type stored inside `http::Extensions`, so that
/// `HttpExtensions` can reproduce that type's layout for the `From<http::Extensions>` impl.
///
/// NOTE(review): this must match the exact map type of the `http` version in use, including
/// the value type and the hasher type. `HashMap` is `repr(Rust)`, so instantiations with
/// different (even same-shaped) type arguments are not guaranteed to share a layout. Newer
/// `http` releases may store a different trait object than `Box<dyn Any + Send + Sync>`. If
/// so, the `Any` vtable assumed here is wrong and the transmute is undefined behavior. Confirm
/// against the pinned `http` version.
type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;
/// Local stand-in for `http::Extensions`, which the `From<http::Extensions>` impl transmutes
/// into so it can move the stored values out.
///
/// NOTE(review): presumably mirrors the single private field of `http::Extensions`. Any change
/// to that crate's field layout makes the transmute unsound. TODO confirm per `http` upgrade.
struct HttpExtensions {
map: Option<Box<AnyMap>>,
}
/// Converts an `http::Extensions` into an [`Extensions`] by moving every stored value across.
///
/// Each boxed value is converted into an `Arc` and placed in a fresh slot, with its fetched
/// flag cleared.
impl From<http::Extensions> for Extensions {
fn from(value: http::Extensions) -> Self {
// NOTE(review): the transmute is presumably used because `http::Extensions` offers no public
// way to iterate over (or drain) its entries. Confirm against the `http` docs.
// SAFETY: relies on `HttpExtensions` having exactly the same layout as `http::Extensions`,
// including the map's value type (whose vtable must be an `Any` vtable) and hasher type.
// `transmute` only checks that the two sizes are equal. That layout is a private detail of
// the `http` crate and is not guaranteed; if it differs, this is undefined behavior.
let value: HttpExtensions = unsafe { std::mem::transmute(value) };
let mut inner = ExtensionInner {
map: Default::default(),
values: Default::default(),
};
// `None` means the source map held no allocation, so there is nothing to move.
if let Some(value) = value.map {
for (type_id, value) in value.into_iter() {
// Each entry gets the next free slot; nothing has been borrowed from it yet.
let item = ExtensionItem {
index: inner.values.len(),
ever_fetched: false,
};
inner.map.insert(type_id, item);
inner.values.push(Some(Arc::from(value)));
}
}
Self {
inner: Arc::new(Mutex::new(inner)),
}
}
}