axol_http/
extensions.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    hash::{BuildHasherDefault, Hasher},
5    sync::{Arc, Mutex},
6};
7
8#[derive(Clone, Debug, Default)]
9pub struct Extensions {
10    inner: Arc<Mutex<ExtensionInner>>,
11}
12
13#[derive(Clone, Debug, Default)]
14struct ExtensionInner {
15    map: HashMap<TypeId, ExtensionItem, BuildHasherDefault<IdHasher>>,
16    values: Vec<Option<Arc<dyn Any + Send + Sync>>>,
17}
18
/// Bookkeeping for a single entry in the extension map.
#[derive(Clone, Debug)]
struct ExtensionItem {
    /// Position of this entry's value in the `values` slot vector.
    index: usize,
    /// Becomes `true` once a plain reference to the value has been handed out by `get`;
    /// from then on the value may not be removed.
    ever_fetched: bool,
}
24
/// What inserting a value into `Extensions` did to the map.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InsertEffect {
    /// An entry of the same type already existed and now refers to the new value.
    Replaced,
    /// No previous value
    New,
}
30
/// Result of asking `Extensions::remove` to take out a value that was present.
#[derive(Clone, Debug)]
pub enum Removed<T> {
    /// The value was taken out, and no other handle to it existed, so it is returned owned.
    Removed(T),
    /// The value was taken out, but other `Arc` handles to it still exist, so it is returned
    /// as a shared `Arc` instead of being unwrapped.
    Referenced(Arc<T>),
    /// The value had previously been returned from a `get` call; it stays in place and can no
    /// longer be removed.
    Invalidated,
}
40
41impl<T> Removed<T> {
42    pub fn unwrap(self) -> T {
43        match self {
44            Removed::Removed(x) => x,
45            Removed::Referenced(_) => panic!("extension is referenced"),
46            Removed::Invalidated => panic!("extension is invalidated (was referenced)"),
47        }
48    }
49}
50
51impl Extensions {
52    pub fn new() -> Self {
53        Default::default()
54    }
55
56    /// Inserts a new value into the extension map. It will replace any existing value with the same TypeId.
57    /// Note that any outstanding `get`/`get_arc` on the type will not be altered.
58    pub fn insert<T: Send + Sync + 'static>(&self, val: T) -> InsertEffect {
59        let type_id = TypeId::of::<T>();
60        let mut inner = self.inner.lock().unwrap();
61        let target_index = inner.values.len();
62        let old_index = inner.map.insert(
63            type_id,
64            ExtensionItem {
65                index: target_index,
66                ever_fetched: false,
67            },
68        );
69        inner.values.push(Some(Arc::new(val)));
70        if old_index.is_some() {
71            return InsertEffect::Replaced;
72        }
73        InsertEffect::New
74    }
75
76    /// Gets a reference to an extension value.
77    /// This will invalidate that value from ever being manually removed.
78    pub fn get<'a, T: Send + Sync + 'static>(&'a self) -> Option<&'a T> {
79        let mut inner = self.inner.lock().unwrap();
80        let index = inner.map.get_mut(&TypeId::of::<T>())?;
81        index.ever_fetched = true;
82        let index = index.index;
83        // SAFETY: we never remove things from Extensions until its dropped, so the reference to the interior is always valid for self's lifetime
84        // Furthermore, we prevent calling `remove` unless a value has **never** been "get"ted before.
85        // ... look, I really want to be able to reference this data and remove elements. I know it's overkill.
86        let value: &T = (&**inner.values.get(index)?.as_ref()?).downcast_ref()?;
87        Some(unsafe { std::mem::transmute(value) })
88    }
89
90    /// Gets a reference to an extension value.
91    /// Since it returns an `Arc` and tracks it's deallocation, it does not prevent a value from being manually removed.
92    /// However, while the `Arc` is alive, it cannot be removed.
93    /// Take care that the `Arc` doesn't outlive the Request/Response, otherwise there will be a panic.
94    pub fn get_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
95        let inner = self.inner.lock().unwrap();
96        let index = inner.map.get(&TypeId::of::<T>())?;
97        let index = index.index;
98        let value: Arc<T> = Arc::downcast(inner.values.get(index)?.as_ref()?.clone()).ok()?;
99        Some(value)
100    }
101
102    /// Removes a non-invalidated entry from the Extensions map
103    pub fn remove<T: Send + Sync + 'static>(&self) -> Option<Removed<T>> {
104        let mut inner = self.inner.lock().unwrap();
105        let index = inner.map.get(&TypeId::of::<T>())?;
106        if index.ever_fetched {
107            return Some(Removed::Invalidated);
108        }
109        let index = index.index;
110        let value = std::mem::replace(inner.values.get_mut(index)?, None)?;
111        let value: Arc<T> = Arc::downcast(value).ok()?;
112        match Arc::try_unwrap(value) {
113            Ok(x) => Some(Removed::Removed(x)),
114            Err(e) => Some(Removed::Referenced(e)),
115        }
116    }
117
118    pub fn extend(&self, other: &Extensions) {
119        let inner = other.inner.lock().unwrap();
120        let mut this = self.inner.lock().unwrap();
121        let inner_map = inner.map.clone();
122        for (type_id, index) in inner_map {
123            let Some(item) = inner.values.get(index.index) else {
124                continue;
125            };
126            let Some(item) = item.clone() else {
127                continue;
128            };
129            let ext_item = ExtensionItem {
130                index: this.values.len(),
131                // the old lifetime is necessarily over since it's being dropped
132                ever_fetched: false,
133            };
134            this.map.insert(type_id, ext_item);
135            this.values.push(Some(item));
136        }
137    }
138}
139
/// Hasher for maps keyed by `TypeId`.
///
/// `TypeId`s are assumed to be well distributed already, so a value written through
/// `write_u64` is used directly as the hash instead of being mixed again.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    /// Fallback for keys hashed as raw bytes (including `write_u128`, whose default forwards
    /// here).
    ///
    /// How `TypeId` feeds a `Hasher` is an undocumented std implementation detail. Rather than
    /// panicking (which would break every lookup if that detail changed), fold the bytes into
    /// the state with an FNV-1a style step.
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.0 = (self.0 ^ u64::from(byte)).wrapping_mul(0x0000_0100_0000_01b3);
        }
    }

    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}
158
159type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;
160struct HttpExtensions {
161    map: Option<Box<AnyMap>>,
162}
163
164impl From<http::Extensions> for Extensions {
165    fn from(value: http::Extensions) -> Self {
166        let value: HttpExtensions = unsafe { std::mem::transmute(value) };
167        let mut inner = ExtensionInner {
168            map: Default::default(),
169            values: Default::default(),
170        };
171        if let Some(value) = value.map {
172            for (type_id, value) in value.into_iter() {
173                let item = ExtensionItem {
174                    index: inner.values.len(),
175                    ever_fetched: false,
176                };
177                inner.map.insert(type_id, item);
178                inner.values.push(Some(Arc::from(value)));
179            }
180        }
181        Self {
182            inner: Arc::new(Mutex::new(inner)),
183        }
184    }
185}
186
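// Converting back into `http::Extensions` is not currently possible on stable Rust: it would
// require turning an `Arc<T>` into a `Box<T>` for `T: ?Sized` (here `dyn Any + Send + Sync`).
// It might be doable with enough layout assumptions (and possibly asm).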
189
190// impl Into<http::Extensions> for Extensions {
191//     fn into(self) -> http::Extensions {
192//         let mut out = http::Extensions::new();
193//         let mut inner = self.inner.lock().unwrap();
194//         for (type_id, index) in std::mem::take(&mut inner.map) {
195//             let Some(item) = inner.values.get_mut(index.index) else {
196//                 continue;
197//             };
198//             let Some(item) = std::mem::take(item) else {
199//                 continue;
200//             };
201//             // item = Arc::try_int
202//         }
203//         out
204//     }
205// }