// graphile_worker_extensions/lib.rs

use std::{
    any::{Any, TypeId},
    collections::HashMap,
    fmt::Debug,
    hash::{BuildHasherDefault, Hasher},
    sync::Arc,
};

9pub(crate) type AnyMap =
10    HashMap<TypeId, Box<dyn AnyClone + Send + Sync>, BuildHasherDefault<IdHasher>>;
11
// With `TypeId`s as keys there is no need to hash them again: they are already
// hashes themselves, coming from the compiler. `IdHasher` just holds the `u64`
// handed over by `TypeId` and returns it, instead of doing any bit fiddling.

/// Pass-through [`Hasher`] intended for maps keyed by `TypeId`.
///
/// The wrapped `u64` is the most recent value given to `write_u64`; a freshly
/// defaulted hasher holds `0`.
#[derive(Default)]
pub(crate) struct IdHasher(u64);

impl Hasher for IdHasher {
    /// Byte-slice input is deliberately unsupported.
    ///
    /// # Panics
    ///
    /// Always panics. This hasher only expects keys that hash themselves
    /// through `write_u64`, which is what `TypeId` does.
    fn write(&mut self, _: &[u8]) {
        unreachable!("TypeId calls write_u64");
    }

    /// Stores `id` as the hash value, replacing whatever was stored before.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    /// Returns the stored value unchanged.
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}

34/// A type map of worker extensions.
35///
36/// `Extensions` can be used by worker job to store extra data.
37#[derive(Clone, Default, Debug)]
38pub struct Extensions {
39    // If extensions are never used, no need to carry around an empty HashMap.
40    // That's 3 words. Instead, this is only 1 word.
41    map: Option<Box<AnyMap>>,
42}
43
44impl Extensions {
45    /// Create an empty `Extensions`.
46    #[inline]
47    pub fn new() -> Extensions {
48        Extensions { map: None }
49    }
50
51    /// Insert a type into this `Extensions`.
52    ///
53    /// If a extension of this type already existed, it will
54    /// be returned.
55    ///
56    /// # Example
57    ///
58    /// ```
59    /// use graphile_worker_extensions::Extensions;
60    /// let mut ext = Extensions::new();
61    /// assert!(ext.insert(5i32).is_none());
62    /// assert!(ext.insert(4u8).is_none());
63    /// assert_eq!(ext.insert(9i32), Some(5i32));
64    /// ```
65    pub fn insert<T: Clone + Send + Sync + Debug + 'static>(&mut self, val: T) -> Option<T> {
66        self.map
67            .get_or_insert_with(Box::default)
68            .insert(TypeId::of::<T>(), Box::new(val))
69            .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed))
70    }
71
72    /// Get a reference to a type previously inserted on this `Extensions`.
73    ///
74    /// # Example
75    ///
76    /// ```
77    /// use graphile_worker_extensions::Extensions;
78    /// let mut ext = Extensions::new();
79    /// assert!(ext.get::<i32>().is_none());
80    /// ext.insert(5i32);
81    ///
82    /// assert_eq!(ext.get::<i32>(), Some(&5i32));
83    /// ```
84    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
85        self.map
86            .as_ref()
87            .and_then(|map| map.get(&TypeId::of::<T>()))
88            .and_then(|boxed| (**boxed).as_any().downcast_ref())
89    }
90
91    /// Get a mutable reference to a type previously inserted on this `Extensions`.
92    ///
93    /// # Example
94    ///
95    /// ```
96    /// use graphile_worker_extensions::Extensions;
97    /// let mut ext = Extensions::new();
98    /// ext.insert(String::from("Hello"));
99    /// ext.get_mut::<String>().unwrap().push_str(" World");
100    ///
101    /// assert_eq!(ext.get::<String>().unwrap(), "Hello World");
102    /// ```
103    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
104        self.map
105            .as_mut()
106            .and_then(|map| map.get_mut(&TypeId::of::<T>()))
107            .and_then(|boxed| (**boxed).as_any_mut().downcast_mut())
108    }
109
110    /// Get a mutable reference to a type, inserting `value` if not already present on this
111    /// `Extensions`.
112    ///
113    /// # Example
114    ///
115    /// ```
116    /// use graphile_worker_extensions::Extensions;
117    /// let mut ext = Extensions::new();
118    /// *ext.get_or_insert(1i32) += 2;
119    ///
120    /// assert_eq!(*ext.get::<i32>().unwrap(), 3);
121    /// ```
122    pub fn get_or_insert<T: Clone + Send + Sync + Debug + 'static>(&mut self, value: T) -> &mut T {
123        self.get_or_insert_with(|| value)
124    }
125
126    /// Get a mutable reference to a type, inserting the value created by `f` if not already present
127    /// on this `Extensions`.
128    ///
129    /// # Example
130    ///
131    /// ```
132    /// use graphile_worker_extensions::Extensions;
133    /// let mut ext = Extensions::new();
134    /// *ext.get_or_insert_with(|| 1i32) += 2;
135    ///
136    /// assert_eq!(*ext.get::<i32>().unwrap(), 3);
137    /// ```
138    pub fn get_or_insert_with<T: Clone + Send + Sync + Debug + 'static, F: FnOnce() -> T>(
139        &mut self,
140        f: F,
141    ) -> &mut T {
142        let out = self
143            .map
144            .get_or_insert_with(Box::default)
145            .entry(TypeId::of::<T>())
146            .or_insert_with(|| Box::new(f()));
147        (**out).as_any_mut().downcast_mut().unwrap()
148    }
149
150    /// Get a mutable reference to a type, inserting the type's default value if not already present
151    /// on this `Extensions`.
152    ///
153    /// # Example
154    ///
155    /// ```
156    /// use graphile_worker_extensions::Extensions;
157    /// let mut ext = Extensions::new();
158    /// *ext.get_or_insert_default::<i32>() += 2;
159    ///
160    /// assert_eq!(*ext.get::<i32>().unwrap(), 2);
161    /// ```
162    pub fn get_or_insert_default<T: Default + Clone + Send + Sync + Debug + 'static>(
163        &mut self,
164    ) -> &mut T {
165        self.get_or_insert_with(T::default)
166    }
167
168    /// Remove a type from this `Extensions`.
169    ///
170    /// If a extension of this type existed, it will be returned.
171    ///
172    /// # Example
173    ///
174    /// ```
175    /// use graphile_worker_extensions::Extensions;
176    /// let mut ext = Extensions::new();
177    /// ext.insert(5i32);
178    /// assert_eq!(ext.remove::<i32>(), Some(5i32));
179    /// assert!(ext.get::<i32>().is_none());
180    /// ```
181    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
182        self.map
183            .as_mut()
184            .and_then(|map| map.remove(&TypeId::of::<T>()))
185            .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed))
186    }
187
188    /// Clear the `Extensions` of all inserted extensions.
189    ///
190    /// # Example
191    ///
192    /// ```
193    /// use graphile_worker_extensions::Extensions;
194    /// let mut ext = Extensions::new();
195    /// ext.insert(5i32);
196    /// ext.clear();
197    ///
198    /// assert!(ext.get::<i32>().is_none());
199    /// ```
200    #[inline]
201    pub fn clear(&mut self) {
202        if let Some(ref mut map) = self.map {
203            map.clear();
204        }
205    }
206
207    /// Check whether the extension set is empty or not.
208    ///
209    /// # Example
210    ///
211    /// ```
212    /// use graphile_worker_extensions::Extensions;
213    /// let mut ext = Extensions::new();
214    /// assert!(ext.is_empty());
215    /// ext.insert(5i32);
216    /// assert!(!ext.is_empty());
217    /// ```
218    #[inline]
219    pub fn is_empty(&self) -> bool {
220        self.map.as_ref().is_none_or(|map| map.is_empty())
221    }
222
223    /// Get the numer of extensions available.
224    ///
225    /// # Example
226    ///
227    /// ```
228    /// use graphile_worker_extensions::Extensions;
229    /// let mut ext = Extensions::new();
230    /// assert_eq!(ext.len(), 0);
231    /// ext.insert(5i32);
232    /// assert_eq!(ext.len(), 1);
233    /// ```
234    #[inline]
235    pub fn len(&self) -> usize {
236        self.map.as_ref().map_or(0, |map| map.len())
237    }
238
239    /// Extends `self` with another `Extensions`.
240    ///
241    /// If an instance of a specific type exists in both, the one in `self` is overwritten with the
242    /// one from `other`.
243    ///
244    /// # Example
245    ///
246    /// ```
247    /// use graphile_worker_extensions::Extensions;
248    /// let mut ext_a = Extensions::new();
249    /// ext_a.insert(8u8);
250    /// ext_a.insert(16u16);
251    ///
252    /// let mut ext_b = Extensions::new();
253    /// ext_b.insert(4u8);
254    /// ext_b.insert("hello");
255    ///
256    /// ext_a.extend(ext_b);
257    /// assert_eq!(ext_a.len(), 3);
258    /// assert_eq!(ext_a.get::<u8>(), Some(&4u8));
259    /// assert_eq!(ext_a.get::<u16>(), Some(&16u16));
260    /// assert_eq!(ext_a.get::<&'static str>().copied(), Some("hello"));
261    /// ```
262    pub fn extend(&mut self, other: Self) {
263        if let Some(other) = other.map {
264            if let Some(map) = &mut self.map {
265                map.extend(*other);
266            } else {
267                self.map = Some(other);
268            }
269        }
270    }
271}
272
273/// A read-only wrapper around `Extensions` that can be safely shared.
274///
275/// `ReadOnlyExtensions` wraps an `Extensions` instance in an `Arc` to allow
276/// sharing it between components without allowing modifications. This is used
277/// to provide task handlers with access to shared application state without
278/// allowing them to modify that state.
279#[derive(Clone, Debug)]
280pub struct ReadOnlyExtensions(Arc<Extensions>);
281
282impl ReadOnlyExtensions {
283    /// Creates a new `ReadOnlyExtensions` from an `Extensions` instance.
284    ///
285    /// This wraps the provided extensions in an `Arc` to allow safe sharing
286    /// across threads and components.
287    ///
288    /// # Arguments
289    ///
290    /// * `ext` - The extensions to wrap
291    ///
292    /// # Returns
293    ///
294    /// A new `ReadOnlyExtensions` instance
295    pub fn new(ext: Extensions) -> Self {
296        ReadOnlyExtensions(Arc::new(ext))
297    }
298
299    /// Get a reference to a type previously inserted in the wrapped `Extensions`.
300    ///
301    /// # Type Parameters
302    ///
303    /// * `T` - The type of extension to retrieve
304    ///
305    /// # Returns
306    ///
307    /// * `Some(&T)` - A reference to the extension value if it exists
308    /// * `None` - If no extension of the requested type is registered
309    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
310        self.0.get()
311    }
312
313    /// Get the number of extensions available.
314    ///
315    /// # Returns
316    ///
317    /// The number of extension values stored in this container
318    pub fn len(&self) -> usize {
319        self.0.len()
320    }
321
322    /// Check whether the extension set is empty or not.
323    ///
324    /// # Returns
325    ///
326    /// `true` if no extensions are stored, `false` otherwise
327    pub fn is_empty(&self) -> bool {
328        self.0.is_empty()
329    }
330}
331
332/// Implements the `From` trait to allow converting an `Extensions` into a
333/// `ReadOnlyExtensions` easily.
334impl From<Extensions> for ReadOnlyExtensions {
335    /// Converts an `Extensions` into a `ReadOnlyExtensions`.
336    ///
337    /// This is a convenience implementation that simply calls `ReadOnlyExtensions::new`.
338    ///
339    /// # Arguments
340    ///
341    /// * `ext` - The extensions to convert
342    ///
343    /// # Returns
344    ///
345    /// A new `ReadOnlyExtensions` instance containing the same extensions
346    fn from(ext: Extensions) -> Self {
347        ReadOnlyExtensions::new(ext)
348    }
349}
350
/// Object-safe combination of [`Any`], [`Debug`] and cloning.
///
/// `Clone` itself cannot be used through a trait object, so this trait exposes
/// cloning as `clone_box` together with the conversions needed to downcast a
/// `dyn AnyClone` back to its concrete type.
pub trait AnyClone: Any + Debug {
    /// Clones the value into a new box.
    fn clone_box(&self) -> Box<dyn AnyClone + Send + Sync>;
    /// Borrows the value as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Borrows the value as `&mut dyn Any` so it can be downcast mutably.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Converts the boxed value into `Box<dyn Any>` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}

/// Blanket implementation for every clonable, debuggable, thread-safe
/// `'static` type.
impl<T: Clone + Send + Sync + Debug + 'static> AnyClone for T {
    fn clone_box(&self) -> Box<dyn AnyClone + Send + Sync> {
        Box::new(self.clone())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

/// Makes boxed trait objects clonable by delegating to `clone_box`.
impl Clone for Box<dyn AnyClone + Send + Sync> {
    fn clone(&self) -> Self {
        // Deref through the `Box` so `clone_box` runs on the stored value; the
        // `Box` itself also satisfies the blanket impl above.
        (**self).clone_box()
    }
}

/// Unit tests for `Extensions` and `ReadOnlyExtensions`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get() {
        let mut ext = Extensions::new();
        assert!(ext.insert(5i32).is_none());
        assert_eq!(ext.get::<i32>(), Some(&5i32));
    }

    #[test]
    fn test_insert_and_get_mut() {
        let mut ext = Extensions::new();
        ext.insert(String::from("Hello"));
        ext.get_mut::<String>().unwrap().push_str(" World");
        assert_eq!(ext.get::<String>().unwrap(), "Hello World");
    }

    #[test]
    fn test_get_or_insert() {
        let mut ext = Extensions::new();
        *ext.get_or_insert(1i32) += 2;
        assert_eq!(ext.get::<i32>(), Some(&3i32));
    }

    #[test]
    fn test_get_or_insert_with() {
        let mut ext = Extensions::new();
        *ext.get_or_insert_with(|| 1i32) += 2;
        assert_eq!(ext.get::<i32>(), Some(&3i32));
    }

    #[test]
    fn test_get_or_insert_default() {
        let mut ext = Extensions::new();
        *ext.get_or_insert_default::<i32>() += 2;
        assert_eq!(ext.get::<i32>(), Some(&2i32));
    }

    #[test]
    fn test_remove() {
        let mut ext = Extensions::new();
        ext.insert(5i32);
        assert_eq!(ext.remove::<i32>(), Some(5i32));
        assert!(ext.get::<i32>().is_none());
    }

    #[test]
    fn test_clear() {
        let mut ext = Extensions::new();
        ext.insert(5i32);
        ext.clear();
        assert!(ext.get::<i32>().is_none());
    }

    #[test]
    fn test_is_empty() {
        let mut ext = Extensions::new();
        assert!(ext.is_empty());
        ext.insert(5i32);
        assert!(!ext.is_empty());
    }

    #[test]
    fn test_len() {
        let mut ext = Extensions::new();
        assert_eq!(ext.len(), 0);
        ext.insert(5i32);
        assert_eq!(ext.len(), 1);
    }

    #[test]
    fn test_extend() {
        let mut ext_a = Extensions::new();
        ext_a.insert(8u8);
        ext_a.insert(16u16);

        let mut ext_b = Extensions::new();
        ext_b.insert(4u8);
        ext_b.insert("hello");

        ext_a.extend(ext_b);
        assert_eq!(ext_a.len(), 3);
        assert_eq!(ext_a.get::<u8>(), Some(&4u8));
        assert_eq!(ext_a.get::<u16>(), Some(&16u16));
        assert_eq!(ext_a.get::<&'static str>().copied(), Some("hello"));
    }

    // A clone must own its values: mutating the copy leaves the original alone.
    #[test]
    fn test_clone_is_independent() {
        let mut original = Extensions::new();
        original.insert(String::from("a"));

        let mut copy = original.clone();
        copy.get_mut::<String>().unwrap().push('b');

        assert_eq!(original.get::<String>().unwrap(), "a");
        assert_eq!(copy.get::<String>().unwrap(), "ab");
    }

    // The read-only wrapper and its clones expose the wrapped values.
    #[test]
    fn test_read_only_extensions() {
        let mut ext = Extensions::new();
        ext.insert(5i32);

        let shared = ReadOnlyExtensions::from(ext);
        let other = shared.clone();

        assert!(!shared.is_empty());
        assert_eq!(shared.len(), 1);
        assert_eq!(other.get::<i32>(), Some(&5i32));
        assert!(other.get::<u8>().is_none());
    }
}