facet_core/impls_std/hashset.rs

1use alloc::collections::VecDeque;
2use core::hash::{BuildHasher, Hash};
3use std::collections::HashSet;
4
5use crate::ptr::{PtrConst, PtrMut};
6
7use crate::{
8    Def, Facet, MarkerTraits, SetDef, SetIterVTable, SetVTable, Shape, Type, TypeParam, UserType,
9    VTableView, ValueVTable,
10};
11
/// Heap-allocated iteration state returned (type-erased) by the set vtable's `iter` hook.
///
/// It owns a queue of borrowed references to the set's elements; the iterator vtable's
/// `next` hook pops them from the front one at a time.
struct HashSetIterator<'set, T> {
    items: VecDeque<&'set T>,
}
15
16unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
17where
18    T: Facet<'a> + core::cmp::Eq + core::hash::Hash,
19    S: Facet<'a> + Default + BuildHasher,
20{
21    const VTABLE: &'static ValueVTable = &const {
22        let mut builder = ValueVTable::builder::<Self>()
23            .marker_traits(
24                MarkerTraits::SEND
25                    .union(MarkerTraits::SYNC)
26                    .union(MarkerTraits::EQ)
27                    .union(MarkerTraits::UNPIN)
28                    .intersection(T::SHAPE.vtable.marker_traits),
29            )
30            .type_name(|f, opts| {
31                if let Some(opts) = opts.for_children() {
32                    write!(f, "HashSet<")?;
33                    (T::SHAPE.vtable.type_name)(f, opts)?;
34                    write!(f, ">")
35                } else {
36                    write!(f, "HashSet<⋯>")
37                }
38            })
39            .default_in_place(|target| unsafe { target.put(Self::default()) })
40            .eq(|a, b| a == b);
41
42        if T::SHAPE.vtable.debug.is_some() {
43            builder = builder.debug(|value, f| {
44                let t_debug = <VTableView<T>>::of().debug().unwrap();
45                write!(f, "{{")?;
46                for (i, item) in value.iter().enumerate() {
47                    if i > 0 {
48                        write!(f, ", ")?;
49                    }
50                    (t_debug)(item, f)?;
51                }
52                write!(f, "}}")
53            });
54        }
55
56        if T::SHAPE.vtable.clone_into.is_some() {
57            builder = builder.clone_into(|src, dst| unsafe {
58                let set = src;
59                let mut new_set = HashSet::with_capacity_and_hasher(set.len(), S::default());
60
61                let t_clone_into = <VTableView<T>>::of().clone_into().unwrap();
62
63                for item in set {
64                    use crate::TypedPtrUninit;
65                    use core::mem::MaybeUninit;
66
67                    let mut new_item = MaybeUninit::<T>::uninit();
68                    let uninit_item = TypedPtrUninit::new(new_item.as_mut_ptr());
69
70                    (t_clone_into)(item, uninit_item);
71
72                    new_set.insert(new_item.assume_init());
73                }
74
75                dst.put(new_set)
76            });
77        }
78
79        if T::SHAPE.vtable.hash.is_some() {
80            builder = builder.hash(|set, hasher_this, hasher_write_fn| unsafe {
81                use crate::HasherProxy;
82                let t_hash = <VTableView<T>>::of().hash().unwrap();
83                let mut hasher = HasherProxy::new(hasher_this, hasher_write_fn);
84                set.len().hash(&mut hasher);
85                for item in set {
86                    (t_hash)(item, hasher_this, hasher_write_fn);
87                }
88            });
89        }
90
91        builder.build()
92    };
93
94    const SHAPE: &'static Shape = &const {
95        Shape::builder_for_sized::<Self>()
96            .type_params(&[
97                TypeParam {
98                    name: "T",
99                    shape: || T::SHAPE,
100                },
101                TypeParam {
102                    name: "S",
103                    shape: || S::SHAPE,
104                },
105            ])
106            .ty(Type::User(UserType::Opaque))
107            .def(Def::Set(
108                SetDef::builder()
109                    .t(|| T::SHAPE)
110                    .vtable(
111                        &const {
112                            SetVTable::builder()
113                                .init_in_place_with_capacity(|uninit, capacity| unsafe {
114                                    uninit
115                                        .put(Self::with_capacity_and_hasher(capacity, S::default()))
116                                })
117                                .insert(|ptr, item| unsafe {
118                                    let set = ptr.as_mut::<HashSet<T>>();
119                                    let item = item.read::<T>();
120                                    set.insert(item)
121                                })
122                                .len(|ptr| unsafe {
123                                    let set = ptr.get::<HashSet<T>>();
124                                    set.len()
125                                })
126                                .contains(|ptr, item| unsafe {
127                                    let set = ptr.get::<HashSet<T>>();
128                                    set.contains(item.get())
129                                })
130                                .iter(|ptr| unsafe {
131                                    let set = ptr.get::<HashSet<T>>();
132                                    let items: VecDeque<&T> = set.iter().collect();
133                                    let iter_state = Box::new(HashSetIterator { items });
134                                    PtrMut::new(Box::into_raw(iter_state) as *mut u8)
135                                })
136                                .iter_vtable(
137                                    SetIterVTable::builder()
138                                        .next(|iter_ptr| unsafe {
139                                            let state = iter_ptr.as_mut::<HashSetIterator<'_, T>>();
140                                            state
141                                                .items
142                                                .pop_front()
143                                                .map(|item| PtrConst::new(item as *const T))
144                                        })
145                                        .dealloc(|iter_ptr| unsafe {
146                                            drop(Box::from_raw(
147                                                iter_ptr.as_ptr::<HashSetIterator<'_, T>>()
148                                                    as *mut HashSetIterator<'_, T>,
149                                            ));
150                                        })
151                                        .build(),
152                                )
153                                .build()
154                        },
155                    )
156                    .build(),
157            ))
158            .build()
159    };
160}
161
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// `HashSet<T, S>` reports two type parameters: the element type, then the hasher state.
    #[test]
    fn test_hashset_type_params() {
        let params = <HashSet<i32>>::SHAPE.type_params;
        let [element, hasher_state] = params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(element.shape(), i32::SHAPE);
        assert_eq!(hasher_state.shape(), RandomState::SHAPE);
    }

    /// Exercises the set vtable end to end: construct, insert (and deduplicate), iterate,
    /// then drop the set and release its storage.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let set_def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");

        // Raw storage first, then an empty set with room for 3 elements.
        let storage = shape.allocate()?;
        let set_ptr = unsafe { (set_def.vtable.init_in_place_with_capacity_fn)(storage, 3) };

        // Element count as reported by the vtable.
        let current_len = || unsafe { (set_def.vtable.len_fn)(set_ptr.as_const()) };

        // Moves an owned copy of `s` into the set; yields whether it was newly added.
        let insert_owned = |s: &str| -> bool {
            let mut owned = s.to_string();
            let inserted =
                unsafe { (set_def.vtable.insert_fn)(set_ptr, PtrMut::new(&raw mut owned)) };
            // `insert_fn` moved the value out, so it must not be dropped here as well.
            core::mem::forget(owned);
            inserted
        };

        assert_eq!(current_len(), 0);

        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // First pass: all values are new, so the length grows by one per insert.
        for (already_present, s) in strings.into_iter().enumerate() {
            let inserted = insert_owned(s);
            assert!(inserted, "expected value to be inserted in the HashSet");
            assert_eq!(current_len(), already_present + 1);
        }

        // Second pass: all values are duplicates, so the length never changes.
        for s in strings {
            let inserted = insert_owned(s);
            assert!(
                !inserted,
                "expected value to not be inserted in the HashSet"
            );
            assert_eq!(current_len(), strings.len());
        }

        // Drain the type-erased iterator, checking that no element is yielded twice.
        let iter_ptr = unsafe { (set_def.vtable.iter_fn)(set_ptr.as_const()) };
        let mut yielded = HashSet::<&str>::new();
        while let Some(item_ptr) = unsafe { (set_def.vtable.iter_vtable.next)(iter_ptr) } {
            let item = unsafe { item_ptr.get::<String>() };
            let first_time = yielded.insert(item.as_str());
            assert!(first_time, "HashSet iterator returned duplicate item");
        }
        unsafe {
            (set_def.vtable.iter_vtable.dealloc)(iter_ptr);
        }

        // The iterator produced exactly the inserted strings.
        let expected: HashSet<&str> = strings.into_iter().collect();
        assert_eq!(yielded, expected);

        // Tear down: run the destructor in place, then free the storage.
        let drop_fn = shape
            .vtable
            .drop_in_place
            .expect("HashSet<T> should have drop_in_place");
        unsafe { drop_fn(set_ptr) };
        unsafe { shape.deallocate_mut(set_ptr)? };

        Ok(())
    }
}