use std::collections::hash_map;
use std::collections::HashMap;
use std::hash;
use std::num;
/// A string interner: maps each distinct string to a key of type `K`.
///
/// All interned strings are stored back to back in a single `String`. The `i`-th
/// string occupies `buffer[start..ends[i]]`, where `start` is `ends[i - 1]` (or 0
/// for the first string). `S` builds the hashers used to deduplicate strings.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct Interner<K = num::NonZeroU32, S = hash_map::RandomState> {
/// Concatenation of every interned string, in interning order.
buffer: String,
/// Exclusive end offset (in bytes) of each interned string within `buffer`.
ends: Vec<usize>,
/// Maps a string's hash to the keys of all interned strings with that hash.
/// Not serialized; it is rebuilt from `buffer` and `ends` when deserializing.
#[cfg_attr(feature = "serde", serde(skip))]
dedup: DeDupMap<K>,
/// Builds the hasher used to hash strings for `dedup`.
/// Not serialized; `S::default()` is used after deserializing.
#[cfg_attr(feature = "serde", serde(skip))]
hash_builder: S,
}
impl<K, S: Default> Default for Interner<K, S> {
    /// Creates an empty interner whose hash builder is `S::default()`.
    fn default() -> Self {
        let hash_builder = S::default();
        Self {
            hash_builder,
            dedup: HashMap::default(),
            ends: Vec::new(),
            buffer: String::new(),
        }
    }
}
// NOTE(review): `into_usize` is NOT the inverse of `try_from_usize` — the interner
// expects it to return `index + 1`. A new implementation that simply round-trips
// the index would make `resolve` return the previous string. Consider making this
// contract symmetric in a future breaking release.
/// Conversion between interner slot indices and key values.
///
/// The interner creates the key for the `n`-th interned string (0-based) with
/// `try_from_usize(n)`. When resolving a key it uses `into_usize() - 1` as that
/// string's slot index. Implementations must therefore satisfy
/// `try_from_usize(n).unwrap().into_usize() == n + 1`, as the `NonZeroU32`
/// implementation does.
pub trait Key: Copy + Eq {
/// Builds the key for slot `index`, or returns `None` if the key type cannot
/// represent it.
fn try_from_usize(index: usize) -> Option<Self>;
/// Returns the key's raw value; the interner subtracts 1 from it to find the slot.
fn into_usize(self) -> usize;
}
impl Key for num::NonZeroU32 {
    /// Maps slot `index` to the key value `index + 1`.
    ///
    /// Returns `None` when `index + 1` does not fit in a `u32`. That covers indices
    /// above `u32::MAX` as well as `u32::MAX` itself.
    fn try_from_usize(index: usize) -> Option<Self> {
        let index = u32::try_from(index).ok()?;
        // `checked_add` instead of `+`: `u32::MAX + 1` panics in debug builds and
        // wraps to 0 in release builds; both cases should simply yield `None`.
        num::NonZeroU32::new(index.checked_add(1)?)
    }

    /// Returns the raw key value, which is one more than the slot index it was
    /// created from.
    fn into_usize(self) -> usize {
        self.get() as usize
    }
}
impl<K: Key, S: hash::BuildHasher> Interner<K, S> {
    /// Returns the key for `s`, interning it first if it has not been seen before.
    ///
    /// Interning an equal string again returns the same key.
    ///
    /// # Panics
    ///
    /// Panics if `K` cannot represent the slot index of a new string, i.e.
    /// `K::try_from_usize` returns `None`. In that case the interner is left
    /// unchanged.
    pub fn get_or_intern(&mut self, s: &str) -> K {
        // Hash once and reuse the value for both the lookup and the insertion.
        let hash = hash(&self.hash_builder, s);
        if let Some(key) = self.get_internal(s, hash) {
            return key;
        }
        // Allocate the key before touching any storage, so a failure here cannot
        // leave `buffer`/`ends` out of sync with `dedup`.
        let key = K::try_from_usize(self.ends.len())
            .expect("interner key space exhausted: too many distinct strings for the key type");
        self.buffer.push_str(s);
        self.ends.push(self.buffer.len());
        populate_dedup_map(&mut self.dedup, hash, key);
        key
    }

    /// Returns the key previously assigned to `s`, or `None` if `s` has not been
    /// interned.
    pub fn get(&self, s: &str) -> Option<K> {
        self.get_internal(s, hash(&self.hash_builder, s))
    }

    /// Looks up `s` given its precomputed `hash`.
    ///
    /// Distinct strings may share a hash, so each key in the bucket's chain is
    /// resolved and compared against `s`.
    fn get_internal(&self, s: &str, hash: u64) -> Option<K> {
        let mut node = self.dedup.get(&hash);
        while let Some(current) = node {
            let candidate = self
                .resolve(current.key)
                .expect("every key in the dedup map refers to an interned string");
            if candidate == s {
                return Some(current.key);
            }
            node = current.next.as_deref();
        }
        None
    }

    /// Returns the string `k` was assigned to, or `None` if `k` does not refer to
    /// an interned string.
    pub fn resolve(&self, k: K) -> Option<&str> {
        // Keys are created from a 0-based slot index and convert back to
        // `index + 1`. `wrapping_sub` maps a raw value of 0 to `usize::MAX`, which
        // `ends.get` then rejects.
        let i = k.into_usize().wrapping_sub(1);
        let end = *self.ends.get(i)?;
        let start = match i.checked_sub(1) {
            None => 0,
            Some(prev) => *self.ends.get(prev)?,
        };
        // `str::get` returns `None` instead of panicking if the offsets are
        // inconsistent or fall inside a multi-byte character.
        self.buffer.get(start..end)
    }
}
/// Hashes `s` with a fresh hasher obtained from `hash_builder`.
///
/// Every string in an interner must be hashed with the same builder so that
/// equal strings produce equal hashes.
fn hash<S: hash::BuildHasher>(hash_builder: &S, s: &str) -> u64 {
    let mut state = hash_builder.build_hasher();
    std::hash::Hash::hash(s, &mut state);
    std::hash::Hasher::finish(&state)
}
/// Deduplication index: maps a string's hash to the chain of keys of every
/// interned string with that hash (distinct strings can collide).
///
/// The map keys are already hash values, so the map uses `SingleU64Hasher`,
/// which returns the single `u64` written to it as the hash.
type DeDupMap<K> = HashMap<u64, LinkedList<K>, hash::BuildHasherDefault<SingleU64Hasher>>;
/// Records `key` under `hash` in the dedup index.
///
/// If other keys already share `hash`, the new key becomes the head of that
/// chain and the previous chain is linked behind it.
fn populate_dedup_map<K>(map: &mut DeDupMap<K>, hash: u64, key: K) {
    let fresh = LinkedList { key, next: None };
    match map.entry(hash) {
        hash_map::Entry::Vacant(slot) => {
            slot.insert(fresh);
        }
        hash_map::Entry::Occupied(mut slot) => {
            // `insert` swaps in the new head and hands back the old chain.
            let previous_head = slot.insert(fresh);
            slot.get_mut().next = Some(Box::new(previous_head));
        }
    }
}
/// A singly linked chain of keys whose strings share the same hash.
///
/// `populate_dedup_map` pushes new keys at the head of the chain.
struct LinkedList<K> {
/// Key of one interned string.
key: K,
/// The next key in the chain, if any.
next: Option<Box<LinkedList<K>>>,
}
/// A pass-through `Hasher` for map keys that are already hash values.
///
/// It accepts exactly one `u64` via `write_u64` and returns it unchanged from
/// `finish`. The following all panic:
/// - writing raw bytes;
/// - writing a second `u64`;
/// - finishing before any value was written.
#[derive(Default)]
struct SingleU64Hasher {
    /// The value recorded by `write_u64`, if any.
    val: Option<u64>,
}

impl hash::Hasher for SingleU64Hasher {
    #[inline]
    fn finish(&self) -> u64 {
        // Panics if `write_u64` was never called.
        self.val.unwrap()
    }

    fn write(&mut self, _bytes: &[u8]) {
        panic!("this hasher does not support writing arbitrary bytes, only a single u64 value")
    }

    #[inline]
    fn write_u64(&mut self, value: u64) {
        assert!(
            self.val.is_none(),
            "this hasher does not support writing multiple u64 values"
        );
        self.val = Some(value);
    }
}
#[cfg(feature = "serde")]
impl<'de, K: Key, S: Default + hash::BuildHasher> serde::Deserialize<'de> for Interner<K, S> {
    /// Deserializes the `buffer` and `ends` fields and rebuilds the dedup index
    /// from them, using `S::default()` as the hash builder.
    ///
    /// # Errors
    ///
    /// Returns any error from the underlying deserializer. Also returns an error
    /// when the data could not have come from serializing an `Interner`:
    /// - an end offset is out of order, past the end of `buffer`, or not on a
    ///   UTF-8 character boundary;
    /// - the last end offset is not `buffer.len()`;
    /// - there are more strings than `K` can key.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(serde::Deserialize)]
        struct DeserializedInterner {
            buffer: String,
            ends: Vec<usize>,
        }
        let DeserializedInterner { buffer, ends } =
            DeserializedInterner::deserialize(deserializer)?;

        // Report malformed input as a deserialization error instead of panicking.
        let invalid = |msg: String| <D::Error as serde::de::Error>::custom(msg);

        // `get_or_intern` appends new strings at `buffer.len()` but starts them at
        // the last end offset. If the two disagree, strings interned later would
        // resolve with stray leading bytes.
        let covered = ends.last().copied().unwrap_or(0);
        if covered != buffer.len() {
            return Err(invalid(format!(
                "interner ends cover {} bytes but buffer has {} bytes",
                covered,
                buffer.len()
            )));
        }

        let hash_builder = S::default();
        let mut dedup = DeDupMap::<K>::default();
        dedup.reserve(ends.len());
        let mut start: usize = 0;
        for (i, &end) in ends.iter().enumerate() {
            // `str::get` returns `None` for decreasing, out-of-bounds or
            // non-char-boundary ranges, all of which plain slicing panics on.
            let s = buffer.get(start..end).ok_or_else(|| {
                invalid(format!("invalid end offset {} for interned string {}", end, i))
            })?;
            let key = K::try_from_usize(i).ok_or_else(|| {
                invalid(format!(
                    "interned string {} cannot be represented by the key type",
                    i
                ))
            })?;
            let hash = hash(&hash_builder, s);
            populate_dedup_map(&mut dedup, hash, key);
            start = end;
        }
        Ok(Self {
            buffer,
            ends,
            dedup,
            hash_builder,
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// A hasher that returns 12 for every input, so that every string lands in the
/// same dedup bucket and the hash-collision chain is exercised.
#[derive(Default)]
struct FixedHasher;
impl hash::Hasher for FixedHasher {
fn finish(&self) -> u64 {
12
}
// Input bytes are ignored; the hash is constant.
fn write(&mut self, _: &[u8]) {}
}
/// Strings with identical hashes must still get distinct keys, and
/// re-interning a string must find its existing key in the collision chain.
#[test]
fn test_hash_collision() {
let mut interner: Interner<num::NonZeroU32, hash::BuildHasherDefault<FixedHasher>> =
Default::default();
let hello_1 = interner.get_or_intern("hello");
let world_1 = interner.get_or_intern("world");
let hello_2 = interner.get_or_intern("hello");
assert_eq!(hello_1, hello_2);
assert_ne!(hello_1, world_1);
assert_eq!(interner.resolve(hello_1), Some("hello"));
assert_eq!(interner.resolve(world_1), Some("world"));
}
/// A JSON round trip must preserve keys: re-interning the same strings in the
/// deserialized interner returns the keys assigned before serialization.
#[cfg(feature = "serde")]
#[test]
fn test_serde() {
let mut interner: Interner = Default::default();
let hello_1 = interner.get_or_intern("hello");
let world_1 = interner.get_or_intern("world");
let serialized = serde_json::to_string_pretty(&interner).unwrap();
let mut interner_de: Interner = serde_json::from_str(&serialized).unwrap();
let hello_2 = interner_de.get_or_intern("hello");
let world_2 = interner_de.get_or_intern("world");
assert_eq!(hello_1, hello_2);
assert_eq!(world_1, world_2);
}
}