use std::sync::{Arc, RwLock};
use dashmap::DashMap;
/// Bidirectional table between strings and `u32` ids.
///
/// Both directions are kept so a string can be resolved to its id and an id
/// back to its string. The fields use interior mutability (`DashMap`,
/// `RwLock`), so every method on this type takes `&self`.
#[derive(Debug, Default)]
pub struct Interner {
/// Forward direction: interned string -> the id assigned to it.
to_id: DashMap<Arc<str>, u32>,
/// Reverse direction: the string for id `n` is stored at index `n`.
to_str: RwLock<Vec<Arc<str>>>,
}
impl Interner {
    /// Interns `s` and returns its id.
    ///
    /// If `s` is already known its existing id is returned. Otherwise the
    /// next free id is assigned; ids are handed out in insertion order
    /// starting at `0`, because a new id is the current length of the
    /// reverse table.
    ///
    /// # Panics
    ///
    /// Panics if the reverse-table lock is poisoned, or if the table already
    /// holds more strings than a `u32` id can address.
    pub fn intern(&self, s: Arc<str>) -> u32 {
        // Fast path: a hit in the forward map needs no write lock.
        if let Some(id) = self.to_id.get(s.as_ref()) {
            return *id;
        }

        let mut strings = self.to_str.write().expect("interner lock poisoned");

        // Look again now that the write lock is held: another caller may
        // have interned the same string between the first lookup and here.
        if let Some(id) = self.to_id.get(s.as_ref()) {
            return *id;
        }

        // Checked conversion instead of `as u32`: a silent wrap-around would
        // hand out an id that already belongs to a different string.
        // Spelled with the full trait path so it does not depend on
        // `TryFrom` being in the edition's prelude.
        let id = <u32 as std::convert::TryFrom<usize>>::try_from(strings.len())
            .expect("interner id space exhausted: more than u32::MAX strings interned");

        // Fill the reverse table first, then publish the id in the forward
        // map, all while the write lock is still held.
        strings.push(Arc::clone(&s));
        self.to_id.insert(s, id);
        id
    }

    /// Interns a borrowed string and returns its id.
    ///
    /// The forward map is probed with the borrowed `&str` first, so an
    /// `Arc<str>` is only allocated when `s` was not found.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Interner::intern`].
    pub fn intern_str(&self, s: &str) -> u32 {
        if let Some(id) = self.to_id.get(s) {
            return *id;
        }
        self.intern(Arc::from(s))
    }

    /// Returns the string stored for `id` (a clone of the `Arc`, not a copy
    /// of the text).
    ///
    /// # Panics
    ///
    /// Panics if the reverse-table lock is poisoned, or if `id` is not a
    /// valid index into the reverse table (i.e. was never returned by this
    /// interner).
    pub fn get(&self, id: u32) -> Arc<str> {
        let strings = self.to_str.read().expect("interner lock poisoned");
        Arc::clone(&strings[id as usize])
    }

    /// Looks up the id of `s` without interning it.
    ///
    /// Returns `None` when `s` is not present in the forward map.
    pub fn get_id(&self, s: &str) -> Option<u32> {
        self.to_id.get(s).map(|id| *id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Interning identical text twice yields one id.
    #[test]
    fn same_string_gives_same_id() {
        let table = Interner::default();
        let first = table.intern_str("Foo::bar");
        let second = table.intern_str("Foo::bar");
        assert_eq!(first, second);
    }

    /// Distinct texts never share an id.
    #[test]
    fn different_strings_give_different_ids() {
        let table = Interner::default();
        let bar = table.intern_str("Foo::bar");
        let baz = table.intern_str("Foo::baz");
        assert_ne!(bar, baz);
    }

    /// An id resolves back to the text it was created from.
    #[test]
    fn get_roundtrips_id_to_string() {
        let table = Interner::default();
        let id = table.intern_str("App\\Service");
        let resolved = table.get(id);
        assert_eq!(&*resolved, "App\\Service");
    }

    /// Looking up a never-interned string reports absence.
    #[test]
    fn get_id_returns_none_for_unknown_string() {
        let table = Interner::default();
        assert_eq!(table.get_id("unknown"), None);
    }

    /// The owned and borrowed entry points share one id space.
    #[test]
    fn intern_and_intern_str_agree() {
        let table = Interner::default();
        let via_arc = table.intern(Arc::from("hello"));
        let via_str = table.intern_str("hello");
        assert_eq!(via_arc, via_str);
    }

    /// Eight threads interning the same text all observe the same id.
    #[test]
    fn concurrent_intern_is_consistent() {
        use std::sync::Arc as StdArc;

        let shared = StdArc::new(Interner::default());

        // Start every worker before joining any of them.
        let mut workers = Vec::with_capacity(8);
        for _ in 0..8 {
            let local = StdArc::clone(&shared);
            workers.push(std::thread::spawn(move || local.intern_str("shared")));
        }

        let mut ids: Vec<u32> = Vec::with_capacity(workers.len());
        for worker in workers {
            ids.push(worker.join().unwrap());
        }

        let expected = ids[0];
        assert!(
            ids.iter().all(|&id| id == expected),
            "all threads must see the same ID"
        );
    }
}