1#![feature(str_from_raw_parts)]
2
3use core::{fmt, str};
4use std::{
5 alloc::{Layout, alloc, dealloc},
6 fmt::{Debug, Display},
7 hint::cold_path,
8 mem::forget,
9 ops::Deref,
10 ptr::{self, NonNull, copy_nonoverlapping},
11 str::FromStr,
12 sync::{
13 LazyLock,
14 atomic::{AtomicUsize, Ordering},
15 },
16};
17
18use dashmap::{DashMap, Entry};
19use serde::{Deserialize, Serialize, de::Visitor};
20
/// Refcount ceiling (the same limit `Arc` uses): aborting before the count
/// exceeds `isize::MAX` guarantees a wrapped counter can never be mistaken
/// for a small live count.
const MAX_REFCOUNT: usize = (isize::MAX) as usize;
22
/// Heap header stored immediately before the interned string bytes, which
/// live in the same allocation (see `InternInner::layout_for`).
///
/// Counting scheme mirrors `Arc`: `strong` counts `InternString` handles;
/// `weak` counts `WeakInternString` handles, including the one held by the
/// global `DATA` map entry (`alloc` starts `weak` at 1 for that slot).
struct InternInner {
    strong: AtomicUsize,
    weak: AtomicUsize,
    // Byte length of the UTF-8 tail that follows this header.
    len: usize,
}
28
/// A strong handle to an interned, immutable string.
///
/// Equality/hashing are derived over the `NonNull` pointer: two handles
/// compare equal iff they share the same allocation, which the interner
/// makes coincide with string equality for live handles.
#[derive(PartialEq, Eq, Hash)]
pub struct InternString {
    inner: NonNull<InternInner>,
}
36
/// A weak handle: keeps the allocation (header + bytes) alive, but does not
/// keep the string interned — once all strong handles drop, `as_str` returns
/// `None` and `into_strong` fails.
#[derive(PartialEq, Eq, Hash)]
pub struct WeakInternString {
    inner: NonNull<InternInner>,
}
43
// Global intern table. The `&'static str` key borrows the bytes of the very
// allocation its weak value points to, so key and value live and die together.
static DATA: LazyLock<DashMap<&'static str, WeakInternString>> = LazyLock::new(Default::default);
48
49impl InternInner {
50 unsafe fn data_ptr<'a>(value: NonNull<Self>) -> *const u8 {
51 unsafe {
52 return value.as_ptr().add(1) as *const u8;
53 }
54 }
55
56 unsafe fn data_ptr_mut<'a>(value: NonNull<Self>) -> *mut u8 {
57 unsafe {
58 return value.as_ptr().add(1) as *mut u8;
59 }
60 }
61
62 unsafe fn data_mut<'a>(value: NonNull<Self>) -> &'a mut str {
63 unsafe {
64 return str::from_raw_parts_mut(Self::data_ptr_mut(value), (*value.as_ptr()).len);
65 }
66 }
67
68 unsafe fn data<'a>(value: NonNull<Self>) -> &'a str {
69 unsafe {
70 return str::from_raw_parts(Self::data_ptr(value), (*value.as_ptr()).len);
71 }
72 }
73
74 fn layout_for(len: usize) -> Layout {
75 return Layout::new::<Self>()
76 .extend(Layout::array::<u8>(len).unwrap())
77 .unwrap()
78 .0
79 .pad_to_align();
80 }
81
82 fn layout(value: NonNull<Self>) -> Layout {
83 unsafe {
84 return Self::layout_for((*value.as_ptr()).len);
85 }
86 }
87
88 unsafe fn alloc(s: &str) -> (NonNull<Self>, &'static str) {
89 unsafe {
90 let layout = Self::layout_for(s.len());
91
92 let ptr = alloc(layout) as *mut Self;
93
94 let value = &mut *ptr;
95
96 value.strong = AtomicUsize::new(1);
97 value.weak = AtomicUsize::new(1);
98 value.len = s.len();
99
100 let ptr = NonNull::new_unchecked(ptr);
101 let data = InternInner::data_mut(ptr);
102
103 copy_nonoverlapping(s.as_ptr(), data.as_mut_ptr(), s.len());
104
105 return (ptr, data);
106 }
107 }
108
109 unsafe fn dealloc(value: NonNull<Self>) {
110 unsafe {
111 let ptr = value.as_ptr();
112 let layout = Self::layout(value);
113
114 dealloc(ptr as *mut u8, layout);
115 }
116 }
117}
118
119impl InternString {
120 fn inner(&self) -> &InternInner {
121 unsafe {
122 return &*self.inner.as_ptr();
123 }
124 }
125
126 pub fn into_raw(self) -> *const u8 {
131 let ptr = self.inner.as_ptr() as *const u8;
132
133 forget(self);
134
135 return ptr;
136 }
137
138 pub unsafe fn from_raw(value: *const u8) -> Option<Self> {
142 let inner = NonNull::new(value as *mut InternInner)?;
143 if !inner.is_aligned() {
144 return None;
145 }
146
147 return Some(Self { inner: inner });
148 }
149
150 pub fn new(s: &str) -> Self {
153 if let Some(weak) = DATA.get(s) {
154 let value = Self { inner: weak.inner };
155
156 value.inner().strong.fetch_add(1, Ordering::Relaxed);
157
158 return value;
159 }
160
161 let (ptr, str) = unsafe { InternInner::alloc(s) };
162
163 match DATA.entry(str) {
164 Entry::Occupied(occupied) => {
165 cold_path();
166
167 unsafe {
168 InternInner::dealloc(ptr);
169 }
170
171 let value = Self {
172 inner: occupied.get().inner,
173 };
174
175 value.inner().strong.fetch_add(1, Ordering::Relaxed);
176
177 return value;
178 }
179 Entry::Vacant(vacant) => {
180 vacant.insert(WeakInternString { inner: ptr });
181
182 let value = Self { inner: ptr };
183
184 return value;
185 }
186 }
187 }
188
189 pub fn as_str<'a>(&'a self) -> &'a str {
191 unsafe {
192 return InternInner::data(self.inner);
193 }
194 }
195
196 pub fn into_weak(&self) -> WeakInternString {
198 let weak = WeakInternString { inner: self.inner };
199
200 self.inner().weak.fetch_add(1, Ordering::Relaxed);
201
202 return weak;
203 }
204}
205
206impl WeakInternString {
207 fn inner(&self) -> &InternInner {
208 unsafe {
209 return &*self.inner.as_ptr();
210 }
211 }
212
213 pub fn is_alive(&self) -> bool {
216 return self.inner().strong.load(Ordering::Acquire) != 0;
217 }
218
219 pub fn as_str<'a>(&'a self) -> Option<&'a str> {
221 if !self.is_alive() {
222 return None;
223 }
224
225 unsafe {
226 return Some(InternInner::data(self.inner));
227 }
228 }
229
230 pub fn into_strong(&self) -> Option<InternString> {
232 #[inline]
233 fn checked_increment(n: usize) -> Option<usize> {
234 if n == 0 {
235 return None;
236 }
237 assert!(n <= MAX_REFCOUNT);
238 return Some(n + 1);
239 }
240
241 self.inner()
242 .strong
243 .fetch_update(Ordering::Acquire, Ordering::Relaxed, checked_increment)
244 .ok()?;
245
246 return Some(InternString { inner: self.inner });
247 }
248}
249
// SAFETY: the interned bytes are immutable after construction and every
// reference-count update goes through atomics, so handles can be sent to and
// shared between threads.
unsafe impl Send for InternString {}
unsafe impl Sync for InternString {}

// SAFETY: same reasoning — weak handles only read atomics and immutable bytes.
unsafe impl Send for WeakInternString {}
unsafe impl Sync for WeakInternString {}
255
256impl Drop for InternString {
257 fn drop(&mut self) {
258 if self.inner().strong.fetch_sub(1, Ordering::Release) != 1 {
259 return;
260 }
261
262 self.inner().strong.load(Ordering::Acquire);
263
264 DATA.remove(self.as_str());
265 }
266}
267
impl Drop for WeakInternString {
    fn drop(&mut self) {
        // Release our weak reference; the thread that brings the count to
        // zero frees the allocation.
        if self.inner().weak.fetch_sub(1, Ordering::Release) == 1 {
            // Acquire-load pairs with every prior Release decrement so all
            // earlier accesses to the allocation happen-before the free
            // (the same two-step fence pattern `Arc` uses).
            self.inner().weak.load(Ordering::Acquire);

            unsafe {
                InternInner::dealloc(self.inner);
            }
        }
    }
}
279
280impl Clone for InternString {
281 fn clone(&self) -> Self {
282 let ref_count = self.inner().strong.fetch_add(1, Ordering::Relaxed);
283
284 assert!(ref_count < MAX_REFCOUNT);
285
286 return Self {
287 inner: self.inner.clone(),
288 };
289 }
290}
291
292impl Clone for WeakInternString {
293 fn clone(&self) -> Self {
294 let ref_count = self.inner().weak.fetch_add(1, Ordering::Relaxed);
295
296 assert!(ref_count < MAX_REFCOUNT);
297
298 return Self {
299 inner: self.inner.clone(),
300 };
301 }
302}
303
304impl PartialEq<WeakInternString> for InternString {
305 fn eq(&self, other: &WeakInternString) -> bool {
306 return ptr::addr_eq(self.inner.as_ptr(), other.inner.as_ptr());
307 }
308}
309
310impl PartialEq<InternString> for WeakInternString {
311 fn eq(&self, other: &InternString) -> bool {
312 return ptr::addr_eq(self.inner.as_ptr(), other.inner.as_ptr());
313 }
314}
315
316impl PartialEq<&str> for InternString {
317 fn eq(&self, other: &&str) -> bool {
318 return self.as_str() == *other;
319 }
320}
321
322impl PartialEq<&str> for WeakInternString {
323 fn eq(&self, other: &&str) -> bool {
324 let Some(str) = self.as_str() else {
325 return false;
326 };
327
328 return str == *other;
329 }
330}
331
332impl Deref for InternString {
333 type Target = str;
334
335 fn deref(&self) -> &Self::Target {
336 return self.as_str();
337 }
338}
339
340impl Debug for InternString {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 return f.debug_tuple("InternString").field(&self.as_str()).finish();
343 }
344}
345
346impl Display for InternString {
347 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
348 return f.write_str(self.as_str());
349 }
350}
351
352impl Debug for WeakInternString {
353 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354 let mut t = f.debug_tuple("WeakInternString");
355
356 let Some(s) = self.as_str() else {
357 return t.field(&"<dead>").finish();
358 };
359
360 return t.field(&s).finish();
361 }
362}
363
364impl Display for WeakInternString {
365 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366 let Some(s) = self.as_str() else {
367 return f.write_str(&"<dead>");
368 };
369
370 return f.write_str(&s);
371 }
372}
373
374impl FromStr for InternString {
375 type Err = ();
376
377 fn from_str(s: &str) -> Result<Self, Self::Err> {
378 return Ok(Self::new(s));
379 }
380}
381
382impl From<&str> for InternString {
383 fn from(value: &str) -> Self {
384 return Self::new(value);
385 }
386}
387
388impl From<String> for InternString {
389 fn from(value: String) -> Self {
390 return Self::new(&value);
391 }
392}
393
394impl Serialize for InternString {
395 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
396 where
397 S: serde::Serializer,
398 {
399 return serializer.serialize_str(self);
400 }
401}
402
403struct InternStringVisitor;
404
405impl<'a> Visitor<'a> for InternStringVisitor {
406 type Value = InternString;
407
408 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
409 formatter.write_str("a string")
410 }
411
412 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
413 where
414 E: serde::de::Error,
415 {
416 Ok(v.into())
417 }
418
419 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
420 where
421 E: serde::de::Error,
422 {
423 Ok(v.into())
424 }
425
426 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
427 where
428 E: serde::de::Error,
429 {
430 match str::from_utf8(v) {
431 Ok(s) => Ok(s.into()),
432 Err(_) => Err(serde::de::Error::invalid_value(
433 serde::de::Unexpected::Bytes(v),
434 &self,
435 )),
436 }
437 }
438
439 fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
440 where
441 E: serde::de::Error,
442 {
443 match String::from_utf8(v) {
444 Ok(s) => Ok(s.into()),
445 Err(e) => Err(serde::de::Error::invalid_value(
446 serde::de::Unexpected::Bytes(&e.into_bytes()),
447 &self,
448 )),
449 }
450 }
451}
452
453impl<'de> Deserialize<'de> for InternString {
454 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
455 where
456 D: serde::Deserializer<'de>,
457 {
458 return deserializer.deserialize_string(InternStringVisitor);
459 }
460}
461
#[cfg(test)]
mod tests {
    // NOTE: all tests share the process-wide `DATA` map, so each test uses a
    // unique literal to avoid cross-test interference.
    use super::*;
    use std::sync::{Arc, Barrier, Mutex};
    use std::thread;

    // Same text yields equal handles; different text yields unequal ones.
    #[test]
    fn basic_interning() {
        let s1 = InternString::new("hello");
        let s2 = InternString::new("hello");
        let s3 = InternString::new("world");

        assert_eq!(s1, s2);
        assert_ne!(s1, s3);

        assert_eq!(s1.as_str(), "hello");
        assert_eq!(s3.as_str(), "world");
    }

    // A weak handle observes the death of the last strong handle.
    #[test]
    fn weak_references() {
        let s1 = InternString::new("weak_test");
        let w1 = s1.into_weak();

        assert!(w1.is_alive());
        assert_eq!(w1.as_str(), Some("weak_test"));

        drop(s1);

        assert!(!w1.is_alive());
        assert_eq!(w1.as_str(), None);
    }

    // Clones of a strong handle all share one allocation.
    #[test]
    fn cloning_strong() {
        let s1 = InternString::new("clone_test");
        let s2 = s1.clone();
        let s3 = s2.clone();

        assert_eq!(s1.inner.as_ptr(), s2.inner.as_ptr());
        assert_eq!(s2.inner.as_ptr(), s3.inner.as_ptr());

        assert_eq!(s1.as_str(), "clone_test");
    }

    // Weak clones share the allocation and all observe death together.
    #[test]
    fn cloning_weak() {
        let s1 = InternString::new("weak_clone");
        let w1 = s1.into_weak();
        let w2 = w1.clone();

        assert_eq!(w1.inner.as_ptr(), w2.inner.as_ptr());

        drop(s1);

        assert!(!w1.is_alive());
        assert!(!w2.is_alive());
    }

    // Cross-type equality: strong<->weak and against plain &str.
    #[test]
    fn eq_partial() {
        let s = InternString::new("eq_test");
        let w = s.into_weak();

        assert_eq!(s, w);
        assert_eq!(w, s);

        assert_eq!(s, "eq_test");
        assert_eq!(w, "eq_test");
    }

    // Interning the same text twice returns pointer-identical allocations.
    #[test]
    fn data_deduplication() {
        let s1 = InternString::new("dedup");
        let s2 = InternString::new("dedup");

        assert!(ptr::addr_eq(s1.inner.as_ptr(), s2.inner.as_ptr()));
    }

    // Dropping the last strong handle removes the map entry.
    #[test]
    fn drop_cleans_data() {
        let s = InternString::new("cleanup_test");

        drop(s);

        assert!(DATA.get("cleanup_test").is_none());
    }

    // Concurrent clones and reads on one interned string.
    #[test]
    fn multithreaded_usage() {
        let s = Arc::new(InternString::new("thread_test"));

        let mut handles = vec![];

        for _ in 0..10 {
            let s_clone = Arc::clone(&s);
            handles.push(thread::spawn(move || {
                let local = s_clone.clone();
                assert_eq!(local.as_str(), "thread_test");
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(s.as_str(), "thread_test");
    }

    // A drop on another thread is visible to a weak handle after join.
    #[test]
    fn weak_after_drop_multithreaded() {
        let s = InternString::new("weak_thread");
        let w = s.into_weak();

        let handle = thread::spawn(move || {
            drop(s);
        });

        handle.join().unwrap();

        assert!(!w.is_alive());
        assert_eq!(w.as_str(), None);
    }

    // Two threads interning the same text simultaneously (barrier-released)
    // must converge on a single allocation.
    #[test]
    fn simultaneous_intern() {
        let barrier = Arc::new(Barrier::new(2));
        let s1 = Arc::new(Mutex::new(None));
        let s2 = Arc::new(Mutex::new(None));

        let b1 = barrier.clone();
        let s1c = s1.clone();
        let t1 = thread::spawn(move || {
            b1.wait();
            *s1c.lock().unwrap() = Some(InternString::new("race_test"));
        });

        let b2 = barrier.clone();
        let s2c = s2.clone();
        let t2 = thread::spawn(move || {
            b2.wait();
            *s2c.lock().unwrap() = Some(InternString::new("race_test"));
        });

        t1.join().unwrap();
        t2.join().unwrap();

        let s1 = s1.lock().unwrap().take().unwrap();
        let s2 = s2.lock().unwrap().take().unwrap();

        assert!(ptr::addr_eq(s1.inner.as_ptr(), s2.inner.as_ptr()));
    }

    // Upgrade succeeds while alive and fails after all strong handles drop.
    #[test]
    fn weak_upgrade() {
        let s = InternString::new("upgrade_test");
        let w = s.into_weak();

        let s2 = w.into_strong().expect("Should upgrade");
        assert_eq!(s2.as_str(), "upgrade_test");

        drop(s);
        drop(s2);

        assert!(w.into_strong().is_none());
    }
}