1use crate::handle_hash_map::{Handle, HandleHashMap, InsertError};
4use crate::hash::DefaultHashBuilder;
5use crate::tokens::{Count, Token, UsizeCount};
6
7#[derive(Debug)]
8pub struct Counted<V> {
9 pub refcount: UsizeCount,
10 pub value: V,
11}
12
13impl<V> Counted<V> {
14 pub fn new(value: V, initial: usize) -> Self {
15 Self {
16 refcount: UsizeCount::new(initial),
17 value,
18 }
19 }
20}
21
22pub struct CountedHashMap<K, V, S = DefaultHashBuilder> {
23 pub(crate) inner: HandleHashMap<K, Counted<V>, S>,
24}
25
26pub struct CountedHandle<'a> {
27 pub(crate) handle: Handle,
28 pub(crate) token: Token<'a, UsizeCount>, }
30
31impl<'a> CountedHandle<'a> {
32 pub fn key_ref<'m, K, V, S>(&self, map: &'m CountedHashMap<K, V, S>) -> Option<&'m K>
33 where
34 K: Eq + core::hash::Hash,
35 S: core::hash::BuildHasher + Clone + Default,
36 {
37 map.inner.handle_key(self.handle)
38 }
39
40 pub fn value_ref<'m, K, V, S>(&self, map: &'m CountedHashMap<K, V, S>) -> Option<&'m V>
41 where
42 K: Eq + core::hash::Hash,
43 S: core::hash::BuildHasher + Clone + Default,
44 {
45 map.inner.handle_value(self.handle).map(|c| &c.value)
46 }
47
48 pub fn value_mut<'m, K, V, S>(&self, map: &'m mut CountedHashMap<K, V, S>) -> Option<&'m mut V>
49 where
50 K: Eq + core::hash::Hash,
51 S: core::hash::BuildHasher + Clone + Default,
52 {
53 map.inner
54 .handle_value_mut(self.handle)
55 .map(|c| &mut c.value)
56 }
57}
58
/// Outcome of returning a [`CountedHandle`] with `CountedHashMap::put`.
///
/// The derives are conditional on `K`/`V`, so existing users are unaffected.
/// They let callers print and compare results, for example with `assert_eq!`.
#[derive(Debug, PartialEq, Eq)]
pub enum PutResult<K, V> {
    /// Other handles to the entry are still outstanding; it stays in the map.
    Live,
    /// The returned handle was the last one. The entry was removed and its
    /// key and value are handed back to the caller.
    Removed { key: K, value: V },
}
64
65impl<K, V> CountedHashMap<K, V>
66where
67 K: Eq + core::hash::Hash,
68{
69 pub fn new() -> Self {
70 Self {
71 inner: HandleHashMap::new(),
72 }
73 }
74}
75
76impl<K, V> Default for CountedHashMap<K, V>
77where
78 K: Eq + core::hash::Hash,
79{
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85pub(crate) struct Iter<'a, K, V, S> {
87 pub(crate) it: crate::handle_hash_map::Iter<'a, K, Counted<V>, S>,
88 pub(crate) _pd: core::marker::PhantomData<&'a (K, V, S)>,
89}
90
91impl<'a, K, V, S> Iterator for Iter<'a, K, V, S> {
92 type Item = (CountedHandle<'static>, &'a K, &'a V);
93 #[inline]
94 fn next(&mut self) -> Option<Self::Item> {
95 self.it.next().map(|(h, k, c)| {
96 let ch = CountedHandle {
97 handle: h,
98 token: c.refcount.get(),
99 };
100 (ch, k, &c.value)
101 })
102 }
103}
104
105pub(crate) struct IterMut<'a, K, V, S> {
107 pub(crate) it: crate::handle_hash_map::IterMut<'a, K, Counted<V>, S>,
108 pub(crate) _pd: core::marker::PhantomData<&'a (K, V, S)>,
109}
110
111impl<'a, K, V, S> Iterator for IterMut<'a, K, V, S> {
112 type Item = (CountedHandle<'static>, &'a K, &'a mut V);
113 #[inline]
114 fn next(&mut self) -> Option<Self::Item> {
115 self.it.next().map(|(h, k, c)| {
116 let token = c.refcount.get();
117 let ch = CountedHandle { handle: h, token };
118 (ch, k, &mut c.value)
119 })
120 }
121}
122
123impl<K, V, S> CountedHashMap<K, V, S>
124where
125 K: Eq + core::hash::Hash,
126 S: core::hash::BuildHasher + Clone + Default,
127{
128 pub fn with_hasher(hasher: S) -> Self {
129 Self {
130 inner: HandleHashMap::with_hasher(hasher),
131 }
132 }
133
134 pub fn len(&self) -> usize {
135 self.inner.len()
136 }
137 pub fn is_empty(&self) -> bool {
138 self.inner.is_empty()
139 }
140
141 pub fn find<Q>(&self, q: &Q) -> Option<CountedHandle<'static>>
142 where
143 K: core::borrow::Borrow<Q>,
144 Q: ?Sized + core::hash::Hash + Eq,
145 {
146 let handle = self.inner.find(q)?;
147 let entry = self.inner.handle_value(handle)?;
148 let counter = &entry.refcount;
149 let token = counter.get();
150 Some(CountedHandle { handle, token })
151 }
152
153 pub fn contains_key<Q>(&self, q: &Q) -> bool
154 where
155 K: core::borrow::Borrow<Q>,
156 Q: ?Sized + core::hash::Hash + Eq,
157 {
158 self.inner.contains_key(q)
159 }
160
161 #[allow(dead_code)]
163 pub fn insert(&mut self, key: K, value: V) -> Result<CountedHandle<'static>, InsertError> {
164 let counted = Counted::new(value, 0);
165 match self.inner.insert(key, counted) {
166 Ok(handle) => {
167 let entry = self
168 .inner
169 .handle_value(handle)
170 .expect("entry must exist immediately after successful insert");
171 let counter = &entry.refcount;
172 let token = counter.get();
173 Ok(CountedHandle { handle, token })
174 }
175 Err(e) => Err(e),
176 }
177 }
178
179 pub fn get(&self, h: &CountedHandle<'_>) -> CountedHandle<'static> {
181 let entry = self
183 .inner
184 .handle_value(h.handle)
185 .expect("handle must be valid while counted handle is live");
186 let token = entry.refcount.get();
187 CountedHandle {
188 handle: h.handle,
189 token,
190 }
191 }
192
193 pub fn insert_with<F>(
195 &mut self,
196 key: K,
197 default: F,
198 ) -> Result<CountedHandle<'static>, InsertError>
199 where
200 F: FnOnce() -> V,
201 {
202 match self.inner.insert_with(key, || Counted::new(default(), 0)) {
203 Ok(handle) => {
204 let entry = self
205 .inner
206 .handle_value(handle)
207 .expect("entry must exist immediately after successful insert");
208 let token = entry.refcount.get();
209 Ok(CountedHandle { handle, token })
210 }
211 Err(e) => Err(e),
212 }
213 }
214
215 pub fn put(&mut self, h: CountedHandle<'_>) -> PutResult<K, V> {
217 let CountedHandle { handle, token, .. } = h;
218 let entry = self
219 .inner
220 .handle_value(handle)
221 .expect("CountedHandle must refer to a live entry when returned to put()");
222 let now_zero = entry.refcount.put(token);
223 if now_zero {
224 let (k, v) = self
225 .inner
226 .remove(handle)
227 .expect("entry must exist when count reaches zero");
228 PutResult::Removed {
229 key: k,
230 value: v.value,
231 }
232 } else {
233 PutResult::Live
234 }
235 }
236
237 #[allow(dead_code)]
241 pub fn iter(&self) -> impl Iterator<Item = (Handle, &K, &V)> {
242 self.inner.iter().map(|(h, k, c)| (h, k, &c.value))
243 }
244
245 #[allow(dead_code)]
246 pub fn iter_mut(&mut self) -> impl Iterator<Item = (Handle, &K, &mut V)> {
247 self.inner.iter_mut().map(|(h, k, c)| (h, k, &mut c.value))
248 }
249
250 pub(crate) fn iter_raw(&self) -> Iter<'_, K, V, S> {
251 let it = self.inner.iter();
252 Iter {
253 it,
254 _pd: core::marker::PhantomData,
255 }
256 }
257
258 pub(crate) fn iter_mut_raw(&mut self) -> IterMut<'_, K, V, S> {
259 let it = self.inner.iter_mut();
260 IterMut {
261 it,
262 _pd: core::marker::PhantomData,
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::cell::Cell;
    use std::collections::BTreeSet;

    proptest! {
        // Model-based check. `live[k]` holds every outstanding handle for key
        // `k`, and the map must contain `k` exactly when that list is non-empty.
        // Ops: 0 = insert, 1 = find, 2 = get (clone a handle), 3 = put one
        // handle, 4 = put every handle for the key.
        #[test]
        fn prop_counted_hashmap_liveness(
            keys in 1usize..=5,
            ops in proptest::collection::vec((0u8..=4u8, 0usize..100usize), 1..100)
        ) {
            let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
            let mut live: Vec<Vec<CountedHandle<'static>>> = std::iter::repeat_with(Vec::new).take(keys).collect();

            for (op, raw_k) in ops.into_iter() {
                let k = raw_k % keys;
                let key = format!("k{}", k);
                match op {
                    0 => {
                        // Insert; a duplicate key is rejected and mints nothing.
                        let res = m.insert(key.clone(), k as i32);
                        match res {
                            Ok(h) => live[k].push(h),
                            Err(InsertError::DuplicateKey) => {}
                        }
                    }
                    1 => {
                        // Find mints a handle only when the key is present.
                        if let Some(h) = m.find(&key) {
                            live[k].push(h);
                        }
                    }
                    2 => {
                        // Clone an existing handle via `get`.
                        if let Some(h) = live[k].pop() {
                            let h2 = m.get(&h);
                            live[k].push(h);
                            live[k].push(h2);
                        }
                    }
                    3 => {
                        // Return one handle; removal must coincide with the model emptying.
                        if let Some(h) = live[k].pop() {
                            match m.put(h) {
                                PutResult::Live => {}
                                PutResult::Removed { key: _, value: _ } => {
                                    prop_assert!(live[k].is_empty());
                                }
                            }
                        }
                    }
                    4 => {
                        // Return every handle for this key.
                        while let Some(h) = live[k].pop() { let _ = m.put(h); }
                    }
                    _ => unreachable!(),
                }

                let present = m.contains_key(&key);
                prop_assert_eq!(present, !live[k].is_empty());
            }

            // Drain all remaining handles; every key must then be gone.
            for (k, handles) in live.iter_mut().enumerate() {
                while let Some(h) = handles.pop() { let _ = m.put(h); }
                let key = format!("k{}", k);
                prop_assert_eq!(m.contains_key(&key), false);
            }
        }
    }

    /// `insert_with` calls its closure only on a real insert, and the handle it
    /// returns is the only one, so a single `put` removes the entry.
    #[test]
    fn insert_with_is_lazy_and_mints_token() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let calls = Cell::new(0);

        let ch = m
            .insert_with("k".to_string(), || {
                calls.set(calls.get() + 1);
                7
            })
            .unwrap();
        assert_eq!(calls.get(), 1);
        assert_eq!(ch.value_ref(&m), Some(&7));

        {
            // A duplicate key must be rejected without evaluating the closure.
            let dup = m.insert_with("k".to_string(), || {
                calls.set(calls.get() + 1);
                99
            });
            match dup {
                Err(InsertError::DuplicateKey) => {}
                _ => panic!("unexpected result"),
            }
        }
        assert_eq!(calls.get(), 1);

        match m.put(ch) {
            PutResult::Removed { key, value } => {
                assert_eq!(key, "k".to_string());
                assert_eq!(value, 7);
            }
            _ => panic!("expected removal"),
        }
        assert!(!m.contains_key(&"k".to_string()));
    }

    /// Mutations made through `value_mut` are visible through `value_ref`.
    #[test]
    fn insert_with_then_mutate_value() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let ch = m.insert_with("k".to_string(), || 10).unwrap();
        if let Some(v) = ch.value_mut(&mut m) {
            *v += 5;
        }
        assert_eq!(ch.value_ref(&m), Some(&15));
        let _ = m.put(ch);
    }

    /// `get` adds a second handle. The entry survives the first `put` and is
    /// removed by the second.
    #[test]
    fn get_mints_new_token_and_put_removes_at_zero() {
        let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
        let h1 = m.insert("a", 1).unwrap();
        let h2 = m.get(&h1);

        match m.put(h1) {
            PutResult::Live => {}
            _ => panic!("expected Live when one handle remains"),
        }
        assert!(m.contains_key(&"a"));

        match m.put(h2) {
            PutResult::Removed { key, value } => {
                assert_eq!(key, "a");
                assert_eq!(value, 1);
            }
            _ => panic!("expected Removed at zero"),
        }
        assert!(!m.contains_key(&"a"));
    }

    /// Handle accessors return the stored key and value, and edits persist.
    #[test]
    fn key_ref_value_ref_and_mutation_persist() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let h = m.insert("k1".to_string(), 10).unwrap();
        assert_eq!(h.key_ref(&m), Some(&"k1".to_string()));
        assert_eq!(h.value_ref(&m), Some(&10));
        if let Some(v) = h.value_mut(&mut m) {
            *v += 7;
        }
        assert_eq!(h.value_ref(&m), Some(&17));
        let _ = m.put(h);
    }

    /// `iter` visits each entry once, and `iter_mut` edits are visible through handles.
    #[test]
    fn iter_yields_all_entries_once_and_iter_mut_updates_values() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let keys = ["k1", "k2", "k3", "k4"];
        let mut handles = Vec::new();
        for (i, k) in keys.iter().enumerate() {
            handles.push(m.insert((*k).to_string(), i as i32).unwrap());
        }

        let seen: BTreeSet<String> = m.iter().map(|(_h, k, _v)| k.clone()).collect();
        let expected: BTreeSet<String> = keys.iter().map(|s| (*s).to_string()).collect();
        assert_eq!(seen, expected);

        for (_h, _k, v) in m.iter_mut() {
            *v += 100;
        }
        for (i, _k) in keys.iter().enumerate() {
            let hv = handles[i].value_ref(&m).copied();
            assert_eq!(hv, Some((i as i32) + 100));
        }

        for h in handles {
            let _ = m.put(h);
        }
    }

    /// Handles minted by `iter_raw` keep entries alive after the insert
    /// handles are returned. Returning the raw handles removes every entry.
    #[test]
    fn iter_raw_requires_put_and_keeps_entries_live() {
        let mut m: CountedHashMap<String, i32> = CountedHashMap::new();
        let h1 = m.insert("a".to_string(), 1).unwrap();
        let h2 = m.insert("b".to_string(), 2).unwrap();
        let h3 = m.insert("c".to_string(), 3).unwrap();

        let mut raw: Vec<CountedHandle<'static>> = m.iter_raw().map(|(ch, _k, _v)| ch).collect();

        match m.put(h1) {
            PutResult::Live => {}
            _ => panic!("expected Live"),
        }
        match m.put(h2) {
            PutResult::Live => {}
            _ => panic!("expected Live"),
        }
        match m.put(h3) {
            PutResult::Live => {}
            _ => panic!("expected Live"),
        }
        assert!(m.contains_key(&"a".to_string()));
        assert!(m.contains_key(&"b".to_string()));
        assert!(m.contains_key(&"c".to_string()));

        let mut removed: BTreeSet<String> = BTreeSet::new();
        while let Some(ch) = raw.pop() {
            match m.put(ch) {
                PutResult::Removed { key, value } => {
                    removed.insert(key.clone());
                    match key.as_str() {
                        "a" => assert_eq!(value, 1),
                        "b" => assert_eq!(value, 2),
                        "c" => assert_eq!(value, 3),
                        _ => unreachable!(),
                    }
                }
                PutResult::Live => {}
            }
        }
        assert_eq!(
            removed,
            ["a", "b", "c"].into_iter().map(|s| s.to_string()).collect()
        );
        assert!(!m.contains_key(&"a".to_string()));
        assert!(!m.contains_key(&"b".to_string()));
        assert!(!m.contains_key(&"c".to_string()));
    }

    /// `iter_mut_raw` mutations persist, and its handles keep entries alive
    /// until they are returned.
    #[test]
    fn iter_mut_raw_requires_put_and_keeps_entries_live() {
        let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
        let h1 = m.insert("x", 10).unwrap();
        let h2 = m.insert("y", 20).unwrap();

        let mut raw: Vec<CountedHandle<'static>> = m
            .iter_mut_raw()
            .map(|(ch, _k, v)| {
                *v += 1;
                ch
            })
            .collect();

        assert!(matches!(m.put(h1), PutResult::Live));
        assert!(matches!(m.put(h2), PutResult::Live));
        assert!(m.contains_key(&"x"));
        assert!(m.contains_key(&"y"));

        let xr = m.find(&"x").unwrap();
        let yr = m.find(&"y").unwrap();
        assert_eq!(xr.value_ref(&m), Some(&11));
        assert_eq!(yr.value_ref(&m), Some(&21));
        let _ = m.put(xr);
        let _ = m.put(yr);

        let mut removed = 0;
        while let Some(ch) = raw.pop() {
            match m.put(ch) {
                PutResult::Removed { key, value } => {
                    removed += 1;
                    match key {
                        "x" => assert_eq!(value, 11),
                        "y" => assert_eq!(value, 21),
                        _ => unreachable!(),
                    }
                }
                PutResult::Live => {}
            }
        }
        assert_eq!(removed, 2);
        assert!(!m.contains_key(&"x"));
        assert!(!m.contains_key(&"y"));
    }

    /// Dropping a handle without `put` panics, both for a handle returned by
    /// `insert` and for a handle minted by `iter_raw`.
    #[test]
    fn dropping_counted_handle_without_put_panics() {
        use std::panic::{catch_unwind, AssertUnwindSafe};
        let res = catch_unwind(AssertUnwindSafe(|| {
            let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
            let h = m.insert("boom", 1).unwrap();
            drop(h);
        }));
        assert!(
            res.is_err(),
            "expected panic when CountedHandle is dropped without put"
        );

        // Exactly one raw handle is outstanding when the drop happens.
        // - The insert handle is bound and returned first. A `let _ =` binding
        //   would drop it on the spot, and the panic would fire before
        //   `iter_raw` ran.
        // - Only a single entry is used, because a second token dropped during
        //   the unwind could turn the panic into an abort.
        // - The flag proves that the panic comes from dropping the raw handle.
        let reached_raw_drop = Cell::new(false);
        let res2 = catch_unwind(AssertUnwindSafe(|| {
            let mut m: CountedHashMap<&'static str, i32> = CountedHashMap::new();
            let h = m.insert("a", 1).unwrap();
            let raw: Vec<CountedHandle<'static>> =
                m.iter_raw().map(|(ch, _k, _v)| ch).collect();
            let _ = m.put(h);
            reached_raw_drop.set(true);
            drop(raw);
        }));
        assert!(
            reached_raw_drop.get(),
            "panic fired before the raw handles were dropped"
        );
        assert!(
            res2.is_err(),
            "expected panic when raw handles are dropped without put"
        );
    }
}