facet_core/impls/std/hashset.rs

1use core::hash::BuildHasher;
2use std::collections::HashSet;
3
4use crate::{PtrConst, PtrMut, PtrUninit};
5
6use crate::{
7    Def, Facet, HashProxy, IterVTable, OxPtrConst, OxPtrMut, OxRef, SetDef, SetVTable, Shape,
8    ShapeBuilder, Type, TypeNameFn, TypeNameOpts, TypeOpsIndirect, TypeParam, UserType,
9    VTableIndirect,
10};
11
/// Borrowing iterator over the elements of a `HashSet`.
///
/// `hashset_iter_init` boxes one of these and hands it out as a raw pointer,
/// so the set vtable can drive iteration without knowing the element type.
type HashSetIterator<'set, T> = std::collections::hash_set::Iter<'set, T>;
13
14unsafe fn hashset_init_in_place_with_capacity<T, S: Default + BuildHasher>(
15    uninit: PtrUninit,
16    capacity: usize,
17) -> PtrMut {
18    unsafe {
19        uninit.put(HashSet::<T, S>::with_capacity_and_hasher(
20            capacity,
21            S::default(),
22        ))
23    }
24}
25
26unsafe fn hashset_insert<T: Eq + core::hash::Hash + 'static>(ptr: PtrMut, item: PtrMut) -> bool {
27    unsafe {
28        let set = ptr.as_mut::<HashSet<T>>();
29        let item = item.read::<T>();
30        set.insert(item)
31    }
32}
33
34unsafe fn hashset_len<T: 'static>(ptr: PtrConst) -> usize {
35    unsafe { ptr.get::<HashSet<T>>().len() }
36}
37
38unsafe fn hashset_contains<T: Eq + core::hash::Hash + 'static>(
39    ptr: PtrConst,
40    item: PtrConst,
41) -> bool {
42    unsafe { ptr.get::<HashSet<T>>().contains(item.get()) }
43}
44
45unsafe fn hashset_iter_init<T: 'static>(ptr: PtrConst) -> PtrMut {
46    unsafe {
47        let set = ptr.get::<HashSet<T>>();
48        let iter: HashSetIterator<'_, T> = set.iter();
49        let iter_state = Box::new(iter);
50        PtrMut::new(Box::into_raw(iter_state) as *mut u8)
51    }
52}
53
54unsafe fn hashset_iter_next<T: 'static>(iter_ptr: PtrMut) -> Option<PtrConst> {
55    unsafe {
56        let state = iter_ptr.as_mut::<HashSetIterator<'static, T>>();
57        state.next().map(|value| PtrConst::new(value as *const T))
58    }
59}
60
61unsafe fn hashset_iter_dealloc<T>(iter_ptr: PtrMut) {
62    unsafe {
63        drop(Box::from_raw(
64            iter_ptr.as_ptr::<HashSetIterator<'_, T>>() as *mut HashSetIterator<'_, T>
65        ));
66    }
67}
68
69/// Extract the SetDef from a shape, returns None if not a Set
70#[inline]
71fn get_set_def(shape: &'static Shape) -> Option<&'static SetDef> {
72    match shape.def {
73        Def::Set(ref def) => Some(def),
74        _ => None,
75    }
76}
77
78/// Debug for `HashSet<T>` - delegates to inner T's debug if available
79unsafe fn hashset_debug(
80    ox: OxPtrConst,
81    f: &mut core::fmt::Formatter<'_>,
82) -> Option<core::fmt::Result> {
83    let shape = ox.shape();
84    let def = get_set_def(shape)?;
85    let ptr = ox.ptr();
86
87    let mut debug_set = f.debug_set();
88
89    // Initialize iterator
90    let iter_init = def.vtable.iter_vtable.init_with_value?;
91    let iter_ptr = unsafe { iter_init(ptr) };
92
93    // Iterate over all elements
94    loop {
95        let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
96        let Some(item_ptr) = item_ptr else {
97            break;
98        };
99        // SAFETY: The iterator returns valid pointers to set items.
100        // The caller guarantees the OxPtrConst points to a valid HashSet.
101        let item_ox = unsafe { OxRef::new(item_ptr, def.t) };
102        debug_set.entry(&item_ox);
103    }
104
105    // Deallocate iterator
106    unsafe {
107        (def.vtable.iter_vtable.dealloc)(iter_ptr);
108    }
109
110    Some(debug_set.finish())
111}
112
113/// Hash for `HashSet<T>` - delegates to inner T's hash if available
114unsafe fn hashset_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
115    let shape = ox.shape();
116    let def = get_set_def(shape)?;
117    let ptr = ox.ptr();
118
119    use core::hash::Hash;
120
121    // Hash the length first
122    let len = unsafe { (def.vtable.len)(ptr) };
123    len.hash(hasher);
124
125    // Initialize iterator
126    let iter_init = def.vtable.iter_vtable.init_with_value?;
127    let iter_ptr = unsafe { iter_init(ptr) };
128
129    // Hash all elements
130    loop {
131        let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
132        let Some(item_ptr) = item_ptr else {
133            break;
134        };
135        unsafe { def.t.call_hash(item_ptr, hasher)? };
136    }
137
138    // Deallocate iterator
139    unsafe {
140        (def.vtable.iter_vtable.dealloc)(iter_ptr);
141    }
142
143    Some(())
144}
145
146/// PartialEq for `HashSet<T>`
147unsafe fn hashset_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
148    let shape = a.shape();
149    let def = get_set_def(shape)?;
150
151    let a_ptr = a.ptr();
152    let b_ptr = b.ptr();
153
154    let a_len = unsafe { (def.vtable.len)(a_ptr) };
155    let b_len = unsafe { (def.vtable.len)(b_ptr) };
156
157    // If lengths differ, sets are not equal
158    if a_len != b_len {
159        return Some(false);
160    }
161
162    // Initialize iterator for set a
163    let iter_init = def.vtable.iter_vtable.init_with_value?;
164    let iter_ptr = unsafe { iter_init(a_ptr) };
165
166    // Check if all elements from a are contained in b
167    let mut all_contained = true;
168    loop {
169        let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
170        let Some(item_ptr) = item_ptr else {
171            break;
172        };
173        let contained = unsafe { (def.vtable.contains)(b_ptr, item_ptr) };
174        if !contained {
175            all_contained = false;
176            break;
177        }
178    }
179
180    // Deallocate iterator
181    unsafe {
182        (def.vtable.iter_vtable.dealloc)(iter_ptr);
183    }
184
185    Some(all_contained)
186}
187
188/// Drop for HashSet<T, S>
189unsafe fn hashset_drop<T: 'static, S: 'static>(ox: OxPtrMut) {
190    unsafe {
191        core::ptr::drop_in_place(ox.as_mut::<HashSet<T, S>>());
192    }
193}
194
195/// Default for HashSet<T, S>
196unsafe fn hashset_default<T: 'static, S: Default + BuildHasher + 'static>(ox: OxPtrMut) {
197    unsafe { ox.ptr().as_uninit().put(HashSet::<T, S>::default()) };
198}
199
200unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
201where
202    T: Facet<'a> + core::cmp::Eq + core::hash::Hash + 'static,
203    S: Facet<'a> + Default + BuildHasher + 'static,
204{
205    const SHAPE: &'static Shape = &const {
206        const fn build_set_vtable<
207            T: Eq + core::hash::Hash + 'static,
208            S: Default + BuildHasher + 'static,
209        >() -> SetVTable {
210            SetVTable::builder()
211                .init_in_place_with_capacity(hashset_init_in_place_with_capacity::<T, S>)
212                .insert(hashset_insert::<T>)
213                .len(hashset_len::<T>)
214                .contains(hashset_contains::<T>)
215                .iter_vtable(IterVTable {
216                    init_with_value: Some(hashset_iter_init::<T>),
217                    next: hashset_iter_next::<T>,
218                    next_back: None,
219                    size_hint: None,
220                    dealloc: hashset_iter_dealloc::<T>,
221                })
222                .build()
223        }
224
225        const fn build_type_name<'a, T: Facet<'a>>() -> TypeNameFn {
226            fn type_name_impl<'a, T: Facet<'a>>(
227                _shape: &'static Shape,
228                f: &mut core::fmt::Formatter<'_>,
229                opts: TypeNameOpts,
230            ) -> core::fmt::Result {
231                write!(f, "HashSet")?;
232                if let Some(opts) = opts.for_children() {
233                    write!(f, "<")?;
234                    T::SHAPE.write_type_name(f, opts)?;
235                    write!(f, ">")?;
236                } else {
237                    write!(f, "<…>")?;
238                }
239                Ok(())
240            }
241            type_name_impl::<T>
242        }
243
244        ShapeBuilder::for_sized::<Self>("HashSet")
245            .type_name(build_type_name::<T>())
246            .ty(Type::User(UserType::Opaque))
247            .def(Def::Set(SetDef::new(
248                &const { build_set_vtable::<T, S>() },
249                T::SHAPE,
250            )))
251            .type_params(&[
252                TypeParam {
253                    name: "T",
254                    shape: T::SHAPE,
255                },
256                TypeParam {
257                    name: "S",
258                    shape: S::SHAPE,
259                },
260            ])
261            .vtable_indirect(
262                &const {
263                    VTableIndirect {
264                        debug: Some(hashset_debug),
265                        hash: Some(hashset_hash),
266                        partial_eq: Some(hashset_partial_eq),
267                        ..VTableIndirect::EMPTY
268                    }
269                },
270            )
271            .type_ops_indirect(
272                &const {
273                    TypeOpsIndirect {
274                        drop_in_place: hashset_drop::<T, S>,
275                        default_in_place: Some(hashset_default::<T, S>),
276                        clone_into: None,
277                        is_truthy: None,
278                    }
279                },
280            )
281            .build()
282    };
283}
284
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use core::ptr::NonNull;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    #[test]
    fn test_hashset_type_params() {
        // Both the element type `T` and the hasher state `S` must be exposed
        // as type parameters, in that order.
        let params = <HashSet<i32>>::SHAPE.type_params;
        let [value_param, hasher_param] = params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(value_param.shape(), i32::SHAPE);
        assert_eq!(hasher_param.shape(), RandomState::SHAPE);
    }

    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let set_def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");
        let vtable = &set_def.vtable;

        // Reads the current element count through the type-erased vtable.
        let current_len = |set: PtrMut| unsafe { (vtable.len)(set.as_const()) };

        // Moves a freshly allocated `String` into the set via the vtable.
        // `ManuallyDrop` is required because ownership is transferred by a
        // bitwise read inside `insert`.
        let insert_str = |set: PtrMut, text: &str| -> bool {
            let mut owned = core::mem::ManuallyDrop::new(text.to_string());
            unsafe { (vtable.insert)(set, PtrMut::new(NonNull::from(&mut owned).as_ptr())) }
        };

        // Allocate storage and build an empty set with capacity 3.
        let storage = shape.allocate().unwrap();
        let set_ptr = unsafe { (vtable.init_in_place_with_capacity)(storage, 3) };
        assert_eq!(current_len(set_ptr), 0);

        let samples = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // First pass: every sample is new, so the length grows by one each time.
        for (idx, sample) in samples.into_iter().enumerate() {
            assert!(
                insert_str(set_ptr, sample),
                "expected value to be inserted in the HashSet"
            );
            assert_eq!(current_len(set_ptr), idx + 1);
        }

        // Second pass: duplicates are rejected and the length stays put.
        for sample in samples {
            assert!(
                !insert_str(set_ptr, sample),
                "expected value to not be inserted in the HashSet"
            );
            assert_eq!(current_len(set_ptr), samples.len());
        }

        // Walk the set through the iterator vtable, checking for duplicates.
        let init_iter = vtable.iter_vtable.init_with_value.unwrap();
        let iter_ptr = unsafe { init_iter(set_ptr.as_const()) };

        let mut seen = HashSet::<&str>::new();
        while let Some(item_ptr) = unsafe { (vtable.iter_vtable.next)(iter_ptr) } {
            let item: &String = unsafe { item_ptr.get::<String>() };
            assert!(
                seen.insert(item.as_str()),
                "HashSet iterator returned duplicate item"
            );
        }

        unsafe { (vtable.iter_vtable.dealloc)(iter_ptr) };

        // Every sample must have been produced by the iterator.
        let expected: HashSet<&str> = samples.into_iter().collect();
        assert_eq!(seen, expected);

        // Tear down: drop the set in place, then release its storage.
        unsafe {
            shape
                .call_drop_in_place(set_ptr)
                .expect("HashSet<T> should have drop_in_place");
            shape.deallocate_mut(set_ptr).unwrap();
        }
    }
}