use std::collections::{HashMap, HashSet};
use std::hash::Hash;
/// The state reported for a tracked key when accessed items are drained.
///
/// Produced by [`TrackingDict::pop_accessed_write_items`]: a key that still
/// has a value in the dictionary is reported as [`EntryValue::Value`], and a
/// key with no value is reported as [`EntryValue::Deleted`].
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EntryValue<V> {
    /// The key is present; carries a copy of its current value.
    Value(V),
    /// The key has no value in the dictionary. It may have been removed or
    /// cleared, or only marked as accessed without ever being set.
    Deleted,
}
/// A key/value map that remembers which keys were touched by tracked mutations.
///
/// Tracked operations (insert, remove, clear, set-default-on-miss, explicit
/// marking) add keys to an internal "dirty" set. The set can later be drained,
/// either as bare keys or paired with each key's current state.
#[derive(Clone, Debug)]
pub struct TrackingDict<K, V> {
    /// The live key/value contents.
    data: HashMap<K, V>,
    /// Keys touched since the dirty set was last drained.
    dirty: HashSet<K>,
}
impl<K, V> Default for TrackingDict<K, V>
where
K: Eq + Hash + Clone,
{
fn default() -> Self {
Self::new()
}
}
impl<K, V> TrackingDict<K, V>
where
K: Eq + Hash + Clone,
{
pub fn new() -> Self {
Self {
data: HashMap::new(),
dirty: HashSet::new(),
}
}
pub fn from_map(data: HashMap<K, V>) -> Self {
Self {
data,
dirty: HashSet::new(),
}
}
pub fn get(&self, key: &K) -> Option<&V> {
self.data.get(key)
}
pub fn contains_key(&self, key: &K) -> bool {
self.data.contains_key(key)
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
self.data.iter()
}
pub fn inner(&self) -> &HashMap<K, V> {
&self.data
}
pub fn insert(&mut self, key: K, value: V) -> Option<V> {
self.dirty.insert(key.clone());
self.data.insert(key, value)
}
pub fn remove(&mut self, key: &K) -> Option<V> {
if self.data.contains_key(key) {
self.dirty.insert(key.clone());
}
self.data.remove(key)
}
pub fn clear(&mut self) {
for key in self.data.keys() {
self.dirty.insert(key.clone());
}
self.data.clear();
}
pub fn set_default(&mut self, key: K, default: V) -> &mut V
where
V: Clone,
{
if !self.data.contains_key(&key) {
self.dirty.insert(key.clone());
self.data.insert(key.clone(), default);
}
self.data.get_mut(&key).expect("just inserted")
}
pub fn update_no_track(&mut self, other: HashMap<K, V>) {
for (k, v) in other {
self.data.insert(k, v);
}
}
pub fn mark_as_accessed(&mut self, key: K) {
self.dirty.insert(key);
}
pub fn pop_accessed_keys(&mut self) -> HashSet<K> {
std::mem::take(&mut self.dirty)
}
pub fn pop_accessed_write_items(&mut self) -> Vec<(K, EntryValue<V>)>
where
V: Clone,
{
let keys = self.pop_accessed_keys();
keys.into_iter()
.map(|k| {
let v = self
.data
.get(&k)
.map_or(EntryValue::Deleted, |v| EntryValue::Value(v.clone()));
(k, v)
})
.collect()
}
}
/// Builds a dictionary from key/value pairs. None of the keys start out
/// tracked.
///
/// Duplicate keys are resolved the same way as when collecting into a
/// `HashMap`: the last value wins.
impl<K: Eq + Hash + Clone, V> FromIterator<(K, V)> for TrackingDict<K, V> {
    fn from_iter<I: IntoIterator<Item = (K, V)>>(pairs: I) -> Self {
        let data: HashMap<K, V> = pairs.into_iter().collect();
        Self::from_map(data)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_tracks_key() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.insert("a".into(), 1);
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("a"));
        // Popping drains the tracker.
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn remove_tracks_key() {
        let mut td = TrackingDict::from_map(HashMap::from([("x".to_owned(), 42)]));
        assert_eq!(td.remove(&"x".to_owned()), Some(42));
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("x"));
    }

    #[test]
    fn remove_missing_key_is_not_tracked() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        assert_eq!(td.remove(&"ghost".to_owned()), None);
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn update_no_track_is_silent() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.update_no_track(HashMap::from([("b".into(), 2)]));
        assert!(td.pop_accessed_keys().is_empty());
        assert_eq!(td.get(&"b".into()), Some(&2));
    }

    #[test]
    fn clear_marks_all_dirty() {
        let mut td =
            TrackingDict::from_map(HashMap::from([("a".to_owned(), 1), ("b".to_owned(), 2)]));
        td.clear();
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("a"));
        assert!(keys.contains("b"));
        assert!(td.is_empty());
    }

    #[test]
    fn clear_on_empty_tracks_nothing() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.clear();
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn pop_accessed_write_items_returns_deleted() {
        let mut td = TrackingDict::from_map(HashMap::from([("k".to_owned(), 10)]));
        td.remove(&"k".to_owned());
        let items = td.pop_accessed_write_items();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].1, EntryValue::Deleted);
    }

    #[test]
    fn pop_accessed_write_items_reports_current_values() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.insert("a".to_owned(), 1);
        td.insert("a".to_owned(), 3);
        // Marking a key that was never set reports it as deleted.
        td.mark_as_accessed("never_set".to_owned());
        let mut items = td.pop_accessed_write_items();
        items.sort_by(|x, y| x.0.cmp(&y.0));
        assert_eq!(
            items,
            vec![
                ("a".to_owned(), EntryValue::Value(3)),
                ("never_set".to_owned(), EntryValue::Deleted),
            ]
        );
        assert!(td.pop_accessed_write_items().is_empty());
    }

    #[test]
    fn set_default_tracks_on_miss() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.set_default("new".to_owned(), 5);
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("new"));
    }

    #[test]
    fn set_default_hit_keeps_value_and_is_not_tracked() {
        let mut td = TrackingDict::from_map(HashMap::from([("k".to_owned(), 1)]));
        assert_eq!(*td.set_default("k".to_owned(), 99), 1);
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn set_default_returns_mutable_slot() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        *td.set_default("n".to_owned(), 1) += 1;
        assert_eq!(td.get(&"n".to_owned()), Some(&2));
    }

    #[test]
    fn collect_starts_untracked() {
        let mut td: TrackingDict<String, i32> = vec![("x".to_owned(), 1), ("y".to_owned(), 2)]
            .into_iter()
            .collect();
        assert_eq!(td.len(), 2);
        assert!(td.pop_accessed_keys().is_empty());
    }
}