use std::collections::HashMap;
use std::hash::{BuildHasher, Hash};
use std::iter::FusedIterator;
use std::{any::Any, fmt::Debug, ops::Index};
use std::{collections::hash_map::RandomState, marker::PhantomData};
use crate::entry;
use crate::typedkey::{Key, TypedKey, TypedKeyRef};
use crate::typedvalue::TypedMapValue;
/// A type that can be used as a key in a [`TypedMap`].
///
/// The associated [`TypedMapKey::Value`] fixes the type of the value stored under keys of
/// this type. `Marker` selects which maps the key may be used with: a key type is usable with
/// a `TypedMap<Marker, _>` only if it implements `TypedMapKey<Marker>`.
pub trait TypedMapKey<Marker = ()>
where
    Self: Eq + Hash,
{
    /// Type of the value stored under keys of this type.
    type Value: 'static;
}
/// Panic message used when a stored key cannot be downcast to the key type it was found by;
/// reaching it means the map's internal invariants are broken.
const INVALID_KEY: &str = "Broken TypedMap: invalid key type";
/// Panic message used when a stored value cannot be downcast to the `K::Value` of the key
/// type `K` it was found by; reaching it means the map's internal invariants are broken.
const INVALID_VALUE: &str = "Broken TypedMap: invalid value type";
/// A hash map whose entries may have different key types, where each key type `K` determines
/// the type `K::Value` of the value stored under it.
///
/// Only key types implementing [`TypedMapKey<Marker>`] can be used with a map, so distinct
/// `Marker` types give independent key namespaces. `S` is the hasher builder of the
/// underlying `HashMap`.
pub struct TypedMap<Marker = (), S = RandomState> {
    // Type-erased storage; keys and values are downcast back to concrete types on access.
    state: HashMap<TypedKey, TypedMapValue, S>,
    // Ties the map to its `Marker` without storing a value of that type.
    _phantom: PhantomData<Marker>,
}
impl<Marker> TypedMap<Marker> {
pub fn new() -> Self {
TypedMap {
state: Default::default(),
_phantom: PhantomData,
}
}
pub fn with_capacity(capacity: usize) -> Self {
TypedMap {
state: HashMap::with_capacity(capacity),
_phantom: PhantomData,
}
}
}
impl<Marker, S> TypedMap<Marker, S>
where
S: BuildHasher,
{
pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
TypedMap {
state: HashMap::with_capacity_and_hasher(capacity, hash_builder),
_phantom: PhantomData,
}
}
pub fn with_hasher(hash_builder: S) -> Self {
TypedMap {
state: HashMap::with_hasher(hash_builder),
_phantom: PhantomData,
}
}
pub fn insert<K: 'static + TypedMapKey<Marker>>(
&mut self,
key: K,
value: K::Value,
) -> Option<K::Value> {
let typed_key = TypedKey::from_key(key);
let value = TypedMapValue::from_value(value);
let old_value = self.state.insert(typed_key, value);
old_value.and_then(|v| v.downcast::<K::Value>().ok())
}
pub fn get<K: 'static + TypedMapKey<Marker>>(&self, key: &K) -> Option<&K::Value> {
let typed_key = TypedKeyRef::from_key_ref(key);
let value = self.state.get(&typed_key as &dyn Key)?;
Some(value.downcast_ref::<K::Value>().expect(INVALID_VALUE))
}
pub fn get_mut<K: 'static + TypedMapKey<Marker>>(&mut self, key: &K) -> Option<&mut K::Value> {
let typed_key = TypedKeyRef::from_key_ref(key);
let value = self.state.get_mut(&typed_key as &dyn Key)?;
Some(value.downcast_mut::<K::Value>().expect(INVALID_VALUE))
}
pub fn get_key_value<K: 'static + TypedMapKey<Marker>>(
&self,
key: &K,
) -> Option<(&K, &K::Value)> {
let typed_key = TypedKeyRef::from_key_ref(key);
let (key, value) = self.state.get_key_value(&typed_key as &dyn Key)?;
Some((
key.downcast_ref().expect(INVALID_KEY),
value.downcast_ref().expect(INVALID_VALUE),
))
}
pub fn remove<K: 'static + TypedMapKey<Marker>>(&mut self, key: &K) -> Option<K::Value> {
let typed_key = TypedKeyRef::from_key_ref(key);
let value = self.state.remove(&typed_key as &dyn Key)?;
Some(value.downcast::<K::Value>().ok().expect(INVALID_VALUE))
}
pub fn remove_entry<K: 'static + TypedMapKey<Marker>>(
&mut self,
key: &K,
) -> Option<(K, K::Value)> {
let typed_key = TypedKeyRef::from_key_ref(key);
let value = self.state.remove_entry(&typed_key as &dyn Key);
value.map(|(k, v)| {
let k = k.downcast::<K>().ok().expect(INVALID_KEY);
let v = v.downcast::<K::Value>().ok().expect(INVALID_VALUE);
(k, v)
})
}
pub fn entry<K: 'static + TypedMapKey<Marker>>(
&mut self,
key: K,
) -> entry::Entry<'_, K, Marker> {
let typed_key = TypedKey::from_key(key);
entry::map_entry(self.state.entry(typed_key))
}
pub fn contains_key<K: 'static + TypedMapKey<Marker>>(&self, key: &K) -> bool {
self.get(key).is_some()
}
pub fn len(&self) -> usize {
self.state.len()
}
pub fn capacity(&self) -> usize {
self.state.capacity()
}
pub fn is_empty(&self) -> bool {
self.state.is_empty()
}
pub fn clear(&mut self) {
self.state.clear();
}
pub fn reserve(&mut self, additional: usize) {
self.state.reserve(additional)
}
pub fn shrink_to_fit(&mut self) {
self.state.shrink_to_fit();
}
pub fn hasher(&self) -> &S {
self.state.hasher()
}
pub fn keys(&self) -> Keys<'_> {
Keys(self.state.keys())
}
pub fn values(&self) -> Values<'_> {
Values(self.state.values())
}
pub fn values_mut(&mut self) -> ValuesMut<'_> {
ValuesMut(self.state.values_mut())
}
pub fn drain(&mut self) -> Drain<'_, Marker> {
Drain(self.state.drain(), PhantomData)
}
pub fn iter(&self) -> Iter<'_, Marker> {
Iter(self.state.iter(), PhantomData)
}
pub fn iter_mut(&mut self) -> IterMut<'_, Marker> {
IterMut(self.state.iter_mut(), PhantomData)
}
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(TypedKeyValueMutRef<'_, Marker>) -> bool,
{
let g = move |key: &TypedKey, value: &mut TypedMapValue| {
f(TypedKeyValueMutRef {
key,
value,
_marker: PhantomData,
})
};
self.state.retain(g)
}
}
impl<Marker> Default for TypedMap<Marker> {
fn default() -> Self {
TypedMap::new()
}
}
/// Consumes the map and yields every entry as an owned [`TypedKeyValue`].
///
/// Implemented for every hasher builder `S`, not only the default `RandomState`, so maps
/// created with [`TypedMap::with_hasher`] can be consumed as well. The underlying
/// `HashMap::into_iter` places no bound on `S`.
impl<Marker, S> IntoIterator for TypedMap<Marker, S> {
    type IntoIter = IntoIter<Marker>;
    type Item = TypedKeyValue<Marker>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIter(self.state.into_iter(), PhantomData)
    }
}
impl<Marker, K, S> Index<&K> for TypedMap<Marker, S>
where
    K: 'static + TypedMapKey<Marker>,
    S: BuildHasher,
{
    type Output = K::Value;

    /// Returns a reference to the value stored under `key`.
    ///
    /// # Panics
    ///
    /// Panics with "no entry found for key" when nothing is stored under `key`, and like
    /// [`TypedMap::get`] when the stored value has the wrong type.
    fn index(&self, key: &K) -> &K::Value {
        match self.get(key) {
            Some(value) => value,
            None => panic!("no entry found for key"),
        }
    }
}
/// Iterator over shared views of a [`TypedMap`]'s entries, returned by [`TypedMap::iter`].
pub struct Iter<'a, Marker>(
    std::collections::hash_map::Iter<'a, TypedKey, TypedMapValue>,
    PhantomData<Marker>,
);

// Written by hand rather than derived: `#[derive(Clone)]` would require `Marker: Clone`,
// although only a `PhantomData<Marker>` is stored and `hash_map::Iter` is always `Clone`.
impl<Marker> Clone for Iter<'_, Marker> {
    fn clone(&self) -> Self {
        Iter(self.0.clone(), PhantomData)
    }
}

impl<'a, Marker> Iterator for Iter<'a, Marker> {
    type Item = TypedKeyValueRef<'a, Marker>;

    fn next(&mut self) -> Option<Self::Item> {
        let (key, value) = self.0.next()?;
        Some(TypedKeyValueRef {
            key,
            value,
            _marker: PhantomData,
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl<Marker> ExactSizeIterator for Iter<'_, Marker> {}
impl<Marker> FusedIterator for Iter<'_, Marker> {}
/// Iterator over mutable views of a [`TypedMap`]'s entries, returned by
/// [`TypedMap::iter_mut`].
pub struct IterMut<'a, Marker>(
    std::collections::hash_map::IterMut<'a, TypedKey, TypedMapValue>,
    PhantomData<Marker>,
);

impl<'a, Marker> Iterator for IterMut<'a, Marker> {
    type Item = TypedKeyValueMutRef<'a, Marker>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|(key, value)| TypedKeyValueMutRef {
            key,
            value,
            _marker: PhantomData,
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl<Marker> FusedIterator for IterMut<'_, Marker> {}
impl<Marker> ExactSizeIterator for IterMut<'_, Marker> {}
/// Draining iterator returned by [`TypedMap::drain`], yielding owned [`TypedKeyValue`]s.
pub struct Drain<'a, Marker>(
    std::collections::hash_map::Drain<'a, TypedKey, TypedMapValue>,
    PhantomData<Marker>,
);

impl<Marker> Iterator for Drain<'_, Marker> {
    type Item = TypedKeyValue<Marker>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|(key, value)| TypedKeyValue {
            key,
            value,
            _marker: PhantomData,
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl<Marker> FusedIterator for Drain<'_, Marker> {}
impl<Marker> ExactSizeIterator for Drain<'_, Marker> {}
/// Owning iterator produced by consuming a [`TypedMap`], yielding [`TypedKeyValue`]s.
pub struct IntoIter<Marker>(
    std::collections::hash_map::IntoIter<TypedKey, TypedMapValue>,
    PhantomData<Marker>,
);

impl<Marker> Iterator for IntoIter<Marker> {
    type Item = TypedKeyValue<Marker>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|(key, value)| TypedKeyValue {
            key,
            value,
            _marker: PhantomData,
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl<Marker> FusedIterator for IntoIter<Marker> {}
impl<Marker> ExactSizeIterator for IntoIter<Marker> {}
/// Iterator over a [`TypedMap`]'s keys, each exposed as `&dyn Any`; see [`TypedMap::keys`].
#[derive(Clone)]
pub struct Keys<'a>(std::collections::hash_map::Keys<'a, TypedKey, TypedMapValue>);

impl<'a> Iterator for Keys<'a> {
    type Item = &'a dyn Any;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|key| key.as_any())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl FusedIterator for Keys<'_> {}
impl ExactSizeIterator for Keys<'_> {}
/// Iterator over a [`TypedMap`]'s values, each exposed as `&dyn Any`; see
/// [`TypedMap::values`].
#[derive(Clone)]
pub struct Values<'a>(std::collections::hash_map::Values<'a, TypedKey, TypedMapValue>);

impl<'a> Iterator for Values<'a> {
    type Item = &'a dyn Any;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|value| value.as_any())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl FusedIterator for Values<'_> {}
impl ExactSizeIterator for Values<'_> {}
/// Iterator over a [`TypedMap`]'s values, each exposed as `&mut dyn Any`; see
/// [`TypedMap::values_mut`].
pub struct ValuesMut<'a>(std::collections::hash_map::ValuesMut<'a, TypedKey, TypedMapValue>);

impl<'a> Iterator for ValuesMut<'a> {
    type Item = &'a mut dyn Any;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|value| value.as_mut_any())
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl FusedIterator for ValuesMut<'_> {}
impl ExactSizeIterator for ValuesMut<'_> {}
/// An owned, type-erased [`TypedMap`] entry, yielded by [`Drain`] and [`IntoIter`].
///
/// Use the `downcast_*` methods to recover the concrete key and value.
pub struct TypedKeyValue<Marker> {
    key: TypedKey,
    value: TypedMapValue,
    _marker: PhantomData<Marker>,
}

impl<Marker> TypedKeyValue<Marker> {
    /// Returns the key as a `&K` if the entry's key has type `K`, otherwise `None`.
    pub fn downcast_key_ref<K: 'static + TypedMapKey<Marker>>(&self) -> Option<&K> {
        self.key.downcast_ref()
    }

    /// Takes the key out as a `K`, or gives the entry back unchanged in `Err` if the key has
    /// a different type.
    pub fn downcast_key<K: 'static + TypedMapKey<Marker>>(self) -> Result<K, Self> {
        let Self {
            key,
            value,
            _marker,
        } = self;
        key.downcast().map_err(|key| Self {
            key,
            value,
            _marker,
        })
    }

    /// Returns the value as a `&V` if the entry's value has type `V`, otherwise `None`.
    pub fn downcast_value_ref<V: Any>(&self) -> Option<&V> {
        // Inspect the stored *value*; downcasting the key here would make this (and
        // `downcast_pair_ref`) fail for every genuine value.
        self.value.downcast_ref()
    }

    /// Takes the value out as a `V`, or gives the entry back unchanged in `Err` if the value
    /// has a different type.
    pub fn downcast_value<V: Any>(self) -> Result<V, Self> {
        let Self {
            key,
            value,
            _marker,
        } = self;
        value.downcast().map_err(|value| Self {
            key,
            value,
            _marker,
        })
    }

    /// Borrows key and value as `(&K, &K::Value)` if both have the types implied by `K`.
    pub fn downcast_pair_ref<K: 'static + TypedMapKey<Marker>>(&self) -> Option<(&K, &K::Value)> {
        let key = self.downcast_key_ref::<K>()?;
        let value = self.downcast_value_ref::<K::Value>()?;
        Some((key, value))
    }

    /// Takes key and value out as `(K, K::Value)`, or gives the entry back in `Err` if either
    /// has a different type.
    pub fn downcast_pair<K: 'static + TypedMapKey<Marker>>(self) -> Result<(K, K::Value), Self> {
        let Self {
            key,
            value,
            _marker,
        } = self;
        match key.downcast() {
            Ok(key) => match value.downcast() {
                Ok(value) => Ok((key, value)),
                // The key was already unwrapped, so re-erase it to rebuild the entry.
                Err(dyn_value) => Err(Self {
                    key: TypedKey::from_key(key),
                    value: dyn_value,
                    _marker,
                }),
            },
            Err(dyn_key) => Err(Self {
                key: dyn_key,
                value,
                _marker,
            }),
        }
    }
}

impl<M> Debug for TypedKeyValue<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("TypedKeyValue")
    }
}
/// Shared, type-erased view of one [`TypedMap`] entry, yielded by [`Iter`].
pub struct TypedKeyValueRef<'a, Marker> {
    key: &'a TypedKey,
    value: &'a TypedMapValue,
    _marker: PhantomData<Marker>,
}

impl<'a, Marker> TypedKeyValueRef<'a, Marker> {
    /// Returns the key as a `&K` when the entry's key has type `K`, otherwise `None`.
    pub fn downcast_key_ref<K: 'static + TypedMapKey<Marker>>(&self) -> Option<&'a K> {
        let erased: &'a TypedKey = self.key;
        erased.downcast_ref()
    }

    /// Returns the value as a `&V` when the entry's value has type `V`, otherwise `None`.
    pub fn downcast_value_ref<V: 'static + Any>(&self) -> Option<&'a V> {
        let erased: &'a TypedMapValue = self.value;
        erased.downcast_ref()
    }

    /// Returns `(&K, &K::Value)` when both key and value have the types implied by `K`.
    pub fn downcast_pair_ref<K: 'static + TypedMapKey<Marker>>(
        &self,
    ) -> Option<(&'a K, &'a K::Value)> {
        let key: &'a K = self.downcast_key_ref()?;
        let value: &'a K::Value = self.downcast_value_ref()?;
        Some((key, value))
    }
}

impl<M> Debug for TypedKeyValueRef<'_, M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "TypedKeyValueRef")
    }
}
/// Mutable, type-erased view of one [`TypedMap`] entry, yielded by [`IterMut`] and passed to
/// the predicate of [`TypedMap::retain`].
pub struct TypedKeyValueMutRef<'a, Marker> {
    key: &'a TypedKey,
    value: &'a mut TypedMapValue,
    _marker: PhantomData<Marker>,
}

impl<'a, Marker> TypedKeyValueMutRef<'a, Marker> {
    /// Returns the key as a `&K` when the entry's key has type `K`, otherwise `None`.
    pub fn downcast_key_ref<K: 'static + TypedMapKey<Marker>>(&self) -> Option<&'a K> {
        self.key.downcast_ref()
    }

    /// Temporarily borrows the value as a `&mut V` when it has type `V`, otherwise `None`.
    pub fn downcast_value_mut<'b, V: 'static + Any>(&'b mut self) -> Option<&'b mut V>
    where
        'a: 'b,
    {
        self.value.downcast_mut()
    }

    /// Converts this view into a `&'a mut V` when the value has type `V`; otherwise gives the
    /// view back unchanged in `Err`.
    pub fn downcast_value<V: 'static + Any>(self) -> Result<&'a mut V, Self> {
        // Test the type first so the long-lived `&'a mut` borrow is only created on the
        // success path, leaving `self` free to be returned in the `Err` case.
        if self.value.is::<V>() {
            Ok(self.value.downcast_mut().expect("Unreachable!"))
        } else {
            Err(self)
        }
    }

    /// Temporarily borrows `(&K, &mut K::Value)` when key and value have the types implied
    /// by `K`, otherwise `None`.
    pub fn downcast_pair_mut<'b, K: 'static + TypedMapKey<Marker>>(
        &'b mut self,
    ) -> Option<(&'b K, &'b mut K::Value)>
    where
        'a: 'b,
    {
        self.downcast_key_ref()
            .and_then(move |key| self.downcast_value_mut().map(move |value| (key, value)))
    }

    /// Converts this view into `(&'a K, &'a mut K::Value)` when key and value have the types
    /// implied by `K`; otherwise gives the view back unchanged in `Err`.
    pub fn downcast_pair<K: 'static + TypedMapKey<Marker>>(
        self,
    ) -> Result<(&'a K, &'a mut K::Value), Self> {
        match self.downcast_key_ref::<K>() {
            Some(key) => self
                .downcast_value::<K::Value>()
                .map(|value| (key, value)),
            None => Err(self),
        }
    }
}

impl<M> Debug for TypedKeyValueMutRef<'_, M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report the actual type name, consistent with the other entry views' `Debug` output.
        f.write_str("TypedKeyValueMutRef")
    }
}
#[cfg(test)]
mod tests {
    use crate::TypedMap;
    use crate::TypedMapKey;
    use std::hash::Hash;

    /// Exercises insertion, lookup, the entry API and iteration on a default-marker map, and
    /// checks that a map with a different `Marker` (`OtherState`) works alongside it.
    #[test]
    fn test_basic_use() {
        struct OtherState;
        // Marker is left to inference from the `TypedMapKey` impls used below.
        let mut state = TypedMap::new();
        let mut other_state = TypedMap::<OtherState>::new();
        #[derive(Clone, Debug, Hash, PartialEq, Eq)]
        struct AThing;
        #[derive(Clone, Debug, Hash, PartialEq, Eq)]
        struct BThing(usize);
        #[derive(Clone, Debug, Hash, PartialEq, Eq)]
        struct CThing(usize);
        impl TypedMapKey for AThing {
            type Value = String;
        }
        impl TypedMapKey for BThing {
            type Value = usize;
        }
        // `CThing` is only a key for maps whose marker is `OtherState`.
        impl TypedMapKey<OtherState> for CThing {
            type Value = usize;
        }
        state.insert(AThing, "Example".to_owned());
        state.insert(BThing(32), 33);
        state.insert(BThing(33), 34);
        other_state.insert(CThing(0), 33);
        assert_eq!(state.get(&AThing), Some(&"Example".to_owned()));
        assert_eq!(state.get(&BThing(0)), None);
        assert_eq!(state.get(&BThing(32)), Some(&33));
        assert_eq!(state.get(&BThing(33)), Some(&34));
        assert_eq!(other_state.get(&CThing(0)), Some(&33));
        // `or_default` inserts 0 for the missing key before the increment.
        *state.entry(BThing(3)).or_default() += 1;
        assert_eq!(*state.get(&BThing(3)).unwrap(), 1usize);
        // The first call inserts 3 and bumps it to 4; the second finds 4 and bumps it to 5.
        *state.entry(BThing(4)).or_insert(3usize) += 1;
        *state.entry(BThing(4)).or_insert(3usize) += 1;
        assert_eq!(*state.get(&BThing(4)).unwrap(), 5usize);
        // Removing through an occupied entry returns both the key and its value.
        if let crate::entry::Entry::Occupied(occupied) = state.entry(BThing(3)) {
            let (k, v) = occupied.remove_entry();
            assert_eq!(k, BThing(3));
            assert_eq!(v, 1usize);
        } else {
            panic!()
        }
        // Only entries whose key is a `BThing` survive `downcast_pair_ref::<BThing>`.
        let mut b_entries: Vec<_> = state
            .iter()
            .flat_map(|r| r.downcast_pair_ref::<BThing>())
            .collect();
        // Iteration order is unspecified, so sort before comparing.
        b_entries.sort_by_key(|kv| (kv.0).0);
        let b4 = BThing(4);
        let b32 = BThing(32);
        let b33 = BThing(33);
        assert_eq!(
            b_entries,
            vec![(&b4, &5usize), (&b32, &33usize), (&b33, &34usize)]
        );
        // Mutate every `BThing` value in place through `iter_mut`.
        state.iter_mut().for_each(|mut r| {
            if let Some((_, value)) = r.downcast_pair_mut::<BThing>() {
                *value += 1;
            }
        });
        // The consuming `downcast_pair` still finds the three `BThing` entries.
        let b_things: Vec<_> = state
            .iter_mut()
            .flat_map(|r| r.downcast_pair::<BThing>())
            .collect();
        assert_eq!(b_things.len(), 3);
    }

    /// Uses `Box<dyn Foo>` keys whose `Hash`/`Eq` impls make every key equal, so boxes that
    /// wrap different concrete types land on the same entry.
    #[test]
    fn test_always_equal_types() {
        let mut state = TypedMap::default();
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        struct AThing;
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        struct BThing;
        trait Foo {}
        impl Foo for AThing {}
        impl Foo for BThing {}
        // Every `Box<dyn Foo>` writes the same byte, so all of them hash identically.
        impl Hash for Box<dyn Foo> {
            fn hash<H>(&self, hasher: &mut H)
            where
                H: std::hash::Hasher,
            {
                hasher.write_i8(0);
                // NOTE(review): the result of `finish()` is discarded; this call appears to
                // have no effect on the hash.
                hasher.finish();
            }
        }
        // ...and every `Box<dyn Foo>` compares equal to every other.
        impl PartialEq for Box<dyn Foo> {
            fn eq(&self, _rhs: &Self) -> bool {
                true
            }
        }
        impl Eq for Box<dyn Foo> {}
        impl TypedMapKey for AThing {
            type Value = String;
        }
        impl TypedMapKey for BThing {
            type Value = usize;
        }
        impl TypedMapKey for Box<dyn Foo> {
            type Value = String;
        }
        let key_a = Box::new(AThing);
        let key_b = Box::new(BThing);
        state.insert(key_a.clone() as Box<dyn Foo>, "test1".to_owned());
        // The second insert replaces the first because the two keys compare equal.
        let old_key = state
            .insert(key_b.clone() as Box<dyn Foo>, "test2".to_owned())
            .unwrap();
        assert_eq!(old_key, "test1".to_owned());
        let key_a = &(key_a as Box<dyn Foo>);
        let key_b = &(key_b as Box<dyn Foo>);
        assert_eq!(state.get(key_a).unwrap(), &"test2".to_owned());
        assert_eq!(state.get(key_b).unwrap(), &"test2".to_owned());
        // Removing via either key empties the map, since there was only one entry.
        assert_eq!(state.remove(key_a).unwrap(), "test2".to_owned());
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);
    }
}