loro_common/
internal_string.rs

1use rustc_hash::FxHashSet;
2use serde::{Deserialize, Serialize};
3use std::borrow::Borrow;
4use std::slice;
5use std::sync::LazyLock;
6use std::{
7    fmt::Display,
8    num::NonZeroU64,
9    ops::Deref,
10    sync::{atomic::AtomicUsize, Arc, Mutex},
11};
12
/// Tag value (low two bits) marking a heap-allocated, interned string.
const DYNAMIC_TAG: u8 = 0;
/// Tag value (low two bits) marking a string packed inline into the `u64`.
const INLINE_TAG: u8 = 1;
/// Mask selecting the two tag bits.
const TAG_MASK: u64 = 0x3;
/// Bit position of the inline length field.
const LEN_OFFSET: u64 = 4;
/// Mask selecting the inline length field (bits 4..=7).
const LEN_MASK: u64 = 0xF << LEN_OFFSET;
18
19#[repr(transparent)]
20#[derive(Clone)]
21pub struct InternalString {
22    unsafe_data: UnsafeData,
23}
24
/// Raw storage behind [`InternalString`], discriminated by the low two bits (`TAG_MASK`).
///
/// * `inline` (tag `INLINE_TAG`): the length sits in bits 4..=7 (`LEN_MASK`) and up to
///   7 bytes of UTF-8 fill the remaining bytes of the word.
/// * `dynamic` (tag `DYNAMIC_TAG`): a pointer produced by `Arc::into_raw` on an
///   `Arc<Box<str>>`; the code relies on its alignment leaving the low bits zero.
///
/// NOTE(review): the tag is always inspected by reading `inline` (all 8 bytes), but
/// writing `dynamic` only initializes pointer-width bytes. On 32-bit targets (e.g.
/// wasm32) the upper bytes would be uninitialized when read. Confirm this is sound there,
/// or store the pointer bits in a `u64`.
union UnsafeData {
    inline: NonZeroU64,
    dynamic: *const Box<str>,
}

// SAFETY: `inline` is plain data, and `dynamic` only ever points at the contents of an
// `Arc<Box<str>>`, which is itself both `Send` and `Sync`.
unsafe impl Send for UnsafeData {}
unsafe impl Sync for UnsafeData {}
32
33impl UnsafeData {
34    #[inline(always)]
35    fn is_inline(&self) -> bool {
36        unsafe { (self.inline.get() & TAG_MASK) as u8 == INLINE_TAG }
37    }
38}
39
40impl Clone for UnsafeData {
41    fn clone(&self) -> Self {
42        if self.is_inline() {
43            Self {
44                inline: unsafe { self.inline },
45            }
46        } else {
47            unsafe {
48                Arc::increment_strong_count(self.dynamic);
49                Self {
50                    dynamic: self.dynamic,
51                }
52            }
53        }
54    }
55}
56
57impl std::fmt::Debug for InternalString {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.write_str("InternalString(")?;
60        std::fmt::Debug::fmt(self.as_str(), f)?;
61        f.write_str(")")
62    }
63}
64
65impl std::hash::Hash for InternalString {
66    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
67        self.as_str().hash(state);
68    }
69}
70
71impl PartialEq for InternalString {
72    fn eq(&self, other: &Self) -> bool {
73        self.as_str() == other.as_str()
74    }
75}
76
77impl Eq for InternalString {}
78
79impl PartialOrd for InternalString {
80    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
81        Some(self.cmp(other))
82    }
83}
84
85impl Ord for InternalString {
86    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
87        self.as_str().cmp(other.as_str())
88    }
89}
90
91impl Serialize for InternalString {
92    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93    where
94        S: serde::Serializer,
95    {
96        serializer.serialize_str(self.as_str())
97    }
98}
99
100impl<'de> Deserialize<'de> for InternalString {
101    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102    where
103        D: serde::Deserializer<'de>,
104    {
105        let s = String::deserialize(deserializer)?;
106        Ok(InternalString::from(s.as_str()))
107    }
108}
109
110impl Default for InternalString {
111    fn default() -> Self {
112        let v: u64 = INLINE_TAG as u64;
113        Self {
114            // SAFETY: INLINE_TAG is non-zero
115            unsafe_data: UnsafeData {
116                inline: unsafe { NonZeroU64::new_unchecked(v) },
117            },
118        }
119    }
120}
121
122impl InternalString {
123    pub fn as_str(&self) -> &str {
124        unsafe {
125            match (self.unsafe_data.inline.get() & TAG_MASK) as u8 {
126                INLINE_TAG => {
127                    let len = (self.unsafe_data.inline.get() & LEN_MASK) >> LEN_OFFSET;
128                    let src = inline_atom_slice(&self.unsafe_data.inline);
129                    // SAFETY: the chosen range is guaranteed to be valid str
130                    std::str::from_utf8_unchecked(&src[..(len as usize)])
131                }
132                DYNAMIC_TAG => {
133                    let ptr = self.unsafe_data.dynamic;
134                    // SAFETY: ptr is valid
135                    (*ptr).deref()
136                }
137                _ => unreachable!(),
138            }
139        }
140    }
141}
142
143impl AsRef<str> for InternalString {
144    fn as_ref(&self) -> &str {
145        self.as_str()
146    }
147}
148
149impl From<&str> for InternalString {
150    #[inline(always)]
151    fn from(s: &str) -> Self {
152        if s.len() <= 7 {
153            let mut v: u64 = (INLINE_TAG as u64) | ((s.len() as u64) << LEN_OFFSET);
154            let arr = inline_atom_slice_mut(&mut v);
155            arr[..s.len()].copy_from_slice(s.as_bytes());
156            Self {
157                unsafe_data: UnsafeData {
158                    // SAFETY: The tag is 1
159                    inline: unsafe { NonZeroU64::new_unchecked(v) },
160                },
161            }
162        } else {
163            let ans: Arc<Box<str>> = get_or_init_internalized_string(s);
164            let raw = Arc::into_raw(ans);
165            // SAFETY: Pointer is non-zero
166            Self {
167                unsafe_data: UnsafeData { dynamic: raw },
168            }
169        }
170    }
171}
172
/// Views the seven payload bytes of an inline word: every byte except the
/// least-significant one, which holds the tag and length.
#[inline(always)]
fn inline_atom_slice(x: &NonZeroU64) -> &[u8] {
    // The least-significant byte is at offset 0 on little-endian targets and at offset 7
    // on big-endian ones; either way, the payload is the other seven.
    let skip: usize = if cfg!(target_endian = "little") { 1 } else { 0 };
    let base = (x as *const NonZeroU64).cast::<u8>();
    // SAFETY: `base` points at the 8 bytes of `*x` and `skip + 7 <= 8`, so the slice stays
    // within them; its lifetime is tied to the borrow of `x`.
    unsafe { slice::from_raw_parts(base.add(skip), 7) }
}
186
/// Mutable view of the seven payload bytes of an inline word: every byte except the
/// least-significant one, which holds the tag and length.
#[inline(always)]
fn inline_atom_slice_mut(x: &mut u64) -> &mut [u8] {
    // The least-significant byte is at offset 0 on little-endian targets and at offset 7
    // on big-endian ones; either way, the payload is the other seven.
    let skip: usize = if cfg!(target_endian = "little") { 1 } else { 0 };
    let base = (x as *mut u64).cast::<u8>();
    // SAFETY: `base` points at the 8 bytes of `*x` and `skip + 7 <= 8`, so the slice stays
    // within them; it inherits the exclusive borrow of `x`.
    unsafe { slice::from_raw_parts_mut(base.add(skip), 7) }
}
200
201impl From<String> for InternalString {
202    fn from(s: String) -> Self {
203        Self::from(s.as_str())
204    }
205}
206
207impl From<&InternalString> for String {
208    #[inline(always)]
209    fn from(value: &InternalString) -> Self {
210        value.as_str().to_string()
211    }
212}
213
214impl Display for InternalString {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        self.as_str().fmt(f)
217    }
218}
219
220impl Deref for InternalString {
221    type Target = str;
222
223    fn deref(&self) -> &Self::Target {
224        self.as_str()
225    }
226}
227
/// Entry in the global intern table.
///
/// The derived `Hash`/`PartialEq` forward through `Arc` and `Box` to the underlying
/// `str`. That keeps them consistent with `Borrow<str>`, so the table can be queried
/// with a plain `&str`.
#[derive(PartialEq, Eq, Hash)]
struct ArcWrapper(Arc<Box<str>>);

impl Borrow<str> for ArcWrapper {
    fn borrow(&self) -> &str {
        &**self.0
    }
}
236
237static STRING_SET: LazyLock<Mutex<FxHashSet<ArcWrapper>>> =
238    LazyLock::new(|| Mutex::new(FxHashSet::default()));
239
240fn get_or_init_internalized_string(s: &str) -> Arc<Box<str>> {
241    static MAX_MET_CACHE_SIZE: AtomicUsize = AtomicUsize::new(1 << 16);
242
243    let mut set = STRING_SET.lock().unwrap();
244    if let Some(v) = set.get(s) {
245        v.0.clone()
246    } else {
247        let ans: Arc<Box<str>> = Arc::new(Box::from(s));
248        set.insert(ArcWrapper(ans.clone()));
249        let max = MAX_MET_CACHE_SIZE.load(std::sync::atomic::Ordering::Relaxed);
250        if set.capacity() >= max {
251            let old = set.len();
252            set.retain(|s| Arc::strong_count(&s.0) > 1);
253            let new = set.len();
254            if old - new > new / 2 {
255                set.shrink_to_fit();
256            }
257
258            MAX_MET_CACHE_SIZE.store(max * 2, std::sync::atomic::Ordering::Relaxed);
259        }
260
261        ans
262    }
263}
264
265fn drop_cache(s: Arc<Box<str>>) {
266    let mut set = STRING_SET.lock().unwrap();
267    set.remove(&ArcWrapper(s));
268    if set.len() < set.capacity() / 2 && set.capacity() > 128 {
269        set.shrink_to_fit();
270    }
271}
272
273impl Drop for InternalString {
274    fn drop(&mut self) {
275        unsafe {
276            if (self.unsafe_data.inline.get() & TAG_MASK) as u8 == DYNAMIC_TAG {
277                let ptr = self.unsafe_data.dynamic;
278                // SAFETY: ptr is a valid Arc
279                let arc: Arc<Box<str>> = Arc::from_raw(ptr);
280                if Arc::strong_count(&arc) == 2 {
281                    drop_cache(arc);
282                } else {
283                    drop(arc)
284                }
285            }
286        }
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::BuildHasher;

    /// Reports whether `s` currently has an entry in the global intern table.
    ///
    /// `STRING_SET` is shared by every test in the process, and the default harness runs
    /// tests in parallel. Tests must therefore inspect only entries for content they
    /// alone create, never the table's total length.
    fn is_interned(s: &str) -> bool {
        STRING_SET.lock().unwrap().contains(s)
    }

    #[test]
    fn test_string_cache() {
        let s1 = InternalString::from("hello");
        let s3 = InternalString::from("world");

        // Content should match
        assert_eq!("hello", s1.as_str());
        assert_eq!(s3.as_str(), "world");

        // Short strings are stored inline and never enter the intern table.
        assert!(s1.unsafe_data.is_inline());
        assert!(s3.unsafe_data.is_inline());
        assert!(!is_interned("hello"));
    }

    #[test]
    fn test_inline_length_boundary() {
        let seven = InternalString::from("1234567");
        let eight = InternalString::from("12345678");
        assert!(seven.unsafe_data.is_inline());
        assert!(!eight.unsafe_data.is_inline());
        assert_eq!(seven.as_str(), "1234567");
        assert_eq!(eight.as_str(), "12345678");

        assert_eq!(InternalString::default().as_str(), "");
        assert_eq!(InternalString::from(""), InternalString::default());
    }

    #[test]
    fn test_non_ascii_round_trip() {
        // 2, 6 and 7 bytes (inline) and 24 and 8 bytes (interned).
        for text in ["é", "héllo", "éééa", "日本語のテキスト", "🦀🦀"] {
            let s = InternalString::from(text);
            assert_eq!(s.as_str(), text);
            assert_eq!(s.len(), text.len());
        }
    }

    #[test]
    fn test_hash_matches_str() {
        let state = std::collections::hash_map::RandomState::new();
        for text in ["short", "a string long enough to be interned"] {
            let s = InternalString::from(text);
            assert_eq!(state.hash_one(&s), state.hash_one(text));
        }
    }

    #[test]
    fn test_clone_shares_interned_allocation() {
        let text = "a long string that is shared between clones";
        let original = InternalString::from(text);
        let copy = original.clone();
        assert_eq!(original, copy);
        assert!(std::ptr::eq(original.as_str().as_ptr(), copy.as_str().as_ptr()));

        drop(original);
        assert_eq!(copy.as_str(), text);
    }

    #[test]
    fn test_long_string_cache() {
        let long_str1 = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
        let long_str2 = "A very long string that contains lots of repeated characters: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";

        let s1 = InternalString::from(long_str1);
        let s2 = InternalString::from(long_str1);
        let s3 = InternalString::from(long_str2);

        // Same long strings should be equal
        assert_eq!(s1, s2);

        // Different long strings should be different
        assert_ne!(s1, s3);

        // Content should match exactly
        assert_eq!(s1.as_str(), long_str1);
        assert_eq!(s2.as_str(), long_str1);
        assert_eq!(s3.as_str(), long_str2);

        // Internal pointers should be same for equal strings
        assert!(std::ptr::eq(s1.as_str().as_ptr(), s2.as_str().as_ptr()));
        assert!(!std::ptr::eq(s1.as_str().as_ptr(), s3.as_str().as_ptr()));
    }

    #[test]
    fn test_long_string_cache_drop() {
        // Only this test interns this content, so its table entry is ours alone to check.
        let text = "hello".repeat(10);
        assert!(!is_interned(&text));
        {
            let s1 = InternalString::from(text.as_str());
            let s2 = InternalString::from(text.clone());
            assert!(std::ptr::eq(s1.as_str().as_ptr(), s2.as_str().as_ptr()));
            assert!(is_interned(&text));
        }
        // Dropping the last handle evicts the entry from the table.
        assert!(!is_interned(&text));
    }
}