cairo_lang_utils/
ordered_hash_map.rs

1use core::hash::{BuildHasher, Hash};
2#[cfg(feature = "std")]
3use std::collections::hash_map::RandomState;
4
5use indexmap::IndexMap;
6use itertools::zip_eq;
7
8#[cfg(feature = "std")]
9type BHImpl = RandomState;
10#[cfg(not(feature = "std"))]
11type BHImpl = hashbrown::DefaultHashBuilder;
12
13#[derive(Clone, Debug)]
14pub struct OrderedHashMap<Key, Value, BH = BHImpl>(IndexMap<Key, Value, BH>);
15
16impl<Key, Value, BH> core::ops::Deref for OrderedHashMap<Key, Value, BH> {
17    type Target = IndexMap<Key, Value, BH>;
18
19    fn deref(&self) -> &Self::Target {
20        &self.0
21    }
22}
23
24impl<Key, Value, BH> core::ops::DerefMut for OrderedHashMap<Key, Value, BH> {
25    fn deref_mut(&mut self) -> &mut Self::Target {
26        &mut self.0
27    }
28}
29
// Adapted from salsa's `Update` implementation for `IndexMap` (the private
// `maybe_update_map!` macro), with one deliberate difference: the iteration order is part of
// this map's identity (its `PartialEq` and `Hash` impls are order-sensitive), so keys are
// compared *in order*. A reordering of the same key set is reported as a change and the new
// order is adopted, rather than silently keeping the stale order and reporting "unchanged".
#[cfg(feature = "salsa")]
unsafe impl<Key: salsa::Update + Eq + Hash, Value: salsa::Update> salsa::Update
    for OrderedHashMap<Key, Value, BHImpl>
{
    /// Updates the map behind `old_pointer` to match `new_map`, returning `true` if it changed.
    ///
    /// # Safety
    ///
    /// `old_pointer` must be valid for creating a `&mut Self`: non-null, aligned, pointing to an
    /// initialized map, and not aliased for the duration of the call (the contract of
    /// `salsa::Update::maybe_update`).
    unsafe fn maybe_update(old_pointer: *mut Self, new_map: Self) -> bool {
        // SAFETY: the caller guarantees `old_pointer` is valid, initialized and unaliased.
        let old_map: &mut Self = unsafe { &mut *old_pointer };

        // Entries can only be updated pairwise if both maps hold the same keys in the same order.
        let same_keys = old_map.len() == new_map.len()
            && old_map.keys().zip(new_map.keys()).all(|(old_key, new_key)| old_key == new_key);

        // The key set or the key order changed: take the new map wholesale.
        if !same_keys {
            *old_map = new_map;
            return true;
        }

        // Same keys in the same order: recursively update the values position by position.
        // We do not invoke `Key::maybe_update` because keys that are `Eq` are assumed not to
        // need updating (see the `Update` trait criteria).
        let mut changed = false;
        for (old_value, (_, new_value)) in old_map.values_mut().zip(new_map) {
            // SAFETY: `old_value` is a unique reference into `*old_map`, hence valid,
            // initialized and unaliased.
            changed |= unsafe { Value::maybe_update(old_value, new_value) };
        }
        changed
    }
}
64
65impl<Key, Value, BH: Default> Default for OrderedHashMap<Key, Value, BH> {
66    #[cfg(feature = "std")]
67    fn default() -> Self {
68        Self(Default::default())
69    }
70    #[cfg(not(feature = "std"))]
71    fn default() -> Self {
72        Self(IndexMap::with_hasher(Default::default()))
73    }
74}
75
76impl<Key, Value, BH> OrderedHashMap<Key, Value, BH> {
77    /// Returns true if the map contains no elements.
78    pub fn is_empty(&self) -> bool {
79        self.0.is_empty()
80    }
81}
82
83impl<Key: Eq + Hash, Value, BH: BuildHasher> OrderedHashMap<Key, Value, BH> {
84    /// Returns true if the maps are equal, ignoring the order of the entries.
85    pub fn eq_unordered(&self, other: &Self) -> bool
86    where
87        Value: Eq,
88    {
89        if self.len() != other.len() {
90            return false;
91        };
92        self.iter().all(|(k, v)| other.get(k) == Some(v))
93    }
94}
95
96/// Entry for an existing key-value pair or a vacant location to insert one.
97pub type Entry<'a, Key, Value> = indexmap::map::Entry<'a, Key, Value>;
98
99impl<Key, Value, BH> IntoIterator for OrderedHashMap<Key, Value, BH> {
100    type Item = (Key, Value);
101    type IntoIter = indexmap::map::IntoIter<Key, Value>;
102    fn into_iter(self) -> Self::IntoIter {
103        let OrderedHashMap(inner) = self;
104        inner.into_iter()
105    }
106}
107
108impl<Key: Eq, Value: Eq, BH> PartialEq for OrderedHashMap<Key, Value, BH> {
109    fn eq(&self, other: &Self) -> bool {
110        if self.len() != other.len() {
111            return false;
112        };
113
114        zip_eq(self.iter(), other.iter()).all(|(a, b)| a == b)
115    }
116}
117
118impl<Key: Hash + Eq, Value: Eq, BH: BuildHasher> Eq for OrderedHashMap<Key, Value, BH> {}
119
120impl<Key: Hash, Value: Hash, BH> Hash for OrderedHashMap<Key, Value, BH> {
121    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
122        self.len().hash(state);
123        for e in self.iter() {
124            e.hash(state);
125        }
126    }
127}
128
129impl<Key: Hash + Eq, Value, BH: BuildHasher + Default> FromIterator<(Key, Value)>
130    for OrderedHashMap<Key, Value, BH>
131{
132    fn from_iter<T: IntoIterator<Item = (Key, Value)>>(iter: T) -> Self {
133        Self(iter.into_iter().collect())
134    }
135}
136
137impl<Key: Hash + Eq, Value, BH: BuildHasher + Default, const N: usize> From<[(Key, Value); N]>
138    for OrderedHashMap<Key, Value, BH>
139{
140    fn from(init_map: [(Key, Value); N]) -> Self {
141        Self(IndexMap::from_iter(init_map))
142    }
143}
144
#[cfg(feature = "serde")]
mod impl_serde {
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;

    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    use super::*;

    /// Serializes the map exactly as the wrapped `IndexMap` does.
    impl<K: Hash + Eq + Serialize, V: Serialize, BH: BuildHasher> Serialize
        for OrderedHashMap<K, V, BH>
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            self.0.serialize(serializer)
        }
    }

    /// Deserializes the map exactly as the wrapped `IndexMap` does.
    impl<'de, K: Hash + Eq + Deserialize<'de>, V: Deserialize<'de>, BH: BuildHasher + Default>
        Deserialize<'de> for OrderedHashMap<K, V, BH>
    {
        fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
            IndexMap::<K, V, BH>::deserialize(deserializer).map(OrderedHashMap)
        }
    }

    /// Serializes `v` as a sequence of `(key, value)` pairs in iteration order, instead of as a
    /// serde map (intended for use with `#[serde(serialize_with = ...)]`).
    ///
    /// Only `Serialize` is required of keys and values.
    pub fn serialize_ordered_hashmap_vec<K, V, BH, S>(
        v: &OrderedHashMap<K, V, BH>,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        K: Serialize,
        V: Serialize,
    {
        // Serializing a `Vec<(&K, &V)>` is itself implemented via `collect_seq`, so this yields
        // the same output without allocating the intermediate vector.
        serializer.collect_seq(v.iter())
    }

    /// Inverse of `serialize_ordered_hashmap_vec`: reads a sequence of `(key, value)` pairs and
    /// collects them, in order, into a map (intended for `#[serde(deserialize_with = ...)]`).
    ///
    /// Only `Deserialize` (plus `Hash + Eq` for keys) is required.
    pub fn deserialize_ordered_hashmap_vec<'de, K, V, BH: BuildHasher + Default, D>(
        deserializer: D,
    ) -> Result<OrderedHashMap<K, V, BH>, D::Error>
    where
        D: Deserializer<'de>,
        K: Deserialize<'de> + Hash + Eq,
        V: Deserialize<'de>,
    {
        Ok(Vec::<(K, V)>::deserialize(deserializer)?.into_iter().collect())
    }
}
#[cfg(feature = "serde")]
pub use impl_serde::*;