trillium_http/
state_set.rs

1// Originally from https://github.com/http-rs/http-types/blob/main/src/extensions.rs
2//
3// Implementation is based on
4// - https://github.com/hyperium/http/blob/master/src/extensions.rs
5// - https://github.com/kardeiz/type-map/blob/master/src/lib.rs
6use hashbrown::HashMap;
7use std::{
8    any::{Any, TypeId},
9    fmt,
10    hash::{BuildHasherDefault, Hasher},
11};
12
13/// Store and retrieve values by
14/// [`TypeId`](https://doc.rust-lang.org/std/any/struct.TypeId.html). This
15/// allows storing arbitrary data that implements `Sync + Send +
16/// 'static`.
17#[derive(Default)]
18pub struct StateSet(HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>);
19
20impl StateSet {
21    /// Create an empty `StateSet`.
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Insert a value into this `StateSet`.
27    ///
28    /// If a value of this type already exists, it will be returned.
29    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
30        self.0
31            .insert(TypeId::of::<T>(), Box::new(val))
32            .and_then(|boxed| (boxed as Box<dyn Any>).downcast().ok().map(|boxed| *boxed))
33    }
34
35    /// Check if container contains value for type
36    pub fn contains<T: 'static>(&self) -> bool {
37        self.0.get(&TypeId::of::<T>()).is_some()
38    }
39
40    /// Get a reference to a value previously inserted on this `StateSet`.
41    pub fn get<T: 'static>(&self) -> Option<&T> {
42        self.0
43            .get(&TypeId::of::<T>())
44            .and_then(|boxed| (&**boxed as &(dyn Any)).downcast_ref())
45    }
46
47    /// Get a mutable reference to a value previously inserted on this `StateSet`.
48    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
49        self.0
50            .get_mut(&TypeId::of::<T>())
51            .and_then(|boxed| (&mut **boxed as &mut (dyn Any)).downcast_mut())
52    }
53
54    /// Remove a value from this `StateSet`.
55    ///
56    /// If a value of this type exists, it will be returned.
57    pub fn take<T: 'static>(&mut self) -> Option<T> {
58        self.0
59            .remove(&TypeId::of::<T>())
60            .and_then(|boxed| (boxed as Box<dyn Any>).downcast().ok().map(|boxed| *boxed))
61    }
62
63    /// Gets a value from this `StateSet` or populates it with the
64    /// provided default.
65    #[allow(clippy::missing_panics_doc)]
66    pub fn get_or_insert<T: Send + Sync + 'static>(&mut self, default: T) -> &mut T {
67        self.0
68            .entry(TypeId::of::<T>())
69            .or_insert_with(|| Box::new(default))
70            .downcast_mut()
71            .expect("StateSet maintains the invariant the value associated with a given TypeId is always the type associated with that TypeId.")
72    }
73
74    /// Gets a value from this `StateSet` or populates it with the
75    /// provided default function.
76    #[allow(clippy::missing_panics_doc)]
77    pub fn get_or_insert_with<F, T>(&mut self, default: F) -> &mut T
78    where
79        F: FnOnce() -> T,
80        T: Send + Sync + 'static,
81    {
82        self.0
83            .entry(TypeId::of::<T>())
84            .or_insert_with(|| Box::new(default()))
85            .downcast_mut()
86            .expect("StateSet maintains the invariant the value associated with a given TypeId is always the type associated with that TypeId.")
87    }
88}
89
90impl fmt::Debug for StateSet {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        f.debug_struct("StateSet").finish()
93    }
94}
95
// With TypeIds as keys, there's no need to hash them. So we simply use an
// identity hasher: the `u64` passed to `write_u64` becomes the hash.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    // Fallback for `Hash` impls that feed raw bytes instead of a single
    // `u64`. Which `write_*` method a `Hash` impl calls is not something
    // this hasher controls, so the bytes are folded into the state rather
    // than panicking.
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.0 = (self.0.rotate_left(8) ^ u64::from(byte)).wrapping_mul(0x0000_0100_0000_01B3);
        }
    }

    // Fast path: store the value as-is.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Values of different types coexist and can be read back, mutated
    /// and removed independently.
    #[test]
    fn test_extensions() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut map = StateSet::new();

        map.insert(5i32);
        map.insert(MyType(10));

        assert_eq!(map.get(), Some(&5i32));
        assert_eq!(map.get_mut(), Some(&mut 5i32));

        assert_eq!(map.take::<i32>(), Some(5i32));
        assert!(map.get::<i32>().is_none());

        assert_eq!(map.get::<bool>(), None);
        assert_eq!(map.get(), Some(&MyType(10)));
    }

    /// Inserting a second value of the same type replaces the first and
    /// hands the old value back.
    #[test]
    fn insert_replaces_and_returns_previous() {
        let mut map = StateSet::new();

        assert_eq!(map.insert(1u8), None);
        assert_eq!(map.insert(2u8), Some(1u8));
        assert_eq!(map.get::<u8>(), Some(&2u8));
    }

    /// `contains` reflects insertions and removals for the queried type only.
    #[test]
    fn contains_tracks_presence() {
        let mut map = StateSet::new();
        assert!(!map.contains::<u8>());

        map.insert(1u8);
        assert!(map.contains::<u8>());
        assert!(!map.contains::<u16>());

        map.take::<u8>();
        assert!(!map.contains::<u8>());
    }

    /// The `get_or_insert*` methods only populate the set when no value of
    /// that type is present, and return a mutable reference to the stored value.
    #[test]
    fn get_or_insert_only_populates_when_absent() {
        let mut map = StateSet::new();

        assert_eq!(*map.get_or_insert(1u8), 1);
        assert_eq!(*map.get_or_insert(2u8), 1);

        let mut calls = 0;
        assert_eq!(
            *map.get_or_insert_with(|| {
                calls += 1;
                3u8
            }),
            1
        );
        assert_eq!(calls, 0);

        *map.get_or_insert_with(|| 10u16) += 1;
        assert_eq!(map.get::<u16>(), Some(&11u16));
    }
}