facet_core/impls_std/
hashset.rs

1use core::hash::{BuildHasher, Hash};
2use std::collections::HashSet;
3
4use crate::ptr::{PtrConst, PtrMut};
5
6use crate::{
7    Def, Facet, IterVTable, MarkerTraits, SetDef, SetVTable, Shape, Type, TypeParam, UserType,
8    VTableView, ValueVTable,
9};
10
11type HashSetIterator<'mem, T> = std::collections::hash_set::Iter<'mem, T>;
12
13unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
14where
15    T: Facet<'a> + core::cmp::Eq + core::hash::Hash,
16    S: Facet<'a> + Default + BuildHasher,
17{
18    const VTABLE: &'static ValueVTable = &const {
19        ValueVTable::builder::<Self>()
20            .marker_traits(|| {
21                MarkerTraits::SEND
22                    .union(MarkerTraits::SYNC)
23                    .union(MarkerTraits::EQ)
24                    .union(MarkerTraits::UNPIN)
25                    .intersection(T::SHAPE.vtable.marker_traits())
26            })
27            .type_name(|f, opts| {
28                if let Some(opts) = opts.for_children() {
29                    write!(f, "{}<", Self::SHAPE.type_identifier)?;
30                    (T::SHAPE.vtable.type_name)(f, opts)?;
31                    write!(f, ">")
32                } else {
33                    write!(f, "HashSet<⋯>")
34                }
35            })
36            .default_in_place(|| Some(|target| unsafe { target.put(Self::default()) }))
37            .partial_eq(|| Some(|a, b| a == b))
38            .debug(|| {
39                if (T::SHAPE.vtable.debug)().is_some() {
40                    Some(|value, f| {
41                        let t_debug = <VTableView<T>>::of().debug().unwrap();
42                        write!(f, "{{")?;
43                        for (i, item) in value.iter().enumerate() {
44                            if i > 0 {
45                                write!(f, ", ")?;
46                            }
47                            (t_debug)(item, f)?;
48                        }
49                        write!(f, "}}")
50                    })
51                } else {
52                    None
53                }
54            })
55            .clone_into(|| {
56                if (T::SHAPE.vtable.clone_into)().is_some() {
57                    Some(|src, dst| unsafe {
58                        let set = src;
59                        let mut new_set =
60                            HashSet::with_capacity_and_hasher(set.len(), S::default());
61
62                        let t_clone_into = <VTableView<T>>::of().clone_into().unwrap();
63
64                        for item in set {
65                            use crate::TypedPtrUninit;
66                            use core::mem::MaybeUninit;
67
68                            let mut new_item = MaybeUninit::<T>::uninit();
69                            let uninit_item = TypedPtrUninit::new(new_item.as_mut_ptr());
70
71                            (t_clone_into)(item, uninit_item);
72
73                            new_set.insert(new_item.assume_init());
74                        }
75
76                        dst.put(new_set)
77                    })
78                } else {
79                    None
80                }
81            })
82            .hash(|| {
83                if (T::SHAPE.vtable.hash)().is_some() {
84                    Some(|set, hasher_this, hasher_write_fn| unsafe {
85                        use crate::HasherProxy;
86                        let t_hash = <VTableView<T>>::of().hash().unwrap();
87                        let mut hasher = HasherProxy::new(hasher_this, hasher_write_fn);
88                        set.len().hash(&mut hasher);
89                        for item in set {
90                            (t_hash)(item, hasher_this, hasher_write_fn);
91                        }
92                    })
93                } else {
94                    None
95                }
96            })
97            .build()
98    };
99
100    const SHAPE: &'static Shape<'static> = &const {
101        Shape::builder_for_sized::<Self>()
102            .type_identifier("HashSet")
103            .type_params(&[
104                TypeParam {
105                    name: "T",
106                    shape: || T::SHAPE,
107                },
108                TypeParam {
109                    name: "S",
110                    shape: || S::SHAPE,
111                },
112            ])
113            .ty(Type::User(UserType::Opaque))
114            .def(Def::Set(
115                SetDef::builder()
116                    .t(|| T::SHAPE)
117                    .vtable(
118                        &const {
119                            SetVTable::builder()
120                                .init_in_place_with_capacity(|uninit, capacity| unsafe {
121                                    uninit
122                                        .put(Self::with_capacity_and_hasher(capacity, S::default()))
123                                })
124                                .insert(|ptr, item| unsafe {
125                                    let set = ptr.as_mut::<HashSet<T>>();
126                                    let item = item.read::<T>();
127                                    set.insert(item)
128                                })
129                                .len(|ptr| unsafe {
130                                    let set = ptr.get::<HashSet<T>>();
131                                    set.len()
132                                })
133                                .contains(|ptr, item| unsafe {
134                                    let set = ptr.get::<HashSet<T>>();
135                                    set.contains(item.get())
136                                })
137                                .iter_vtable(
138                                    IterVTable::builder()
139                                        .init_with_value(|ptr| unsafe {
140                                            let set = ptr.get::<HashSet<T>>();
141                                            let iter: HashSetIterator<'_, T> = set.iter();
142                                            let iter_state = Box::new(iter);
143                                            PtrMut::new(Box::into_raw(iter_state) as *mut u8)
144                                        })
145                                        .next(|iter_ptr| unsafe {
146                                            let state = iter_ptr.as_mut::<HashSetIterator<'_, T>>();
147                                            state.next().map(|value| PtrConst::new(value))
148                                        })
149                                        .dealloc(|iter_ptr| unsafe {
150                                            drop(Box::from_raw(
151                                                iter_ptr.as_ptr::<HashSetIterator<'_, T>>()
152                                                    as *mut HashSetIterator<'_, T>,
153                                            ));
154                                        })
155                                        .build(),
156                                )
157                                .build()
158                        },
159                    )
160                    .build(),
161            ))
162            .build()
163    };
164}
165
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// `HashSet<T>` exposes two type parameters: the element type and the hasher state.
    #[test]
    fn test_hashset_type_params() {
        let type_params = <HashSet<i32>>::SHAPE.type_params;
        let [value_param, hasher_param] = type_params else {
            panic!("HashSet<T> should have 2 type params")
        };

        assert_eq!(value_param.shape(), i32::SHAPE);
        assert_eq!(hasher_param.shape(), RandomState::SHAPE);
    }

    /// Drives a `HashSet<String>` purely through its vtables: allocate, initialize,
    /// insert (fresh values and duplicates), iterate, drop in place, deallocate.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let set_def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");

        // Reserve storage, then initialize a set with room for 3 elements in it.
        let uninit_ptr = shape.allocate()?;
        let set_ptr = unsafe { (set_def.vtable.init_in_place_with_capacity_fn)(uninit_ptr, 3) };

        // A freshly initialized set holds nothing.
        let initial_len = unsafe { (set_def.vtable.len_fn)(set_ptr.as_const()) };
        assert_eq!(initial_len, 0);

        // Sample values used for every phase of the test.
        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // Phase 1: each value is new, so every insert succeeds and grows the set by one.
        for (already_inserted, text) in strings.iter().copied().enumerate() {
            let mut owned = String::from(text);

            let was_new =
                unsafe { (set_def.vtable.insert_fn)(set_ptr, PtrMut::new(&raw mut owned)) };

            // Ownership moved into the set; the local must not run its destructor.
            core::mem::forget(owned);

            assert!(was_new, "expected value to be inserted in the HashSet");

            let len_now = unsafe { (set_def.vtable.len_fn)(set_ptr.as_const()) };
            assert_eq!(len_now, already_inserted + 1);
        }

        // Phase 2: the same values again are rejected and the length stays put.
        let full_len = strings.len();
        for text in strings {
            let mut owned = String::from(text);

            let was_new =
                unsafe { (set_def.vtable.insert_fn)(set_ptr, PtrMut::new(&raw mut owned)) };

            // The value was still read out by the insert call, so forget the local.
            core::mem::forget(owned);

            assert!(
                !was_new,
                "expected value to not be inserted in the HashSet"
            );

            let len_now = unsafe { (set_def.vtable.len_fn)(set_ptr.as_const()) };
            assert_eq!(len_now, full_len);
        }

        // Phase 3: walk the set through the iterator vtable.
        let init_iter = set_def.vtable.iter_vtable.init_with_value.unwrap();
        let iter_ptr = unsafe { init_iter(set_ptr.as_const()) };

        let mut seen = HashSet::<&str>::new();
        while let Some(entry_ptr) = unsafe { (set_def.vtable.iter_vtable.next)(iter_ptr) } {
            let entry = unsafe { entry_ptr.get::<String>() };

            // Each element must be yielded exactly once.
            let first_time = seen.insert(entry.as_str());
            assert!(first_time, "HashSet iterator returned duplicate item");
        }

        // Release the iterator state.
        unsafe {
            (set_def.vtable.iter_vtable.dealloc)(iter_ptr);
        }

        // The iterator must have produced exactly the inserted values.
        let expected: HashSet<&str> = HashSet::from(strings);
        assert_eq!(seen, expected);

        // Phase 4: destroy the set in place and give its storage back.
        let drop_fn =
            (shape.vtable.drop_in_place)().expect("HashSet<T> should have drop_in_place");
        unsafe { drop_fn(set_ptr) };
        unsafe { shape.deallocate_mut(set_ptr)? };

        Ok(())
    }
}
297}