use crate::handle_hash_map::{Handle, HandleHashMap, InsertError};
use crate::hash::DefaultHashBuilder;
use crate::tokens::{Count, Token, UsizeCount};
/// A stored value bundled with the reference count of tokens minted for it.
#[derive(Debug)]
pub struct Counted<V> {
    /// Counter that tokens are taken from and returned to.
    pub refcount: UsizeCount,
    /// The user's value.
    pub value: V,
}

impl<V> Counted<V> {
    /// Wraps `value` with a counter that starts at `initial`.
    pub fn new(value: V, initial: usize) -> Self {
        let refcount = UsizeCount::new(initial);
        Counted { refcount, value }
    }
}
/// Hash map whose entries are kept alive by outstanding [`CountedHandle`]s.
///
/// Each handle holds one count on its entry; the entry is removed from the map
/// when the last handle is handed back via `put`.
pub struct CountedHashMap<K, V, S = DefaultHashBuilder> {
    /// Backing storage; every value carries its own reference count.
    pub(crate) inner: HandleHashMap<K, Counted<V>, S>,
}
/// A counted reference to one entry of a [`CountedHashMap`].
///
/// Each handle owns one token minted from its entry's reference count. It must be
/// handed back with `CountedHashMap::put`; dropping it instead panics, so the type
/// is `#[must_use]` to surface accidentally discarded handles at compile time.
#[must_use = "a CountedHandle must be returned with CountedHashMap::put; dropping it panics"]
pub struct CountedHandle<'a> {
    pub(crate) handle: Handle,
    pub(crate) token: Token<'a, UsizeCount>,
}

impl CountedHandle<'_> {
    /// Borrows the key of this handle's entry in `map`.
    ///
    /// Returns whatever the backing map reports for the handle (`None` if it
    /// cannot resolve it).
    pub fn key_ref<'m, K, V, S>(&self, map: &'m CountedHashMap<K, V, S>) -> Option<&'m K>
    where
        K: Eq + core::hash::Hash,
        S: core::hash::BuildHasher + Clone + Default,
    {
        map.inner.handle_key(self.handle)
    }

    /// Borrows the value of this handle's entry in `map`, or `None` if the
    /// backing map cannot resolve the handle.
    pub fn value_ref<'m, K, V, S>(&self, map: &'m CountedHashMap<K, V, S>) -> Option<&'m V>
    where
        K: Eq + core::hash::Hash,
        S: core::hash::BuildHasher + Clone + Default,
    {
        let counted = map.inner.handle_value(self.handle)?;
        Some(&counted.value)
    }

    /// Mutably borrows the value of this handle's entry in `map`, or `None` if
    /// the backing map cannot resolve the handle.
    pub fn value_mut<'m, K, V, S>(&self, map: &'m mut CountedHashMap<K, V, S>) -> Option<&'m mut V>
    where
        K: Eq + core::hash::Hash,
        S: core::hash::BuildHasher + Clone + Default,
    {
        let counted = map.inner.handle_value_mut(self.handle)?;
        Some(&mut counted.value)
    }
}
/// Outcome of handing a handle back with `CountedHashMap::put`.
///
/// Derives the standard traits so callers can print, compare and clone outcomes
/// (the derives only apply when `K` and `V` implement the same traits).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PutResult<K, V> {
    /// Other handles to the entry are still outstanding; it stays in the map.
    Live,
    /// The returned handle was the last one: the entry was removed from the map
    /// and its key and value are given back to the caller.
    Removed { key: K, value: V },
}
impl<K, V> CountedHashMap<K, V>
where
K: Eq + core::hash::Hash,
{
pub fn new() -> Self {
Self {
inner: HandleHashMap::new(),
}
}
}
impl<K, V> Default for CountedHashMap<K, V>
where
K: Eq + core::hash::Hash,
{
fn default() -> Self {
Self::new()
}
}
/// Shared iterator over `(handle, &key, &value)` that mints a token per entry.
///
/// Created by `CountedHashMap::iter_raw`. Every yielded handle holds one count
/// on its entry and must be given back with `CountedHashMap::put`; dropping one
/// instead panics.
pub(crate) struct Iter<'a, K, V, S> {
    pub(crate) it: crate::handle_hash_map::Iter<'a, K, Counted<V>, S>,
    pub(crate) _pd: core::marker::PhantomData<&'a (K, V, S)>,
}

impl<'a, K, V, S> Iterator for Iter<'a, K, V, S> {
    type Item = (CountedHandle<'static>, &'a K, &'a V);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let (handle, key, counted) = self.it.next()?;
        let token = counted.refcount.get();
        Some((CountedHandle { handle, token }, key, &counted.value))
    }

    // Each inner entry yields exactly one item, so the inner bounds carry over
    // unchanged and let `collect()` preallocate.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.it.size_hint()
    }
}
/// Mutable iterator over `(handle, &key, &mut value)` that mints a token per entry.
///
/// Created by `CountedHashMap::iter_mut_raw`. Every yielded handle holds one
/// count on its entry and must be given back with `CountedHashMap::put`;
/// dropping one instead panics.
pub(crate) struct IterMut<'a, K, V, S> {
    pub(crate) it: crate::handle_hash_map::IterMut<'a, K, Counted<V>, S>,
    pub(crate) _pd: core::marker::PhantomData<&'a (K, V, S)>,
}

impl<'a, K, V, S> Iterator for IterMut<'a, K, V, S> {
    type Item = (CountedHandle<'static>, &'a K, &'a mut V);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let (handle, key, counted) = self.it.next()?;
        // Mint the token through a shared reborrow before handing out `&mut value`.
        let token = counted.refcount.get();
        Some((CountedHandle { handle, token }, key, &mut counted.value))
    }

    // One item per inner entry: forward the inner bounds so `collect()` can
    // preallocate.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.it.size_hint()
    }
}
impl<K, V, S> CountedHashMap<K, V, S>
where
K: Eq + core::hash::Hash,
S: core::hash::BuildHasher + Clone + Default,
{
pub fn with_hasher(hasher: S) -> Self {
Self {
inner: HandleHashMap::with_hasher(hasher),
}
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn find<Q>(&self, q: &Q) -> Option<CountedHandle<'static>>
where
K: core::borrow::Borrow<Q>,
Q: ?Sized + core::hash::Hash + Eq,
{
let handle = self.inner.find(q)?;
let entry = self.inner.handle_value(handle)?;
let counter = &entry.refcount;
let token = counter.get();
Some(CountedHandle { handle, token })
}
pub fn contains_key<Q>(&self, q: &Q) -> bool
where
K: core::borrow::Borrow<Q>,
Q: ?Sized + core::hash::Hash + Eq,
{
self.inner.contains_key(q)
}
#[allow(dead_code)]
pub fn insert(&mut self, key: K, value: V) -> Result<CountedHandle<'static>, InsertError> {
let counted = Counted::new(value, 0);
match self.inner.insert(key, counted) {
Ok(handle) => {
let entry = self
.inner
.handle_value(handle)
.expect("entry must exist immediately after successful insert");
let counter = &entry.refcount;
let token = counter.get();
Ok(CountedHandle { handle, token })
}
Err(e) => Err(e),
}
}
pub fn get(&self, h: &CountedHandle<'_>) -> CountedHandle<'static> {
let entry = self
.inner
.handle_value(h.handle)
.expect("handle must be valid while counted handle is live");
let token = entry.refcount.get();
CountedHandle {
handle: h.handle,
token,
}
}
pub fn insert_with<F>(
&mut self,
key: K,
default: F,
) -> Result<CountedHandle<'static>, InsertError>
where
F: FnOnce() -> V,
{
match self.inner.insert_with(key, || Counted::new(default(), 0)) {
Ok(handle) => {
let entry = self
.inner
.handle_value(handle)
.expect("entry must exist immediately after successful insert");
let token = entry.refcount.get();
Ok(CountedHandle { handle, token })
}
Err(e) => Err(e),
}
}
pub fn put(&mut self, h: CountedHandle<'_>) -> PutResult<K, V> {
let CountedHandle { handle, token, .. } = h;
let entry = self
.inner
.handle_value(handle)
.expect("CountedHandle must refer to a live entry when returned to put()");
let now_zero = entry.refcount.put(token);
if now_zero {
let (k, v) = self
.inner
.remove(handle)
.expect("entry must exist when count reaches zero");
PutResult::Removed {
key: k,
value: v.value,
}
} else {
PutResult::Live
}
}
#[allow(dead_code)]
pub fn iter(&self) -> impl Iterator<Item = (Handle, &K, &V)> {
self.inner.iter().map(|(h, k, c)| (h, k, &c.value))
}
#[allow(dead_code)]
pub fn iter_mut(&mut self) -> impl Iterator<Item = (Handle, &K, &mut V)> {
self.inner.iter_mut().map(|(h, k, c)| (h, k, &mut c.value))
}
pub(crate) fn iter_raw(&self) -> Iter<'_, K, V, S> {
let it = self.inner.iter();
Iter {
it,
_pd: core::marker::PhantomData,
}
}
pub(crate) fn iter_mut_raw(&mut self) -> IterMut<'_, K, V, S> {
let it = self.inner.iter_mut();
IterMut {
it,
_pd: core::marker::PhantomData,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::cell::Cell;
    use std::collections::BTreeSet;

    proptest! {
        // Random insert/find/get/put sequences: a key must be present in the map
        // exactly while the model still holds at least one handle for it.
        #[test]
        fn prop_counted_hashmap_liveness(keys in 1usize..=5, ops in proptest::collection::vec((0u8..=4u8, 0usize..100usize), 1..100)) {
            let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
            // live[k] holds every outstanding handle for key "k{k}".
            let mut live: Vec<Vec<CountedHandle<'static>>> =
                std::iter::repeat_with(Vec::new).take(keys).collect();
            for (op, raw_k) in ops.into_iter() {
                let k = raw_k % keys;
                let key = format!("k{}", k);
                match op {
                    0 => match m.insert(key.clone(), k as i32) {
                        Ok(h) => live[k].push(h),
                        Err(InsertError::DuplicateKey) => {}
                    },
                    1 => {
                        if let Some(h) = m.find(&key) {
                            live[k].push(h);
                        }
                    }
                    2 => {
                        if let Some(h) = live[k].pop() {
                            let h2 = m.get(&h);
                            live[k].push(h);
                            live[k].push(h2);
                        }
                    }
                    3 => {
                        if let Some(h) = live[k].pop() {
                            match m.put(h) {
                                PutResult::Live => {}
                                PutResult::Removed { .. } => {
                                    prop_assert!(live[k].is_empty());
                                }
                            }
                        }
                    }
                    4 => {
                        while let Some(h) = live[k].pop() {
                            let _ = m.put(h);
                        }
                    }
                    _ => unreachable!(),
                }
                let present = m.contains_key(&key);
                prop_assert_eq!(present, !live[k].is_empty());
            }
            for (k, handles) in live.iter_mut().enumerate() {
                while let Some(h) = handles.pop() {
                    let _ = m.put(h);
                }
                let key = format!("k{}", k);
                prop_assert!(!m.contains_key(&key));
            }
        }
    }

    #[test]
    fn insert_with_is_lazy_and_mints_token() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let calls = Cell::new(0);
        let ch = m
            .insert_with("k".to_string(), || {
                calls.set(calls.get() + 1);
                7
            })
            .unwrap();
        assert_eq!(calls.get(), 1);
        assert_eq!(ch.value_ref(&m), Some(&7));
        {
            let dup = m.insert_with("k".to_string(), || {
                calls.set(calls.get() + 1);
                99
            });
            match dup {
                Err(InsertError::DuplicateKey) => {}
                _ => panic!("unexpected result"),
            }
        }
        // The duplicate insert must not have invoked its value factory.
        assert_eq!(calls.get(), 1);
        match m.put(ch) {
            PutResult::Removed { key, value } => {
                assert_eq!(key, "k".to_string());
                assert_eq!(value, 7);
            }
            _ => panic!("expected removal"),
        }
        assert!(!m.contains_key(&"k".to_string()));
    }

    #[test]
    fn insert_with_then_mutate_value() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let ch = m.insert_with("k".to_string(), || 10).unwrap();
        if let Some(v) = ch.value_mut(&mut m) {
            *v += 5;
        }
        assert_eq!(ch.value_ref(&m), Some(&15));
        let _ = m.put(ch);
    }

    #[test]
    fn get_mints_new_token_and_put_removes_at_zero() {
        let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
        let h1 = m.insert("a", 1).unwrap();
        let h2 = m.get(&h1);
        match m.put(h1) {
            PutResult::Live => {}
            _ => panic!("expected Live when one handle remains"),
        }
        assert!(m.contains_key(&"a"));
        match m.put(h2) {
            PutResult::Removed { key, value } => {
                assert_eq!(key, "a");
                assert_eq!(value, 1);
            }
            _ => panic!("expected Removed at zero"),
        }
        assert!(!m.contains_key(&"a"));
    }

    #[test]
    fn key_ref_value_ref_and_mutation_persist() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let h = m.insert("k1".to_string(), 10).unwrap();
        assert_eq!(h.key_ref(&m), Some(&"k1".to_string()));
        assert_eq!(h.value_ref(&m), Some(&10));
        if let Some(v) = h.value_mut(&mut m) {
            *v += 7;
        }
        assert_eq!(h.value_ref(&m), Some(&17));
        let _ = m.put(h);
    }

    #[test]
    fn iter_yields_all_entries_once_and_iter_mut_updates_values() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let keys = ["k1", "k2", "k3", "k4"];
        let mut handles = Vec::new();
        for (i, k) in keys.iter().enumerate() {
            handles.push(m.insert((*k).to_string(), i as i32).unwrap());
        }
        let seen: BTreeSet<String> = m.iter().map(|(_h, k, _v)| k.clone()).collect();
        let expected: BTreeSet<String> = keys.iter().map(|s| (*s).to_string()).collect();
        assert_eq!(seen, expected);
        for (_h, _k, v) in m.iter_mut() {
            *v += 100;
        }
        for (i, h) in handles.iter().enumerate() {
            assert_eq!(h.value_ref(&m).copied(), Some((i as i32) + 100));
        }
        for h in handles {
            let _ = m.put(h);
        }
    }

    #[test]
    fn iter_raw_requires_put_and_keeps_entries_live() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let h1 = m.insert("a".to_string(), 1).unwrap();
        let h2 = m.insert("b".to_string(), 2).unwrap();
        let h3 = m.insert("c".to_string(), 3).unwrap();
        let mut raw: Vec<CountedHandle<'static>> = m.iter_raw().map(|(ch, _k, _v)| ch).collect();
        // The raw handles keep every entry alive after the insertion handles go back.
        match m.put(h1) {
            PutResult::Live => {}
            _ => panic!("expected Live"),
        }
        match m.put(h2) {
            PutResult::Live => {}
            _ => panic!("expected Live"),
        }
        match m.put(h3) {
            PutResult::Live => {}
            _ => panic!("expected Live"),
        }
        assert!(m.contains_key(&"a".to_string()));
        assert!(m.contains_key(&"b".to_string()));
        assert!(m.contains_key(&"c".to_string()));
        let mut removed: BTreeSet<String> = BTreeSet::new();
        while let Some(ch) = raw.pop() {
            match m.put(ch) {
                PutResult::Removed { key, value } => {
                    removed.insert(key.clone());
                    match key.as_str() {
                        "a" => assert_eq!(value, 1),
                        "b" => assert_eq!(value, 2),
                        "c" => assert_eq!(value, 3),
                        _ => unreachable!(),
                    }
                }
                PutResult::Live => {}
            }
        }
        assert_eq!(
            removed,
            ["a", "b", "c"].into_iter().map(|s| s.to_string()).collect()
        );
        assert!(!m.contains_key(&"a".to_string()));
        assert!(!m.contains_key(&"b".to_string()));
        assert!(!m.contains_key(&"c".to_string()));
    }

    #[test]
    fn iter_mut_raw_requires_put_and_keeps_entries_live() {
        let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
        let h1 = m.insert("x", 10).unwrap();
        let h2 = m.insert("y", 20).unwrap();
        let mut raw: Vec<CountedHandle<'static>> = m
            .iter_mut_raw()
            .map(|(ch, _k, v)| {
                *v += 1;
                ch
            })
            .collect();
        assert!(matches!(m.put(h1), PutResult::Live));
        assert!(matches!(m.put(h2), PutResult::Live));
        assert!(m.contains_key(&"x"));
        assert!(m.contains_key(&"y"));
        let xr = m.find(&"x").unwrap();
        let yr = m.find(&"y").unwrap();
        assert_eq!(xr.value_ref(&m), Some(&11));
        assert_eq!(yr.value_ref(&m), Some(&21));
        let _ = m.put(xr);
        let _ = m.put(yr);
        let mut removed = 0;
        while let Some(ch) = raw.pop() {
            match m.put(ch) {
                PutResult::Removed { key, value } => {
                    removed += 1;
                    match key {
                        "x" => assert_eq!(value, 11),
                        "y" => assert_eq!(value, 21),
                        _ => unreachable!(),
                    }
                }
                PutResult::Live => {}
            }
        }
        assert_eq!(removed, 2);
        assert!(!m.contains_key(&"x"));
        assert!(!m.contains_key(&"y"));
    }

    #[test]
    fn dropping_counted_handle_without_put_panics() {
        use std::panic::{catch_unwind, AssertUnwindSafe};

        // Each case records that it reached the drop under test, so a panic during
        // setup cannot masquerade as the panic being asserted.
        let reached_drop = Cell::new(false);
        let res = catch_unwind(AssertUnwindSafe(|| {
            let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
            let h = m.insert("boom", 1).unwrap();
            reached_drop.set(true);
            drop(h);
        }));
        assert!(
            reached_drop.get(),
            "setup panicked before the handle was dropped"
        );
        assert!(
            res.is_err(),
            "expected panic when CountedHandle is dropped without put"
        );

        let reached_drop = Cell::new(false);
        let res2 = catch_unwind(AssertUnwindSafe(|| {
            let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
            // Bind the insertion handle and return it: `let _ = m.insert(..)` would
            // drop it on the spot and panic before any raw handle exists.
            let h = m.insert("a", 1).unwrap();
            let raw: Vec<CountedHandle<'static>> =
                m.iter_raw().map(|(ch, _k, _v)| ch).collect();
            let _ = m.put(h);
            // A single raw handle keeps this to one drop-time panic; dropping a
            // second one during the resulting unwind could abort the test binary.
            reached_drop.set(true);
            drop(raw);
        }));
        assert!(
            reached_drop.get(),
            "setup panicked before the raw handles were dropped"
        );
        assert!(
            res2.is_err(),
            "expected panic when raw handles are dropped without put"
        );
    }
}