Skip to main content

cljrs_value/collections/
hash_set.rs

1use crate::Value;
2
3/// An immutable hash set backed by `rpds::HashTrieSet`.
4#[derive(Debug, Clone)]
5pub struct PersistentHashSet {
6    inner: rpds::HashTrieSetSync<Value>,
7}
8
9impl PersistentHashSet {
10    pub fn empty() -> Self {
11        Self {
12            inner: rpds::HashTrieSetSync::new_sync(),
13        }
14    }
15
16    pub fn from_set(set: rpds::HashTrieSetSync<Value>) -> Self {
17        Self { inner: set }
18    }
19
20    pub fn count(&self) -> usize {
21        self.inner.size()
22    }
23
24    pub fn is_empty(&self) -> bool {
25        self.inner.is_empty()
26    }
27
28    pub fn contains(&self, val: &Value) -> bool {
29        self.inner.contains(val)
30    }
31
32    /// Return a new set with `val` added.
33    pub fn conj(&self, val: Value) -> Self {
34        Self {
35            inner: self.inner.insert(val),
36        }
37    }
38
39    pub fn conj_mut(&mut self, val: Value) -> &mut Self {
40        self.inner.insert_mut(val);
41        self
42    }
43
44    /// Return a new set with `val` removed.
45    pub fn disj(&self, val: &Value) -> Self {
46        Self {
47            inner: self.inner.remove(val),
48        }
49    }
50
51    /// Iterate over all elements in an unspecified order.
52    pub fn iter(&self) -> impl Iterator<Item = &Value> {
53        self.inner.iter()
54    }
55
56    pub fn inner(&self) -> &rpds::HashTrieSetSync<Value> {
57        &self.inner
58    }
59}
60
61impl FromIterator<Value> for PersistentHashSet {
62    fn from_iter<I: IntoIterator<Item = Value>>(iter: I) -> Self {
63        let mut inner = rpds::HashTrieSetSync::new_sync();
64        for v in iter {
65            inner.insert_mut(v)
66        }
67        Self { inner }
68    }
69}
70
71impl PartialEq for PersistentHashSet {
72    fn eq(&self, other: &Self) -> bool {
73        if self.count() != other.count() {
74            return false;
75        }
76        self.inner.iter().all(|k| other.contains(k))
77    }
78}
79
80impl cljrs_gc::Trace for PersistentHashSet {
81    fn trace(&self, visitor: &mut cljrs_gc::MarkVisitor) {
82        for v in self.inner.iter() {
83            v.trace(visitor);
84        }
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    fn int(n: i64) -> Value {
        Value::Long(n)
    }

    /// Collects the `Long` elements of `s` into a sorted vector so tests can
    /// compare contents without depending on iteration order.
    fn sorted_longs(s: &PersistentHashSet) -> Vec<i64> {
        let mut out: Vec<i64> = s
            .iter()
            .filter_map(|v| match v {
                Value::Long(n) => Some(*n),
                _ => None,
            })
            .collect();
        out.sort_unstable();
        out
    }

    #[test]
    fn test_empty() {
        let s = PersistentHashSet::empty();
        assert!(s.is_empty());
        assert_eq!(s.count(), 0);
        assert!(!s.contains(&int(1)));
    }

    #[test]
    fn test_basic() {
        let s = PersistentHashSet::empty();
        let s = s.conj(int(1)).conj(int(2)).conj(int(3));
        assert_eq!(s.count(), 3);
        assert!(!s.is_empty());
        assert!(s.contains(&int(1)));
        assert!(s.contains(&int(2)));
        assert!(!s.contains(&int(99)));
    }

    #[test]
    fn test_idempotent_conj() {
        let s = PersistentHashSet::empty().conj(int(1)).conj(int(1));
        assert_eq!(s.count(), 1);
    }

    #[test]
    fn test_conj_is_persistent() {
        let a = PersistentHashSet::empty().conj(int(1));
        let b = a.conj(int(2));
        // The original set must be unaffected by deriving a new one.
        assert_eq!(a.count(), 1);
        assert!(!a.contains(&int(2)));
        assert_eq!(b.count(), 2);
    }

    #[test]
    fn test_conj_mut() {
        let mut s = PersistentHashSet::empty();
        s.conj_mut(int(1)).conj_mut(int(2)).conj_mut(int(1));
        assert_eq!(s.count(), 2);
        assert!(s.contains(&int(1)));
        assert!(s.contains(&int(2)));
    }

    #[test]
    fn test_disj() {
        let s = PersistentHashSet::empty().conj(int(1)).conj(int(2));
        let s2 = s.disj(&int(1));
        assert!(!s2.contains(&int(1)));
        assert!(s2.contains(&int(2)));
        assert_eq!(s2.count(), 1);
        // The original set still holds the removed element.
        assert!(s.contains(&int(1)));
        assert_eq!(s.count(), 2);
    }

    #[test]
    fn test_disj_absent_element() {
        let s = PersistentHashSet::empty().conj(int(1));
        let s2 = s.disj(&int(42));
        assert_eq!(s2, s);
        assert_eq!(PersistentHashSet::empty().disj(&int(1)).count(), 0);
    }

    #[test]
    fn test_iter_visits_every_element() {
        let s = PersistentHashSet::from_iter([int(3), int(1), int(2)]);
        assert_eq!(sorted_longs(&s), vec![1, 2, 3]);
        assert_eq!(PersistentHashSet::empty().iter().count(), 0);
    }

    #[test]
    fn test_from_iter_dedups() {
        let s: PersistentHashSet = vec![int(1), int(1), int(2)].into_iter().collect();
        assert_eq!(s.count(), 2);
        assert_eq!(sorted_longs(&s), vec![1, 2]);
    }

    #[test]
    fn test_from_set_round_trip() {
        let s = PersistentHashSet::from_iter([int(1), int(2)]);
        let copy = PersistentHashSet::from_set(s.inner().clone());
        assert_eq!(copy, s);
    }

    #[test]
    fn test_equality_order_independent() {
        let a = PersistentHashSet::from_iter([int(1), int(2), int(3)]);
        let b = PersistentHashSet::from_iter([int(3), int(1), int(2)]);
        assert_eq!(a, b);
    }

    #[test]
    fn test_inequality() {
        let a = PersistentHashSet::from_iter([int(1), int(2)]);
        let bigger = PersistentHashSet::from_iter([int(1), int(2), int(3)]);
        let same_size_other = PersistentHashSet::from_iter([int(1), int(9)]);
        assert_ne!(a, bigger);
        assert_ne!(a, same_size_other);
        assert_ne!(a, PersistentHashSet::empty());
    }
}
128}