use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
/// Raw id reserved for the sentinel entry; a freshly built interner maps it to
/// [`SENTINEL_STR`] and never hands it out for a newly interned string.
const SENTINEL_ID: u32 = 0;
/// String stored in the sentinel slot (the empty string).
const SENTINEL_STR: &str = "";
/// Compact, copyable handle for an interned string.
///
/// Wraps the raw `u32` slot number issued by `StringInterner`; the raw value
/// `SENTINEL_ID` (0) is reserved for the sentinel entry.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct InternId(pub(crate) u32);
impl InternId {
#[inline]
pub(crate) fn as_u32(self) -> u32 {
self.0
}
#[cfg(test)]
#[inline]
pub(crate) fn sentinel() -> Self {
Self(SENTINEL_ID)
}
#[cfg(test)]
#[inline]
pub(crate) fn is_sentinel(self) -> bool {
self.0 == SENTINEL_ID
}
}
/// Two-way table between strings and the `InternId`s assigned to them.
///
/// Each table sits behind its own `RwLock`; `intern` takes the `forward` write
/// lock first and the `reverse` write lock second.
pub(crate) struct StringInterner {
/// String -> id lookup.
forward: RwLock<HashMap<Arc<str>, InternId>>,
/// Id -> string lookup: the element at index `i` is the string for raw id `i`,
/// with the sentinel string at index 0.
reverse: RwLock<Vec<Arc<str>>>,
/// Next raw id to hand out; starts at 1 because slot 0 holds the sentinel.
next_id: AtomicU32,
}
impl StringInterner {
/// Creates an interner that contains only the sentinel entry.
///
/// The empty sentinel string is registered under id 0 in both tables and the
/// id counter starts at 1, so the first newly interned string receives id 1.
pub(crate) fn new() -> Self {
    let empty: Arc<str> = Arc::from(SENTINEL_STR);
    let by_string = HashMap::from([(Arc::clone(&empty), InternId(SENTINEL_ID))]);
    let by_id = vec![empty];
    Self {
        forward: RwLock::new(by_string),
        reverse: RwLock::new(by_id),
        next_id: AtomicU32::new(1),
    }
}
/// Replaces the entire interner contents with the sentinel followed by `strings`.
///
/// Ids are assigned by position: the first yielded string gets id 1, the next
/// id 2, and so on. If the same string is yielded more than once, the forward
/// map keeps the id of its last occurrence while every occurrence still
/// occupies its own slot in the reverse table.
///
/// The replacement tables are built and validated before any lock is taken.
/// Both write locks are then held together — `forward` first, then `reverse`,
/// the same order `intern` uses — while the tables and `next_id` are swapped,
/// so a concurrent `intern` can never pair the new forward map with the old
/// reverse table or a stale counter.
///
/// # Panics
///
/// Panics if the resulting table would not fit in `u32` slots. The check runs
/// before anything is swapped in, so a panic leaves the interner untouched.
fn install_snapshot_iter<I>(&self, strings: I)
where
    I: IntoIterator<Item = Arc<str>>,
{
    let sentinel: Arc<str> = Arc::from(SENTINEL_STR);
    let mut forward_map = HashMap::new();
    let mut reverse = vec![Arc::clone(&sentinel)];
    forward_map.insert(sentinel, InternId(SENTINEL_ID));
    for (idx, string) in strings.into_iter().enumerate() {
        // Slot 0 is the sentinel, so the element at position `idx` lands in slot `idx + 1`.
        let raw = u32::try_from(idx + 1).expect("interner snapshot exceeds u32 slots");
        forward_map.insert(Arc::clone(&string), InternId(raw));
        reverse.push(string);
    }
    // Validate the final size up front: failing after the swap would leave the
    // tables replaced but the counter stale.
    let next = u32::try_from(reverse.len()).expect("interner size exceeds u32 slots");

    // Hold both write locks for the whole swap so the three pieces of state
    // change together from the point of view of `intern`.
    let mut fwd = self.forward.write();
    let mut rev = self.reverse.write();
    *fwd = forward_map;
    *rev = reverse;
    self.next_id.store(next, Ordering::Release);
}
/// Returns the id for `s`, assigning the next free id if `s` has not been
/// interned yet.
///
/// A lookup under the shared read lock handles strings that are already
/// present. Otherwise the forward write lock is taken and the lookup is
/// repeated, because another thread may have inserted `s` between the two
/// lock acquisitions. The reverse table is updated while the forward write
/// lock is still held.
pub(crate) fn intern(&self, s: &str) -> InternId {
    // The read guard is a temporary of this statement, so it is released
    // before the write lock below is requested.
    let cached = self.forward.read().get(s).copied();
    if let Some(id) = cached {
        return id;
    }

    let mut by_string = self.forward.write();
    if let Some(existing) = by_string.get(s).copied() {
        return existing;
    }

    let raw = self.next_id.fetch_add(1, Ordering::AcqRel);
    let shared: Arc<str> = Arc::from(s);
    by_string.insert(Arc::clone(&shared), InternId(raw));

    let mut by_id = self.reverse.write();
    // The new string must land at index `raw` of the reverse table.
    debug_assert_eq!(
        u32::try_from(by_id.len()).unwrap_or(u32::MAX),
        raw,
        "reverse table length mismatch: expected {raw}, got {}",
        by_id.len()
    );
    by_id.push(shared);
    // Release in the same order as before: reverse first, then forward.
    drop(by_id);
    drop(by_string);

    InternId(raw)
}
/// Looks up the string stored for `id` (test-only).
///
/// Returns `None` when `id` is beyond the end of the reverse table; otherwise
/// returns a new handle to the shared string.
#[cfg(test)]
pub(crate) fn resolve(&self, id: InternId) -> Option<Arc<str>> {
    let table = self.reverse.read();
    let index = id.0 as usize;
    table.get(index).cloned()
}
/// Number of ids handed out so far, not counting the sentinel.
///
/// Derived from the `next_id` counter alone, so neither table lock is taken.
pub(crate) fn len(&self) -> usize {
    let next = self.next_id.load(Ordering::Acquire);
    next.saturating_sub(1) as usize
}
/// Copies the interned strings, in id order, into an owned `Vec`.
///
/// The sentinel at index 0 is left out, so element `i` of the result is the
/// string stored for id `i + 1`.
pub(crate) fn to_snapshot(&self) -> Vec<String> {
    let table = self.reverse.read();
    let mut snapshot = Vec::with_capacity(table.len().saturating_sub(1));
    for entry in table.iter().skip(1) {
        snapshot.push(String::from(&**entry));
    }
    snapshot
}
/// Builds a fresh interner by interning `strings` in order (test-only).
///
/// Each string goes through `intern`, so a string that is already present
/// (including the empty sentinel string) does not consume a new id.
#[cfg(test)]
pub(crate) fn from_snapshot(strings: Vec<String>) -> Self {
    let restored = Self::new();
    strings.iter().for_each(|entry| {
        let _ = restored.intern(entry);
    });
    restored
}
pub(crate) fn reset(&self) {
self.install_snapshot_iter(std::iter::empty());
}
/// Replaces the interner contents from `strings`, ignoring element 0.
///
/// The remaining elements are installed in order, so `strings[i]` becomes the
/// string for id `i` (for `i >= 1`). An empty slice installs nothing beyond
/// the sentinel.
///
/// NOTE(review): element 0 is presumably the sentinel slot of a full table
/// dump — confirm callers never pass the output of `to_snapshot`, which
/// already omits the sentinel and would lose its first string here.
pub(crate) fn replace_from_full_snapshot(&self, strings: &[String]) {
    let payload = strings.get(1..).unwrap_or_default();
    let shared = payload.iter().map(|entry| Arc::<str>::from(entry.as_str()));
    self.install_snapshot_iter(shared);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot 0 of a fresh interner resolves to the empty sentinel string.
    #[test]
    fn sentinel_is_slot_zero() {
        let interner = StringInterner::new();
        let slot_zero = interner.resolve(InternId(0));
        let resolved = slot_zero.expect("sentinel must always resolve");
        assert_eq!(&*resolved, "");
    }

    /// The sentinel id reports itself as sentinel; a newly interned id does not.
    #[test]
    fn sentinel_is_sentinel() {
        let sentinel = InternId::sentinel();
        assert!(sentinel.is_sentinel());

        let interner = StringInterner::new();
        assert!(!interner.intern("hello").is_sentinel());
    }

    /// Interning the same string twice yields the same id.
    #[test]
    fn intern_returns_stable_id() {
        let interner = StringInterner::new();
        let first = interner.intern("entity:1");
        let second = interner.intern("entity:1");
        assert_eq!(first, second);
    }

    /// Different strings never share an id.
    #[test]
    fn distinct_strings_get_distinct_ids() {
        let interner = StringInterner::new();
        let ids = [interner.intern("entity:1"), interner.intern("entity:2")];
        assert_ne!(ids[0], ids[1]);
    }

    /// The first interned string receives raw id 1.
    #[test]
    fn ids_are_one_based() {
        let interner = StringInterner::new();
        assert_eq!(interner.intern("first").as_u32(), 1);
    }

    /// An id returned by `intern` resolves back to the original text.
    #[test]
    fn resolve_roundtrips() {
        let interner = StringInterner::new();
        let id = interner.intern("scope:orders");
        let resolved = interner.resolve(id).expect("must resolve a valid id");
        assert_eq!(&*resolved, "scope:orders");
    }

    /// Resolving an id that was never issued yields `None`.
    #[test]
    fn resolve_out_of_range_returns_none() {
        let interner = StringInterner::new();
        let missing = interner.resolve(InternId(999));
        assert!(missing.is_none());
    }

    /// `len` counts distinct interned strings and ignores the sentinel.
    #[test]
    fn len_excludes_sentinel() {
        let interner = StringInterner::new();
        assert_eq!(interner.len(), 0);
        // Each step interns one string and checks the resulting count;
        // the repeated "a" must not grow the table.
        for (text, expected_len) in [("a", 1_usize), ("a", 1), ("b", 2)] {
            let _ = interner.intern(text);
            assert_eq!(interner.len(), expected_len);
        }
    }

    /// Rebuilding from a snapshot assigns the same ids as the original.
    #[test]
    fn snapshot_roundtrip_preserves_ids() {
        let original = StringInterner::new();
        let id_a = original.intern("entity:apple");
        let id_b = original.intern("entity:banana");

        let restored = StringInterner::from_snapshot(original.to_snapshot());

        assert_eq!(restored.intern("entity:apple"), id_a);
        assert_eq!(restored.intern("entity:banana"), id_b);
    }

    /// The snapshot holds only real strings, never the empty sentinel.
    #[test]
    fn snapshot_excludes_sentinel() {
        let interner = StringInterner::new();
        let _ = interner.intern("x");
        let snap = interner.to_snapshot();
        assert!(snap.iter().all(|s| !s.is_empty()));
        assert_eq!(snap.len(), 1);
    }

    /// Several threads interning the same strings all observe the same ids,
    /// and each distinct string is stored exactly once.
    #[test]
    fn concurrent_intern_is_consistent() {
        use std::sync::Arc;
        use std::thread;

        let interner = Arc::new(StringInterner::new());
        let n_threads = 8_usize;
        let n_strings = 50_usize;

        let mut handles = Vec::with_capacity(n_threads);
        for t in 0..n_threads {
            let interner = Arc::clone(&interner);
            let handle = thread::Builder::new()
                .name(format!("test-intern-{t}"))
                .spawn(move || {
                    let mut ids = Vec::with_capacity(n_strings);
                    for i in 0..n_strings {
                        let s = format!("string:{i}");
                        ids.push(interner.intern(&s));
                    }
                    ids
                })
                .expect("thread spawn must succeed in tests");
            handles.push(handle);
        }

        let mut all_results: Vec<Vec<InternId>> = Vec::with_capacity(n_threads);
        for handle in handles {
            all_results.push(handle.join().expect("thread must not panic"));
        }

        for i in 0..n_strings {
            let s = format!("string:{i}");
            let expected_id = interner.intern(&s);
            for thread_results in &all_results {
                assert_eq!(thread_results[i], expected_id, "mismatch for {s}");
            }
        }
        assert_eq!(interner.len(), n_strings);
    }
}