// facet_core/impls_std/hashset.rs

1use core::hash::BuildHasher;
2use core::ptr::NonNull;
3use std::collections::HashSet;
4
5use crate::ptr::{PtrConst, PtrMut};
6
7use crate::{
8    Def, Facet, IterVTable, MarkerTraits, SetDef, SetVTable, Shape, Type, TypeParam, UserType,
9    ValueVTable,
10};
11
12type HashSetIterator<'mem, T> = std::collections::hash_set::Iter<'mem, T>;
13
14unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
15where
16    T: Facet<'a> + core::cmp::Eq + core::hash::Hash,
17    S: Facet<'a> + Default + BuildHasher,
18{
19    const SHAPE: &'static Shape = &const {
20        Shape::builder_for_sized::<Self>()
21            .vtable(
22                ValueVTable::builder::<Self>()
23                    .marker_traits({
24                        MarkerTraits::SEND
25                            .union(MarkerTraits::SYNC)
26                            .union(MarkerTraits::EQ)
27                            .union(MarkerTraits::UNPIN)
28                            .intersection(T::SHAPE.vtable.marker_traits())
29                    })
30                    .type_name(|f, opts| {
31                        write!(f, "{}<", Self::SHAPE.type_identifier)?;
32                        if let Some(opts) = opts.for_children() {
33                            (T::SHAPE.vtable.type_name())(f, opts)?;
34                        } else {
35                            write!(f, "…")?;
36                        }
37                        write!(f, ">")
38                    })
39                    .default_in_place({
40                        Some(|target| unsafe { target.put(Self::default()).into() })
41                    })
42                    .build(),
43            )
44            .type_identifier("HashSet")
45            .type_params(&[
46                TypeParam {
47                    name: "T",
48                    shape: T::SHAPE,
49                },
50                TypeParam {
51                    name: "S",
52                    shape: S::SHAPE,
53                },
54            ])
55            .ty(Type::User(UserType::Opaque))
56            .def(Def::Set(
57                SetDef::builder()
58                    .t(T::SHAPE)
59                    .vtable(
60                        &const {
61                            SetVTable::builder()
62                                .init_in_place_with_capacity(|uninit, capacity| unsafe {
63                                    uninit
64                                        .put(Self::with_capacity_and_hasher(capacity, S::default()))
65                                })
66                                .insert(|ptr, item| unsafe {
67                                    let set = ptr.as_mut::<HashSet<T>>();
68                                    let item = item.read::<T>();
69                                    set.insert(item)
70                                })
71                                .len(|ptr| unsafe {
72                                    let set = ptr.get::<HashSet<T>>();
73                                    set.len()
74                                })
75                                .contains(|ptr, item| unsafe {
76                                    let set = ptr.get::<HashSet<T>>();
77                                    set.contains(item.get())
78                                })
79                                .iter_vtable(
80                                    IterVTable::builder()
81                                        .init_with_value(|ptr| unsafe {
82                                            let set = ptr.get::<HashSet<T>>();
83                                            let iter: HashSetIterator<'_, T> = set.iter();
84                                            let iter_state = Box::new(iter);
85                                            PtrMut::new(NonNull::new_unchecked(Box::into_raw(
86                                                iter_state,
87                                            )
88                                                as *mut u8))
89                                        })
90                                        .next(|iter_ptr| unsafe {
91                                            let state = iter_ptr.as_mut::<HashSetIterator<'_, T>>();
92                                            state
93                                                .next()
94                                                .map(|value| PtrConst::new(NonNull::from(value)))
95                                        })
96                                        .dealloc(|iter_ptr| unsafe {
97                                            drop(Box::from_raw(
98                                                iter_ptr.as_ptr::<HashSetIterator<'_, T>>()
99                                                    as *mut HashSetIterator<'_, T>,
100                                            ));
101                                        })
102                                        .build(),
103                                )
104                                .build()
105                        },
106                    )
107                    .build(),
108            ))
109            .build()
110    };
111}
112
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use core::ptr::NonNull;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// `HashSet<T, S>` reports both the element type `T` and the hasher state `S`
    /// as type parameters, in that order.
    #[test]
    fn test_hashset_type_params() {
        let params = <HashSet<i32>>::SHAPE.type_params;
        let [elem_param, hasher_param] = params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(elem_param.shape(), i32::SHAPE);
        assert_eq!(hasher_param.shape(), RandomState::SHAPE);
    }

    /// Drives the type-erased set vtable end to end. It constructs the set, inserts
    /// values (checking deduplication), iterates, drops, and deallocates.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let set_def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");
        let vtable = &set_def.vtable;

        // Reserve backing memory, then build an empty set with room for three items.
        let uninit = shape.allocate().unwrap();
        let set_ptr = unsafe { (vtable.init_in_place_with_capacity_fn)(uninit, 3) };

        // Current number of elements, read through the vtable.
        let len_of = || unsafe { (vtable.len_fn)(set_ptr.as_const()) };

        // Moves a freshly allocated `String` into the set. `insert_fn` takes ownership,
        // so the local copy is wrapped in `ManuallyDrop` to prevent a double drop.
        let insert = |word: &str| {
            let mut owned = core::mem::ManuallyDrop::new(word.to_string());
            unsafe { (vtable.insert_fn)(set_ptr, PtrMut::new(NonNull::from(&mut owned))) }
        };

        assert_eq!(len_of(), 0);

        let words = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // First pass: every word is new, so the length grows by one each time.
        for (count, word) in words.iter().copied().enumerate() {
            assert!(insert(word), "expected value to be inserted in the HashSet");
            assert_eq!(len_of(), count + 1);
        }

        // Second pass: every word is a duplicate, so nothing is inserted.
        for word in words.iter().copied() {
            assert!(
                !insert(word),
                "expected value to not be inserted in the HashSet"
            );
            assert_eq!(len_of(), words.len());
        }

        // Walk the set through the iterator vtable, recording every element seen.
        let init_iter = vtable.iter_vtable.init_with_value.unwrap();
        let iter_ptr = unsafe { init_iter(set_ptr.as_const()) };

        let mut seen = HashSet::<&str>::new();
        while let Some(item_ptr) = unsafe { (vtable.iter_vtable.next)(iter_ptr) } {
            let item: &String = unsafe { item_ptr.get::<String>() };
            assert!(
                seen.insert(item.as_str()),
                "HashSet iterator returned duplicate item"
            );
        }

        unsafe { (vtable.iter_vtable.dealloc)(iter_ptr) };

        let expected: HashSet<&str> = words.iter().copied().collect();
        assert_eq!(seen, expected);

        // Tear down: drop the set's contents, then release its memory.
        let drop_in_place = shape
            .vtable
            .drop_in_place
            .expect("HashSet<T> should have drop_in_place");
        unsafe {
            drop_in_place(set_ptr);
            shape.deallocate_mut(set_ptr).unwrap();
        }
    }
}
243        unsafe { hashset_shape.deallocate_mut(hashset_ptr).unwrap() };
244    }
245}