facet_core/impls_std/
hashset.rs

1use core::hash::{BuildHasher, Hash};
2use std::collections::HashSet;
3
4use crate::ptr::{PtrConst, PtrMut};
5
6use crate::{
7    Def, Facet, IterVTable, MarkerTraits, SetDef, SetVTable, Shape, Type, TypeParam, UserType,
8    VTableView, ValueVTable,
9};
10
/// Borrowing iterator over a `HashSet`'s elements.
///
/// This is the per-iteration state that the set vtable's `init_with_value`
/// boxes up and that its `next` function advances.
type HashSetIterator<'set, T> = std::collections::hash_set::Iter<'set, T>;
12
13unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
14where
15    T: Facet<'a> + core::cmp::Eq + core::hash::Hash,
16    S: Facet<'a> + Default + BuildHasher,
17{
18    const VTABLE: &'static ValueVTable = &const {
19        let mut builder = ValueVTable::builder::<Self>()
20            .marker_traits(
21                MarkerTraits::SEND
22                    .union(MarkerTraits::SYNC)
23                    .union(MarkerTraits::EQ)
24                    .union(MarkerTraits::UNPIN)
25                    .intersection(T::SHAPE.vtable.marker_traits),
26            )
27            .type_name(|f, opts| {
28                if let Some(opts) = opts.for_children() {
29                    write!(f, "HashSet<")?;
30                    (T::SHAPE.vtable.type_name)(f, opts)?;
31                    write!(f, ">")
32                } else {
33                    write!(f, "HashSet<⋯>")
34                }
35            })
36            .default_in_place(|target| unsafe { target.put(Self::default()) })
37            .eq(|a, b| a == b);
38
39        if T::SHAPE.vtable.debug.is_some() {
40            builder = builder.debug(|value, f| {
41                let t_debug = <VTableView<T>>::of().debug().unwrap();
42                write!(f, "{{")?;
43                for (i, item) in value.iter().enumerate() {
44                    if i > 0 {
45                        write!(f, ", ")?;
46                    }
47                    (t_debug)(item, f)?;
48                }
49                write!(f, "}}")
50            });
51        }
52
53        if T::SHAPE.vtable.clone_into.is_some() {
54            builder = builder.clone_into(|src, dst| unsafe {
55                let set = src;
56                let mut new_set = HashSet::with_capacity_and_hasher(set.len(), S::default());
57
58                let t_clone_into = <VTableView<T>>::of().clone_into().unwrap();
59
60                for item in set {
61                    use crate::TypedPtrUninit;
62                    use core::mem::MaybeUninit;
63
64                    let mut new_item = MaybeUninit::<T>::uninit();
65                    let uninit_item = TypedPtrUninit::new(new_item.as_mut_ptr());
66
67                    (t_clone_into)(item, uninit_item);
68
69                    new_set.insert(new_item.assume_init());
70                }
71
72                dst.put(new_set)
73            });
74        }
75
76        if T::SHAPE.vtable.hash.is_some() {
77            builder = builder.hash(|set, hasher_this, hasher_write_fn| unsafe {
78                use crate::HasherProxy;
79                let t_hash = <VTableView<T>>::of().hash().unwrap();
80                let mut hasher = HasherProxy::new(hasher_this, hasher_write_fn);
81                set.len().hash(&mut hasher);
82                for item in set {
83                    (t_hash)(item, hasher_this, hasher_write_fn);
84                }
85            });
86        }
87
88        builder.build()
89    };
90
91    const SHAPE: &'static Shape<'static> = &const {
92        Shape::builder_for_sized::<Self>()
93            .type_params(&[
94                TypeParam {
95                    name: "T",
96                    shape: || T::SHAPE,
97                },
98                TypeParam {
99                    name: "S",
100                    shape: || S::SHAPE,
101                },
102            ])
103            .ty(Type::User(UserType::Opaque))
104            .def(Def::Set(
105                SetDef::builder()
106                    .t(|| T::SHAPE)
107                    .vtable(
108                        &const {
109                            SetVTable::builder()
110                                .init_in_place_with_capacity(|uninit, capacity| unsafe {
111                                    uninit
112                                        .put(Self::with_capacity_and_hasher(capacity, S::default()))
113                                })
114                                .insert(|ptr, item| unsafe {
115                                    let set = ptr.as_mut::<HashSet<T>>();
116                                    let item = item.read::<T>();
117                                    set.insert(item)
118                                })
119                                .len(|ptr| unsafe {
120                                    let set = ptr.get::<HashSet<T>>();
121                                    set.len()
122                                })
123                                .contains(|ptr, item| unsafe {
124                                    let set = ptr.get::<HashSet<T>>();
125                                    set.contains(item.get())
126                                })
127                                .iter_vtable(
128                                    IterVTable::builder()
129                                        .init_with_value(|ptr| unsafe {
130                                            let set = ptr.get::<HashSet<T>>();
131                                            let iter: HashSetIterator<'_, T> = set.iter();
132                                            let iter_state = Box::new(iter);
133                                            PtrMut::new(Box::into_raw(iter_state) as *mut u8)
134                                        })
135                                        .next(|iter_ptr| unsafe {
136                                            let state = iter_ptr.as_mut::<HashSetIterator<'_, T>>();
137                                            state.next().map(|value| PtrConst::new(value))
138                                        })
139                                        .dealloc(|iter_ptr| unsafe {
140                                            drop(Box::from_raw(
141                                                iter_ptr.as_ptr::<HashSetIterator<'_, T>>()
142                                                    as *mut HashSetIterator<'_, T>,
143                                            ));
144                                        })
145                                        .build(),
146                                )
147                                .build()
148                        },
149                    )
150                    .build(),
151            ))
152            .build()
153    };
154}
155
/// Tests for the `Facet` implementation of `HashSet`.
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// Checks that the shape lists both generic parameters, `T` and `S`.
    #[test]
    fn test_hashset_type_params() {
        // HashSet should have a type param for both its value type
        // and its hasher state
        let [type_param_1, type_param_2] = <HashSet<i32>>::SHAPE.type_params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
        assert_eq!(type_param_2.shape(), RandomState::SHAPE);
    }

    /// Drives the set vtable end to end through its raw pointers: construct,
    /// insert (with deduplication), iterate, then drop and deallocate.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let hashset_shape = <HashSet<String>>::SHAPE;
        let hashset_def = hashset_shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");

        // Allocate memory for the HashSet
        let hashset_uninit_ptr = hashset_shape.allocate()?;

        // Create the HashSet with a capacity of 3
        let hashset_ptr =
            unsafe { (hashset_def.vtable.init_in_place_with_capacity_fn)(hashset_uninit_ptr, 3) };

        // The HashSet is empty, so ensure its length is 0
        let hashset_actual_length = unsafe { (hashset_def.vtable.len_fn)(hashset_ptr.as_const()) };
        assert_eq!(hashset_actual_length, 0);

        // 5 sample values to insert
        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // Insert the 5 values into the HashSet
        let mut hashset_length = 0;
        for string in strings {
            // Create the value
            let mut new_value = string.to_string();

            // Insert the value
            let did_insert = unsafe {
                (hashset_def.vtable.insert_fn)(hashset_ptr, PtrMut::new(&raw mut new_value))
            };

            // `insert_fn` moved the value out of `new_value` and into the set;
            // forget the moved-from local so it is not dropped a second time.
            core::mem::forget(new_value);

            assert!(did_insert, "expected value to be inserted in the HashSet");

            // Ensure the HashSet's length increased by 1
            hashset_length += 1;
            let hashset_actual_length =
                unsafe { (hashset_def.vtable.len_fn)(hashset_ptr.as_const()) };
            assert_eq!(hashset_actual_length, hashset_length);
        }

        // Insert the same 5 values again, ensuring they are deduplicated
        for string in strings {
            // Create the value
            let mut new_value = string.to_string();

            // Try to insert the value
            let did_insert = unsafe {
                (hashset_def.vtable.insert_fn)(hashset_ptr, PtrMut::new(&raw mut new_value))
            };

            // `insert_fn` still moved the value out of `new_value`. Because it
            // is a duplicate, `HashSet::insert` dropped it instead of storing
            // it. Either way the local no longer owns it, so forget it to avoid
            // a double drop.
            core::mem::forget(new_value);

            assert!(
                !did_insert,
                "expected value to not be inserted in the HashSet"
            );

            // Ensure the HashSet's length did not increase
            let hashset_actual_length =
                unsafe { (hashset_def.vtable.len_fn)(hashset_ptr.as_const()) };
            assert_eq!(hashset_actual_length, hashset_length);
        }

        // Create a new iterator over the HashSet
        let iter_init_with_value_fn = hashset_def.vtable.iter_vtable.init_with_value.unwrap();
        let hashset_iter_ptr = unsafe { iter_init_with_value_fn(hashset_ptr.as_const()) };

        // Collect all the items from the HashSet's iterator
        let mut iter_items = HashSet::<&str>::new();
        loop {
            // Get the next item from the iterator
            let item_ptr = unsafe { (hashset_def.vtable.iter_vtable.next)(hashset_iter_ptr) };
            let Some(item_ptr) = item_ptr else {
                break;
            };

            let item = unsafe { item_ptr.get::<String>() };

            // Insert the item into the set of items returned from the iterator
            let did_insert = iter_items.insert(&**item);

            assert!(did_insert, "HashSet iterator returned duplicate item");
        }

        // Deallocate the iterator (must happen before the set itself is dropped,
        // since the iterator borrows from it)
        unsafe {
            (hashset_def.vtable.iter_vtable.dealloc)(hashset_iter_ptr);
        }

        // Ensure the iterator returned all of the strings
        assert_eq!(iter_items, strings.iter().copied().collect::<HashSet<_>>());

        // Get the function pointer for dropping the HashSet
        let drop_fn = hashset_shape
            .vtable
            .drop_in_place
            .expect("HashSet<T> should have drop_in_place");

        // Drop the HashSet in place
        unsafe { drop_fn(hashset_ptr) };

        // Deallocate the memory
        unsafe { hashset_shape.deallocate_mut(hashset_ptr)? };

        Ok(())
    }
}
287        Ok(())
288    }
289}