1use ahash::RandomState;
54use dashmap::DashMap;
55use once_cell::sync::OnceCell;
56use serde::{Deserialize, Deserializer, Serialize, Serializer};
57use std::any::{Any, TypeId};
58use std::borrow::Borrow;
59use std::fmt::Display;
60use std::hash::{Hash, Hasher};
61use std::ops::Deref;
62use std::sync::Arc;
63
/// A handle to a reference-counted, interned value of type `T`.
///
/// Every handle created through the interner for equal values points at one shared
/// allocation, which is what lets equality be decided by comparing pointers.
#[derive(Debug)]
pub struct ArcIntern<T>
where
    T: ?Sized + Eq + Hash + Send + Sync + 'static,
{
    // The shared allocation; the per-type intern table holds one extra clone of it.
    arc: Arc<T>,
}
83
/// Per-type intern table: the set of canonical `Arc<T>` allocations that still have at least
/// one live `ArcIntern` handle. Only the keys matter; the `()` value is a placeholder.
type Container<T> = DashMap<Arc<T>, (), RandomState>;
85
/// Global registry mapping each interned type's `TypeId` to its `Container<T>`.
///
/// Tables for different `T` are stored type-erased as `Box<dyn Any + Send + Sync>` and
/// recovered with `downcast_ref::<Container<T>>()`. The registry itself is created lazily,
/// on the first `get_or_init`.
static CONTAINER: OnceCell<DashMap<TypeId, Box<dyn Any + Send + Sync>, RandomState>> =
    OnceCell::new();
88
89impl<T: Eq + Hash + Send + Sync + 'static + ?Sized> ArcIntern<T> {
90 fn from_arc(val: Arc<T>) -> ArcIntern<T> {
91 let type_map = CONTAINER.get_or_init(|| DashMap::with_hasher(RandomState::new()));
92
93 let boxed = if let Some(boxed) = type_map.get(&TypeId::of::<T>()) {
95 boxed
96 } else {
97 type_map
98 .entry(TypeId::of::<T>())
99 .or_insert_with(|| Box::new(Container::<T>::with_hasher(RandomState::new())))
100 .downgrade()
101 };
102
103 let m: &Container<T> = boxed.value().downcast_ref::<Container<T>>().unwrap();
104 let b = m.entry(val).or_insert(());
105 return ArcIntern {
106 arc: b.key().clone(),
107 };
108 }
109
110 pub fn num_objects_interned() -> usize {
113 if let Some(m) = CONTAINER
114 .get()
115 .and_then(|type_map| type_map.get(&TypeId::of::<T>()))
116 {
117 return m.downcast_ref::<Container<T>>().unwrap().len();
118 }
119 0
120 }
121 pub fn refcount(&self) -> usize {
123 Arc::strong_count(&self.arc) - 1
126 }
127}
128
129impl<T: Eq + Hash + Send + Sync + 'static> ArcIntern<T> {
130 pub fn new(val: T) -> ArcIntern<T> {
138 Self::from_arc(Arc::new(val))
139 }
140}
141
142impl<T: Eq + Hash + Send + Sync + 'static + ?Sized> Clone for ArcIntern<T> {
143 fn clone(&self) -> Self {
144 ArcIntern {
145 arc: self.arc.clone(),
146 }
147 }
148}
149
/// Releases one handle. The last handle to a value also evicts it from the intern table.
impl<T: Eq + Hash + Send + Sync + ?Sized> Drop for ArcIntern<T> {
    fn drop(&mut self) {
        // If the type map or `T`'s table was never created, there is nothing to evict.
        if let Some(m) = CONTAINER
            .get()
            .and_then(|type_map| type_map.get(&TypeId::of::<T>()))
        {
            let m: &Container<T> = m.downcast_ref::<Container<T>>().unwrap();
            // Evict the entry only when exactly two strong references remain: the table's
            // key and `self.arc`. Any other live handle keeps the count above 2.
            // NOTE(review): correctness relies on two things. Dashmap must evaluate this
            // predicate under the shard write lock, and `from_arc` must clone keys only while
            // holding the entry guard for the same shard. Together these stop the count from
            // rising between the check and the removal.
            m.remove_if(&self.arc, |k, _v| {
                Arc::strong_count(k) == 2
            });
            // Whatever `remove_if` returns is discarded at the end of the statement above. If
            // the key was removed, `self.arc` is then the last reference, and the value is
            // freed when `self.arc` is dropped after this method returns.
        }
    }
}
166
167impl<T: Send + Sync + Hash + Eq + ?Sized> AsRef<T> for ArcIntern<T> {
168 fn as_ref(&self) -> &T {
169 self.arc.as_ref()
170 }
171}
172impl<T: Eq + Hash + Send + Sync + ?Sized> Borrow<T> for ArcIntern<T> {
173 fn borrow(&self) -> &T {
174 self.as_ref()
175 }
176}
177impl<T: Eq + Hash + Send + Sync + ?Sized> Deref for ArcIntern<T> {
178 type Target = T;
179 fn deref(&self) -> &T {
180 self.as_ref()
181 }
182}
183
184impl<T: Eq + Hash + Send + Sync + Display + ?Sized> Display for ArcIntern<T> {
185 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
186 self.deref().fmt(f)
187 }
188}
189
190impl<T: Eq + Hash + Send + Sync + 'static + ?Sized> From<Box<T>> for ArcIntern<T> {
191 fn from(b: Box<T>) -> Self {
192 Self::from_arc(Arc::from(b))
193 }
194}
195
196impl<'a, T> From<&'a T> for ArcIntern<T>
197where
198 T: Eq + Hash + Send + Sync + 'static + ?Sized,
199 Arc<T>: From<&'a T>,
200{
201 fn from(t: &'a T) -> Self {
202 Self::from_arc(Arc::from(t))
203 }
204}
205
206impl<T: Eq + Hash + Send + Sync + 'static> From<T> for ArcIntern<T> {
207 fn from(t: T) -> Self {
208 ArcIntern::new(t)
209 }
210}
211impl<T: Eq + Hash + Send + Sync + Default + 'static + ?Sized> Default for ArcIntern<T> {
212 fn default() -> ArcIntern<T> {
213 ArcIntern::new(Default::default())
214 }
215}
216
217impl<T: Eq + Hash + Send + Sync + ?Sized> Hash for ArcIntern<T> {
218 fn hash<H: Hasher>(&self, state: &mut H) {
220 let borrow: &T = self.borrow();
221 borrow.hash(state);
222 }
223}
224
225impl<T: Eq + Hash + Send + Sync + ?Sized> PartialEq for ArcIntern<T> {
227 fn eq(&self, other: &ArcIntern<T>) -> bool {
228 Arc::ptr_eq(&self.arc, &other.arc)
229 }
230}
231impl<T: Eq + Hash + Send + Sync + ?Sized> Eq for ArcIntern<T> {}
232
233impl<T: Eq + Hash + Send + Sync + PartialOrd + ?Sized> PartialOrd for ArcIntern<T> {
234 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
235 self.as_ref().partial_cmp(other)
236 }
237 fn lt(&self, other: &Self) -> bool {
238 self.as_ref().lt(other)
239 }
240 fn le(&self, other: &Self) -> bool {
241 self.as_ref().le(other)
242 }
243 fn gt(&self, other: &Self) -> bool {
244 self.as_ref().gt(other)
245 }
246 fn ge(&self, other: &Self) -> bool {
247 self.as_ref().ge(other)
248 }
249}
250
251impl<T: Eq + Hash + Send + Sync + Ord + ?Sized> Ord for ArcIntern<T> {
252 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
253 self.as_ref().cmp(other)
254 }
255}
256
257impl<T: Eq + Hash + Send + Sync + Serialize + ?Sized> Serialize for ArcIntern<T> {
258 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
259 self.as_ref().serialize(serializer)
260 }
261}
262
263impl<'de, T: Eq + Hash + Send + Sync + 'static + ?Sized + Deserialize<'de>> Deserialize<'de>
264 for ArcIntern<T>
265{
266 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
267 T::deserialize(deserializer).map(Self::new)
268 }
269}
270
#[cfg(test)]
mod tests {
    use crate::ArcIntern;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::thread;

    // NOTE: the intern table is process-global, and the default test harness runs tests in
    // parallel. Each test therefore interns types no other test uses (&str/String, i32,
    // TestStruct2, TestStruct, [usize]), so the `num_objects_interned` checks do not race.

    #[test]
    fn basic() {
        // Each temporary handle is dropped at the end of its assertion statement, so the
        // `&str` table is empty again afterwards.
        assert_eq!(ArcIntern::new("foo"), ArcIntern::new("foo"));
        assert_ne!(ArcIntern::new("foo"), ArcIntern::new("bar"));
        assert_eq!(ArcIntern::<&str>::num_objects_interned(), 0);

        let _interned1 = ArcIntern::new("foo".to_string());
        {
            let interned2 = ArcIntern::new("foo".to_string());
            let interned3 = ArcIntern::new("bar".to_string());

            // "foo" is held by `_interned1` and `interned2`; "bar" only by `interned3`.
            assert_eq!(interned2.refcount(), 2);
            assert_eq!(interned3.refcount(), 1);
            assert_eq!(ArcIntern::<String>::num_objects_interned(), 2);
        }

        // Dropping the only "bar" handle evicted it; "foo" survives through `_interned1`.
        assert_eq!(ArcIntern::<String>::num_objects_interned(), 1);
    }

    #[test]
    fn sorting() {
        // `Ord` on handles must order by value, not by allocation address.
        let mut interned_vals = vec![
            ArcIntern::new(4),
            ArcIntern::new(2),
            ArcIntern::new(5),
            ArcIntern::new(0),
            ArcIntern::new(1),
            ArcIntern::new(3),
        ];
        interned_vals.sort();
        let sorted: Vec<String> = interned_vals.iter().map(|v| format!("{}", v)).collect();
        assert_eq!(&sorted.join(","), "0,1,2,3,4,5");
    }

    #[derive(Eq, PartialEq, Hash)]
    pub struct TestStruct2(String, u64);

    #[test]
    fn sequential() {
        // Each batch of 100 distinct values is dropped at the end of its iteration, so
        // repeated intern/evict cycles must leave the table empty.
        for _i in 0..10_000 {
            let mut interned = Vec::with_capacity(100);
            for j in 0..100 {
                interned.push(ArcIntern::new(TestStruct2("foo".to_string(), j)));
            }
        }

        assert_eq!(ArcIntern::<TestStruct2>::num_objects_interned(), 0);
    }

    // The `Arc<bool>` field lets the test detect leaked values through its strong count.
    #[derive(Eq, PartialEq, Hash)]
    pub struct TestStruct(String, u64, Arc<bool>);

    #[test]
    fn multithreading1() {
        // Ten threads concurrently intern and drop the same two values.
        let mut thandles = vec![];
        let drop_check = Arc::new(true);
        for _i in 0..10 {
            let t = thread::spawn({
                let drop_check = drop_check.clone();
                move || {
                    for _i in 0..100_000 {
                        let interned1 =
                            ArcIntern::new(TestStruct("foo".to_string(), 5, drop_check.clone()));
                        let _interned2 =
                            ArcIntern::new(TestStruct("bar".to_string(), 10, drop_check.clone()));
                        let mut m = HashMap::new();
                        m.insert(interned1, ());
                    }
                }
            });
            thandles.push(t);
        }
        for h in thandles.into_iter() {
            h.join().unwrap()
        }
        // Every interned value must actually have been freed (releasing its `drop_check`
        // clone), and the table must be empty.
        assert_eq!(Arc::strong_count(&drop_check), 1);
        assert_eq!(ArcIntern::<TestStruct>::num_objects_interned(), 0);
    }

    #[test]
    fn test_unsized() {
        // Unsized slices are interned through `From<&[T]>`; equal slices share one allocation.
        assert_eq!(
            ArcIntern::<[usize]>::from(&[1, 2, 3][..]),
            ArcIntern::from(&[1, 2, 3][..])
        );
        assert_ne!(
            ArcIntern::<[usize]>::from(&[1, 2][..]),
            ArcIntern::from(&[1, 2, 3][..])
        );
        assert_eq!(ArcIntern::<[usize]>::num_objects_interned(), 0);
    }
}