Skip to main content

cairo_lang_utils/
ordered_hash_set.rs

1use core::hash::{BuildHasher, Hash};
2use core::ops::Sub;
3
4use indexmap::IndexSet;
5use itertools::zip_eq;
6
7#[cfg(feature = "std")]
8type BHImpl = std::collections::hash_map::RandomState;
9#[cfg(not(feature = "std"))]
10type BHImpl = hashbrown::DefaultHashBuilder;
11
12#[derive(Clone, Debug)]
13#[cfg_attr(
14    feature = "serde",
15    derive(serde::Deserialize, serde::Serialize),
16    serde(transparent),
17    serde(bound(
18        serialize = "Key: serde::Serialize",
19        deserialize = "Key: serde::Deserialize<'de> + Hash + Eq, BH: BuildHasher + Default"
20    ))
21)]
22pub struct OrderedHashSet<Key, BH = BHImpl>(IndexSet<Key, BH>);
23
24impl<Key, BH> core::ops::Deref for OrderedHashSet<Key, BH> {
25    type Target = IndexSet<Key, BH>;
26
27    fn deref(&self) -> &Self::Target {
28        &self.0
29    }
30}
31
32impl<Key, BH> core::ops::DerefMut for OrderedHashSet<Key, BH> {
33    fn deref_mut(&mut self) -> &mut Self::Target {
34        &mut self.0
35    }
36}
37
// Adapted from the `salsa::Update` implementation for `IndexSet`, which is defined privately
// (in `macro_rules! maybe_update_set`) in the db-ext-macro repo and so cannot be reused here.
#[cfg(feature = "salsa")]
unsafe impl<Key: Eq + Hash, BH: BuildHasher> salsa::Update for OrderedHashSet<Key, BH> {
    /// Overwrites the set behind `old_pointer` with the contents of `new_set` when the two
    /// differ under this type's `PartialEq` (which compares in iteration order), and reports
    /// whether anything changed.
    unsafe fn maybe_update(old_pointer: *mut Self, new_set: Self) -> bool {
        // SAFETY: `maybe_update`'s contract requires the caller to pass a valid pointer to an
        // initialized `Self` that may be mutably borrowed for the duration of this call.
        let old_set: &mut Self = unsafe { &mut *old_pointer };

        if *old_set != new_set {
            // Refill in place (via `DerefMut` to `IndexSet`) rather than replacing the value.
            old_set.clear();
            old_set.extend(new_set);
            return true;
        }
        false
    }
}
54
55pub type Iter<'a, Key> = indexmap::set::Iter<'a, Key>;
56
57impl<Key, BH: Default> Default for OrderedHashSet<Key, BH> {
58    #[cfg(feature = "std")]
59    fn default() -> Self {
60        Self(Default::default())
61    }
62    #[cfg(not(feature = "std"))]
63    fn default() -> Self {
64        Self(IndexSet::with_hasher(Default::default()))
65    }
66}
67
68impl<Key, BH> IntoIterator for OrderedHashSet<Key, BH> {
69    type Item = Key;
70    type IntoIter = <IndexSet<Key, BH> as IntoIterator>::IntoIter;
71
72    fn into_iter(self) -> Self::IntoIter {
73        self.0.into_iter()
74    }
75}
76
77impl<'a, Key, BH> IntoIterator for &'a OrderedHashSet<Key, BH> {
78    type Item = &'a Key;
79    type IntoIter = <&'a IndexSet<Key, BH> as IntoIterator>::IntoIter;
80
81    fn into_iter(self) -> Self::IntoIter {
82        self.iter()
83    }
84}
85
86impl<Key: Eq, BH> PartialEq for OrderedHashSet<Key, BH> {
87    fn eq(&self, other: &Self) -> bool {
88        if self.len() != other.len() {
89            return false;
90        };
91
92        zip_eq(self.iter(), other.iter()).all(|(a, b)| a == b)
93    }
94}
95
96impl<Key: Eq, BH> Eq for OrderedHashSet<Key, BH> {}
97
98impl<Key: Hash + Eq, BH: BuildHasher + Default> FromIterator<Key> for OrderedHashSet<Key, BH> {
99    fn from_iter<T: IntoIterator<Item = Key>>(iter: T) -> Self {
100        Self(iter.into_iter().collect())
101    }
102}
103
104impl<'a, Key, BH> Sub<&'a OrderedHashSet<Key, BH>> for &'a OrderedHashSet<Key, BH>
105where
106    &'a IndexSet<Key, BH>: Sub<Output = IndexSet<Key, BH>>,
107{
108    type Output = OrderedHashSet<Key, BH>;
109
110    fn sub(self, rhs: Self) -> Self::Output {
111        OrderedHashSet::<Key, BH>(&self.0 - &rhs.0)
112    }
113}