Skip to main content

facet_core/impls/std/
hashset.rs

1use core::hash::BuildHasher;
2use std::collections::HashSet;
3
4use crate::{PtrConst, PtrMut, PtrUninit};
5
6use crate::{
7    Def, Facet, HashProxy, IterVTable, OxPtrConst, OxPtrMut, OxPtrUninit, OxRef, SetDef, SetVTable,
8    Shape, ShapeBuilder, Type, TypeNameFn, TypeNameOpts, TypeOpsIndirect, TypeParam, UserType,
9    VTableIndirect, Variance, VarianceDep, VarianceDesc,
10};
11
/// Borrowing iterator over the elements of a `HashSet`, as returned by
/// `HashSet::iter`. It carries no hasher parameter because
/// `hash_set::Iter` is generic only over the element type.
type HashSetIterator<'a, T> = std::collections::hash_set::Iter<'a, T>;
13
14unsafe extern "C" fn hashset_init_in_place_with_capacity<T, S: Default + BuildHasher>(
15    uninit: PtrUninit,
16    capacity: usize,
17) -> PtrMut {
18    unsafe {
19        uninit.put(HashSet::<T, S>::with_capacity_and_hasher(
20            capacity,
21            S::default(),
22        ))
23    }
24}
25
26unsafe extern "C" fn hashset_insert<T: Eq + core::hash::Hash + 'static>(
27    ptr: PtrMut,
28    item: PtrMut,
29) -> bool {
30    unsafe {
31        let set = ptr.as_mut::<HashSet<T>>();
32        let item = item.read::<T>();
33        set.insert(item)
34    }
35}
36
37unsafe extern "C" fn hashset_len<T: 'static>(ptr: PtrConst) -> usize {
38    unsafe { ptr.get::<HashSet<T>>().len() }
39}
40
41unsafe extern "C" fn hashset_contains<T: Eq + core::hash::Hash + 'static>(
42    ptr: PtrConst,
43    item: PtrConst,
44) -> bool {
45    unsafe { ptr.get::<HashSet<T>>().contains(item.get()) }
46}
47
48unsafe extern "C" fn hashset_iter_init<T: 'static>(ptr: PtrConst) -> PtrMut {
49    unsafe {
50        let set = ptr.get::<HashSet<T>>();
51        let iter: HashSetIterator<'_, T> = set.iter();
52        let iter_state = Box::new(iter);
53        PtrMut::new(Box::into_raw(iter_state) as *mut u8)
54    }
55}
56
57unsafe fn hashset_iter_next<T: 'static>(iter_ptr: PtrMut) -> Option<PtrConst> {
58    unsafe {
59        let state = iter_ptr.as_mut::<HashSetIterator<'static, T>>();
60        state.next().map(|value| PtrConst::new(value as *const T))
61    }
62}
63
64unsafe extern "C" fn hashset_iter_dealloc<T>(iter_ptr: PtrMut) {
65    unsafe {
66        drop(Box::from_raw(
67            iter_ptr.as_ptr::<HashSetIterator<'_, T>>() as *mut HashSetIterator<'_, T>
68        ));
69    }
70}
71
72/// Build a HashSet from a contiguous slice of elements.
73///
74/// # Safety
75/// - `set` must point to uninitialized memory
76/// - `elements_ptr` must point to `count` consecutive initialized T values
77/// - Elements are moved out and should not be dropped by caller
78unsafe extern "C" fn hashset_from_slice<
79    T: Eq + core::hash::Hash + 'static,
80    S: Default + BuildHasher + 'static,
81>(
82    set: PtrUninit,
83    elements_ptr: *mut u8,
84    count: usize,
85) -> PtrMut {
86    unsafe {
87        let elements = elements_ptr as *mut T;
88        let mut hashset = HashSet::<T, S>::with_capacity_and_hasher(count, S::default());
89        for i in 0..count {
90            let elem = core::ptr::read(elements.add(i));
91            hashset.insert(elem);
92        }
93        set.put(hashset)
94    }
95}
96
97/// Extract the SetDef from a shape, returns None if not a Set
98#[inline]
99const fn get_set_def(shape: &'static Shape) -> Option<&'static SetDef> {
100    match shape.def {
101        Def::Set(ref def) => Some(def),
102        _ => None,
103    }
104}
105
106/// Debug for `HashSet<T>` - delegates to inner T's debug if available
107unsafe fn hashset_debug(
108    ox: OxPtrConst,
109    f: &mut core::fmt::Formatter<'_>,
110) -> Option<core::fmt::Result> {
111    let shape = ox.shape();
112    let def = get_set_def(shape)?;
113    let ptr = ox.ptr();
114
115    let mut debug_set = f.debug_set();
116
117    // Initialize iterator
118    let iter_init = def.vtable.iter_vtable.init_with_value?;
119    let iter_ptr = unsafe { iter_init(ptr) };
120
121    // Iterate over all elements
122    loop {
123        let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
124        let Some(item_ptr) = item_ptr else {
125            break;
126        };
127        // SAFETY: The iterator returns valid pointers to set items.
128        // The caller guarantees the OxPtrConst points to a valid HashSet.
129        let item_ox = unsafe { OxRef::new(item_ptr, def.t) };
130        debug_set.entry(&item_ox);
131    }
132
133    // Deallocate iterator
134    unsafe {
135        (def.vtable.iter_vtable.dealloc)(iter_ptr);
136    }
137
138    Some(debug_set.finish())
139}
140
141/// Hash for `HashSet<T>` - delegates to inner T's hash if available
142unsafe fn hashset_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
143    let shape = ox.shape();
144    let def = get_set_def(shape)?;
145    let ptr = ox.ptr();
146
147    use core::hash::Hash;
148
149    // Hash the length first
150    let len = unsafe { (def.vtable.len)(ptr) };
151    len.hash(hasher);
152
153    // Initialize iterator
154    let iter_init = def.vtable.iter_vtable.init_with_value?;
155    let iter_ptr = unsafe { iter_init(ptr) };
156
157    // Hash all elements
158    loop {
159        let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
160        let Some(item_ptr) = item_ptr else {
161            break;
162        };
163        unsafe { def.t.call_hash(item_ptr, hasher)? };
164    }
165
166    // Deallocate iterator
167    unsafe {
168        (def.vtable.iter_vtable.dealloc)(iter_ptr);
169    }
170
171    Some(())
172}
173
174/// PartialEq for `HashSet<T>`
175unsafe fn hashset_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
176    let shape = a.shape();
177    let def = get_set_def(shape)?;
178
179    let a_ptr = a.ptr();
180    let b_ptr = b.ptr();
181
182    let a_len = unsafe { (def.vtable.len)(a_ptr) };
183    let b_len = unsafe { (def.vtable.len)(b_ptr) };
184
185    // If lengths differ, sets are not equal
186    if a_len != b_len {
187        return Some(false);
188    }
189
190    // Initialize iterator for set a
191    let iter_init = def.vtable.iter_vtable.init_with_value?;
192    let iter_ptr = unsafe { iter_init(a_ptr) };
193
194    // Check if all elements from a are contained in b
195    let mut all_contained = true;
196    loop {
197        let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
198        let Some(item_ptr) = item_ptr else {
199            break;
200        };
201        let contained = unsafe { (def.vtable.contains)(b_ptr, item_ptr) };
202        if !contained {
203            all_contained = false;
204            break;
205        }
206    }
207
208    // Deallocate iterator
209    unsafe {
210        (def.vtable.iter_vtable.dealloc)(iter_ptr);
211    }
212
213    Some(all_contained)
214}
215
216/// Drop for HashSet<T, S>
217unsafe fn hashset_drop<T: 'static, S: 'static>(ox: OxPtrMut) {
218    unsafe {
219        core::ptr::drop_in_place(ox.as_mut::<HashSet<T, S>>());
220    }
221}
222
223/// Default for HashSet<T, S>
224unsafe fn hashset_default<T: 'static, S: Default + BuildHasher + 'static>(ox: OxPtrUninit) -> bool {
225    unsafe { ox.put(HashSet::<T, S>::default()) };
226    true
227}
228
229unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
230where
231    T: Facet<'a> + core::cmp::Eq + core::hash::Hash + 'static,
232    S: Facet<'a> + Default + BuildHasher + 'static,
233{
234    const SHAPE: &'static Shape = &const {
235        const fn build_set_vtable<
236            T: Eq + core::hash::Hash + 'static,
237            S: Default + BuildHasher + 'static,
238        >() -> SetVTable {
239            SetVTable::builder()
240                .init_in_place_with_capacity(hashset_init_in_place_with_capacity::<T, S>)
241                .insert(hashset_insert::<T>)
242                .len(hashset_len::<T>)
243                .contains(hashset_contains::<T>)
244                .iter_vtable(IterVTable {
245                    init_with_value: Some(hashset_iter_init::<T>),
246                    next: hashset_iter_next::<T>,
247                    next_back: None,
248                    size_hint: None,
249                    dealloc: hashset_iter_dealloc::<T>,
250                })
251                .from_slice(Some(hashset_from_slice::<T, S>))
252                .build()
253        }
254
255        const fn build_type_name<'a, T: Facet<'a>>() -> TypeNameFn {
256            fn type_name_impl<'a, T: Facet<'a>>(
257                _shape: &'static Shape,
258                f: &mut core::fmt::Formatter<'_>,
259                opts: TypeNameOpts,
260            ) -> core::fmt::Result {
261                write!(f, "HashSet")?;
262                if let Some(opts) = opts.for_children() {
263                    write!(f, "<")?;
264                    T::SHAPE.write_type_name(f, opts)?;
265                    write!(f, ">")?;
266                } else {
267                    write!(f, "<…>")?;
268                }
269                Ok(())
270            }
271            type_name_impl::<T>
272        }
273
274        ShapeBuilder::for_sized::<Self>("HashSet")
275            .module_path("std::collections::hash_set")
276            .type_name(build_type_name::<T>())
277            .ty(Type::User(UserType::Opaque))
278            .def(Def::Set(SetDef::new(
279                &const { build_set_vtable::<T, S>() },
280                T::SHAPE,
281            )))
282            .type_params(&[
283                TypeParam {
284                    name: "T",
285                    shape: T::SHAPE,
286                },
287                TypeParam {
288                    name: "S",
289                    shape: S::SHAPE,
290                },
291            ])
292            .inner(T::SHAPE)
293            // HashSet<T> propagates T's variance
294            .variance(VarianceDesc {
295                base: Variance::Bivariant,
296                deps: &const { [VarianceDep::covariant(T::SHAPE)] },
297            })
298            .vtable_indirect(
299                &const {
300                    VTableIndirect {
301                        debug: Some(hashset_debug),
302                        hash: Some(hashset_hash),
303                        partial_eq: Some(hashset_partial_eq),
304                        ..VTableIndirect::EMPTY
305                    }
306                },
307            )
308            .type_ops_indirect(
309                &const {
310                    TypeOpsIndirect {
311                        drop_in_place: hashset_drop::<T, S>,
312                        default_in_place: Some(hashset_default::<T, S>),
313                        clone_into: None,
314                        is_truthy: None,
315                    }
316                },
317            )
318            .build()
319    };
320}
321
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use core::ptr::NonNull;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// The shape of `HashSet<T, S>` exposes both generic parameters in
    /// declaration order: the element type, then the hasher builder.
    #[test]
    fn test_hashset_type_params() {
        let [element, hasher] = <HashSet<i32>>::SHAPE.type_params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(element.shape(), i32::SHAPE);
        assert_eq!(hasher.shape(), RandomState::SHAPE);
    }

    /// End-to-end exercise of the set vtable on `HashSet<String>`: allocate,
    /// initialize with a capacity, insert (including de-duplication), iterate,
    /// then drop and deallocate.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");

        // Allocate backing memory and initialize an empty set with capacity 3,
        // fewer than the five values inserted below.
        let uninit = shape.allocate().unwrap();
        let set_ptr = unsafe { (def.vtable.init_in_place_with_capacity)(uninit, 3) };

        let current_len = || unsafe { (def.vtable.len)(set_ptr.as_const()) };

        // Move a freshly allocated copy of `s` into the set and report whether it
        // was newly inserted. The vtable takes ownership with a bitwise read, so
        // the local copy is wrapped in `ManuallyDrop` to avoid a double drop.
        let insert = |s: &str| -> bool {
            let mut owned = core::mem::ManuallyDrop::new(s.to_string());
            unsafe {
                (def.vtable.insert)(set_ptr, PtrMut::new(NonNull::from(&mut owned).as_ptr()))
            }
        };

        assert_eq!(current_len(), 0);

        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // First pass: every value is new, so the length grows by one each time.
        for (already_inserted, s) in strings.iter().copied().enumerate() {
            assert!(insert(s), "expected value to be inserted in the HashSet");
            assert_eq!(current_len(), already_inserted + 1);
        }

        // Second pass: every value is a duplicate, so the length stays put.
        for s in strings.iter().copied() {
            assert!(
                !insert(s),
                "expected value to not be inserted in the HashSet"
            );
            assert_eq!(current_len(), strings.len());
        }

        // Walk the set through the type-erased iterator, rejecting duplicates.
        let iter_init = def.vtable.iter_vtable.init_with_value.unwrap();
        let iter_ptr = unsafe { iter_init(set_ptr.as_const()) };

        let mut seen = HashSet::<&str>::new();
        while let Some(item_ptr) = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) } {
            let item = unsafe { item_ptr.get::<String>() };
            assert!(
                seen.insert(item.as_str()),
                "HashSet iterator returned duplicate item"
            );
        }

        unsafe {
            (def.vtable.iter_vtable.dealloc)(iter_ptr);
        }

        // Every inserted string must have been yielded exactly once.
        assert_eq!(seen, strings.iter().copied().collect::<HashSet<&str>>());

        // Tear down: drop the set in place, then release its backing memory.
        unsafe {
            shape
                .call_drop_in_place(set_ptr)
                .expect("HashSet<T> should have drop_in_place");
            shape.deallocate_mut(set_ptr).unwrap();
        }
    }
}