use core::hash::{Hash, Hasher};
use hashbrown::hash_table::HashTable;
/// Equality that is decided by an external context rather than by the
/// values' own `PartialEq` implementations.
///
/// The two sides may have different (and unsized) types, which lets a map
/// keyed by `V1` be probed with a borrowed `V2`.
pub trait CtxEq<V1: ?Sized, V2: ?Sized> {
    /// Returns `true` when `lhs` and `rhs` are equal according to `self`.
    fn ctx_eq(&self, lhs: &V1, rhs: &V2) -> bool;
}
/// Hashing that is decided by an external context, paired with [`CtxEq`].
///
/// Implementations must be consistent with `ctx_eq`: two values that the
/// context considers equal must feed identical data into the hasher.
/// `CtxHashMap` compares cached hashes before calling `ctx_eq`, so breaking
/// this rule makes equal keys invisible to lookups.
pub trait CtxHash<V: ?Sized>: CtxEq<V, V> {
    /// Feeds the context's notion of `value` into `hasher`.
    fn ctx_hash<H: Hasher>(&self, hasher: &mut H, value: &V);
}
/// A context that carries no data.
///
/// Comparison and hashing under `NullCtx` defer directly to the value's own
/// `Eq` and `Hash` implementations. The type is zero-sized, so it is also
/// `Copy` and can be passed around freely.
#[derive(Clone, Copy, Debug, Default)]
pub struct NullCtx;
/// `NullCtx` ignores itself and compares values with their own `PartialEq`.
impl<V: Eq + Hash> CtxEq<V, V> for NullCtx {
    fn ctx_eq(&self, lhs: &V, rhs: &V) -> bool {
        PartialEq::eq(lhs, rhs)
    }
}
/// `NullCtx` ignores itself and hashes values with their own `Hash` impl.
impl<V: Eq + Hash> CtxHash<V> for NullCtx {
    fn ctx_hash<H: Hasher>(&self, hasher: &mut H, value: &V) {
        Hash::hash(value, hasher);
    }
}
/// One slot of a `CtxHashMap`: a key/value pair together with the key's
/// cached context hash.
struct BucketData<K, V> {
    /// The stored key. It is only ever compared through a context.
    k: K,
    /// The value associated with `k`.
    v: V,
    /// 32-bit context hash of `k`, computed once at insertion. Lookups use it
    /// as a cheap pre-filter before calling `ctx_eq`.
    hash: u32,
}
/// A hash map whose keys are hashed and compared through a caller-supplied
/// context (`CtxHash` / `CtxEq`) instead of their own `Hash`/`Eq` impls.
///
/// The context is not stored in the map; every operation takes it as an
/// argument. Hashes are cached at insertion time, so lookups are only
/// meaningful when callers pass an equivalent context each time.
pub struct CtxHashMap<K, V> {
    /// Backing table. Each bucket carries its key's cached hash.
    raw: HashTable<BucketData<K, V>>,
}
impl<K, V> CtxHashMap<K, V> {
pub fn with_capacity(capacity: usize) -> Self {
Self {
raw: HashTable::with_capacity(capacity),
}
}
}
/// Hashes `k` through `ctx` with an `FxHasher` and returns the low 32 bits,
/// which is the width cached in each bucket.
fn compute_hash<Ctx, K>(ctx: &Ctx, k: &K) -> u32
where
    Ctx: CtxHash<K>,
{
    let mut state = rustc_hash::FxHasher::default();
    ctx.ctx_hash(&mut state, k);
    let full_hash: u64 = state.finish();
    // Truncation is intentional: buckets store a `u32` hash.
    full_hash as u32
}
/// Widens a cached 32-bit bucket hash into the 64-bit hash given to the
/// backing `hashbrown` table.
///
/// `hashbrown` takes the bucket index from the low bits of the `u64` hash.
/// It takes the SIMD control-byte tag from the top 7 bits. Zero-extending a
/// `u32` would make that tag 0 for every entry on 64-bit targets, so group
/// probing could never skip a non-matching slot.
///
/// Replicating the value into the high half keeps the low 32 bits unchanged
/// and gives the tag real entropy. Every hash handed to the table (lookup,
/// insert and rehash) must go through this function so that they all agree.
fn table_hash(hash: u32) -> u64 {
    let wide = u64::from(hash);
    (wide << 32) | wide
}

impl<K, V> CtxHashMap<K, V> {
    /// Inserts `v` under `k`, hashing and comparing keys through `ctx`.
    ///
    /// If a key equal to `k` (per `ctx`) is already present, its value is
    /// replaced and the previous value is returned. In that case the stored
    /// key is kept and `k` is dropped. Otherwise a new entry is added and
    /// `None` is returned.
    pub fn insert<Ctx>(&mut self, k: K, v: V, ctx: &Ctx) -> Option<V>
    where
        Ctx: CtxEq<K, K> + CtxHash<K>,
    {
        let hash = compute_hash(ctx, &k);
        // Comparing cached hashes first means `ctx_eq` only runs on plausible matches.
        match self.raw.find_mut(table_hash(hash), |bucket| {
            hash == bucket.hash && ctx.ctx_eq(&bucket.k, &k)
        }) {
            Some(bucket) => Some(core::mem::replace(&mut bucket.v, v)),
            None => {
                let data = BucketData { hash, k, v };
                self.raw
                    .insert_unique(table_hash(hash), data, |bucket| table_hash(bucket.hash));
                None
            }
        }
    }

    /// Returns the value stored under a key equal (per `ctx`) to `k`, if any.
    ///
    /// `k` may have a different type `Q` from the stored keys. `ctx` must
    /// hash a `Q` to the same value it produced for the equal stored `K`,
    /// because cached hashes are compared before `ctx_eq` is consulted.
    pub fn get<'a, Q, Ctx>(&'a self, k: &Q, ctx: &Ctx) -> Option<&'a V>
    where
        Ctx: CtxEq<K, Q> + CtxHash<Q>,
    {
        let hash = compute_hash(ctx, k);
        self.raw
            .find(table_hash(hash), |bucket| {
                hash == bucket.hash && ctx.ctx_eq(&bucket.k, k)
            })
            .map(|bucket| &bucket.v)
    }

    /// Returns the entry for `k`, for in-place inspection, mutation or insertion.
    ///
    /// If a matching key is already present, `k` is dropped and the stored key
    /// is kept. Otherwise `k` is moved into the returned vacant entry.
    pub fn entry<'a, Ctx>(&'a mut self, k: K, ctx: &Ctx) -> Entry<'a, K, V>
    where
        Ctx: CtxEq<K, K> + CtxHash<K>,
    {
        let hash = compute_hash(ctx, &k);
        // Rehash from the cached bucket hash, as `insert` does, instead of
        // calling back into `ctx` for every element moved during growth.
        let raw = self.raw.entry(
            table_hash(hash),
            |bucket| hash == bucket.hash && ctx.ctx_eq(&bucket.k, &k),
            |bucket| table_hash(bucket.hash),
        );
        match raw {
            hashbrown::hash_table::Entry::Occupied(o) => Entry::Occupied(OccupiedEntry { raw: o }),
            hashbrown::hash_table::Entry::Vacant(v) => Entry::Vacant(VacantEntry {
                hash,
                key: k,
                raw: v,
            }),
        }
    }
}
/// A view into a single key's slot in a `CtxHashMap`, as returned by
/// `CtxHashMap::entry`.
pub enum Entry<'a, K, V> {
    /// A key equal (per the context) to the requested one is already present.
    Occupied(OccupiedEntry<'a, K, V>),
    /// No matching key is present. Fill it with `VacantEntry::insert`.
    Vacant(VacantEntry<'a, K, V>),
}
/// An entry whose key is already present in a `CtxHashMap`.
pub struct OccupiedEntry<'a, K, V> {
    /// Handle to the occupied bucket in the backing table.
    raw: hashbrown::hash_table::OccupiedEntry<'a, BucketData<K, V>>,
}
/// An entry for a key that is not yet present in a `CtxHashMap`.
///
/// It keeps the key and its already-computed context hash, so the bucket can
/// be built on insertion without consulting a context again.
pub struct VacantEntry<'a, K, V> {
    /// Context hash of `key`, computed when the entry was looked up.
    hash: u32,
    /// The key that will be stored if the entry is filled.
    key: K,
    /// Handle into the backing table used to perform the insertion.
    raw: hashbrown::hash_table::VacantEntry<'a, BucketData<K, V>>,
}
impl<K, V> OccupiedEntry<'_, K, V> {
    /// Borrows the value held by this entry.
    pub fn get(&self) -> &V {
        let bucket = self.raw.get();
        &bucket.v
    }

    /// Mutably borrows the value held by this entry.
    pub fn get_mut(&mut self) -> &mut V {
        let bucket = self.raw.get_mut();
        &mut bucket.v
    }
}
impl<'a, K, V> VacantEntry<'a, K, V> {
pub fn insert(self, v: V) {
self.raw.insert(BucketData {
hash: self.hash,
k: self.key,
v,
});
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Test key: an index into the context's string table. Two keys are
    /// ctx-equal exactly when the strings they index are equal.
    #[derive(Clone, Copy, Debug)]
    struct Key {
        index: u32,
    }

    /// Context that resolves a `Key` to its string for hashing and comparison.
    struct Ctx {
        vals: &'static [&'static str],
    }

    impl CtxEq<Key, Key> for Ctx {
        fn ctx_eq(&self, a: &Key, b: &Key) -> bool {
            self.vals[a.index as usize].eq(self.vals[b.index as usize])
        }
    }

    impl CtxHash<Key> for Ctx {
        fn ctx_hash<H: Hasher>(&self, state: &mut H, value: &Key) {
            self.vals[value.index as usize].hash(state);
        }
    }

    #[test]
    fn test_basic() {
        let ctx = Ctx {
            vals: &["a", "b", "a"],
        };
        let k0 = Key { index: 0 };
        let k1 = Key { index: 1 };
        let k2 = Key { index: 2 };
        assert!(ctx.ctx_eq(&k0, &k2));
        assert!(!ctx.ctx_eq(&k0, &k1));
        assert!(!ctx.ctx_eq(&k2, &k1));
        let mut map: CtxHashMap<Key, u64> = CtxHashMap::with_capacity(4);
        assert_eq!(map.insert(k0, 42, &ctx), None);
        assert_eq!(map.insert(k2, 84, &ctx), Some(42));
        assert_eq!(map.get(&k1, &ctx), None);
        assert_eq!(*map.get(&k0, &ctx).unwrap(), 84);
    }

    /// Entry API: a vacant insert, then an occupied hit through a distinct
    /// but ctx-equal key, then in-place mutation visible through `get`.
    #[test]
    fn test_entry() {
        let ctx = Ctx {
            vals: &["a", "b", "a"],
        };
        let k0 = Key { index: 0 };
        let k1 = Key { index: 1 };
        let k2 = Key { index: 2 };
        let mut map: CtxHashMap<Key, u64> = CtxHashMap::with_capacity(0);
        match map.entry(k0, &ctx) {
            Entry::Vacant(slot) => slot.insert(1),
            Entry::Occupied(_) => panic!("empty map must produce a vacant entry"),
        }
        match map.entry(k2, &ctx) {
            Entry::Occupied(mut slot) => {
                assert_eq!(*slot.get(), 1);
                *slot.get_mut() += 10;
            }
            Entry::Vacant(_) => panic!("k2 is ctx-equal to k0, so its entry must be occupied"),
        }
        assert_eq!(map.get(&k0, &ctx), Some(&11));
        assert_eq!(map.get(&k1, &ctx), None);
    }

    /// Grows the table from zero capacity through both `insert` and `entry`,
    /// forcing repeated rehashes, then checks that every key is still reachable.
    #[test]
    fn test_growth() {
        let mut map: CtxHashMap<u32, u32> = CtxHashMap::with_capacity(0);
        for i in 0..512u32 {
            assert_eq!(map.insert(i, i * 2, &NullCtx), None);
        }
        for i in 512..1024u32 {
            match map.entry(i, &NullCtx) {
                Entry::Vacant(slot) => slot.insert(i * 2),
                Entry::Occupied(_) => panic!("key {} reported as already present", i),
            }
        }
        for i in 0..1024u32 {
            assert_eq!(map.get(&i, &NullCtx), Some(&(i * 2)));
        }
        assert_eq!(map.get(&1024u32, &NullCtx), None);
        assert_eq!(map.insert(7u32, 0, &NullCtx), Some(14));
    }
}