Skip to main content

pawkit_interner/
lib.rs

1#![feature(str_from_raw_parts)]
2
3use core::{fmt, str};
4use std::{
5    alloc::{Layout, alloc, dealloc},
6    fmt::{Debug, Display},
7    hint::cold_path,
8    mem::forget,
9    ops::Deref,
10    ptr::{self, NonNull, copy_nonoverlapping},
11    str::FromStr,
12    sync::{
13        LazyLock,
14        atomic::{AtomicUsize, Ordering},
15    },
16};
17
18use dashmap::{DashMap, Entry};
19use serde::{Deserialize, Serialize, de::Visitor};
20
/// Upper bound for the strong and weak reference counts; the `Clone` impls and
/// `WeakInternString::into_strong` assert that a count never exceeds it.
const MAX_REFCOUNT: usize = isize::MAX as usize;

/// Header of an interned allocation.
///
/// Every allocation is this header immediately followed by `len` bytes of UTF-8 string
/// data (see `InternInner::layout_for`).
struct InternInner {
    /// Number of live `InternString` handles.
    strong: AtomicUsize,
    /// Number of live `WeakInternString` handles, including the one stored in `DATA`;
    /// the allocation is freed when this reaches zero.
    weak: AtomicUsize,
    /// Length in bytes of the trailing string data.
    len: usize,
}

/// An interned, reference-counted string.
///
/// It behaves much like an `Arc<str>`, except that equal strings share a single
/// allocation. Equality and hashing therefore compare the allocation address instead of
/// the contents. Looking strings up and cloning them is cheap; creating and releasing them
/// goes through a global table.
#[derive(Hash, Eq, PartialEq)]
pub struct InternString {
    inner: NonNull<InternInner>,
}

/// A non-owning handle to an interned string, analogous to `Weak<str>`.
///
/// It keeps the allocation alive but not the string itself. Use
/// `WeakInternString::into_strong` to obtain an `InternString` while one still exists.
#[derive(Hash, Eq, PartialEq)]
pub struct WeakInternString {
    inner: NonNull<InternInner>,
}
43
44/// # SAFETY
45/// The keys are casted to a static lifetime, and reference the underlying InternInner data,
46/// but the keys are removed before the InternInner is freed.
47static DATA: LazyLock<DashMap<&'static str, WeakInternString>> = LazyLock::new(Default::default);
48
49impl InternInner {
50    unsafe fn data_ptr<'a>(value: NonNull<Self>) -> *const u8 {
51        unsafe {
52            return value.as_ptr().add(1) as *const u8;
53        }
54    }
55
56    unsafe fn data_ptr_mut<'a>(value: NonNull<Self>) -> *mut u8 {
57        unsafe {
58            return value.as_ptr().add(1) as *mut u8;
59        }
60    }
61
62    unsafe fn data_mut<'a>(value: NonNull<Self>) -> &'a mut str {
63        unsafe {
64            return str::from_raw_parts_mut(Self::data_ptr_mut(value), (*value.as_ptr()).len);
65        }
66    }
67
68    unsafe fn data<'a>(value: NonNull<Self>) -> &'a str {
69        unsafe {
70            return str::from_raw_parts(Self::data_ptr(value), (*value.as_ptr()).len);
71        }
72    }
73
74    fn layout_for(len: usize) -> Layout {
75        return Layout::new::<Self>()
76            .extend(Layout::array::<u8>(len).unwrap())
77            .unwrap()
78            .0
79            .pad_to_align();
80    }
81
82    fn layout(value: NonNull<Self>) -> Layout {
83        unsafe {
84            return Self::layout_for((*value.as_ptr()).len);
85        }
86    }
87
88    unsafe fn alloc(s: &str) -> (NonNull<Self>, &'static str) {
89        unsafe {
90            let layout = Self::layout_for(s.len());
91
92            let ptr = alloc(layout) as *mut Self;
93
94            let value = &mut *ptr;
95
96            value.strong = AtomicUsize::new(1);
97            value.weak = AtomicUsize::new(1);
98            value.len = s.len();
99
100            let ptr = NonNull::new_unchecked(ptr);
101            let data = InternInner::data_mut(ptr);
102
103            copy_nonoverlapping(s.as_ptr(), data.as_mut_ptr(), s.len());
104
105            return (ptr, data);
106        }
107    }
108
109    unsafe fn dealloc(value: NonNull<Self>) {
110        unsafe {
111            let ptr = value.as_ptr();
112            let layout = Self::layout(value);
113
114            dealloc(ptr as *mut u8, layout);
115        }
116    }
117}
118
119impl InternString {
120    fn inner(&self) -> &InternInner {
121        unsafe {
122            return &*self.inner.as_ptr();
123        }
124    }
125
126    /// Consumes the `InternString`, returning the wrapped pointer.
127    ///
128    /// To avoid a memory leak the pointer must be converted back to an `InternString` using
129    /// [`InternString::from_raw`].
130    pub fn into_raw(self) -> *const u8 {
131        let ptr = self.inner.as_ptr() as *const u8;
132
133        forget(self);
134
135        return ptr;
136    }
137
138    /// Constructs an `InternString` from a raw pointer.
139    ///
140    /// The raw pointer must have been previously returned by a call to [`InternString::into_raw`].
141    pub unsafe fn from_raw(value: *const u8) -> Option<Self> {
142        let inner = NonNull::new(value as *mut InternInner)?;
143        if !inner.is_aligned() {
144            return None;
145        }
146
147        return Some(Self { inner: inner });
148    }
149
150    /// Creates a new intern string from the provided string
151    /// If an intern string exists with the same content, it returns a clone of that.
152    pub fn new(s: &str) -> Self {
153        if let Some(weak) = DATA.get(s) {
154            let value = Self { inner: weak.inner };
155
156            value.inner().strong.fetch_add(1, Ordering::Relaxed);
157
158            return value;
159        }
160
161        let (ptr, str) = unsafe { InternInner::alloc(s) };
162
163        match DATA.entry(str) {
164            Entry::Occupied(occupied) => {
165                cold_path();
166
167                unsafe {
168                    InternInner::dealloc(ptr);
169                }
170
171                let value = Self {
172                    inner: occupied.get().inner,
173                };
174
175                value.inner().strong.fetch_add(1, Ordering::Relaxed);
176
177                return value;
178            }
179            Entry::Vacant(vacant) => {
180                vacant.insert(WeakInternString { inner: ptr });
181
182                let value = Self { inner: ptr };
183
184                return value;
185            }
186        }
187    }
188
189    /// Retruns the underlying string data.
190    pub fn as_str<'a>(&'a self) -> &'a str {
191        unsafe {
192            return InternInner::data(self.inner);
193        }
194    }
195
196    /// Converts into a weak reference
197    pub fn into_weak(&self) -> WeakInternString {
198        let weak = WeakInternString { inner: self.inner };
199
200        self.inner().weak.fetch_add(1, Ordering::Relaxed);
201
202        return weak;
203    }
204}
205
206impl WeakInternString {
207    fn inner(&self) -> &InternInner {
208        unsafe {
209            return &*self.inner.as_ptr();
210        }
211    }
212
213    /// Returns true if the string is still alive
214    /// The string is considered dead if there are no strong references
215    pub fn is_alive(&self) -> bool {
216        return self.inner().strong.load(Ordering::Acquire) != 0;
217    }
218
219    /// Gets the underlying string if it's still alive
220    pub fn as_str<'a>(&'a self) -> Option<&'a str> {
221        if !self.is_alive() {
222            return None;
223        }
224
225        unsafe {
226            return Some(InternInner::data(self.inner));
227        }
228    }
229
230    /// Converts to a strong reference if the string is still alive
231    pub fn into_strong(&self) -> Option<InternString> {
232        #[inline]
233        fn checked_increment(n: usize) -> Option<usize> {
234            if n == 0 {
235                return None;
236            }
237            assert!(n <= MAX_REFCOUNT);
238            return Some(n + 1);
239        }
240
241        self.inner()
242            .strong
243            .fetch_update(Ordering::Acquire, Ordering::Relaxed, checked_increment)
244            .ok()?;
245
246        return Some(InternString { inner: self.inner });
247    }
248}
249
250unsafe impl Send for InternString {}
251unsafe impl Sync for InternString {}
252
253unsafe impl Send for WeakInternString {}
254unsafe impl Sync for WeakInternString {}
255
256impl Drop for InternString {
257    fn drop(&mut self) {
258        if self.inner().strong.fetch_sub(1, Ordering::Release) != 1 {
259            return;
260        }
261
262        self.inner().strong.load(Ordering::Acquire);
263
264        DATA.remove(self.as_str());
265    }
266}
267
268impl Drop for WeakInternString {
269    fn drop(&mut self) {
270        if self.inner().weak.fetch_sub(1, Ordering::Release) == 1 {
271            self.inner().weak.load(Ordering::Acquire);
272
273            unsafe {
274                InternInner::dealloc(self.inner);
275            }
276        }
277    }
278}
279
280impl Clone for InternString {
281    fn clone(&self) -> Self {
282        let ref_count = self.inner().strong.fetch_add(1, Ordering::Relaxed);
283
284        assert!(ref_count < MAX_REFCOUNT);
285
286        return Self {
287            inner: self.inner.clone(),
288        };
289    }
290}
291
292impl Clone for WeakInternString {
293    fn clone(&self) -> Self {
294        let ref_count = self.inner().weak.fetch_add(1, Ordering::Relaxed);
295
296        assert!(ref_count < MAX_REFCOUNT);
297
298        return Self {
299            inner: self.inner.clone(),
300        };
301    }
302}
303
304impl PartialEq<WeakInternString> for InternString {
305    fn eq(&self, other: &WeakInternString) -> bool {
306        return ptr::addr_eq(self.inner.as_ptr(), other.inner.as_ptr());
307    }
308}
309
310impl PartialEq<InternString> for WeakInternString {
311    fn eq(&self, other: &InternString) -> bool {
312        return ptr::addr_eq(self.inner.as_ptr(), other.inner.as_ptr());
313    }
314}
315
316impl PartialEq<&str> for InternString {
317    fn eq(&self, other: &&str) -> bool {
318        return self.as_str() == *other;
319    }
320}
321
322impl PartialEq<&str> for WeakInternString {
323    fn eq(&self, other: &&str) -> bool {
324        let Some(str) = self.as_str() else {
325            return false;
326        };
327
328        return str == *other;
329    }
330}
331
332impl Deref for InternString {
333    type Target = str;
334
335    fn deref(&self) -> &Self::Target {
336        return self.as_str();
337    }
338}
339
340impl Debug for InternString {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        return f.debug_tuple("InternString").field(&self.as_str()).finish();
343    }
344}
345
346impl Display for InternString {
347    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
348        return f.write_str(self.as_str());
349    }
350}
351
352impl Debug for WeakInternString {
353    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354        let mut t = f.debug_tuple("WeakInternString");
355
356        let Some(s) = self.as_str() else {
357            return t.field(&"<dead>").finish();
358        };
359
360        return t.field(&s).finish();
361    }
362}
363
364impl Display for WeakInternString {
365    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366        let Some(s) = self.as_str() else {
367            return f.write_str(&"<dead>");
368        };
369
370        return f.write_str(&s);
371    }
372}
373
374impl FromStr for InternString {
375    type Err = ();
376
377    fn from_str(s: &str) -> Result<Self, Self::Err> {
378        return Ok(Self::new(s));
379    }
380}
381
382impl From<&str> for InternString {
383    fn from(value: &str) -> Self {
384        return Self::new(value);
385    }
386}
387
388impl From<String> for InternString {
389    fn from(value: String) -> Self {
390        return Self::new(&value);
391    }
392}
393
394impl Serialize for InternString {
395    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
396    where
397        S: serde::Serializer,
398    {
399        return serializer.serialize_str(self);
400    }
401}
402
403struct InternStringVisitor;
404
405impl<'a> Visitor<'a> for InternStringVisitor {
406    type Value = InternString;
407
408    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
409        formatter.write_str("a string")
410    }
411
412    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
413    where
414        E: serde::de::Error,
415    {
416        Ok(v.into())
417    }
418
419    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
420    where
421        E: serde::de::Error,
422    {
423        Ok(v.into())
424    }
425
426    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
427    where
428        E: serde::de::Error,
429    {
430        match str::from_utf8(v) {
431            Ok(s) => Ok(s.into()),
432            Err(_) => Err(serde::de::Error::invalid_value(
433                serde::de::Unexpected::Bytes(v),
434                &self,
435            )),
436        }
437    }
438
439    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
440    where
441        E: serde::de::Error,
442    {
443        match String::from_utf8(v) {
444            Ok(s) => Ok(s.into()),
445            Err(e) => Err(serde::de::Error::invalid_value(
446                serde::de::Unexpected::Bytes(&e.into_bytes()),
447                &self,
448            )),
449        }
450    }
451}
452
453impl<'de> Deserialize<'de> for InternString {
454    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
455    where
456        D: serde::Deserializer<'de>,
457    {
458        return deserializer.deserialize_string(InternStringVisitor);
459    }
460}
461
#[cfg(test)]
mod tests {
    //! Every test shares the process-wide `DATA` table, so each one interns its own
    //! distinct strings to avoid interfering with the others.

    use super::*;
    use std::sync::{Arc, Barrier, Mutex};
    use std::thread;

    /// Address of the allocation behind a strong handle.
    fn addr(s: &InternString) -> *mut InternInner {
        s.inner.as_ptr()
    }

    #[test]
    fn basic_interning() {
        let [first, again, other] = ["hello", "hello", "world"].map(InternString::new);

        assert_eq!(first, again);
        assert_ne!(first, other);

        assert_eq!(first.as_str(), "hello");
        assert_eq!(other.as_str(), "world");
    }

    #[test]
    fn weak_references() {
        let strong = InternString::new("weak_test");
        let weak = strong.into_weak();

        assert!(weak.is_alive());
        assert_eq!(weak.as_str(), Some("weak_test"));

        drop(strong);

        assert!(!weak.is_alive());
        assert_eq!(weak.as_str(), None);
    }

    #[test]
    fn cloning_strong() {
        let original = InternString::new("clone_test");
        let copy = original.clone();
        let copy_of_copy = copy.clone();

        assert_eq!(addr(&original), addr(&copy));
        assert_eq!(addr(&copy), addr(&copy_of_copy));

        assert_eq!(original.as_str(), "clone_test");
    }

    #[test]
    fn cloning_weak() {
        let strong = InternString::new("weak_clone");
        let weak = strong.into_weak();
        let weak_copy = weak.clone();

        assert_eq!(weak.inner.as_ptr(), weak_copy.inner.as_ptr());

        drop(strong);

        assert!(!weak.is_alive());
        assert!(!weak_copy.is_alive());
    }

    #[test]
    fn eq_partial() {
        let strong = InternString::new("eq_test");
        let weak = strong.into_weak();

        assert_eq!(strong, weak);
        assert_eq!(weak, strong);

        assert_eq!(strong, "eq_test");
        assert_eq!(weak, "eq_test");
    }

    #[test]
    fn data_deduplication() {
        let a = InternString::new("dedup");
        let b = InternString::new("dedup");

        assert!(ptr::addr_eq(addr(&a), addr(&b)));
    }

    #[test]
    fn drop_cleans_data() {
        drop(InternString::new("cleanup_test"));

        assert!(DATA.get("cleanup_test").is_none());
    }

    #[test]
    fn multithreaded_usage() {
        let shared = Arc::new(InternString::new("thread_test"));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let handle = Arc::clone(&shared);
                thread::spawn(move || {
                    let local = Arc::clone(&handle);
                    assert_eq!(local.as_str(), "thread_test");
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(shared.as_str(), "thread_test");
    }

    #[test]
    fn weak_after_drop_multithreaded() {
        let strong = InternString::new("weak_thread");
        let weak = strong.into_weak();

        thread::spawn(move || drop(strong)).join().unwrap();

        assert!(!weak.is_alive());
        assert_eq!(weak.as_str(), None);
    }

    #[test]
    fn simultaneous_intern() {
        let barrier = Arc::new(Barrier::new(2));
        let slots: [Arc<Mutex<Option<InternString>>>; 2] = Default::default();

        let workers = slots.clone().map(|slot| {
            let barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                barrier.wait();
                *slot.lock().unwrap() = Some(InternString::new("race_test"));
            })
        });

        for worker in workers {
            worker.join().unwrap();
        }

        let [first, second] = slots.map(|slot| {
            let taken = slot.lock().unwrap().take();
            taken.unwrap()
        });

        assert!(ptr::addr_eq(addr(&first), addr(&second)));
    }

    #[test]
    fn weak_upgrade() {
        let strong = InternString::new("upgrade_test");
        let weak = strong.into_weak();

        let upgraded = weak.into_strong().expect("Should upgrade");
        assert_eq!(upgraded.as_str(), "upgrade_test");

        drop(strong);
        drop(upgraded);

        assert!(weak.into_strong().is_none());
    }
}