1use std::collections::HashSet;
2use std::sync::Arc;
3
4use super::{value_structural_hash_key, VmValue};
5
6#[derive(Debug, Clone, Default)]
19pub struct VmSet {
20 items: Arc<Vec<VmValue>>,
24 keys: HashSet<String>,
25}
26
27impl VmSet {
28 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn with_capacity(capacity: usize) -> Self {
33 Self {
34 items: Arc::new(Vec::with_capacity(capacity)),
35 keys: HashSet::with_capacity(capacity),
36 }
37 }
38
39 pub fn insert(&mut self, value: VmValue) -> bool {
42 let key = value_structural_hash_key(&value);
43 if self.keys.insert(key) {
44 Arc::make_mut(&mut self.items).push(value);
45 true
46 } else {
47 false
48 }
49 }
50
51 pub fn contains(&self, value: &VmValue) -> bool {
53 self.keys.contains(&value_structural_hash_key(value))
54 }
55
56 pub fn remove(&mut self, value: &VmValue) -> bool {
58 let key = value_structural_hash_key(value);
59 if self.keys.remove(&key) {
60 Arc::make_mut(&mut self.items).retain(|item| value_structural_hash_key(item) != key);
61 true
62 } else {
63 false
64 }
65 }
66
67 pub fn len(&self) -> usize {
68 self.items.len()
69 }
70
71 pub fn is_empty(&self) -> bool {
72 self.items.is_empty()
73 }
74
75 pub fn iter(&self) -> std::slice::Iter<'_, VmValue> {
76 self.items.iter()
77 }
78
79 pub fn items(&self) -> &[VmValue] {
80 self.items.as_slice()
81 }
82
83 pub fn shared_items(&self) -> Arc<Vec<VmValue>> {
86 Arc::clone(&self.items)
87 }
88
89 pub fn into_items(self) -> Vec<VmValue> {
92 Arc::try_unwrap(self.items).unwrap_or_else(|items| (*items).clone())
93 }
94
95 pub fn union(&self, other: &VmSet) -> VmSet {
97 let mut out = self.clone();
98 for value in other.iter() {
99 out.insert(value.clone());
100 }
101 out
102 }
103
104 pub fn intersect(&self, other: &VmSet) -> VmSet {
106 self.items
107 .iter()
108 .filter(|item| other.contains(item))
109 .cloned()
110 .collect()
111 }
112
113 pub fn difference(&self, other: &VmSet) -> VmSet {
115 self.items
116 .iter()
117 .filter(|item| !other.contains(item))
118 .cloned()
119 .collect()
120 }
121
122 pub fn symmetric_difference(&self, other: &VmSet) -> VmSet {
124 let mut out: VmSet = self
125 .iter()
126 .filter(|item| !other.contains(item))
127 .cloned()
128 .collect();
129 for value in other.iter() {
130 if !self.contains(value) {
131 out.insert(value.clone());
132 }
133 }
134 out
135 }
136
137 pub fn is_subset(&self, other: &VmSet) -> bool {
138 self.items.iter().all(|item| other.contains(item))
139 }
140
141 pub fn is_superset(&self, other: &VmSet) -> bool {
142 other.is_subset(self)
143 }
144
145 pub fn is_disjoint(&self, other: &VmSet) -> bool {
146 !self.items.iter().any(|item| other.contains(item))
147 }
148
149 pub fn sorted_keys(&self) -> Vec<&str> {
153 let mut keys: Vec<&str> = self.keys.iter().map(String::as_str).collect();
154 keys.sort_unstable();
155 keys
156 }
157}
158
159impl FromIterator<VmValue> for VmSet {
160 fn from_iter<I: IntoIterator<Item = VmValue>>(iter: I) -> Self {
161 let iter = iter.into_iter();
162 let mut set = VmSet::with_capacity(iter.size_hint().0);
163 for value in iter {
164 set.insert(value);
165 }
166 set
167 }
168}
169
170impl<'a> IntoIterator for &'a VmSet {
171 type Item = &'a VmValue;
172 type IntoIter = std::slice::Iter<'a, VmValue>;
173
174 fn into_iter(self) -> Self::IntoIter {
175 self.iter()
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::{value_structural_hash_key, values_equal};

    /// Collects `values` into a `VmSet` of `VmValue::Int`, in slice order.
    fn int_set(values: &[i64]) -> VmSet {
        values.iter().copied().map(VmValue::Int).collect()
    }

    #[test]
    fn insert_dedups_and_preserves_first_seen_order() {
        let set = int_set(&[3, 1, 3, 2, 1]);
        assert_eq!(set.len(), 3);

        let mut order = Vec::new();
        for value in &set {
            match value {
                VmValue::Int(n) => order.push(*n),
                _ => unreachable!(),
            }
        }
        assert_eq!(order, vec![3, 1, 2]);
    }

    #[test]
    fn contains_is_structural() {
        let set = int_set(&[1, 2, 3]);
        let present = VmValue::Int(2);
        let absent = VmValue::Int(9);
        assert!(set.contains(&present));
        assert!(!set.contains(&absent));
        // NOTE(review): this pins Int(1) and Float(1.0) to the same structural key;
        // confirm that matches `value_structural_hash_key`, which is defined elsewhere.
        assert!(set.contains(&VmValue::Float(1.0)));
    }

    #[test]
    fn remove_updates_index_and_order() {
        let mut set = int_set(&[1, 2, 3]);
        let target = VmValue::Int(2);
        // The first removal succeeds; the second finds nothing left to remove.
        assert!(set.remove(&target));
        assert!(!set.remove(&target));
        assert!(!set.contains(&target));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn set_algebra() {
        let a = int_set(&[1, 2, 3]);
        let b = int_set(&[2, 3, 4]);

        assert_eq!(a.union(&b).len(), 4);

        assert_eq!(a.intersect(&b).len(), 2);
        assert!(a.intersect(&b).contains(&VmValue::Int(2)));

        let diff = a.difference(&b);
        assert_eq!(diff.len(), 1);
        assert!(diff.contains(&VmValue::Int(1)));

        let sym = a.symmetric_difference(&b);
        assert_eq!(sym.len(), 2);
        assert!(sym.contains(&VmValue::Int(1)) && sym.contains(&VmValue::Int(4)));

        let small = int_set(&[1, 2]);
        assert!(small.is_subset(&a));
        assert!(a.is_superset(&int_set(&[1, 2])));

        let unrelated = int_set(&[7, 8]);
        assert!(a.is_disjoint(&unrelated));
        assert!(!a.is_disjoint(&b));
    }

    #[test]
    fn equality_and_hash_are_order_independent() {
        let a = VmValue::set([VmValue::Int(1), VmValue::Int(2), VmValue::Int(3)]);
        let b = VmValue::set([VmValue::Int(3), VmValue::Int(2), VmValue::Int(1)]);
        assert!(values_equal(&a, &b));
        assert_eq!(value_structural_hash_key(&a), value_structural_hash_key(&b));

        let c = VmValue::set([VmValue::Int(1), VmValue::Int(2)]);
        assert!(!values_equal(&a, &c));
    }
}