Skip to main content

any_intern/
any.rs

1use super::common::{self, Interned, RawInterned, UnsafeLock};
2use bumpalo::Bump;
3use hashbrown::{hash_table::Entry, HashTable};
4use std::{
5    alloc::Layout,
6    any::TypeId,
7    borrow,
8    cell::Cell,
9    hash::Hasher,
10    hash::{BuildHasher, Hash},
11    mem,
12    ptr::NonNull,
13};
14
15/// A type-erased interner for storing and deduplicating values of a single type.
16///
17/// This interner is simply a wrapper of [`AnyInternSet`] with interior mutability. If you need a
18/// collection of interners for various types like a hash map of interners, then consider using the
19/// `AnyInterSet` with a container providing interior mutability such as [`ManualMutex`].
20///
21/// # Examples
22///
23/// ```
24/// use any_intern::AnyInterner;
25///
26/// #[derive(PartialEq, Eq, Hash, Debug)]
27/// struct A(u32);
28///
29/// let interner = AnyInterner::of::<A>();
30///
31/// unsafe {
32///     let a1 = interner.intern(A(42));
33///     let a2 = interner.intern(A(42));
34///     assert_eq!(a1, a2); // Same value, same reference
35///
36///     let a3 = interner.intern(A(99));
37///     assert_ne!(a1, a3); // Different values, different references
38/// }
39/// ```
40///
41/// # Safety
42///
43/// Many methods in `AnyInterner` are marked as `unsafe` because they rely on the caller to ensure
44/// that the correct type is used when interacting with the interner. Using an incorrect type can
45/// lead to undefined behavior.
46pub struct AnyInterner<S = fxhash::FxBuildHasher> {
47    inner: UnsafeLock<AnyInternSet<S>>,
48}
49
50impl AnyInterner {
51    pub fn of<K: 'static>() -> Self {
52        // Safety: Only one instance
53        let inner = unsafe { UnsafeLock::new(AnyInternSet::of::<K>()) };
54        Self { inner }
55    }
56}
57
58impl<S: BuildHasher> AnyInterner<S> {
59    pub fn with_hasher<K: 'static>(hash_builder: S) -> Self {
60        // Safety: Only one instance
61        let inner = unsafe { UnsafeLock::new(AnyInternSet::with_hasher::<K>(hash_builder)) };
62        Self { inner }
63    }
64
65    /// Returns number of values the interner contains.
66    pub fn len(&self) -> usize {
67        self.with_inner(|set| set.len())
68    }
69
70    /// Returns true if the interner is empty.
71    pub fn is_empty(&self) -> bool {
72        self.with_inner(|set| set.is_empty())
73    }
74
75    /// Stores a value in the interner, returning a reference to the interned value.
76    ///
77    /// This method inserts the given value into the interner if it does not already exist. If the
78    /// value already exists, a reference to the existing value is returned.
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use any_intern::AnyInterner;
84    ///
85    /// #[derive(PartialEq, Eq, Hash, Debug)]
86    /// struct A(u32);
87    ///
88    /// let interner = AnyInterner::of::<A>();
89    ///
90    /// unsafe {
91    ///     let a1 = interner.intern(A(42));
92    ///     let a2 = interner.intern(A(42));
93    ///     assert_eq!(a1, a2); // Same value, same reference
94    ///     assert_eq!(a1.raw().as_ptr(), a2.raw().as_ptr());
95    /// }
96    /// ```
97    ///
98    /// # Safety
99    ///
100    /// Undefined behavior if incorrect type `K` is given.
101    pub unsafe fn intern<K>(&self, value: K) -> Interned<'_, K>
102    where
103        K: Hash + Eq + 'static,
104    {
105        self.with_inner(|set| unsafe { set.intern(value) })
106    }
107
108    /// Stores a value in the interner, creating it only if it does not already exist.
109    ///
110    /// This method allows you to provide a key and a closure to generate the value. If the key
111    /// already exists in the interner, the closure is not called, and a reference to the existing
112    /// value is returned. If the key does not exist, the closure is called to create the value,
113    /// which is then stored in the interner.
114    ///
115    /// This method is useful when the value is expensive to compute, as it avoids unnecessary
116    /// computation if the value already exists.
117    ///
118    /// # Examples
119    ///
120    /// ```
121    /// use any_intern::AnyInterner;
122    ///
123    ///
124    /// #[derive(PartialEq, Eq, Hash, Debug)]
125    /// struct A(i32);
126    ///
127    /// impl std::borrow::Borrow<i32> for A {
128    ///     fn borrow(&self) -> &i32 {
129    ///         &self.0
130    ///     }
131    /// }
132    ///
133    /// let interner = AnyInterner::of::<A>();
134    ///
135    /// unsafe {
136    ///     let a = interner.intern_with(&42, || A(42));
137    ///     assert_eq!(interner.len(), 1);
138    ///     assert_eq!(*a, A(42));
139    ///
140    ///     let b = interner.intern_with(&42, || A(99)); // Closure is not called
141    ///     assert_eq!(interner.len(), 1);
142    ///     assert_eq!(*b, A(42));
143    ///
144    ///     let c = interner.intern_with(&43, || A(43));
145    ///     assert_eq!(interner.len(), 2);
146    ///     assert_eq!(*c, A(43));
147    /// }
148    /// ```
149    ///
150    /// # Safety
151    ///
152    /// Undefined behavior if incorrect type `K` is given.
153    pub unsafe fn intern_with<K, Q, F>(&self, key: &Q, make_value: F) -> Interned<'_, K>
154    where
155        K: borrow::Borrow<Q> + 'static,
156        Q: Hash + Eq + ?Sized,
157        F: FnOnce() -> K,
158    {
159        self.with_inner(|set| unsafe { set.intern_with(key, make_value) })
160    }
161
162    /// Retrieves a reference to a value in the interner based on the provided key.
163    ///
164    /// This method checks if a value corresponding to the given key exists in the interner. If it
165    /// exists, a reference to the interned value is returned. Otherwise, `None` is returned.
166    ///
167    /// # Eaxmples
168    ///
169    /// ```
170    /// use any_intern::AnyInterner;
171    ///
172    /// let interner = AnyInterner::of::<i32>();
173    /// unsafe {
174    ///     interner.intern(42);
175    ///     assert_eq!(interner.get::<i32, _>(&42).as_deref(), Some(&42));
176    ///     assert!(interner.get::<i32, _>(&99).is_none());
177    /// }
178    /// ```
179    ///
180    /// # Safety
181    ///
182    /// Undefined behavior if incorrect type `K` is given.
183    pub unsafe fn get<K, Q>(&self, key: &Q) -> Option<Interned<'_, K>>
184    where
185        K: borrow::Borrow<Q> + 'static,
186        Q: Hash + Eq + ?Sized,
187    {
188        self.with_inner(|set| unsafe { set.get(key) })
189    }
190
191    /// Returns true if the interner contains values of the given type.
192    pub fn is_type_of<K: 'static>(&self) -> bool {
193        self.with_inner(|set| set.is_type_of::<K>())
194    }
195
196    /// Removes all items in the interner.
197    ///
198    /// Although the interner support interior mutability, clear method requires mutable access
199    /// to the interner to invalidate all [`Interned`]s referencing the interner.
200    pub fn clear(&mut self) {
201        self.with_inner(|set| set.clear())
202    }
203
204    fn with_inner<'this, F, R>(&'this self, f: F) -> R
205    where
206        F: FnOnce(&'this mut AnyInternSet<S>) -> R,
207        R: 'this,
208    {
209        // Safety: Mutex unlocking is paired with the locking.
210        unsafe {
211            let set = self.inner.lock().as_mut();
212            let ret = f(set);
213            self.inner.unlock();
214            ret
215        }
216    }
217}
218
219/// A type-erased interning set for storing and deduplicating values of a single type without
220/// interior mutability.
221///
222/// # Examples
223///
224/// ```
225/// use any_intern::AnyInternSet;
226///
227/// #[derive(PartialEq, Eq, Hash, Debug)]
228/// struct A(u32);
229///
230/// let mut set = AnyInternSet::of::<A>();
231///
232/// unsafe {
233///     let a1 = set.intern(A(42)).raw();
234///     let a2 = set.intern(A(42)).raw();
235///     assert_eq!(a1, a2); // Same value, same reference
236///
237///     let a3 = set.intern(A(99)).raw();
238///     assert_ne!(a1, a3); // Different values, different references
239/// }
240/// ```
241///
242/// # Safety
243///
244/// Many methods in `AnyInternSet` are marked as `unsafe` because they rely on the caller to ensure
245/// that the correct type is used when interacting with the interner. Using an incorrect type can
246/// lead to undefined behavior.
247pub struct AnyInternSet<S = fxhash::FxBuildHasher> {
248    arena: AnyArena,
249    set: HashTable<RawInterned>,
250    hash_builder: S,
251}
252
253impl AnyInternSet {
254    pub fn of<K: 'static>() -> Self {
255        Self {
256            arena: AnyArena::of::<K>(),
257            set: HashTable::new(),
258            hash_builder: Default::default(),
259        }
260    }
261}
262
263impl<S: Default> AnyInternSet<S> {
264    pub fn default_of<K: 'static>() -> Self {
265        Self {
266            arena: AnyArena::of::<K>(),
267            set: HashTable::new(),
268            hash_builder: Default::default(),
269        }
270    }
271}
272
273impl<S: BuildHasher> AnyInternSet<S> {
274    pub fn with_hasher<K: 'static>(hash_builder: S) -> Self {
275        Self {
276            arena: AnyArena::of::<K>(),
277            set: HashTable::new(),
278            hash_builder,
279        }
280    }
281
282    /// Returns number of values in the set.
283    pub fn len(&self) -> usize {
284        self.arena.len()
285    }
286
287    /// Returns true if the set is empty.
288    pub fn is_empty(&self) -> bool {
289        self.len() == 0
290    }
291
292    /// Stores a value in the set, returning a reference to the value.
293    ///
294    /// This method inserts the given value into the set if it does not already exist. If the value
295    /// already exists, a reference to the existing value is returned.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// use any_intern::AnyInternSet;
301    ///
302    /// #[derive(PartialEq, Eq, Hash, Debug)]
303    /// struct A(u32);
304    ///
305    /// let mut set = AnyInternSet::of::<A>();
306    ///
307    /// unsafe {
308    ///     let a1 = set.intern(A(42)).raw();
309    ///     let a2 = set.intern(A(42)).raw();
310    ///     assert_eq!(a1, a2); // Same value, same reference
311    /// }
312    /// ```
313    ///
314    /// # Safety
315    ///
316    /// Undefined behavior if incorrect type `K` is given.
317    pub unsafe fn intern<K>(&mut self, value: K) -> Interned<'_, K>
318    where
319        K: Hash + Eq + 'static,
320    {
321        debug_assert!(self.is_type_of::<K>());
322
323        unsafe {
324            let hash = Self::hash(&self.hash_builder, &value);
325            let eq = Self::table_eq::<K, K>(&value);
326            let hasher = Self::table_hasher::<K, K>(&self.hash_builder);
327            match self.set.entry(hash, eq, hasher) {
328                Entry::Occupied(entry) => Interned::from_erased_raw(*entry.get()),
329                Entry::Vacant(entry) => {
330                    let ref_ = self.arena.alloc(value);
331                    let interned = Interned::unique(ref_);
332                    let raw = interned.erased_raw();
333                    entry.insert(raw);
334                    interned
335                }
336            }
337        }
338    }
339
340    /// Stores a value in the set, creating it only if it does not already exist.
341    ///
342    /// This method allows you to provide a key and a closure to generate the value. If the key
343    /// already exists in the set, the closure is not called, and a reference to the existing value
344    /// is returned. If the key does not exist, the closure is called to create the value, which is
345    /// then stored in the set.
346    ///
347    /// This method is useful when the value is expensive to compute, as it avoids unnecessary
348    /// computation if the value already exists.
349    ///
350    /// # Examples
351    ///
352    /// ```
353    /// use any_intern::AnyInternSet;
354    ///
355    ///
356    /// #[derive(PartialEq, Eq, Hash, Debug)]
357    /// struct A(i32);
358    ///
359    /// impl std::borrow::Borrow<i32> for A {
360    ///     fn borrow(&self) -> &i32 {
361    ///         &self.0
362    ///     }
363    /// }
364    ///
365    /// let mut set = AnyInternSet::of::<A>();
366    ///
367    /// unsafe {
368    ///     let a = set.intern_with(&42, || A(42));
369    ///     assert_eq!(*a, A(42));
370    ///     assert_eq!(set.len(), 1);
371    ///
372    ///     let b = set.intern_with(&42, || A(99)); // Closure is not called
373    ///     assert_eq!(*b, A(42));
374    ///     assert_eq!(set.len(), 1);
375    ///
376    ///     let c = set.intern_with(&43, || A(43));
377    ///     assert_eq!(*c, A(43));
378    ///     assert_eq!(set.len(), 2);
379    /// }
380    /// ```
381    ///
382    /// # Safety
383    ///
384    /// Undefined behavior if incorrect type `K` is given.
385    pub unsafe fn intern_with<K, Q, F>(&mut self, key: &Q, make_value: F) -> Interned<'_, K>
386    where
387        K: borrow::Borrow<Q> + 'static,
388        Q: Hash + Eq + ?Sized,
389        F: FnOnce() -> K,
390    {
391        debug_assert!(self.is_type_of::<K>());
392
393        unsafe {
394            let hash = Self::hash(&self.hash_builder, key);
395            let eq = Self::table_eq::<K, Q>(key);
396            let hasher = Self::table_hasher::<K, Q>(&self.hash_builder);
397            match self.set.entry(hash, eq, hasher) {
398                Entry::Occupied(entry) => Interned::from_erased_raw(*entry.get()),
399                Entry::Vacant(entry) => {
400                    let value = make_value();
401                    let ref_ = self.arena.alloc(value);
402                    let interned = Interned::unique(ref_);
403                    let raw = interned.erased_raw();
404                    entry.insert(raw);
405                    interned
406                }
407            }
408        }
409    }
410
411    /// Retrieves a reference to a value in the set based on the provided key.
412    ///
413    /// This method checks if a value corresponding to the given key exists in the set. If it
414    /// exists, a reference to the value is returned. Otherwise, `None` is returned.
415    ///
416    /// # Eaxmples
417    ///
418    /// ```
419    /// use any_intern::AnyInternSet;
420    ///
421    /// let mut set = AnyInternSet::of::<i32>();
422    /// unsafe {
423    ///     set.intern(42);
424    ///     assert_eq!(set.get::<i32, _>(&42).as_deref(), Some(&42));
425    ///     assert!(set.get::<i32, _>(&99).is_none());
426    /// }
427    /// ```
428    ///
429    /// # Safety
430    ///
431    /// Undefined behavior if incorrect type `K` is given.
432    pub unsafe fn get<K, Q>(&self, key: &Q) -> Option<Interned<'_, K>>
433    where
434        K: borrow::Borrow<Q> + 'static,
435        Q: Hash + Eq + ?Sized,
436    {
437        debug_assert!(self.is_type_of::<K>());
438
439        unsafe {
440            let hash = Self::hash(&self.hash_builder, key);
441            let eq = Self::table_eq::<K, Q>(key);
442            self.set
443                .find(hash, eq)
444                .map(|raw| Interned::from_erased_raw(*raw))
445        }
446    }
447
448    /// Returns true if the set contains values of the given type.
449    pub fn is_type_of<K: 'static>(&self) -> bool {
450        self.arena.is_type_of::<K>()
451    }
452
453    /// Removes all items in the set.
454    pub fn clear(&mut self) {
455        self.arena.clear();
456        self.set.clear();
457    }
458
459    /// Returns `eq` closure that is used for some methods on the [`HashTable`].
460    ///
461    /// # Safety
462    ///
463    /// Undefined behavior if incorrect type `K` is given.
464    unsafe fn table_eq<'a, K, Q>(key: &'a Q) -> impl FnMut(&RawInterned) -> bool + 'a
465    where
466        K: borrow::Borrow<Q>,
467        Q: Hash + Eq + ?Sized,
468    {
469        move |entry: &RawInterned| unsafe {
470            let value = entry.cast::<K>().as_ref();
471            value.borrow() == key
472        }
473    }
474
475    /// Returns `hasher` closure that is used for some methods on the [`HashTable`].
476    ///
477    /// # Safety
478    ///
479    /// Undefined behavior if incorrect type `K` is given.
480    unsafe fn table_hasher<'a, K, Q>(hash_builder: &'a S) -> impl Fn(&RawInterned) -> u64 + 'a
481    where
482        K: borrow::Borrow<Q>,
483        Q: Hash + Eq + ?Sized,
484    {
485        |entry: &RawInterned| unsafe {
486            let value = entry.cast::<K>().as_ref();
487            Self::hash(hash_builder, value.borrow())
488        }
489    }
490
491    fn hash<K: Hash + ?Sized>(hash_builder: &S, value: &K) -> u64 {
492        let mut hasher = hash_builder.build_hasher();
493        value.hash(&mut hasher);
494        hasher.finish()
495    }
496}
497
498pub struct AnyArena {
499    bump: Bump,
500    ty: TypeId,
501    stride: usize,
502    raw_drop_slice: Option<unsafe fn(*mut u8, usize)>,
503    len: Cell<usize>,
504}
505
506impl AnyArena {
507    pub fn of<T: 'static>() -> Self {
508        Self {
509            bump: Bump::new(),
510            ty: TypeId::of::<T>(),
511            stride: Layout::new::<T>().pad_to_align().size(),
512            raw_drop_slice: if mem::needs_drop::<T>() {
513                Some(common::cast_then_drop_slice::<T>)
514            } else {
515                None
516            },
517            len: Cell::new(0),
518        }
519    }
520
521    pub fn is_type_of<T: 'static>(&self) -> bool {
522        TypeId::of::<T>() == self.ty
523    }
524
525    /// Returns number of elements in this arena.
526    pub fn len(&self) -> usize {
527        self.len.get()
528    }
529
530    pub fn is_empty(&self) -> bool {
531        self.len() == 0
532    }
533
534    pub fn alloc<T: 'static>(&self, value: T) -> &mut T {
535        debug_assert!(self.is_type_of::<T>());
536
537        self.len.set(self.len() + 1);
538        self.bump.alloc(value)
539    }
540
541    pub fn clear(&mut self) {
542        self.drop_all();
543        self.bump.reset();
544        self.len.set(0);
545    }
546
547    fn drop_all(&mut self) {
548        if let Some(raw_drop_slice) = self.raw_drop_slice {
549            if self.stride > 0 {
550                unsafe {
551                    for (ptr, len) in self.bump.iter_allocated_chunks_raw() {
552                        // Chunk would not be divisible by the `stride` especially when the stride
553                        // is greater than 16. In that case, we should ignore the remainder.
554                        let num_elems = len / self.stride;
555                        raw_drop_slice(ptr, num_elems);
556                    }
557                }
558            } else {
559                // ZST, but it has drop impl, which means the drop could have side effects, so we
560                // should call it.
561                let ptr = NonNull::<()>::dangling(); // aligned dangling pointer for ZST
562                unsafe {
563                    raw_drop_slice(ptr.as_ptr().cast(), self.len());
564                }
565            }
566        }
567    }
568}
569
570impl Drop for AnyArena {
571    fn drop(&mut self) {
572        self.drop_all();
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_any_interner() {
        #[derive(PartialEq, Eq, Hash, Debug)]
        struct A(i32);

        let interner = AnyInterner::of::<A>();

        unsafe {
            let a = interner.intern(A(0));
            let b = interner.intern(A(0));
            let c = interner.intern(A(1));
            assert_eq!(a, b);
            assert_ne!(a, c);
        }
    }

    /// Exercises `intern`, `intern_with`, `get` and `clear` on the non-locking set.
    #[test]
    fn test_any_intern_set() {
        #[derive(PartialEq, Eq, Hash, Debug)]
        struct A(i32);

        impl borrow::Borrow<i32> for A {
            fn borrow(&self) -> &i32 {
                &self.0
            }
        }

        let mut set = AnyInternSet::of::<A>();
        assert!(set.is_type_of::<A>());
        assert!(!set.is_type_of::<i32>());

        unsafe {
            let a = set.intern(A(1)).raw();
            // Key already present: the closure must not be used.
            let b = set.intern_with(&1, || A(99)).raw();
            assert_eq!(a, b);
            assert_eq!(set.len(), 1);

            assert_eq!(set.get::<A, i32>(&1).as_deref(), Some(&A(1)));
            assert!(set.get::<A, i32>(&2).is_none());
        }

        set.clear();
        assert!(set.is_empty());
        unsafe {
            assert!(set.get::<A, i32>(&1).is_none());
        }
    }

    #[test]
    fn test_arena() {
        test_arena_alloc();
        test_arena_drop();
    }

    /// `clear` must drop every value exactly once, reset the length, and leave the arena
    /// usable; values allocated afterwards are dropped with the arena.
    #[test]
    fn test_arena_clear() {
        thread_local! {
            static SUM: Cell<u64> = Cell::new(0);
        }

        struct B(u64);

        impl Drop for B {
            fn drop(&mut self) {
                SUM.set(SUM.get() + self.0);
            }
        }

        let mut arena = AnyArena::of::<B>();
        for i in 1..=10 {
            arena.alloc(B(i));
        }
        assert_eq!(arena.len(), 10);

        arena.clear();
        assert_eq!(SUM.get(), 55);
        assert!(arena.is_empty());

        arena.alloc(B(100));
        assert_eq!(arena.len(), 1);
        drop(arena);
        assert_eq!(SUM.get(), 155);
    }

    fn test_arena_alloc() {
        const START: u32 = 0;
        const END: u32 = 100;
        const EXPECTED: u32 = (END + START) * (END - START + 1) / 2;

        let arena = AnyArena::of::<u32>();
        let mut refs = Vec::new();
        for i in START..=END {
            let ref_ = arena.alloc(i);
            refs.push(ref_);
        }
        let acc = refs.into_iter().map(|ref_| *ref_).sum::<u32>();
        assert_eq!(acc, EXPECTED);
    }

    fn test_arena_drop() {
        macro_rules! test {
            ($arr_len:literal, $align:literal) => {{
                thread_local! {
                    static SUM: Cell<u32> = Cell::new(0);
                    static CNT: Cell<u32> = Cell::new(0);
                }

                #[repr(align($align))]
                struct A([u8; $arr_len]);

                // Restricted by `u8` and `A::new()`.
                const _: () = const { assert!($arr_len < 256) };

                impl A {
                    fn new() -> Self {
                        Self(std::array::from_fn(|i| i as u8))
                    }

                    fn sum() -> u32 {
                        ($arr_len - 1) * $arr_len / 2
                    }
                }

                impl Drop for A {
                    fn drop(&mut self) {
                        let sum = self.0.iter().map(|n| *n as u32).sum::<u32>();
                        SUM.set(SUM.get() + sum);
                        CNT.set(CNT.get() + 1);
                    }
                }

                struct Zst;

                impl Drop for Zst {
                    fn drop(&mut self) {
                        CNT.set(CNT.get() + 1);
                    }
                }

                const REPEAT: u32 = 10;

                // === Non-ZST type ===

                let arena = AnyArena::of::<A>();
                for _ in 0..REPEAT {
                    arena.alloc(A::new());
                }
                drop(arena);

                assert_eq!(SUM.get(), A::sum() * REPEAT);
                assert_eq!(CNT.get(), REPEAT);
                SUM.set(0);
                CNT.set(0);

                // === ZST type ===

                let arena = AnyArena::of::<Zst>();
                for _ in 0..REPEAT {
                    arena.alloc(Zst);
                }
                drop(arena);

                assert_eq!(CNT.get(), REPEAT);
                CNT.set(0);
            }};
        }

        // Array len, align
        test!(1, 1);
        test!(1, 2);
        test!(1, 4);
        test!(1, 8);
        test!(1, 16);
        test!(1, 32);
        test!(1, 64);
        test!(1, 128);
        test!(1, 256);

        test!(100, 1);
        test!(100, 2);
        test!(100, 4);
        test!(100, 8);
        test!(100, 16);
        test!(100, 32);
        test!(100, 64);
        test!(100, 128);
        test!(100, 256);
    }
}