socketioxide/extensions.rs
1//! [`Extensions`] used to store extra data in each socket instance.
2//!
3//! It is heavily inspired by the [`http::Extensions`] type from the `http` crate.
4//!
5//! The main difference is that the inner [`HashMap`] is wrapped with an [`RwLock`]
6//! to allow concurrent access. Moreover, any value extracted from the map is cloned before being returned.
7//!
8//! This is necessary because [`Extensions`] are shared between all the threads that handle the same socket.
9//!
10//! You can use the [`Extension`](crate::extract::Extension) or
11//! [`MaybeExtension`](crate::extract::MaybeExtension) extractor to extract an extension of the given type.
12
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::RwLock;
16use std::{
17 any::{Any, TypeId},
18 hash::{BuildHasherDefault, Hasher},
19};
20
/// Type-erased storage slot holding a single extension value.
///
/// The `Send + Sync` bounds allow the boxed value to live in a map shared between threads.
type AnyVal = Box<dyn Any + Sync + Send>;
23
24/// The [`AnyHashMap`] is a [`HashMap`] that uses `TypeId` as keys and `Any` as values.
25type AnyHashMap = RwLock<HashMap<TypeId, AnyVal, BuildHasherDefault<IdHasher>>>;
26
// With `TypeId`s as keys there is no need to hash them again: they are already
// hashes produced by the compiler. `IdHasher` just stores the `u64` the `TypeId`
// writes into it and returns it unchanged from `finish`, skipping any bit mixing.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }

    /// Fallback for byte-oriented input.
    ///
    /// How `TypeId` feeds a hasher is an implementation detail of the standard library;
    /// current versions write a single `u64` through [`Hasher::write_u64`]. Rather than
    /// panicking if that ever changes, fold the bytes into the state so hashing stays
    /// deterministic (and map lookups stay correct), just via a slightly slower path.
    fn write(&mut self, bytes: &[u8]) {
        self.0 = bytes
            .iter()
            .fold(self.0, |acc, &b| acc.rotate_left(8).wrapping_add(u64::from(b)));
    }

    /// Fast path: keep the id as-is, it is already a well-distributed hash.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }
}
48
49/// A type map of protocol extensions.
50///
51/// It is heavily inspired by the `Extensions` type from the `http` crate.
52///
53/// The main difference is that the inner Map is wrapped with an `RwLock` to allow concurrent access.
54///
55/// This is necessary because `Extensions` are shared between all the threads that handle the same socket.
56///
57/// You can use the [`Extension`](crate::extract::Extension) or
58/// [`MaybeExtension`](crate::extract::MaybeExtension) extractor to extract an extension of the given type.
59#[derive(Default)]
60pub struct Extensions {
61 /// The underlying map
62 map: AnyHashMap,
63}
64
65impl Extensions {
66 /// Create an empty `Extensions`.
67 #[inline]
68 pub fn new() -> Extensions {
69 Extensions {
70 map: AnyHashMap::default(),
71 }
72 }
73
74 /// Insert a type into the `Extensions`.
75 ///
76 /// The type must be cloneable and thread safe to be stored.
77 ///
78 /// If a extension of this type already existed, it will
79 /// be returned.
80 ///
81 /// # Example
82 ///
83 /// ```
84 /// # use socketioxide::extensions::Extensions;
85 /// let mut ext = Extensions::new();
86 /// assert!(ext.insert(5i32).is_none());
87 /// assert!(ext.insert(4u8).is_none());
88 /// assert_eq!(ext.insert(9i32), Some(5i32));
89 /// ```
90 pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) -> Option<T> {
91 self.map
92 .write()
93 .unwrap()
94 .insert(TypeId::of::<T>(), Box::new(val))
95 .and_then(|v| v.downcast().ok().map(|boxed| *boxed))
96 }
97
98 /// Get a cloned value of a type previously inserted in the `Extensions`.
99 ///
100 /// # Example
101 ///
102 /// ```
103 /// # use socketioxide::extensions::Extensions;
104 /// let ext = Extensions::new();
105 /// assert!(ext.get::<i32>().is_none());
106 /// ext.insert(5i32);
107 ///
108 /// assert_eq!(ext.get::<i32>().unwrap(), 5i32);
109 /// ```
110 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
111 self.map
112 .read()
113 .unwrap()
114 .get(&TypeId::of::<T>())
115 .and_then(|v| v.downcast_ref::<T>())
116 .cloned()
117 }
118
119 /// Remove a type from the `Extensions`.
120 ///
121 /// If a extension of this type existed, it will be returned.
122 ///
123 /// # Example
124 ///
125 /// ```
126 /// # use socketioxide::extensions::Extensions;
127 /// let mut ext = Extensions::new();
128 /// ext.insert(5i32);
129 /// assert_eq!(ext.remove::<i32>(), Some(5i32));
130 /// assert!(ext.get::<i32>().is_none());
131 /// ```
132 pub fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
133 self.map
134 .write()
135 .unwrap()
136 .remove(&TypeId::of::<T>())
137 .and_then(|v| v.downcast().ok().map(|boxed| *boxed))
138 }
139
140 /// Clear the `Extensions` of all inserted extensions.
141 ///
142 /// # Example
143 ///
144 /// ```
145 /// # use socketioxide::extensions::Extensions;
146 /// let mut ext = Extensions::new();
147 /// ext.insert(5i32);
148 /// ext.clear();
149 ///
150 /// assert!(ext.get::<i32>().is_none());
151 /// ```
152 #[inline]
153 pub fn clear(&self) {
154 self.map.write().unwrap().clear();
155 }
156
157 /// Check whether the extension set is empty or not.
158 ///
159 /// # Example
160 ///
161 /// ```
162 /// # use socketioxide::extensions::Extensions;
163 /// let mut ext = Extensions::new();
164 /// assert!(ext.is_empty());
165 /// ext.insert(5i32);
166 /// assert!(!ext.is_empty());
167 /// ```
168 #[inline]
169 pub fn is_empty(&self) -> bool {
170 self.map.read().unwrap().is_empty()
171 }
172
173 /// Get the number of extensions available.
174 ///
175 /// # Example
176 ///
177 /// ```
178 /// # use socketioxide::extensions::Extensions;
179 /// let mut ext = Extensions::new();
180 /// assert_eq!(ext.len(), 0);
181 /// ext.insert(5i32);
182 /// assert_eq!(ext.len(), 1);
183 /// ```
184 #[inline]
185 pub fn len(&self) -> usize {
186 self.map.read().unwrap().len()
187 }
188}
189
190impl fmt::Debug for Extensions {
191 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192 f.debug_struct("Extensions").finish()
193 }
194}
195
#[test]
fn test_extensions() {
    use std::sync::Arc;
    #[derive(Debug, Clone, PartialEq)]
    struct MyType(i32);

    #[derive(Debug, PartialEq)]
    struct ComplexSharedType(u64);
    let shared = Arc::new(ComplexSharedType(20));

    let extensions = Extensions::new();
    assert!(extensions.is_empty());

    assert!(extensions.insert(5i32).is_none());
    assert!(extensions.insert(MyType(10)).is_none());
    assert!(extensions.insert(shared.clone()).is_none());
    assert_eq!(extensions.len(), 3);

    assert_eq!(extensions.get(), Some(5i32));
    assert_eq!(
        extensions.get::<Arc<ComplexSharedType>>(),
        Some(shared.clone())
    );

    // Re-inserting the same type replaces the value and hands back the old one.
    assert_eq!(extensions.insert(7i32), Some(5i32));
    assert_eq!(extensions.len(), 3);

    assert_eq!(extensions.remove::<i32>(), Some(7i32));
    assert!(extensions.get::<i32>().is_none());
    assert!(extensions.remove::<i32>().is_none());

    assert!(extensions.get::<bool>().is_none());
    assert_eq!(extensions.get(), Some(MyType(10)));

    extensions.clear();
    assert!(extensions.is_empty());
    assert!(extensions.get::<MyType>().is_none());
    // `clear` dropped the stored `Arc`, so only the local handle remains.
    assert_eq!(Arc::strong_count(&shared), 1);
}