use crate::{InternerSymbol, Symbol};
use boxcar::Vec as LFVec;
use bumpalo::Bump;
use dashmap::DashMap;
use hashbrown::hash_table;
use std::{collections::hash_map::RandomState, hash::BuildHasher};
use thread_local::ThreadLocal;
/// Key stored in the shared map: the precomputed hash paired with the interned bytes.
pub(crate) type MapKey = (u64, &'static [u8]);
/// A complete raw-table entry: the key together with its value.
pub(crate) type RawMapKey<S> = (MapKey, S);
/// Sharded lookup table from key to symbol. Hashes are computed by the interner itself,
/// so the map's own hasher is the uninhabited `NoHasherBuilder`.
pub(crate) type Map<S> = DashMap<MapKey, S, NoHasherBuilder>;
/// One bump allocator per thread, used to copy newly interned bytes.
type Arena = ThreadLocal<Bump>;
/// Interner mapping byte strings to symbols of type `S`, hashing with `H`.
///
/// Interning works through `&self` (shard locks) or `&mut self` (lock-free shard access).
pub struct BytesInterner<S = Symbol, H = RandomState> {
/// Sharded table from `(precomputed hash, bytes)` to symbol. Hashes come from
/// `hash_builder`, not from the map's own hasher.
pub(crate) map: Map<S>,
/// Builds the hasher used for every key.
hash_builder: H,
/// Symbol-indexed storage: the slice at index `sym.to_usize()` is the bytes for `sym`.
/// The `'static` lifetime is internal; slices point into `arena` or caller data.
strs: LFVec<&'static [u8]>,
/// Per-thread bump arenas that own copies made by the allocating `intern*` methods.
arena: Arena,
}
impl Default for BytesInterner {
    /// Creates an empty interner with the default symbol type and a random hasher,
    /// exactly like `BytesInterner::new`.
    #[inline]
    fn default() -> Self {
        Self::with_capacity(0)
    }
}
impl BytesInterner<Symbol, RandomState> {
    /// Creates an empty interner using `Symbol` and a randomly seeded `RandomState`.
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity_and_hasher(0, RandomState::new())
    }

    /// Creates an interner pre-sized for `capacity` strings, using a randomly seeded
    /// `RandomState`.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self::with_capacity_and_hasher(capacity, RandomState::new())
    }
}
impl<S: InternerSymbol, H: BuildHasher> BytesInterner<S, H> {
    /// Creates an empty interner that hashes strings with `hash_builder`.
    #[inline]
    pub fn with_hasher(hash_builder: H) -> Self {
        Self::with_capacity_and_hasher(0, hash_builder)
    }

    /// Creates an interner pre-sized for `capacity` strings, hashing with `hash_builder`.
    ///
    /// `capacity` is forwarded to both the shared lookup map and the symbol storage.
    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: H) -> Self {
        let map = Map::with_capacity_and_hasher(capacity, Default::default());
        let strs = LFVec::with_capacity(capacity);
        Self { map, strs, arena: Default::default(), hash_builder }
    }

    /// Returns the number of interned strings.
    #[inline]
    pub fn len(&self) -> usize {
        self.strs.count()
    }

    /// Returns `true` if nothing has been interned.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over `(symbol, bytes)` pairs in symbol order.
    ///
    /// Each pair is produced with [`Self::resolve`], so its panic conditions apply.
    #[inline]
    pub fn iter(&self) -> impl ExactSizeIterator<Item = (S, &[u8])> + Clone {
        self.all_symbols().map(|s| (s, self.resolve(s)))
    }

    /// Iterates over the symbols `0..self.len()`, converted with `S::from_usize`.
    #[inline]
    pub fn all_symbols(&self) -> impl ExactSizeIterator<Item = S> + Send + Sync + Clone {
        (0..self.len()).map(S::from_usize)
    }

    /// Interns `s` and returns its symbol; equal bytes always map to the same symbol.
    ///
    /// New strings are copied into the calling thread's arena, so `s` need not outlive
    /// the call.
    pub fn intern(&self, s: &[u8]) -> S {
        self.do_intern(s, alloc)
    }

    /// Like [`Self::intern`], but uses exclusive access to reach the map shards
    /// without taking their locks.
    pub fn intern_mut(&mut self, s: &[u8]) -> S {
        self.do_intern_mut(s, alloc)
    }

    /// Interns `s` without copying it: if `s` is new, the borrowed slice itself is stored.
    ///
    /// NOTE(review): the `'b: 'a` bound only guarantees that `s` outlives this call's
    /// borrow of `self`, not the interner. The slice is stored internally as
    /// `&'static [u8]`. The caller must therefore keep `s` alive for as long as the
    /// interner is used afterwards (`resolve`, `iter`, ...), or those reads dangle.
    /// Consider requiring `&'static [u8]` or making this `unsafe`.
    pub fn intern_static<'a, 'b: 'a>(&'a self, s: &'b [u8]) -> S {
        self.do_intern(s, no_alloc)
    }

    /// Exclusive-access variant of [`Self::intern_static`]. The same lifetime caveat
    /// applies: `s` must outlive every later use of the interner.
    pub fn intern_mut_static<'a, 'b: 'a>(&'a mut self, s: &'b [u8]) -> S {
        self.do_intern_mut(s, no_alloc)
    }

    /// Interns every slice in `strings` with [`Self::intern`], discarding the symbols.
    pub fn intern_many<'a>(&self, strings: impl IntoIterator<Item = &'a [u8]>) {
        for s in strings {
            self.intern(s);
        }
    }

    /// Interns every slice in `strings` with [`Self::intern_mut`], discarding the symbols.
    pub fn intern_many_mut<'a>(&mut self, strings: impl IntoIterator<Item = &'a [u8]>) {
        for s in strings {
            self.intern_mut(s);
        }
    }

    /// Interns every slice in `strings` with [`Self::intern_static`], discarding the
    /// symbols. The lifetime caveat of `intern_static` applies to each slice.
    pub fn intern_many_static<'a, 'b: 'a>(&'a self, strings: impl IntoIterator<Item = &'b [u8]>) {
        for s in strings {
            self.intern_static(s);
        }
    }

    /// Interns every slice in `strings` with [`Self::intern_mut_static`], discarding the
    /// symbols. The lifetime caveat of `intern_static` applies to each slice.
    pub fn intern_many_mut_static<'a, 'b: 'a>(
        &'a mut self,
        strings: impl IntoIterator<Item = &'b [u8]>,
    ) {
        for s in strings {
            self.intern_mut_static(s);
        }
    }

    /// Returns the bytes interned under `sym`.
    ///
    /// # Panics
    ///
    /// Panics with "symbol out of bounds" if `sym` does not index a stored string. This
    /// happens, for example, when `sym` came from another interner or was built with
    /// `S::from_usize`.
    #[inline]
    #[must_use]
    #[track_caller]
    pub fn resolve(&self, sym: S) -> &[u8] {
        // Always checked: this is a safe fn taking an arbitrary symbol, so an unchecked
        // lookup would let safe callers read past the initialized storage.
        self.strs.get(sym.to_usize()).expect("symbol out of bounds")
    }

    /// Shared-access interning.
    ///
    /// Looks the key up under the shard's read lock. Only on a miss does it take the
    /// write lock and call `get_or_insert`, which re-checks through the entry API
    /// because another thread may have inserted the same bytes in between.
    #[inline]
    fn do_intern(&self, s: &[u8], alloc: impl FnOnce(&Arena, &[u8]) -> &'static [u8]) -> S {
        let hash = self.hash(s);
        // The shard is chosen from our own hash, not via the map's hasher.
        let shard_idx = self.map.determine_shard(hash as usize);
        let shard = &*self.map.shards()[shard_idx];
        // The read guard is a temporary of this `if let`, so it is released before the
        // `write()` below.
        if let Some((_, v)) = cvt(&shard.read()).find(hash, mk_eq(s)) {
            return *v.get();
        }
        get_or_insert(&self.strs, &self.arena, s, hash, cvt_mut(&mut shard.write()), alloc)
    }

    /// Exclusive-access interning: `&mut self` gives direct access to the shard's table
    /// through `get_mut`, so no lock is taken.
    #[inline]
    fn do_intern_mut(&mut self, s: &[u8], alloc: impl FnOnce(&Arena, &[u8]) -> &'static [u8]) -> S {
        let hash = self.hash(s);
        let shard_idx = self.map.determine_shard(hash as usize);
        let shard = &mut *self.map.shards_mut()[shard_idx];
        get_or_insert(&self.strs, &self.arena, s, hash, cvt_mut(shard.get_mut()), alloc)
    }

    /// Hashes `s` with the interner's `BuildHasher`, feeding the bytes in one `write`.
    ///
    /// The result selects the shard and is stored alongside the key for later rehashing.
    #[inline]
    fn hash(&self, s: &[u8]) -> u64 {
        use std::hash::Hasher;
        let mut h = self.hash_builder.build_hasher();
        h.write(s);
        h.finish()
    }
}
/// `BuildHasher` for the shared interner map.
///
/// Keys are hashed by the interner before the map is touched, so this builder is never
/// expected to produce a hasher.
pub(crate) type NoHasherBuilder = std::hash::BuildHasherDefault<NoHasher>;

/// An uninhabited hasher.
///
/// No value of this type can exist, so its `Hasher` methods are statically unreachable.
/// Asking `NoHasherBuilder` for a hasher can only panic (via `Default::default`).
pub(crate) enum NoHasher {}

impl Default for NoHasher {
    /// Always panics.
    ///
    /// This is reached if the map is used through an API that hashes keys itself (for
    /// example `BuildHasher::build_hasher` on `NoHasherBuilder`) instead of the
    /// precomputed-hash raw-table path. The message therefore names the misuse rather
    /// than claiming the code is unreachable.
    #[cold]
    fn default() -> Self {
        panic!(
            "NoHasher must never be constructed: interner map keys are hashed manually, \
             so the map's own hasher must not be used"
        )
    }
}

impl std::hash::Hasher for NoHasher {
    #[inline]
    fn finish(&self) -> u64 {
        // `NoHasher` has no variants, so this empty match is exhaustive.
        match *self {}
    }

    #[inline]
    fn write(&mut self, _bytes: &[u8]) {
        match *self {}
    }
}
/// Returns the symbol for `s` in `shard`, inserting it if absent.
///
/// On a miss, the bytes are passed through `alloc` (copy into the arena, or keep the
/// borrow) and appended to `strs`. The index of the appended slice becomes the new
/// symbol, and the entry `((hash, bytes), symbol)` is inserted into the shard.
#[inline]
fn get_or_insert<S: InternerSymbol>(
    strs: &LFVec<&'static [u8]>,
    arena: &Arena,
    s: &[u8],
    hash: u64,
    shard: &mut hash_table::HashTable<RawMapKey<dashmap::SharedValue<S>>>,
    alloc: impl FnOnce(&Arena, &[u8]) -> &'static [u8],
) -> S {
    let vacant = match shard.entry(hash, mk_eq(s), hasher) {
        hash_table::Entry::Occupied(found) => return *found.get().1.get(),
        hash_table::Entry::Vacant(slot) => slot,
    };
    let stored = alloc(arena, s);
    let sym = S::from_usize(strs.push(stored));
    vacant.insert(((hash, stored), dashmap::SharedValue::new(sym)));
    sym
}
/// Rehash callback for the raw table: returns the hash stored in the entry's key.
#[inline]
fn hasher<S>(entry: &RawMapKey<S>) -> u64 {
    let ((stored_hash, _), _) = entry;
    *stored_hash
}
/// Builds the equality predicate for raw-table lookups: an entry matches when its stored
/// bytes equal `s`.
#[inline]
fn mk_eq<S>(s: &[u8]) -> impl Fn(&RawMapKey<S>) -> bool + Copy + '_ {
    move |entry: &RawMapKey<S>| {
        let ((_, stored), _) = entry;
        *stored == s
    }
}
/// Copies `s` into the calling thread's bump arena and returns the copy with an erased
/// (`'static`) lifetime.
#[inline]
fn alloc(arena: &Arena, s: &[u8]) -> &'static [u8] {
    let bump = arena.get_or_default();
    let copied: &[u8] = bump.alloc_slice_copy(s);
    // SAFETY: the copy lives as long as `arena`. Callers must not let the returned slice
    // outlive the arena it was allocated from.
    unsafe { &*(copied as *const [u8]) }
}
/// Returns `s` itself with an erased (`'static`) lifetime, without copying.
#[inline]
fn no_alloc(_: &Arena, s: &[u8]) -> &'static [u8] {
    let borrowed: *const [u8] = s;
    // SAFETY: none is provided here. The caller must guarantee the bytes behind `s`
    // outlive every use of the returned slice.
    unsafe { &*borrowed }
}
/// Views dashmap's raw shard table through hashbrown's `HashTable` API.
#[inline]
fn cvt<T>(old: &hashbrown::raw::RawTable<T>) -> &hash_table::HashTable<T> {
    let raw: *const hashbrown::raw::RawTable<T> = old;
    // SAFETY: assumes `HashTable<T>` is a layout-compatible wrapper around `RawTable<T>`.
    // NOTE(review): this is not a documented hashbrown guarantee; verify it against the
    // pinned hashbrown version.
    unsafe { &*raw.cast::<hash_table::HashTable<T>>() }
}
/// Mutable counterpart of `cvt`: views a raw shard table as a `HashTable`.
#[inline]
fn cvt_mut<T>(old: &mut hashbrown::raw::RawTable<T>) -> &mut hash_table::HashTable<T> {
    let raw: *mut hashbrown::raw::RawTable<T> = old;
    // SAFETY: same layout assumption as `cvt`. The exclusive borrow of `old` is carried
    // over to the result.
    unsafe { &mut *raw.cast::<hash_table::HashTable<T>>() }
}