use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::Hash;
use super::error::RefMapRemoveError;
/// A reference counter keyed by `K`.
///
/// Each tracked key maps to the number of outstanding references to it.
/// Keys are inserted with a count of 1 on their first reference and removed
/// when their last reference is released.
///
/// Trait bounds (`K: Hash + Eq`) live on the `impl` blocks that need them
/// rather than on the type itself.
#[derive(Debug)]
pub struct RefMap<K> {
    /// Current reference count for every tracked key.
    references: HashMap<K, u64>,
}
impl<K: Hash + Eq> RefMap<K> {
    /// Creates an empty `RefMap` that tracks no references.
    pub fn new() -> Self {
        RefMap {
            references: HashMap::new(),
        }
    }

    /// Records one more reference to `ref_id`.
    ///
    /// A key that is not yet tracked is inserted with a count of 1.
    ///
    /// Returns the reference count for `ref_id` after the increment.
    pub fn add_ref(&mut self, ref_id: K) -> u64 {
        // The entry API hashes the key once and bumps the count in place,
        // instead of removing the entry and re-inserting it.
        let ref_count = self.references.entry(ref_id).or_insert(0);
        *ref_count += 1;
        *ref_count
    }

    /// Releases one reference to `ref_id`.
    ///
    /// Returns `Ok(None)` if other references to the key remain. Returns
    /// `Ok(Some(key))` with the stored key if this call released the last
    /// reference; the key is then no longer tracked.
    ///
    /// # Errors
    ///
    /// Returns [`RefMapRemoveError`] if `ref_id` is not currently tracked.
    pub fn remove_ref<Q: ?Sized>(&mut self, ref_id: &Q) -> Result<Option<K>, RefMapRemoveError>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let ref_count = match self.references.get_mut(ref_id) {
            Some(ref_count) => ref_count,
            None => {
                return Err(RefMapRemoveError(
                    "Trying to remove a reference that does not exist".into(),
                ))
            }
        };
        if *ref_count > 1 {
            // Other references remain: decrement in place, no re-insertion.
            *ref_count -= 1;
            return Ok(None);
        }
        // This was the last reference: drop the entry and hand back the key
        // that was stored in the map.
        Ok(self.references.remove_entry(ref_id).map(|(key, _)| key))
    }
}

impl<K: Hash + Eq> Default for RefMap<K> {
    /// Equivalent to [`RefMap::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    fn test_add_ref() {
        let mut ref_map = RefMap::new();
        assert_eq!(ref_map.add_ref("test_id".to_string()), 1);
        assert_eq!(ref_map.add_ref("test_id".to_string()), 2);
        // A different key gets its own, independent count.
        assert_eq!(ref_map.add_ref("test_id_2".to_string()), 1);
    }

    #[test]
    fn test_remove_ref() {
        let mut ref_map = RefMap::new();
        assert_eq!(ref_map.add_ref("test_id".to_string()), 1);
        assert_eq!(ref_map.add_ref("test_id".to_string()), 2);

        // Releasing one of two references keeps the key tracked.
        assert_eq!(ref_map.remove_ref("test_id"), Ok(None));
        assert_eq!(ref_map.references.get("test_id"), Some(&1));

        // Releasing the last reference returns the stored key and drops it.
        assert_eq!(
            ref_map.remove_ref("test_id"),
            Ok(Some("test_id".to_string()))
        );
        assert_eq!(ref_map.references.get("test_id"), None);
    }

    #[test]
    fn test_remove_ref_err() {
        let mut ref_map: RefMap<String> = RefMap::new();
        assert!(
            ref_map.remove_ref("test_id").is_err(),
            "remove_ref should have returned an error"
        );
    }

    #[test]
    fn test_remove_ref_after_last_release_errs() {
        let mut ref_map = RefMap::new();
        assert_eq!(ref_map.add_ref("test_id".to_string()), 1);
        assert_eq!(
            ref_map.remove_ref("test_id"),
            Ok(Some("test_id".to_string()))
        );
        // Once fully released, the key is no longer known to the map.
        assert!(ref_map.remove_ref("test_id").is_err());
    }
}