1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 hash::{BuildHasherDefault, Hasher},
5 sync::{Arc, Mutex},
6};
7
8#[derive(Clone, Debug, Default)]
9pub struct Extensions {
10 inner: Arc<Mutex<ExtensionInner>>,
11}
12
13#[derive(Clone, Debug, Default)]
14struct ExtensionInner {
15 map: HashMap<TypeId, ExtensionItem, BuildHasherDefault<IdHasher>>,
16 values: Vec<Option<Arc<dyn Any + Send + Sync>>>,
17}
18
/// Bookkeeping entry kept in `ExtensionInner::map` for one stored type.
#[derive(Clone, Debug)]
struct ExtensionItem {
    /// Position of this type's slot inside `ExtensionInner::values`.
    index: usize,
    /// Set by `Extensions::get`. While set, `Extensions::remove` reports
    /// `Removed::Invalidated` instead of taking the value out.
    ever_fetched: bool,
}
24
/// Outcome of [`Extensions::insert`].
///
/// The derives mirror the other public types in this module, so callers can
/// compare, copy and debug-print the result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InsertEffect {
    /// An existing entry for the same type was superseded.
    Replaced,
    /// The value was stored without superseding an existing value.
    New,
}
30
/// Result of `Extensions::remove` when an entry for the type exists.
#[derive(Clone, Debug)]
pub enum Removed<T> {
    /// The value was taken out, and the map held the only reference to it.
    Removed(T),
    /// The value was taken out, but other `Arc` handles still share it.
    Referenced(Arc<T>),
    /// The value was borrowed via `Extensions::get`, so it was left in place.
    Invalidated,
}
40
41impl<T> Removed<T> {
42 pub fn unwrap(self) -> T {
43 match self {
44 Removed::Removed(x) => x,
45 Removed::Referenced(_) => panic!("extension is referenced"),
46 Removed::Invalidated => panic!("extension is invalidated (was referenced)"),
47 }
48 }
49}
50
51impl Extensions {
52 pub fn new() -> Self {
53 Default::default()
54 }
55
56 pub fn insert<T: Send + Sync + 'static>(&self, val: T) -> InsertEffect {
59 let type_id = TypeId::of::<T>();
60 let mut inner = self.inner.lock().unwrap();
61 let target_index = inner.values.len();
62 let old_index = inner.map.insert(
63 type_id,
64 ExtensionItem {
65 index: target_index,
66 ever_fetched: false,
67 },
68 );
69 inner.values.push(Some(Arc::new(val)));
70 if old_index.is_some() {
71 return InsertEffect::Replaced;
72 }
73 InsertEffect::New
74 }
75
76 pub fn get<'a, T: Send + Sync + 'static>(&'a self) -> Option<&'a T> {
79 let mut inner = self.inner.lock().unwrap();
80 let index = inner.map.get_mut(&TypeId::of::<T>())?;
81 index.ever_fetched = true;
82 let index = index.index;
83 let value: &T = (&**inner.values.get(index)?.as_ref()?).downcast_ref()?;
87 Some(unsafe { std::mem::transmute(value) })
88 }
89
90 pub fn get_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
95 let inner = self.inner.lock().unwrap();
96 let index = inner.map.get(&TypeId::of::<T>())?;
97 let index = index.index;
98 let value: Arc<T> = Arc::downcast(inner.values.get(index)?.as_ref()?.clone()).ok()?;
99 Some(value)
100 }
101
102 pub fn remove<T: Send + Sync + 'static>(&self) -> Option<Removed<T>> {
104 let mut inner = self.inner.lock().unwrap();
105 let index = inner.map.get(&TypeId::of::<T>())?;
106 if index.ever_fetched {
107 return Some(Removed::Invalidated);
108 }
109 let index = index.index;
110 let value = std::mem::replace(inner.values.get_mut(index)?, None)?;
111 let value: Arc<T> = Arc::downcast(value).ok()?;
112 match Arc::try_unwrap(value) {
113 Ok(x) => Some(Removed::Removed(x)),
114 Err(e) => Some(Removed::Referenced(e)),
115 }
116 }
117
118 pub fn extend(&self, other: &Extensions) {
119 let inner = other.inner.lock().unwrap();
120 let mut this = self.inner.lock().unwrap();
121 let inner_map = inner.map.clone();
122 for (type_id, index) in inner_map {
123 let Some(item) = inner.values.get(index.index) else {
124 continue;
125 };
126 let Some(item) = item.clone() else {
127 continue;
128 };
129 let ext_item = ExtensionItem {
130 index: this.values.len(),
131 ever_fetched: false,
133 };
134 this.map.insert(type_id, ext_item);
135 this.values.push(Some(item));
136 }
137 }
138}
139
/// Pass-through hasher for `TypeId` keys.
///
/// `TypeId`'s `Hash` impl writes a single `u64` through `write_u64`. That value
/// is already a hash, so `finish` returns it unchanged.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    /// Fallback for byte-oriented writes.
    ///
    /// Many of `Hasher`'s default methods (e.g. `write_u32`, `write_u128`,
    /// `write_usize`) forward here. Panicking would turn any such call into a
    /// crash of every map operation, for example if `TypeId` ever changed how
    /// it hashes itself. Instead, mix the bytes deterministically (FNV-1a
    /// style). `TypeId` lookups currently never reach this path.
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.0 = (self.0 ^ u64::from(byte)).wrapping_mul(0x0000_0100_0000_01b3);
        }
    }

    /// Stores the id as-is; this is the path `TypeId` takes.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}
158
159type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;
160struct HttpExtensions {
161 map: Option<Box<AnyMap>>,
162}
163
164impl From<http::Extensions> for Extensions {
165 fn from(value: http::Extensions) -> Self {
166 let value: HttpExtensions = unsafe { std::mem::transmute(value) };
167 let mut inner = ExtensionInner {
168 map: Default::default(),
169 values: Default::default(),
170 };
171 if let Some(value) = value.map {
172 for (type_id, value) in value.into_iter() {
173 let item = ExtensionItem {
174 index: inner.values.len(),
175 ever_fetched: false,
176 };
177 inner.map.insert(type_id, item);
178 inner.values.push(Some(Arc::from(value)));
179 }
180 }
181 Self {
182 inner: Arc::new(Mutex::new(inner)),
183 }
184 }
185}
186
187