1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::hash::{BuildHasherDefault, Hasher};
4
/// Hasher used by `TypeMap`, whose keys are `TypeId`s.
///
/// `write_u64` stores its argument verbatim and `finish` hands it back, so a
/// key that hashes itself through a single `write_u64` call is used as its own
/// hash with no further mixing.
///
/// NOTE(review): this relies on std's `TypeId` feeding its hash through
/// `write_u64`, which is an implementation detail rather than a documented
/// guarantee; `write` is the byte-level fallback if that ever changes.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    /// Byte-level fallback: XOR-folds `bytes` into a `u64`, placing byte `idx`
    /// at bit offset `8 * (idx % 8)`, so bytes eight positions apart share a slot.
    ///
    /// Each call replaces the current state instead of mixing into it.
    fn write(&mut self, bytes: &[u8]) {
        self.0 = bytes
            .iter()
            .enumerate()
            .fold(0, |acc, (idx, &byte)| acc ^ (u64::from(byte) << (8 * (idx % 8))));
    }

    /// Stores `id` unchanged, replacing any previous state.
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    /// Returns the most recently stored value (0 if nothing was written).
    fn finish(&self) -> u64 {
        let IdHasher(current) = *self;
        current
    }
}
33
/// Backing storage for `State`: at most one boxed, type-erased value per
/// `TypeId`, hashed with `IdHasher` (which passes `write_u64` input through
/// unchanged instead of using the default SipHash).
type TypeMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;
35
/// A type-indexed store holding at most one value of each `'static` type.
///
/// Values are looked up by their type alone (`state.get::<T>()`). Every stored
/// value must be `Send + Sync`.
pub struct State(TypeMap);
43
44impl State {
45 pub fn new() -> Self {
46 State(HashMap::default())
47 }
48
49 pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
51 self.0.insert(TypeId::of::<T>(), Box::new(value));
52 }
53
54 pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
56 self.0
57 .get(&TypeId::of::<T>())
58 .and_then(|v| v.downcast_ref())
59 }
60
61 pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
63 self.0
64 .get_mut(&TypeId::of::<T>())
65 .and_then(|v| v.downcast_mut())
66 }
67
68 pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
70 self.0.contains_key(&TypeId::of::<T>())
71 }
72
73 pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
75 self.0
76 .remove(&TypeId::of::<T>())
77 .and_then(|v| v.downcast().ok())
78 .map(|v| *v)
79 }
80}
81
82impl Default for State {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl std::fmt::Debug for State {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_tuple("State")
91 .field(&format!("{} entries", self.0.len()))
92 .finish()
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_get() {
        let mut state = State::new();
        state.insert(42u32);
        assert_eq!(state.get::<u32>(), Some(&42));
    }

    #[test]
    fn insert_replaces_existing() {
        let mut state = State::new();
        state.insert(1u32);
        state.insert(2u32);
        assert_eq!(state.get::<u32>(), Some(&2));
    }

    #[test]
    fn get_missing_returns_none() {
        let state = State::new();
        assert_eq!(state.get::<u32>(), None);
    }

    #[test]
    fn get_mut_allows_modification() {
        let mut state = State::new();
        state.insert(String::from("hello"));
        if let Some(s) = state.get_mut::<String>() {
            s.push_str(" world");
        }
        assert_eq!(state.get::<String>().unwrap(), "hello world");
    }

    #[test]
    fn get_mut_missing_returns_none() {
        let mut state = State::new();
        assert!(state.get_mut::<u32>().is_none());
    }

    #[test]
    fn contains() {
        let mut state = State::new();
        assert!(!state.contains::<u32>());
        state.insert(42u32);
        assert!(state.contains::<u32>());
    }

    #[test]
    fn remove_returns_owned_value() {
        let mut state = State::new();
        state.insert(String::from("removed"));
        let removed = state.remove::<String>();
        assert_eq!(removed, Some(String::from("removed")));
        assert!(!state.contains::<String>());
    }

    #[test]
    fn remove_missing_returns_none() {
        let mut state = State::new();
        assert_eq!(state.remove::<u32>(), None);
    }

    #[test]
    fn multiple_types() {
        let mut state = State::new();
        state.insert(42u32);
        state.insert("hello");
        state.insert(vec![1, 2, 3]);

        assert_eq!(state.get::<u32>(), Some(&42));
        assert_eq!(state.get::<&str>(), Some(&"hello"));
        assert_eq!(state.get::<Vec<i32>>(), Some(&vec![1, 2, 3]));
    }

    // Types with the same layout must still occupy separate slots, and
    // removing one must leave the other untouched.
    #[test]
    fn similar_types_are_distinct_keys() {
        let mut state = State::new();
        state.insert(1u32);
        state.insert(2u64);
        assert_eq!(state.get::<u32>(), Some(&1));
        assert_eq!(state.get::<u64>(), Some(&2));
        assert_eq!(state.remove::<u32>(), Some(1));
        assert_eq!(state.get::<u64>(), Some(&2));
    }

    #[test]
    fn default_is_empty() {
        let state = State::default();
        assert!(!state.contains::<u32>());
        assert!(format!("{state:?}").contains("0 entries"));
    }

    #[test]
    fn debug_output() {
        let mut state = State::new();
        state.insert(1u32);
        state.insert("hello");
        let debug = format!("{state:?}");
        assert!(debug.contains("2 entries"));
    }

    #[test]
    fn id_hasher_passes_u64_through() {
        let mut hasher = IdHasher::default();
        hasher.write_u64(0x1234_5678_9abc_def0);
        assert_eq!(hasher.finish(), 0x1234_5678_9abc_def0);
    }

    #[test]
    fn state_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<State>();
    }
}