facet_core/impls_std/
hashset.rs

1use core::hash::BuildHasher;
2use std::collections::HashSet;
3
4use crate::ptr::{PtrConst, PtrMut};
5
6use crate::{
7    Def, Facet, IterVTable, MarkerTraits, SetDef, SetVTable, Shape, Type, TypeParam, UserType,
8    ValueVTable,
9};
10
/// Borrowing iterator over a `HashSet`'s elements; this is the state boxed
/// behind the type-erased iterator pointer handed out by the set vtable.
type HashSetIterator<'set, T> = std::collections::hash_set::Iter<'set, T>;
12
13unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
14where
15    T: Facet<'a> + core::cmp::Eq + core::hash::Hash,
16    S: Facet<'a> + Default + BuildHasher,
17{
18    const VTABLE: &'static ValueVTable = &const {
19        ValueVTable::builder::<Self>()
20            .marker_traits(|| {
21                MarkerTraits::SEND
22                    .union(MarkerTraits::SYNC)
23                    .union(MarkerTraits::EQ)
24                    .union(MarkerTraits::UNPIN)
25                    .intersection(T::SHAPE.vtable.marker_traits())
26            })
27            .type_name(|f, opts| {
28                if let Some(opts) = opts.for_children() {
29                    write!(f, "{}<", Self::SHAPE.type_identifier)?;
30                    (T::SHAPE.vtable.type_name())(f, opts)?;
31                    write!(f, ">")
32                } else {
33                    write!(f, "HashSet<⋯>")
34                }
35            })
36            .default_in_place(|| Some(|target| unsafe { target.put(Self::default()) }))
37            .build()
38    };
39
40    const SHAPE: &'static Shape = &const {
41        Shape::builder_for_sized::<Self>()
42            .type_identifier("HashSet")
43            .type_params(&[
44                TypeParam {
45                    name: "T",
46                    shape: || T::SHAPE,
47                },
48                TypeParam {
49                    name: "S",
50                    shape: || S::SHAPE,
51                },
52            ])
53            .ty(Type::User(UserType::Opaque))
54            .def(Def::Set(
55                SetDef::builder()
56                    .t(|| T::SHAPE)
57                    .vtable(
58                        &const {
59                            SetVTable::builder()
60                                .init_in_place_with_capacity(|uninit, capacity| unsafe {
61                                    uninit
62                                        .put(Self::with_capacity_and_hasher(capacity, S::default()))
63                                })
64                                .insert(|ptr, item| unsafe {
65                                    let set = ptr.as_mut::<HashSet<T>>();
66                                    let item = item.read::<T>();
67                                    set.insert(item)
68                                })
69                                .len(|ptr| unsafe {
70                                    let set = ptr.get::<HashSet<T>>();
71                                    set.len()
72                                })
73                                .contains(|ptr, item| unsafe {
74                                    let set = ptr.get::<HashSet<T>>();
75                                    set.contains(item.get())
76                                })
77                                .iter_vtable(
78                                    IterVTable::builder()
79                                        .init_with_value(|ptr| unsafe {
80                                            let set = ptr.get::<HashSet<T>>();
81                                            let iter: HashSetIterator<'_, T> = set.iter();
82                                            let iter_state = Box::new(iter);
83                                            PtrMut::new(Box::into_raw(iter_state) as *mut u8)
84                                        })
85                                        .next(|iter_ptr| unsafe {
86                                            let state = iter_ptr.as_mut::<HashSetIterator<'_, T>>();
87                                            state.next().map(|value| PtrConst::new(value))
88                                        })
89                                        .dealloc(|iter_ptr| unsafe {
90                                            drop(Box::from_raw(
91                                                iter_ptr.as_ptr::<HashSetIterator<'_, T>>()
92                                                    as *mut HashSetIterator<'_, T>,
93                                            ));
94                                        })
95                                        .build(),
96                                )
97                                .build()
98                        },
99                    )
100                    .build(),
101            ))
102            .build()
103    };
104}
105
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// `HashSet<T>` reports two type params: the element type first, then the
    /// build-hasher (which defaults to `RandomState`).
    #[test]
    fn test_hashset_type_params() {
        let params = <HashSet<i32>>::SHAPE.type_params;
        let [elem_param, hasher_param] = params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(elem_param.shape(), i32::SHAPE);
        assert_eq!(hasher_param.shape(), RandomState::SHAPE);
    }

    /// Drives the set vtable end to end for `HashSet<String>`.
    ///
    /// The steps are: construct with a capacity, insert fresh values,
    /// re-insert duplicates, iterate, then drop and free the storage.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let set_def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");

        // Reserve raw storage, then construct the set in it with room for 3 items.
        let uninit = shape.allocate().unwrap();
        let set_ptr = unsafe { (set_def.vtable.init_in_place_with_capacity_fn)(uninit, 3) };

        // A freshly constructed set is empty.
        let initial_len = unsafe { (set_def.vtable.len_fn)(set_ptr.as_const()) };
        assert_eq!(initial_len, 0);

        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // First pass: every value is new, so each insert succeeds and grows
        // the set by exactly one.
        for (inserted_before, &text) in strings.iter().enumerate() {
            let mut owned = text.to_string();
            let inserted = unsafe {
                (set_def.vtable.insert_fn)(set_ptr, PtrMut::new(&raw mut owned))
            };
            // `insert_fn` moved the String out of `owned`, so the local must
            // not run its destructor.
            core::mem::forget(owned);

            assert!(inserted, "expected value to be inserted in the HashSet");

            let len_now = unsafe { (set_def.vtable.len_fn)(set_ptr.as_const()) };
            assert_eq!(len_now, inserted_before + 1);
        }

        // Second pass: the same values again. Each is rejected as a duplicate,
        // and the length stays put.
        let expected_len = strings.len();
        for &text in &strings {
            let mut owned = text.to_string();
            let inserted = unsafe {
                (set_def.vtable.insert_fn)(set_ptr, PtrMut::new(&raw mut owned))
            };
            // Ownership was moved out by `insert_fn`, and the rejected copy is
            // disposed of there, so forget the local shell.
            core::mem::forget(owned);

            assert!(
                !inserted,
                "expected value to not be inserted in the HashSet"
            );

            let len_now = unsafe { (set_def.vtable.len_fn)(set_ptr.as_const()) };
            assert_eq!(len_now, expected_len);
        }

        // Walk the set through the type-erased iterator, checking that no
        // element is yielded twice.
        let init_iter = set_def.vtable.iter_vtable.init_with_value.unwrap();
        let iter_ptr = unsafe { init_iter(set_ptr.as_const()) };

        let mut seen = HashSet::<&str>::new();
        while let Some(item_ptr) = unsafe { (set_def.vtable.iter_vtable.next)(iter_ptr) } {
            let item = unsafe { item_ptr.get::<String>() };
            let fresh = seen.insert(item.as_str());
            assert!(fresh, "HashSet iterator returned duplicate item");
        }

        // Release the boxed iterator state.
        unsafe { (set_def.vtable.iter_vtable.dealloc)(iter_ptr) };

        // Every inserted string must have been yielded.
        let expected: HashSet<&str> = strings.iter().copied().collect();
        assert_eq!(seen, expected);

        // Run the set's destructor, then give the raw storage back.
        let drop_in_place = (shape.vtable.sized().unwrap().drop_in_place)()
            .expect("HashSet<T> should have drop_in_place");
        unsafe { drop_in_place(set_ptr) };
        unsafe { shape.deallocate_mut(set_ptr).unwrap() };
    }
}