1use std::{
2 borrow::Borrow,
3 collections::HashMap,
4 fmt::{self, Display, Formatter},
5 hash::{Hash, Hasher},
6 sync::{Arc, Mutex, OnceLock, Weak},
7};
8
9#[derive(Clone, Debug, serde::Serialize)]
23#[serde(transparent)]
24pub struct Symbol(Arc<str>);
25
26impl Symbol {
27 pub fn as_str(&self) -> &str {
28 self.0.as_ref()
29 }
30
31 pub fn intern(name: &str) -> Self {
32 let mutex = INTERNER.get_or_init(|| Mutex::new(HashMap::new()));
33 let mut table = match mutex.lock() {
34 Ok(guard) => guard,
35 Err(poisoned) => poisoned.into_inner(),
36 };
37
38 if let Some(existing) = table.get(name).and_then(Weak::upgrade) {
39 return Symbol(existing);
40 }
41
42 prune_dead_symbols_if_needed(&mut table);
43 if let Some(existing) = table.get(name).and_then(Weak::upgrade) {
44 return Symbol(existing);
45 }
46
47 let sym = Symbol(Arc::from(name));
48 table.insert(name.to_string(), Arc::downgrade(&sym.0));
49 sym
50 }
51}
52
53impl PartialEq for Symbol {
54 fn eq(&self, other: &Self) -> bool {
55 let same_allocation = Arc::ptr_eq(&self.0, &other.0);
56 debug_assert!(
57 same_allocation || self.as_ref() != other.as_ref(),
58 "symbol interner invariant violated: duplicate live symbols for `{}`",
59 self.as_ref()
60 );
61 same_allocation
62 }
63}
64
65impl Eq for Symbol {}
66
67impl PartialOrd for Symbol {
68 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
69 Some(self.cmp(other))
70 }
71}
72
73impl Ord for Symbol {
74 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
75 let ordering = self.as_ref().cmp(other.as_ref());
76 debug_assert!(
77 ordering != std::cmp::Ordering::Equal || Arc::ptr_eq(&self.0, &other.0),
78 "symbol interner invariant violated: duplicate live symbols for `{}`",
79 self.as_ref()
80 );
81 ordering
82 }
83}
84
85impl Hash for Symbol {
86 fn hash<H: Hasher>(&self, state: &mut H) {
87 self.as_ref().hash(state);
88 }
89}
90
91impl AsRef<str> for Symbol {
92 fn as_ref(&self) -> &str {
93 self.0.as_ref()
94 }
95}
96
97impl PartialEq<&str> for Symbol {
98 fn eq(&self, other: &&str) -> bool {
99 self.as_ref() == *other
100 }
101}
102
103impl PartialEq<Symbol> for &str {
104 fn eq(&self, other: &Symbol) -> bool {
105 *self == other.as_ref()
106 }
107}
108
109impl Borrow<str> for Symbol {
110 fn borrow(&self) -> &str {
111 self.0.as_ref()
112 }
113}
114
115impl Display for Symbol {
116 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
117 self.0.fmt(f)
118 }
119}
120
121impl<'de> serde::Deserialize<'de> for Symbol {
122 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123 where
124 D: serde::Deserializer<'de>,
125 {
126 let name = <String as serde::Deserialize>::deserialize(deserializer)?;
127 Ok(Symbol::intern(&name))
128 }
129}
130
/// Process-wide interner table, created lazily on first use.
///
/// Maps symbol text to a `Weak` handle on the shared allocation, so the table
/// alone never keeps a symbol's text alive.
static INTERNER: OnceLock<Mutex<HashMap<String, Weak<str>>>> = OnceLock::new();

/// Tables with fewer entries than this are never pruned.
const PRUNE_DEAD_SYMBOLS_MIN_LEN: usize = 1024;

/// Drops entries whose symbol has no remaining strong reference.
///
/// The sweep runs only when the table holds at least
/// [`PRUNE_DEAD_SYMBOLS_MIN_LEN`] entries and its length is a power of two;
/// otherwise the table is left untouched.
fn prune_dead_symbols_if_needed(table: &mut HashMap<String, Weak<str>>) {
    let entries = table.len();
    if entries < PRUNE_DEAD_SYMBOLS_MIN_LEN || !entries.is_power_of_two() {
        return;
    }
    table.retain(|_, weak| weak.strong_count() != 0);
}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::{BTreeMap, HashMap},
        hash::{DefaultHasher, Hash, Hasher},
        sync::Barrier,
        thread,
    };

    /// A deserialized symbol shares the allocation of an already-live symbol.
    #[test]
    fn deserializing_symbol_interns_it() {
        let interned = Symbol::intern("alpha");
        let decoded: Symbol = serde_json::from_str("\"alpha\"").unwrap();

        assert!(Arc::ptr_eq(&interned.0, &decoded.0));
    }

    /// Serialization emits a bare JSON string, not a wrapper object.
    #[test]
    fn serializing_symbol_stays_string_shaped() {
        let encoded = serde_json::to_string(&Symbol::intern("alpha")).unwrap();

        assert_eq!(encoded, "\"alpha\"");
    }

    /// `Symbol == &str` and `&str == Symbol` both compare by text.
    #[test]
    fn symbol_compares_with_str_refs_on_either_side() {
        let symbol = Symbol::intern("alpha");

        assert!(symbol == "alpha");
        assert!("alpha" == symbol);
    }

    /// Hashing and ordering follow the text, so `&str` lookups work in maps
    /// keyed by `Symbol`.
    #[test]
    fn symbol_hash_and_order_remain_text_shaped() {
        let alpha = Symbol::intern("alpha");
        let same_alpha = Symbol::intern("alpha");
        let beta = Symbol::intern("beta");

        assert_eq!(symbol_hash(&alpha), symbol_hash(&same_alpha));
        assert!(alpha < beta);

        let mut hash_map = HashMap::new();
        hash_map.insert(alpha.clone(), 1);
        assert_eq!(hash_map.get("alpha"), Some(&1));

        let mut tree_map = BTreeMap::new();
        tree_map.insert(alpha, 1);
        assert_eq!(tree_map.get("alpha"), Some(&1));
    }

    /// Many threads interning the same name at once all get one allocation.
    #[test]
    fn concurrent_interning_returns_one_live_allocation() {
        const THREADS: usize = 32;

        // Iterator adapters are lazy: chaining `spawn` and `join` in one
        // pipeline would join each thread before spawning the next. Spawn
        // every thread first, and use a barrier so they call `intern` together.
        let start = Arc::new(Barrier::new(THREADS));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let start = Arc::clone(&start);
                thread::spawn(move || {
                    start.wait();
                    Symbol::intern("concurrent-symbol")
                })
            })
            .collect();

        let symbols: Vec<_> = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();

        let first = &symbols[0];
        for symbol in &symbols[1..] {
            assert_eq!(first, symbol);
            assert!(Arc::ptr_eq(&first.0, &symbol.0));
        }
    }

    /// Dropping every symbol for a name frees its text; interning the name
    /// again creates a new allocation that is shared as usual.
    #[test]
    fn weak_interner_releases_payload_after_last_symbol_drops() {
        let name = "weak-interner-releases-payload-after-last-symbol-drops";
        let old = {
            let symbol = Symbol::intern(name);
            let same_symbol = Symbol::intern(name);

            assert_eq!(symbol, same_symbol);
            assert!(Arc::ptr_eq(&symbol.0, &same_symbol.0));

            Arc::downgrade(&symbol.0)
        };

        // Both strong handles went out of scope with the block above.
        assert!(old.upgrade().is_none());

        let symbol = Symbol::intern(name);
        let same_symbol = Symbol::intern(name);

        // The old allocation stays dead; the new symbols share a fresh one.
        assert!(old.upgrade().is_none());
        assert_eq!(symbol, same_symbol);
        assert!(Arc::ptr_eq(&symbol.0, &same_symbol.0));
    }

    /// Hashes `symbol` with the standard library's default hasher.
    fn symbol_hash(symbol: &Symbol) -> u64 {
        let mut hasher = DefaultHasher::new();
        symbol.hash(&mut hasher);
        hasher.finish()
    }
}