// internment/arc.rs

1#![deny(missing_docs)]
2use ahash::RandomState;
3use std::any::{Any, TypeId};
4use std::fmt::{Debug, Display, Pointer};
// Per-type interning table: the `BoxRefCount<T>` key owns the refcounted
// allocation, the value is unit (the map is really a set).
type Container<T> = DashMap<BoxRefCount<T>, (), RandomState>;
// Type-erased, leaked reference to some `Container<T>`, stored in the global
// `TypeId`-keyed map inside `get_container`.
type Untyped = &'static (dyn Any + Send + Sync + 'static);
7use std::borrow::Borrow;
8use std::convert::AsRef;
9use std::ffi::OsStr;
10use std::hash::{Hash, Hasher};
11use std::ops::Deref;
12use std::path::Path;
13use std::sync::atomic::AtomicUsize;
14use std::sync::atomic::Ordering;
15
16use dashmap::{mapref::entry::Entry, DashMap};
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20
21/// A pointer to a reference-counted interned object.
22///
23/// This type requires feature "arc".  The interned object will be held in memory only until its
24/// reference count reaches zero.
25///
26/// # Example
27/// ```rust
28/// use internment::ArcIntern;
29///
30/// let x = ArcIntern::new("hello");
31/// let y = ArcIntern::new("world");
32/// assert_ne!(x, y);
33/// assert_eq!(x, ArcIntern::new("hello"));
34/// assert_eq!(*x, "hello"); // dereference an ArcIntern like a pointer
35/// ```
36///
37/// # Example with owned `String` data
38///
39/// ```rust
40/// use internment::ArcIntern;
41///
42/// let x = ArcIntern::new("hello".to_string());
43/// let y = ArcIntern::<String>::from_ref("world");
44/// assert_eq!(x, ArcIntern::from_ref("hello"));
45/// assert_eq!(&*x, "hello"); // dereference an ArcIntern like a pointer
46/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "arc")))]
pub struct ArcIntern<T: ?Sized + Eq + Hash + Send + Sync + 'static> {
    // Pointer into the `RefCount<T>` boxed inside the global container.
    // `NonNull` gives `Option<ArcIntern<T>>` a niche (see `arc_has_niche` test).
    pub(crate) pointer: std::ptr::NonNull<RefCount<T>>,
}
51
#[cfg(feature = "deepsize")]
impl<T: ?Sized + Eq + Hash + Send + Sync + 'static> deepsize::DeepSizeOf for ArcIntern<T> {
    /// A handle owns no memory of its own: the allocation belongs to the
    /// global container, so a single `ArcIntern` reports zero child bytes.
    /// Use [`deep_size_of_arc_interned`] for the shared memory.
    #[inline(always)]
    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
        0
    }
}
59
#[cfg_attr(docsrs, doc(cfg(all(feature = "deepsize", feature = "arc"))))]
/// Return the memory used by all interned objects of the given type.
#[cfg(feature = "deepsize")]
pub fn deep_size_of_arc_interned<
    T: ?Sized + Eq + Hash + Send + Sync + 'static + deepsize::DeepSizeOf,
>() -> usize {
    let container = ArcIntern::<T>::get_container();
    // Space taken by the table's slots themselves.
    let table_bytes = container.capacity() * std::mem::size_of::<BoxRefCount<T>>();
    // Add, per live entry, the refcount word plus the payload's deep size.
    container.iter().fold(table_bytes, |total, entry| {
        total + std::mem::size_of::<usize>() + entry.key().0.data.deep_size_of()
    })
}
74
// SAFETY: an `ArcIntern<T>` is only a pointer to an allocation shared through
// the global container, with an atomic reference count — the same shape as
// `std::sync::Arc`, so it is Send/Sync exactly when `T: Send + Sync`.
unsafe impl<T: ?Sized + Eq + Hash + Send + Sync> Send for ArcIntern<T> {}
unsafe impl<T: ?Sized + Eq + Hash + Send + Sync> Sync for ArcIntern<T> {}
77
/// Heap payload of an interned value: the atomic reference count plus the data.
#[derive(Debug)]
pub(crate) struct RefCount<T: ?Sized> {
    // Number of live `ArcIntern` handles pointing at this allocation.
    pub(crate) count: AtomicUsize,
    // The interned value; last field so `T: ?Sized` is permitted.
    pub(crate) data: T,
}
83
// Equality and hashing for `RefCount` consider only the payload, never the
// reference count, so container lookups by value find the right entry and
// `Hash` stays consistent with `Eq`.
impl<T: ?Sized + Eq> Eq for RefCount<T> {}
impl<T: ?Sized + PartialEq> PartialEq for RefCount<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.data == other.data
    }
}
impl<T: ?Sized + Hash> Hash for RefCount<T> {
    #[inline]
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        self.data.hash(hasher)
    }
}
97
/// Owning wrapper used as the container key; the derived `Eq`/`PartialEq`
/// delegate through `RefCount`, so it compares by payload.
#[derive(Eq, PartialEq)]
pub(crate) struct BoxRefCount<T: ?Sized>(pub Box<RefCount<T>>);
impl<T: ?Sized + Hash> Hash for BoxRefCount<T> {
    #[inline]
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        // Hash only the payload so lookups by `&T` (via `Borrow`) hit the entry.
        self.0.data.hash(hasher)
    }
}
106
107impl<T> BoxRefCount<T> {
108    #[inline(always)]
109    fn into_inner(self) -> T {
110        self.0.data
111    }
112}
113
// The two `Borrow` impls let the `DashMap` be queried two ways: by the bare
// value (`&T`, used when interning in `new`) and by the full `RefCount<T>`
// (used in `drop`, where only the pointer to the allocation is available).
impl<T: ?Sized> Borrow<T> for BoxRefCount<T> {
    #[inline(always)]
    fn borrow(&self) -> &T {
        &self.0.data
    }
}
impl<T: ?Sized> Borrow<RefCount<T>> for BoxRefCount<T> {
    #[inline(always)]
    fn borrow(&self) -> &RefCount<T> {
        &self.0
    }
}
impl<T: ?Sized> Deref for BoxRefCount<T> {
    type Target = T;
    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &self.0.data
    }
}
133
impl<T: ?Sized + Eq + Hash + Send + Sync + 'static> ArcIntern<T> {
    // Raw pointer to the shared allocation; backs the pointer-based
    // `Eq`, `Hash` and `Pointer` impls below.
    #[inline(always)]
    fn get_pointer(&self) -> *const RefCount<T> {
        self.pointer.as_ptr()
    }
    // Return the global container holding every interned value of type `T`.
    // Containers live in a type-erased map keyed by `TypeId`; each one is
    // leaked on first use, which is what makes the `'static` reference sound.
    pub(crate) fn get_container() -> &'static Container<T> {
        use once_cell::sync::OnceCell;
        static ARC_CONTAINERS: OnceCell<DashMap<TypeId, Untyped, RandomState>> = OnceCell::new();

        // make some shortcuts to speed up get_container.
        macro_rules! common_containers {
            ($($t:ty),*) => {
                $(
                // hopefully this will be optimized away by compiler for types that are not matched.
                // for matched types, this completely avoids the need to look up dashmap.
                if TypeId::of::<T>() == TypeId::of::<$t>() {
                    static CONTAINER: OnceCell<Container<$t>> = OnceCell::new();
                    let c: &'static Container<$t> = CONTAINER.get_or_init(|| Container::with_hasher(RandomState::new()));
                    // SAFETY: we just compared to make sure `T` == `$t`.
                    // This converts Container<$t> to Container<T> to make the compiler happy.
                    return unsafe { &*((c as *const Container<$t>).cast::<Container<T>>()) };
                }
                )*
            };
        }
        common_containers!(str, String);

        let type_map = ARC_CONTAINERS.get_or_init(|| DashMap::with_hasher(RandomState::new()));
        // Prefer taking the read lock to reduce contention, only use entry api if necessary.
        let boxed = if let Some(boxed) = type_map.get(&TypeId::of::<T>()) {
            boxed
        } else {
            type_map
                .entry(TypeId::of::<T>())
                .or_insert_with(|| {
                    // Deliberately leaked: one container per type, process-lifetime.
                    Box::leak(Box::new(Container::<T>::with_hasher(RandomState::new())))
                })
                .downgrade()
        };
        (*boxed).downcast_ref().unwrap()
    }
    /// Intern a value from a reference with atomic reference counting.
    ///
    /// If this value has not previously been
    /// interned, then `new` will allocate a spot for the value on the
    /// heap and generate that value using `T::from(val)`.
    pub fn from_ref<'a, Q: ?Sized + Eq + Hash + 'a>(val: &'a Q) -> ArcIntern<T>
    where
        T: Borrow<Q> + From<&'a Q>,
    {
        // No reference only fast-path as
        // the trait `std::borrow::Borrow<Q>` is not implemented for `Arc<T>`
        Self::new(val.into())
    }
    /// See how many objects have been interned.  This may be helpful
    /// in analyzing memory use.
    pub fn num_objects_interned() -> usize {
        Self::get_container().len()
    }
    /// Return the number of counts for this pointer.
    pub fn refcount(&self) -> usize {
        unsafe { self.pointer.as_ref().count.load(Ordering::Acquire) }
    }

    /// Only for benchmarking, this will cause problems
    #[cfg(feature = "bench")]
    pub fn benchmarking_only_clear_interns() {}
}
202
impl<T: Eq + Hash + Send + Sync + 'static> ArcIntern<T> {
    /// Intern a value.  If this value has not previously been
    /// interned, then `new` will allocate a spot for the value on the
    /// heap.  Otherwise, it will return a pointer to the object
    /// previously allocated.
    ///
    /// Note that `ArcIntern::new` is a bit slow, since it needs to check
    /// a `DashMap` which is protected by internal sharded locks.
    pub fn new(mut val: T) -> ArcIntern<T> {
        loop {
            let m = Self::get_container();
            if let Some(b) = m.get_mut(&val) {
                let b = b.key();
                // First increment the count.  We are holding the write mutex here.
                // Has to be the write mutex to avoid a race
                let oldval = b.0.count.fetch_add(1, Ordering::SeqCst);
                if oldval != 0 {
                    // we can only use this value if the value is not about to be freed
                    return ArcIntern {
                        pointer: std::ptr::NonNull::from(b.0.borrow()),
                    };
                } else {
                    // Race: a concurrent `drop` saw the count hit zero and has
                    // committed to removing this entry.  Undo our increment and
                    // retry once the removal has finished.
                    b.0.count.fetch_sub(1, Ordering::SeqCst);
                }
            } else {
                // Not present: allocate a fresh refcounted copy and try to insert it.
                let b = Box::new(RefCount {
                    count: AtomicUsize::new(1),
                    data: val,
                });
                match m.entry(BoxRefCount(b)) {
                    Entry::Vacant(e) => {
                        // We can insert, all is good
                        let p = ArcIntern {
                            pointer: std::ptr::NonNull::from(e.key().0.borrow()),
                        };
                        e.insert(());
                        return p;
                    }
                    Entry::Occupied(e) => {
                        // Race, map already has data, go round again.
                        // `into_key` returns OUR unused allocation, so we can
                        // recover `val` without cloning and retry the loop.
                        let box_ref_count = e.into_key();
                        val = box_ref_count.into_inner();
                    }
                }
            }
            // yield so that the object can finish being freed,
            // and then we will be able to intern a new copy.
            std::thread::yield_now();
        }
    }
}
257
impl<T: ?Sized + Eq + Hash + Send + Sync + 'static> Clone for ArcIntern<T> {
    /// Cloning is O(1): bump the reference count and copy the pointer.
    fn clone(&self) -> Self {
        // First increment the count.  Using a relaxed ordering is
        // alright here, as knowledge of the original reference
        // prevents other threads from erroneously deleting the
        // object.  (See `std::sync::Arc` documentation for more
        // explanation.)
        unsafe { self.pointer.as_ref().count.fetch_add(1, Ordering::Relaxed) };
        ArcIntern {
            pointer: self.pointer,
        }
    }
}
271
// In test builds, yield just before `drop` takes the container lock so the
// refcount-vs-removal race in `new`/`drop` is exercised more often.
// No-op in normal builds.
#[cfg(not(test))]
fn yield_on_tests() {}
#[cfg(test)]
fn yield_on_tests() {
    std::thread::yield_now();
}
278
impl<T: ?Sized + Eq + Hash + Send + Sync> Drop for ArcIntern<T> {
    fn drop(&mut self) {
        // (Quoting from std::sync::Arc): Because `fetch_sub` is
        // already atomic, we do not need to synchronize with other
        // threads unless we are going to delete the object.
        let count_was = unsafe { self.pointer.as_ref().count.fetch_sub(1, Ordering::SeqCst) };
        if count_was == 1 {
            // The following causes the code only when testing, to yield
            // control before taking the mutex, which should make it
            // easier to trigger any race condition (and hopefully won't
            // mask any other race conditions).
            yield_on_tests();
            // This fence prevents reordering of the final use of the data
            // with the deletion of the data: the reference-count decrements
            // synchronize with it, so all accesses happen-before the removal
            // below.  (`SeqCst` here is stronger than the `Acquire` fence
            // `std::sync::Arc` uses for the same purpose.)
            std::sync::atomic::fence(Ordering::SeqCst);

            // `_remove` is declared before `m`, so the removed entry is
            // dropped *after* `m`.  NOTE(review): the ordering constraint
            // dates from a mutex-guarded container; with `DashMap`, `m` is
            // only a `&'static` reference, but the original drop order is
            // preserved here — entry destructors may recursively touch the
            // container (see `test_arcintern_nested_drop`).
            #[allow(clippy::needless_late_init)]
            let _remove;
            let m = Self::get_container();
            _remove = m.remove(unsafe { self.pointer.as_ref() });
        }
    }
}
312
impl<T: ?Sized + Send + Sync + Hash + Eq> AsRef<T> for ArcIntern<T> {
    #[inline(always)]
    fn as_ref(&self) -> &T {
        // SAFETY: `pointer` is non-null and remains valid for the lifetime of
        // `self`, because this handle holds one count on the allocation.
        unsafe { &self.pointer.as_ref().data }
    }
}
319
// Forward `AsRef` conversions through the interned value, so that e.g. an
// `ArcIntern<str>` can be used wherever `AsRef<Path>` or `AsRef<OsStr>`
// is expected.
macro_rules! impl_as_ref {
    ($from:ty => $to:ty) => {
        impl AsRef<$to> for ArcIntern<$from> {
            #[inline(always)]
            fn as_ref(&self) -> &$to {
                let ptr: &$from = &*self;
                ptr.as_ref()
            }
        }
    };
}

impl_as_ref!(str => OsStr);
impl_as_ref!(str => Path);
impl_as_ref!(OsStr => Path);
impl_as_ref!(Path => OsStr);
336
337impl<T: ?Sized + Eq + Hash + Send + Sync> Deref for ArcIntern<T> {
338    type Target = T;
339    #[inline(always)]
340    fn deref(&self) -> &T {
341        self.as_ref()
342    }
343}
344
345impl<T: ?Sized + Eq + Hash + Send + Sync + Display> Display for ArcIntern<T> {
346    #[inline]
347    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
348        self.deref().fmt(f)
349    }
350}
351
352impl<T: ?Sized + Eq + Hash + Send + Sync> Pointer for ArcIntern<T> {
353    #[inline]
354    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
355        Pointer::fmt(&self.get_pointer(), f)
356    }
357}
358
/// The hash implementation returns the hash of the pointer
/// value, not the hash of the value pointed to.  This should
/// be irrelevant, since there is a unique pointer for every
/// value, but it *is* observable, since you could compare the
/// hash of the pointer with hash of the data itself.
///
/// Hashing the pointer keeps `Hash` consistent with the
/// pointer-based `Eq` below, and is O(1) regardless of `T`.
impl<T: ?Sized + Eq + Hash + Send + Sync> Hash for ArcIntern<T> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.get_pointer().hash(state);
    }
}
370
// Because every value is interned at most once, two handles are equal iff they
// point at the same allocation, so an O(1) address comparison suffices.
impl<T: ?Sized + Eq + Hash + Send + Sync> PartialEq for ArcIntern<T> {
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        std::ptr::eq(self.get_pointer(), other.get_pointer())
    }
}
impl<T: ?Sized + Eq + Hash + Send + Sync> Eq for ArcIntern<T> {}
378
379impl<T: ?Sized + Eq + Hash + Send + Sync + PartialOrd> PartialOrd for ArcIntern<T> {
380    #[inline]
381    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
382        self.as_ref().partial_cmp(other)
383    }
384    #[inline]
385    fn lt(&self, other: &Self) -> bool {
386        self.as_ref().lt(other)
387    }
388    #[inline]
389    fn le(&self, other: &Self) -> bool {
390        self.as_ref().le(other)
391    }
392    #[inline]
393    fn gt(&self, other: &Self) -> bool {
394        self.as_ref().gt(other)
395    }
396    #[inline]
397    fn ge(&self, other: &Self) -> bool {
398        self.as_ref().ge(other)
399    }
400}
401impl<T: ?Sized + Eq + Hash + Send + Sync + Ord> Ord for ArcIntern<T> {
402    #[inline]
403    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
404        self.as_ref().cmp(other)
405    }
406}
407
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
#[cfg(feature = "serde")]
impl<T: ?Sized + Eq + Hash + Send + Sync + Serialize> Serialize for ArcIntern<T> {
    /// Serialize transparently as the underlying value.
    #[inline]
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        T::serialize(self, serializer)
    }
}
416
417impl<T: Eq + Hash + Send + Sync + 'static> From<T> for ArcIntern<T> {
418    #[inline]
419    fn from(t: T) -> Self {
420        ArcIntern::new(t)
421    }
422}
423
424impl<T: Eq + Hash + Send + Sync + Default + 'static> Default for ArcIntern<T> {
425    #[inline]
426    fn default() -> Self {
427        ArcIntern::new(Default::default())
428    }
429}
430
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for ArcIntern<T>
where
    T: Eq + Hash + Send + Sync + 'static + Deserialize<'de>,
{
    /// Deserialize a plain `T`, then intern it.
    #[inline]
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let value = T::deserialize(deserializer)?;
        Ok(Self::new(value))
    }
}
442
#[cfg(test)]
mod arc_test {
    use super::ArcIntern;
    use super::{Borrow, Deref};
    // Interning equal values yields equal handles; distinct values differ.
    #[test]
    fn eq_string() {
        assert_eq!(ArcIntern::new("hello"), ArcIntern::new("hello"));
        assert_ne!(ArcIntern::new("goodbye"), ArcIntern::new("farewell"));
    }
    // Smoke test: `Display` forwards to the interned value.
    #[test]
    fn display() {
        let world = ArcIntern::new("world");
        println!("Hello {}", world);
    }
    // Smoke test: `Debug` forwards to the interned value.
    #[test]
    fn debug() {
        let world = ArcIntern::new("world");
        println!("Hello {:?}", world);
    }
    // `Default` interns `T::default()`.
    #[test]
    fn has_default() {
        assert_eq!(
            ArcIntern::<Option<String>>::default(),
            ArcIntern::<Option<String>>::new(None)
        );
    }
    // A clone compares equal to an independently interned copy.
    #[test]
    fn can_clone() {
        assert_eq!(
            ArcIntern::<Option<String>>::default().clone(),
            ArcIntern::<Option<String>>::new(None)
        );
    }
    // `Borrow<T>` exposes the interned value.
    #[test]
    fn has_borrow() {
        let x = ArcIntern::<Option<String>>::default();
        let b: &Option<String> = x.borrow();
        assert_eq!(b, ArcIntern::<Option<String>>::new(None).as_ref());
    }
    // `Deref` and `AsRef` agree.
    #[test]
    fn has_deref() {
        let x = ArcIntern::<Option<String>>::default();
        let b: &Option<String> = x.as_ref();
        assert_eq!(b, ArcIntern::<Option<String>>::new(None).deref());
    }
}
489
// Objects are freed when the last handle drops: the interned-object count
// tracks only currently-live values.
#[test]
fn test_arcintern_freeing() {
    assert_eq!(ArcIntern::<i32>::num_objects_interned(), 0);
    // Both temporaries drop at the end of the statement.
    assert_eq!(ArcIntern::new(5), ArcIntern::new(5));
    {
        let _interned = ArcIntern::new(6);
        assert_eq!(ArcIntern::<i32>::num_objects_interned(), 1);
    }
    {
        let _interned = ArcIntern::new(6);
        assert_eq!(ArcIntern::<i32>::num_objects_interned(), 1);
    }
    {
        let _interned = ArcIntern::new(7);
        assert_eq!(ArcIntern::<i32>::num_objects_interned(), 1);
    }

    // Hold a 6 alive across the next block, so 6 and 7 coexist.
    let six = ArcIntern::new(6);

    {
        let _interned = ArcIntern::new(7);
        assert_eq!(ArcIntern::<i32>::num_objects_interned(), 2);
    }
    assert_eq!(ArcIntern::new(6), six);
}
515
// Exercises recursive drop: freeing the outer interned value drops an inner
// `ArcIntern`, which re-enters the container machinery from within `drop`.
#[test]
fn test_arcintern_nested_drop() {
    #[derive(PartialEq, Eq, Hash)]
    enum Nat {
        Zero,
        Successor(ArcIntern<Nat>),
    }
    let zero = ArcIntern::new(Nat::Zero);
    let _one = ArcIntern::new(Nat::Successor(zero));
}
526
527impl<T: ?Sized + Eq + Hash + Send + Sync + Debug> Debug for ArcIntern<T> {
528    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
529        self.deref().fmt(f)
530    }
531}
532
// Test fixture: the `Arc<bool>` field lets tests observe via `strong_count`
// that every interned instance was actually dropped.
#[cfg(test)]
#[derive(Eq, PartialEq, Hash)]
pub struct TestStructCount(String, u64, std::sync::Arc<bool>);

// Simple two-field fixture for interning tests.
#[cfg(test)]
#[derive(Eq, PartialEq, Hash)]
pub struct TestStruct(String, u64);
540
// Quickly create and destroy a small number of interned objects from
// multiple threads, hammering the create/free race in `new`/`drop`.
#[test]
fn multithreading1() {
    use std::sync::Arc;
    use std::thread;
    let mut thandles = vec![];
    // Cloned into every interned value; its strong count proves all copies dropped.
    let drop_check = Arc::new(true);
    for _i in 0..10 {
        thandles.push(thread::spawn({
            let drop_check = drop_check.clone();
            move || {
                for _i in 0..100_000 {
                    let _interned1 =
                        ArcIntern::new(TestStructCount("foo".to_string(), 5, drop_check.clone()));
                    let _interned2 =
                        ArcIntern::new(TestStructCount("bar".to_string(), 10, drop_check.clone()));
                }
            }
        }));
    }
    for h in thandles.into_iter() {
        h.join().unwrap()
    }
    // Only the local handle remains: every interned copy was freed.
    assert_eq!(Arc::strong_count(&drop_check), 1);
    assert_eq!(ArcIntern::<TestStructCount>::num_objects_interned(), 0);
}
568
// `NonNull` gives `ArcIntern` a niche: the handle is one machine word, and
// wrapping it in `Option` costs nothing extra.
#[test]
fn arc_has_niche() {
    let word = std::mem::size_of::<usize>();
    assert_eq!(std::mem::size_of::<ArcIntern<String>>(), word);
    assert_eq!(std::mem::size_of::<Option<ArcIntern<String>>>(), word);
}
580
// Mirrors the `ArcIntern` doc-test above, so the behavior is also covered
// by `cargo test` without the doctest harness.
#[test]
fn like_doctest_arcintern() {
    let x = ArcIntern::new("hello".to_string());
    let y = ArcIntern::<String>::from_ref("world");
    assert_ne!(x, y);
    assert_eq!(x, ArcIntern::from_ref("hello"));
    assert_eq!(y, ArcIntern::from_ref("world"));
    assert_eq!(&*x, "hello"); // dereference an ArcIntern like a pointer
}
590
/// This function illustrates that dashmap has a failure under miri
///
/// This prevents us from using miri to validate our unsafe code.
#[test]
fn just_dashmap() {
    // Minimal entry-API round trip involving no `internment` code at all.
    let m: DashMap<Box<&'static str>, ()> = DashMap::new();
    match m.entry(Box::new("hello")) {
        Entry::Vacant(e) => {
            e.insert(());
        }
        Entry::Occupied(_) => {
            panic!("Should not exist yet");
        }
    };
}