loro_common/
internal_string.rs1use rustc_hash::FxHashSet;
2use serde::{Deserialize, Serialize};
3use std::borrow::Borrow;
4use std::slice;
5use std::sync::LazyLock;
6use std::{
7 fmt::Display,
8 num::NonZeroU64,
9 ops::Deref,
10 sync::{atomic::AtomicUsize, Arc, Mutex},
11};
12
/// Tag value (low two bits of the word) marking a pointer to an interned `Box<str>`.
/// Relies on that pointer's low two bits being zero.
const DYNAMIC_TAG: u8 = 0;
/// Tag value (low two bits of the word) marking a string packed into the word itself.
const INLINE_TAG: u8 = 1;
/// Selects the two tag bits.
const TAG_MASK: u64 = (1 << 2) - 1;
/// Bit position of the 4-bit inline length field.
const LEN_OFFSET: u64 = 4;
/// Selects the inline length field (bits 4..8).
const LEN_MASK: u64 = 0b1111 << LEN_OFFSET;
18
19#[repr(transparent)]
20#[derive(Clone)]
21pub struct InternalString {
22 unsafe_data: UnsafeData,
23}
24
/// Raw storage shared by both representations.
///
/// The low two bits of `inline` are a tag saying which field is meaningful:
/// `INLINE_TAG` means packed bytes, and `DYNAMIC_TAG` means an `Arc::into_raw` pointer.
///
/// NOTE(review): on targets whose pointers are narrower than 64 bits, writing `dynamic`
/// does not initialise all eight bytes. Tag checks nevertheless read the whole `inline`
/// word — confirm this is sound there (e.g. wasm32).
union UnsafeData {
    inline: NonZeroU64,
    dynamic: *const Box<str>,
}

// SAFETY: the inline form is plain data. The dynamic form points at a `Box<str>` owned by
// an `Arc` and is only ever read through, never mutated, so sharing or moving it across
// threads is as safe as doing so with `Arc<Box<str>>`.
unsafe impl Send for UnsafeData {}
unsafe impl Sync for UnsafeData {}
32
33impl UnsafeData {
34 #[inline(always)]
35 fn is_inline(&self) -> bool {
36 unsafe { (self.inline.get() & TAG_MASK) as u8 == INLINE_TAG }
37 }
38}
39
40impl Clone for UnsafeData {
41 fn clone(&self) -> Self {
42 if self.is_inline() {
43 Self {
44 inline: unsafe { self.inline },
45 }
46 } else {
47 unsafe {
48 Arc::increment_strong_count(self.dynamic);
49 Self {
50 dynamic: self.dynamic,
51 }
52 }
53 }
54 }
55}
56
57impl std::fmt::Debug for InternalString {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.write_str("InternalString(")?;
60 std::fmt::Debug::fmt(self.as_str(), f)?;
61 f.write_str(")")
62 }
63}
64
65impl std::hash::Hash for InternalString {
66 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
67 self.as_str().hash(state);
68 }
69}
70
71impl PartialEq for InternalString {
72 fn eq(&self, other: &Self) -> bool {
73 self.as_str() == other.as_str()
74 }
75}
76
77impl Eq for InternalString {}
78
79impl PartialOrd for InternalString {
80 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
81 Some(self.cmp(other))
82 }
83}
84
85impl Ord for InternalString {
86 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
87 self.as_str().cmp(other.as_str())
88 }
89}
90
91impl Serialize for InternalString {
92 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93 where
94 S: serde::Serializer,
95 {
96 serializer.serialize_str(self.as_str())
97 }
98}
99
100impl<'de> Deserialize<'de> for InternalString {
101 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102 where
103 D: serde::Deserializer<'de>,
104 {
105 let s = String::deserialize(deserializer)?;
106 Ok(InternalString::from(s.as_str()))
107 }
108}
109
110impl Default for InternalString {
111 fn default() -> Self {
112 let v: u64 = INLINE_TAG as u64;
113 Self {
114 unsafe_data: UnsafeData {
116 inline: unsafe { NonZeroU64::new_unchecked(v) },
117 },
118 }
119 }
120}
121
122impl InternalString {
123 pub fn as_str(&self) -> &str {
124 unsafe {
125 match (self.unsafe_data.inline.get() & TAG_MASK) as u8 {
126 INLINE_TAG => {
127 let len = (self.unsafe_data.inline.get() & LEN_MASK) >> LEN_OFFSET;
128 let src = inline_atom_slice(&self.unsafe_data.inline);
129 std::str::from_utf8_unchecked(&src[..(len as usize)])
131 }
132 DYNAMIC_TAG => {
133 let ptr = self.unsafe_data.dynamic;
134 (*ptr).deref()
136 }
137 _ => unreachable!(),
138 }
139 }
140 }
141}
142
143impl AsRef<str> for InternalString {
144 fn as_ref(&self) -> &str {
145 self.as_str()
146 }
147}
148
149impl From<&str> for InternalString {
150 #[inline(always)]
151 fn from(s: &str) -> Self {
152 if s.len() <= 7 {
153 let mut v: u64 = (INLINE_TAG as u64) | ((s.len() as u64) << LEN_OFFSET);
154 let arr = inline_atom_slice_mut(&mut v);
155 arr[..s.len()].copy_from_slice(s.as_bytes());
156 Self {
157 unsafe_data: UnsafeData {
158 inline: unsafe { NonZeroU64::new_unchecked(v) },
160 },
161 }
162 } else {
163 let ans: Arc<Box<str>> = get_or_init_internalized_string(s);
164 let raw = Arc::into_raw(ans);
165 Self {
167 unsafe_data: UnsafeData { dynamic: raw },
168 }
169 }
170 }
171}
172
/// Views the seven payload bytes of an inline word.
///
/// That is every byte except the least significant one, which holds the tag and length.
#[inline(always)]
fn inline_atom_slice(x: &NonZeroU64) -> &[u8] {
    // The least significant byte is first in memory on little-endian targets and last on
    // big-endian ones, so skip it only in the former case.
    let skip: usize = if cfg!(target_endian = "little") { 1 } else { 0 };
    let base = (x as *const NonZeroU64).cast::<u8>();
    // SAFETY: `x` refers to eight initialised bytes and `skip + 7 <= 8`; the returned slice
    // borrows from `x`, so it cannot outlive it.
    unsafe { slice::from_raw_parts(base.add(skip), 7) }
}
186
/// Mutable view of the seven payload bytes of an inline word.
///
/// Writing through the view leaves the least significant byte (tag and length) untouched.
#[inline(always)]
fn inline_atom_slice_mut(x: &mut u64) -> &mut [u8] {
    // The least significant byte is first in memory on little-endian targets.
    let skip: usize = if cfg!(target_endian = "little") { 1 } else { 0 };
    let base = (x as *mut u64).cast::<u8>();
    // SAFETY: `x` is an exclusive borrow of eight bytes and `skip + 7 <= 8`; the slice
    // reborrows `x` for its whole lifetime.
    unsafe { slice::from_raw_parts_mut(base.add(skip), 7) }
}
200
201impl From<String> for InternalString {
202 fn from(s: String) -> Self {
203 Self::from(s.as_str())
204 }
205}
206
207impl From<&InternalString> for String {
208 #[inline(always)]
209 fn from(value: &InternalString) -> Self {
210 value.as_str().to_string()
211 }
212}
213
214impl Display for InternalString {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 self.as_str().fmt(f)
217 }
218}
219
220impl Deref for InternalString {
221 type Target = str;
222
223 fn deref(&self) -> &Self::Target {
224 self.as_str()
225 }
226}
227
228#[derive(Hash, PartialEq, Eq)]
229struct ArcWrapper(Arc<Box<str>>);
230
231impl Borrow<str> for ArcWrapper {
232 fn borrow(&self) -> &str {
233 &self.0
234 }
235}
236
237static STRING_SET: LazyLock<Mutex<FxHashSet<ArcWrapper>>> =
238 LazyLock::new(|| Mutex::new(FxHashSet::default()));
239
240fn get_or_init_internalized_string(s: &str) -> Arc<Box<str>> {
241 static MAX_MET_CACHE_SIZE: AtomicUsize = AtomicUsize::new(1 << 16);
242
243 let mut set = STRING_SET.lock().unwrap();
244 if let Some(v) = set.get(s) {
245 v.0.clone()
246 } else {
247 let ans: Arc<Box<str>> = Arc::new(Box::from(s));
248 set.insert(ArcWrapper(ans.clone()));
249 let max = MAX_MET_CACHE_SIZE.load(std::sync::atomic::Ordering::Relaxed);
250 if set.capacity() >= max {
251 let old = set.len();
252 set.retain(|s| Arc::strong_count(&s.0) > 1);
253 let new = set.len();
254 if old - new > new / 2 {
255 set.shrink_to_fit();
256 }
257
258 MAX_MET_CACHE_SIZE.store(max * 2, std::sync::atomic::Ordering::Relaxed);
259 }
260
261 ans
262 }
263}
264
265fn drop_cache(s: Arc<Box<str>>) {
266 let mut set = STRING_SET.lock().unwrap();
267 set.remove(&ArcWrapper(s));
268 if set.len() < set.capacity() / 2 && set.capacity() > 128 {
269 set.shrink_to_fit();
270 }
271}
272
273impl Drop for InternalString {
274 fn drop(&mut self) {
275 unsafe {
276 if (self.unsafe_data.inline.get() & TAG_MASK) as u8 == DYNAMIC_TAG {
277 let ptr = self.unsafe_data.dynamic;
278 let arc: Arc<Box<str>> = Arc::from_raw(ptr);
280 if Arc::strong_count(&arc) == 2 {
281 drop_cache(arc);
282 } else {
283 drop(arc)
284 }
285 }
286 }
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether `text` currently has an entry in the global intern set.
    fn is_interned(text: &str) -> bool {
        STRING_SET.lock().unwrap().contains(text)
    }

    #[test]
    fn test_string_cache() {
        let s1 = InternalString::from("hello");
        let s2 = InternalString::from("world");

        assert_eq!("hello", s1.as_str());
        assert_eq!(s2.as_str(), "world");
    }

    #[test]
    fn test_inline_length_boundary() {
        // Up to 7 bytes are packed inline; anything longer is interned.
        for len in 0..=9 {
            let text = "z".repeat(len);
            let s = InternalString::from(text.as_str());
            assert_eq!(s.as_str(), text);
            assert_eq!(s.unsafe_data.is_inline(), len <= 7, "len = {len}");
        }
    }

    #[test]
    fn test_multibyte_utf8_round_trip() {
        // The inline limit counts bytes, not chars: "héllo" is 6 bytes, "日本語" is 9.
        for text in ["é", "héllo", "日本語", "ünïcödé ✓"] {
            let s = InternalString::from(text);
            assert_eq!(s.as_str(), text);
            assert_eq!(s.unsafe_data.is_inline(), text.len() <= 7);
        }
    }

    #[test]
    fn test_default_is_empty() {
        let s = InternalString::default();
        assert_eq!(s.as_str(), "");
        assert_eq!(s, InternalString::from(""));
    }

    #[test]
    fn test_clone_shares_interned_storage() {
        let original = InternalString::from("test_clone_shares_interned_storage");
        let copy = original.clone();
        assert_eq!(original, copy);
        assert!(std::ptr::eq(original.as_str().as_ptr(), copy.as_str().as_ptr()));
        drop(original);
        assert_eq!(copy.as_str(), "test_clone_shares_interned_storage");
    }

    #[test]
    fn test_ordering_follows_contents() {
        let short = InternalString::from("abc");
        let long = InternalString::from("abcdefghijkl");
        assert!(short < long);
        assert_eq!(short.cmp(&long), "abc".cmp("abcdefghijkl"));
    }

    #[test]
    fn test_long_string_cache() {
        let long_str1 = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
        let long_str2 = "A very long string that contains lots of repeated characters: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";

        let s1 = InternalString::from(long_str1);
        let s2 = InternalString::from(long_str1);
        let s3 = InternalString::from(long_str2);

        assert_eq!(s1, s2);
        assert_ne!(s1, s3);

        assert_eq!(s1.as_str(), long_str1);
        assert_eq!(s2.as_str(), long_str1);
        assert_eq!(s3.as_str(), long_str2);

        assert!(std::ptr::eq(s1.as_str().as_ptr(), s2.as_str().as_ptr()));
        assert!(!std::ptr::eq(s1.as_str().as_ptr(), s3.as_str().as_ptr()));
    }

    #[test]
    fn test_long_string_cache_drop() {
        // Tests run in parallel and share STRING_SET. Only inspect an entry that no other
        // test creates, rather than asserting that the whole set is empty.
        let text = "test_long_string_cache_drop ".repeat(4);
        assert!(!is_interned(&text));
        {
            let s1 = InternalString::from(text.as_str());
            let s2 = InternalString::from(text.clone());
            assert!(std::ptr::eq(s1.as_str().as_ptr(), s2.as_str().as_ptr()));
            assert!(is_interned(&text));
        }
        assert!(!is_interned(&text));
    }
}
344}