Skip to main content

loro_common/
internal_string.rs

1use rustc_hash::FxHashSet;
2use serde::{Deserialize, Serialize};
3use std::borrow::Borrow;
4use std::slice;
5use std::sync::LazyLock;
6use std::{
7    fmt::Display,
8    num::NonZeroU64,
9    ops::Deref,
10    sync::{atomic::AtomicUsize, Arc, Mutex},
11};
12
// Layout of the lowest byte of an inline string's packed `u64`:
// bits 0-1 carry the representation tag and bits 4-7 carry the byte length.

/// Tag value stored in the low bits of an inline string's packed word.
const INLINE_TAG: u8 = 1;
/// Mask selecting the tag bits of the packed word.
const TAG_MASK: u64 = 3;
/// Number of bits the length field is shifted left within the packed word.
const LEN_OFFSET: u64 = 4;
/// Mask selecting the four length bits of the packed word.
const LEN_MASK: u64 = 0xF << LEN_OFFSET;
17
/// A cheaply clonable string.
///
/// When built from a `&str`, contents of at most 7 bytes are packed directly
/// into the value; longer contents are stored in a shared allocation that is
/// deduplicated through a process-wide set. Equality, ordering and hashing
/// are all defined on the string contents.
#[derive(Clone)]
pub struct InternalString {
    // The actual representation (inline or shared); kept private so the
    // packing invariants cannot be broken from outside this module.
    data: InternalStringData,
}
22
/// Storage backing an `InternalString`.
#[derive(Clone)]
enum InternalStringData {
    /// Short string packed into a single non-zero word: the lowest byte holds
    /// the tag (`INLINE_TAG`, under `TAG_MASK`) and the byte length (under
    /// `LEN_MASK`, shifted by `LEN_OFFSET`); the other seven bytes hold the
    /// UTF-8 data.
    Inline(NonZeroU64),
    /// Longer string stored in a reference-counted allocation that is shared
    /// with the global `STRING_SET` cache.
    Dynamic(Arc<Box<str>>),
}
28
29impl std::fmt::Debug for InternalString {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.write_str("InternalString(")?;
32        std::fmt::Debug::fmt(self.as_str(), f)?;
33        f.write_str(")")
34    }
35}
36
37impl std::hash::Hash for InternalString {
38    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39        self.as_str().hash(state);
40    }
41}
42
43impl PartialEq for InternalString {
44    fn eq(&self, other: &Self) -> bool {
45        self.as_str() == other.as_str()
46    }
47}
48
49impl Eq for InternalString {}
50
51impl PartialOrd for InternalString {
52    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
53        Some(self.cmp(other))
54    }
55}
56
57impl Ord for InternalString {
58    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
59        self.as_str().cmp(other.as_str())
60    }
61}
62
63impl Serialize for InternalString {
64    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
65    where
66        S: serde::Serializer,
67    {
68        serializer.serialize_str(self.as_str())
69    }
70}
71
72impl<'de> Deserialize<'de> for InternalString {
73    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
74    where
75        D: serde::Deserializer<'de>,
76    {
77        let s = String::deserialize(deserializer)?;
78        Ok(InternalString::from(s.as_str()))
79    }
80}
81
82impl Default for InternalString {
83    fn default() -> Self {
84        let v: u64 = INLINE_TAG as u64;
85        Self {
86            // SAFETY: INLINE_TAG is non-zero
87            data: InternalStringData::new_inline(unsafe { NonZeroU64::new_unchecked(v) }),
88        }
89    }
90}
91
92impl InternalString {
93    pub fn as_str(&self) -> &str {
94        match &self.data {
95            InternalStringData::Inline(inline) => unsafe {
96                let len = (inline.get() & LEN_MASK) >> LEN_OFFSET;
97                let src = inline_atom_slice(inline);
98                // SAFETY: the chosen range is guaranteed to be valid str
99                std::str::from_utf8_unchecked(&src[..(len as usize)])
100            },
101            InternalStringData::Dynamic(dynamic) => dynamic.deref(),
102        }
103    }
104}
105
106impl InternalStringData {
107    fn new_inline(inline: NonZeroU64) -> Self {
108        debug_assert_eq!((inline.get() & TAG_MASK) as u8, INLINE_TAG);
109        Self::Inline(inline)
110    }
111
112    fn new_dynamic(dynamic: Arc<Box<str>>) -> Self {
113        Self::Dynamic(dynamic)
114    }
115}
116
117impl Drop for InternalStringData {
118    fn drop(&mut self) {
119        if let InternalStringData::Dynamic(arc) = self {
120            if Arc::strong_count(arc) == 2 {
121                drop_cache(arc.clone());
122            }
123        }
124    }
125}
126
127impl AsRef<str> for InternalString {
128    fn as_ref(&self) -> &str {
129        self.as_str()
130    }
131}
132
133impl From<&str> for InternalString {
134    #[inline(always)]
135    fn from(s: &str) -> Self {
136        if s.len() <= 7 {
137            let mut v: u64 = (INLINE_TAG as u64) | ((s.len() as u64) << LEN_OFFSET);
138            let arr = inline_atom_slice_mut(&mut v);
139            arr[..s.len()].copy_from_slice(s.as_bytes());
140            Self {
141                // SAFETY: The tag is 1
142                data: InternalStringData::new_inline(unsafe { NonZeroU64::new_unchecked(v) }),
143            }
144        } else {
145            let ans: Arc<Box<str>> = get_or_init_internalized_string(s);
146            Self {
147                data: InternalStringData::new_dynamic(ans),
148            }
149        }
150    }
151}
152
/// Returns the seven payload bytes of a packed inline word, i.e. every byte
/// of `x` except the one holding its lowest-order bits.
#[inline(always)]
fn inline_atom_slice(x: &NonZeroU64) -> &[u8] {
    // The lowest byte comes first in memory on little-endian targets and last
    // on big-endian ones, so only little-endian needs to skip a byte.
    let skip: usize = if cfg!(target_endian = "little") { 1 } else { 0 };
    let base = (x as *const NonZeroU64).cast::<u8>();
    // SAFETY: `x` points to 8 readable bytes and `skip + 7 <= 8`, so the
    // range stays inside the referenced value; the returned slice borrows `x`.
    unsafe { slice::from_raw_parts(base.add(skip), 7) }
}
166
/// Returns a mutable view of the seven payload bytes of a packed inline word,
/// i.e. every byte of `x` except the one holding its lowest-order bits.
#[inline(always)]
fn inline_atom_slice_mut(x: &mut u64) -> &mut [u8] {
    // The lowest byte comes first in memory on little-endian targets and last
    // on big-endian ones, so only little-endian needs to skip a byte.
    let skip: usize = if cfg!(target_endian = "little") { 1 } else { 0 };
    let base = (x as *mut u64).cast::<u8>();
    // SAFETY: `x` points to 8 writable bytes and `skip + 7 <= 8`, so the
    // range stays inside the referenced value; the returned slice is the only
    // live access to `x` for as long as it is borrowed.
    unsafe { slice::from_raw_parts_mut(base.add(skip), 7) }
}
180
181impl From<String> for InternalString {
182    fn from(s: String) -> Self {
183        Self::from(s.as_str())
184    }
185}
186
187impl From<&InternalString> for String {
188    #[inline(always)]
189    fn from(value: &InternalString) -> Self {
190        value.as_str().to_string()
191    }
192}
193
194impl Display for InternalString {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        self.as_str().fmt(f)
197    }
198}
199
200impl Deref for InternalString {
201    type Target = str;
202
203    fn deref(&self) -> &Self::Target {
204        self.as_str()
205    }
206}
207
/// Set element wrapping a shared string so the set can be queried with a
/// plain `&str` key.
///
/// The derived `Hash`/`Eq` forward through `Arc` and `Box` to the `str`
/// contents, matching what `Borrow<str>` hands out.
#[derive(Hash, PartialEq, Eq)]
struct ArcWrapper(Arc<Box<str>>);

impl Borrow<str> for ArcWrapper {
    /// Exposes the wrapped string contents.
    fn borrow(&self) -> &str {
        // Peel off the `Arc` and then the `Box` explicitly.
        &**self.0
    }
}
216
217static STRING_SET: LazyLock<Mutex<FxHashSet<ArcWrapper>>> =
218    LazyLock::new(|| Mutex::new(FxHashSet::default()));
219
220fn get_or_init_internalized_string(s: &str) -> Arc<Box<str>> {
221    static MAX_MET_CACHE_SIZE: AtomicUsize = AtomicUsize::new(1 << 16);
222
223    let mut set = STRING_SET.lock().unwrap();
224    if let Some(v) = set.get(s) {
225        v.0.clone()
226    } else {
227        let ans: Arc<Box<str>> = Arc::new(Box::from(s));
228        set.insert(ArcWrapper(ans.clone()));
229        let max = MAX_MET_CACHE_SIZE.load(std::sync::atomic::Ordering::Relaxed);
230        if set.capacity() >= max {
231            let old = set.len();
232            set.retain(|s| Arc::strong_count(&s.0) > 1);
233            let new = set.len();
234            if old - new > new / 2 {
235                set.shrink_to_fit();
236            }
237
238            MAX_MET_CACHE_SIZE.store(max * 2, std::sync::atomic::Ordering::Relaxed);
239        }
240
241        ans
242    }
243}
244
245fn drop_cache(s: Arc<Box<str>>) {
246    let mut set = STRING_SET.lock().unwrap();
247    set.remove(&ArcWrapper(s));
248    if set.len() < set.capacity() / 2 && set.capacity() > 128 {
249        set.shrink_to_fit();
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Short strings round-trip and compare by content.
    #[test]
    fn test_string_cache() {
        let s1 = InternalString::from("hello");
        let s2 = InternalString::from("hello");
        let s3 = InternalString::from("world");

        assert_eq!(s1, s2);
        assert_ne!(s1, s3);

        // Content should match
        assert_eq!("hello", s1.as_str());
        assert_eq!(s3.as_str(), "world");
    }

    /// Covers the empty string and the 7/8-byte switch between the inline
    /// and the shared representation.
    #[test]
    fn test_inline_dynamic_boundary() {
        assert_eq!(InternalString::default().as_str(), "");
        assert_eq!(InternalString::default(), InternalString::from(""));

        let seven = InternalString::from("1234567");
        assert!(matches!(seven.data, InternalStringData::Inline(_)));
        assert_eq!(seven.as_str(), "1234567");

        let eight = InternalString::from("12345678");
        assert!(matches!(eight.data, InternalStringData::Dynamic(_)));
        assert_eq!(eight.as_str(), "12345678");

        // Multi-byte UTF-8 is measured in bytes, not characters.
        let accented = InternalString::from("héllo");
        assert!(matches!(accented.data, InternalStringData::Inline(_)));
        assert_eq!(accented.as_str(), "héllo");
    }

    #[cfg(all(miri, target_pointer_width = "32"))]
    #[test]
    fn miri_dynamic_string_does_not_read_uninitialized_tag_bytes_on_32_bit() {
        let s = InternalString::from("long enough to use the dynamic representation");

        assert_eq!(s.as_str(), "long enough to use the dynamic representation");
    }

    #[test]
    fn test_long_string_cache() {
        let long_str1 = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
        let long_str2 = "A very long string that contains lots of repeated characters: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";

        let s1 = InternalString::from(long_str1);
        let s2 = InternalString::from(long_str1);
        let s3 = InternalString::from(long_str2);

        // Same long strings should be equal
        assert_eq!(s1, s2);

        // Different long strings should be different
        assert_ne!(s1, s3);

        // Content should match exactly
        assert_eq!(s1.as_str(), long_str1);
        assert_eq!(s2.as_str(), long_str1);
        assert_eq!(s3.as_str(), long_str2);

        // Internal pointers should be same for equal strings
        assert!(std::ptr::eq(s1.as_str().as_ptr(), s2.as_str().as_ptr()));
        assert!(!std::ptr::eq(s1.as_str().as_ptr(), s3.as_str().as_ptr()));
    }

    /// The cache entry must disappear once the last owner is dropped.
    ///
    /// Only this test's own key is inspected: tests run in parallel and share
    /// the global set, so asserting on the set's total size would be racy.
    /// The lock is never held across an assertion, so a failure cannot
    /// poison the mutex for other tests.
    #[test]
    fn test_long_string_cache_drop() {
        let text = "hello".repeat(10);
        let is_cached = |key: &str| STRING_SET.lock().unwrap().contains(key);

        let cached_before = is_cached(&text);
        assert!(!cached_before);
        {
            let s1 = InternalString::from(text.as_str());
            let s2 = InternalString::from(text.as_str());
            assert!(std::ptr::eq(s1.as_str().as_ptr(), s2.as_str().as_ptr()));

            let cached_while_alive = is_cached(&text);
            assert!(cached_while_alive);
        }
        let cached_after = is_cached(&text);
        assert!(!cached_after);
    }
}