socketioxide/
extensions.rs

//! [`Extensions`] used to store extra data in each socket instance.
//!
//! It is heavily inspired by the [`http::Extensions`] type from the `http` crate.
//!
//! The main difference is that the inner [`HashMap`] is wrapped with an [`RwLock`]
//! to allow concurrent access. Moreover, values read from the map with
//! [`Extensions::get`] are cloned before being returned instead of being borrowed.
//!
//! This is necessary because [`Extensions`] are shared between all the threads that handle the same socket.
//!
//! You can use the [`Extension`](crate::extract::Extension) or
//! [`MaybeExtension`](crate::extract::MaybeExtension) extractor to extract an extension of the given type.

13use std::collections::HashMap;
14use std::fmt;
15use std::sync::RwLock;
16use std::{
17    any::{Any, TypeId},
18    hash::{BuildHasherDefault, Hasher},
19};
20
21/// TypeMap value
22type AnyVal = Box<dyn Any + Send + Sync>;
23
24/// The [`AnyHashMap`] is a [`HashMap`] that uses `TypeId` as keys and `Any` as values.
25type AnyHashMap = RwLock<HashMap<TypeId, AnyVal, BuildHasherDefault<IdHasher>>>;
26
/// Pass-through [`Hasher`] for `TypeId` keys.
///
/// A `TypeId` is already a hash produced by the compiler, so mixing its bits again
/// would be wasted work. This hasher keeps the last `u64` given to
/// [`Hasher::write_u64`] and returns it verbatim from [`Hasher::finish`]. It expects
/// to be fed only through `write_u64`; the generic byte-slice `write` is treated as
/// unreachable.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    #[inline]
    fn finish(&self) -> u64 {
        let IdHasher(hash) = self;
        *hash
    }

    fn write(&mut self, _: &[u8]) {
        unreachable!("TypeId calls write_u64");
    }

    #[inline]
    fn write_u64(&mut self, id: u64) {
        // Overwrite rather than combine: the id itself is the hash.
        *self = IdHasher(id);
    }
}

49/// A type map of protocol extensions.
50///
51/// It is heavily inspired by the `Extensions` type from the `http` crate.
52///
53/// The main difference is that the inner Map is wrapped with an `RwLock` to allow concurrent access.
54///
55/// This is necessary because `Extensions` are shared between all the threads that handle the same socket.
56///
57/// You can use the [`Extension`](crate::extract::Extension) or
58/// [`MaybeExtension`](crate::extract::MaybeExtension) extractor to extract an extension of the given type.
59#[derive(Default)]
60pub struct Extensions {
61    /// The underlying map
62    map: AnyHashMap,
63}
64
65impl Extensions {
66    /// Create an empty `Extensions`.
67    #[inline]
68    pub fn new() -> Extensions {
69        Extensions {
70            map: AnyHashMap::default(),
71        }
72    }
73
74    /// Insert a type into the `Extensions`.
75    ///
76    /// The type must be cloneable and thread safe to be stored.
77    ///
78    /// If a extension of this type already existed, it will
79    /// be returned.
80    ///
81    /// # Example
82    ///
83    /// ```
84    /// # use socketioxide::extensions::Extensions;
85    /// let mut ext = Extensions::new();
86    /// assert!(ext.insert(5i32).is_none());
87    /// assert!(ext.insert(4u8).is_none());
88    /// assert_eq!(ext.insert(9i32), Some(5i32));
89    /// ```
90    pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) -> Option<T> {
91        self.map
92            .write()
93            .unwrap()
94            .insert(TypeId::of::<T>(), Box::new(val))
95            .and_then(|v| v.downcast().ok().map(|boxed| *boxed))
96    }
97
98    /// Get a cloned value of a type previously inserted in the `Extensions`.
99    ///
100    /// # Example
101    ///
102    /// ```
103    /// # use socketioxide::extensions::Extensions;
104    /// let ext = Extensions::new();
105    /// assert!(ext.get::<i32>().is_none());
106    /// ext.insert(5i32);
107    ///
108    /// assert_eq!(ext.get::<i32>().unwrap(), 5i32);
109    /// ```
110    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
111        self.map
112            .read()
113            .unwrap()
114            .get(&TypeId::of::<T>())
115            .and_then(|v| v.downcast_ref::<T>())
116            .cloned()
117    }
118
119    /// Remove a type from the `Extensions`.
120    ///
121    /// If a extension of this type existed, it will be returned.
122    ///
123    /// # Example
124    ///
125    /// ```
126    /// # use socketioxide::extensions::Extensions;
127    /// let mut ext = Extensions::new();
128    /// ext.insert(5i32);
129    /// assert_eq!(ext.remove::<i32>(), Some(5i32));
130    /// assert!(ext.get::<i32>().is_none());
131    /// ```
132    pub fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
133        self.map
134            .write()
135            .unwrap()
136            .remove(&TypeId::of::<T>())
137            .and_then(|v| v.downcast().ok().map(|boxed| *boxed))
138    }
139
140    /// Clear the `Extensions` of all inserted extensions.
141    ///
142    /// # Example
143    ///
144    /// ```
145    /// # use socketioxide::extensions::Extensions;
146    /// let mut ext = Extensions::new();
147    /// ext.insert(5i32);
148    /// ext.clear();
149    ///
150    /// assert!(ext.get::<i32>().is_none());
151    /// ```
152    #[inline]
153    pub fn clear(&self) {
154        self.map.write().unwrap().clear();
155    }
156
157    /// Check whether the extension set is empty or not.
158    ///
159    /// # Example
160    ///
161    /// ```
162    /// # use socketioxide::extensions::Extensions;
163    /// let mut ext = Extensions::new();
164    /// assert!(ext.is_empty());
165    /// ext.insert(5i32);
166    /// assert!(!ext.is_empty());
167    /// ```
168    #[inline]
169    pub fn is_empty(&self) -> bool {
170        self.map.read().unwrap().is_empty()
171    }
172
173    /// Get the number of extensions available.
174    ///
175    /// # Example
176    ///
177    /// ```
178    /// # use socketioxide::extensions::Extensions;
179    /// let mut ext = Extensions::new();
180    /// assert_eq!(ext.len(), 0);
181    /// ext.insert(5i32);
182    /// assert_eq!(ext.len(), 1);
183    /// ```
184    #[inline]
185    pub fn len(&self) -> usize {
186        self.map.read().unwrap().len()
187    }
188}
189
190impl fmt::Debug for Extensions {
191    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192        f.debug_struct("Extensions").finish()
193    }
194}
195
#[test]
fn test_extensions() {
    use std::sync::Arc;

    #[derive(Debug, Clone, PartialEq)]
    struct MyType(i32);

    // Not `Clone` itself: it is stored behind an `Arc` so that `get` can clone it.
    #[derive(Debug, PartialEq)]
    struct ComplexSharedType(u64);

    let ext = Extensions::new();
    let shared = Arc::new(ComplexSharedType(20));

    // One entry per distinct type.
    ext.insert(5i32);
    ext.insert(MyType(10));
    ext.insert(Arc::clone(&shared));

    // Reads return clones of the stored values.
    assert_eq!(ext.get(), Some(5i32));
    assert_eq!(ext.get::<Arc<ComplexSharedType>>(), Some(shared));

    // Removing hands the value back and the entry disappears.
    assert_eq!(ext.remove::<i32>(), Some(5i32));
    assert!(ext.get::<i32>().is_none());

    // Unknown types miss, and other entries are unaffected.
    assert!(ext.get::<bool>().is_none());
    assert_eq!(ext.get(), Some(MyType(10)));
}