1use core::hash::BuildHasher;
2use std::collections::HashSet;
3
4use crate::{PtrConst, PtrMut, PtrUninit};
5
6use crate::{
7 Def, Facet, HashProxy, IterVTable, OxPtrConst, OxPtrMut, OxPtrUninit, OxRef, SetDef, SetVTable,
8 Shape, ShapeBuilder, Type, TypeNameFn, TypeNameOpts, TypeOpsIndirect, TypeParam, UserType,
9 VTableIndirect, Variance, VarianceDep, VarianceDesc,
10};
11
12type HashSetIterator<'mem, T> = std::collections::hash_set::Iter<'mem, T>;
13
14unsafe extern "C" fn hashset_init_in_place_with_capacity<T, S: Default + BuildHasher>(
15 uninit: PtrUninit,
16 capacity: usize,
17) -> PtrMut {
18 unsafe {
19 uninit.put(HashSet::<T, S>::with_capacity_and_hasher(
20 capacity,
21 S::default(),
22 ))
23 }
24}
25
26unsafe extern "C" fn hashset_insert<T: Eq + core::hash::Hash + 'static>(
27 ptr: PtrMut,
28 item: PtrMut,
29) -> bool {
30 unsafe {
31 let set = ptr.as_mut::<HashSet<T>>();
32 let item = item.read::<T>();
33 set.insert(item)
34 }
35}
36
37unsafe extern "C" fn hashset_len<T: 'static>(ptr: PtrConst) -> usize {
38 unsafe { ptr.get::<HashSet<T>>().len() }
39}
40
41unsafe extern "C" fn hashset_contains<T: Eq + core::hash::Hash + 'static>(
42 ptr: PtrConst,
43 item: PtrConst,
44) -> bool {
45 unsafe { ptr.get::<HashSet<T>>().contains(item.get()) }
46}
47
48unsafe extern "C" fn hashset_iter_init<T: 'static>(ptr: PtrConst) -> PtrMut {
49 unsafe {
50 let set = ptr.get::<HashSet<T>>();
51 let iter: HashSetIterator<'_, T> = set.iter();
52 let iter_state = Box::new(iter);
53 PtrMut::new(Box::into_raw(iter_state) as *mut u8)
54 }
55}
56
57unsafe fn hashset_iter_next<T: 'static>(iter_ptr: PtrMut) -> Option<PtrConst> {
58 unsafe {
59 let state = iter_ptr.as_mut::<HashSetIterator<'static, T>>();
60 state.next().map(|value| PtrConst::new(value as *const T))
61 }
62}
63
64unsafe extern "C" fn hashset_iter_dealloc<T>(iter_ptr: PtrMut) {
65 unsafe {
66 drop(Box::from_raw(
67 iter_ptr.as_ptr::<HashSetIterator<'_, T>>() as *mut HashSetIterator<'_, T>
68 ));
69 }
70}
71
72unsafe extern "C" fn hashset_from_slice<
79 T: Eq + core::hash::Hash + 'static,
80 S: Default + BuildHasher + 'static,
81>(
82 set: PtrUninit,
83 elements_ptr: *mut u8,
84 count: usize,
85) -> PtrMut {
86 unsafe {
87 let elements = elements_ptr as *mut T;
88 let mut hashset = HashSet::<T, S>::with_capacity_and_hasher(count, S::default());
89 for i in 0..count {
90 let elem = core::ptr::read(elements.add(i));
91 hashset.insert(elem);
92 }
93 set.put(hashset)
94 }
95}
96
97#[inline]
99const fn get_set_def(shape: &'static Shape) -> Option<&'static SetDef> {
100 match shape.def {
101 Def::Set(ref def) => Some(def),
102 _ => None,
103 }
104}
105
106unsafe fn hashset_debug(
108 ox: OxPtrConst,
109 f: &mut core::fmt::Formatter<'_>,
110) -> Option<core::fmt::Result> {
111 let shape = ox.shape();
112 let def = get_set_def(shape)?;
113 let ptr = ox.ptr();
114
115 let mut debug_set = f.debug_set();
116
117 let iter_init = def.vtable.iter_vtable.init_with_value?;
119 let iter_ptr = unsafe { iter_init(ptr) };
120
121 loop {
123 let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
124 let Some(item_ptr) = item_ptr else {
125 break;
126 };
127 let item_ox = unsafe { OxRef::new(item_ptr, def.t) };
130 debug_set.entry(&item_ox);
131 }
132
133 unsafe {
135 (def.vtable.iter_vtable.dealloc)(iter_ptr);
136 }
137
138 Some(debug_set.finish())
139}
140
141unsafe fn hashset_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
143 let shape = ox.shape();
144 let def = get_set_def(shape)?;
145 let ptr = ox.ptr();
146
147 use core::hash::Hash;
148
149 let len = unsafe { (def.vtable.len)(ptr) };
151 len.hash(hasher);
152
153 let iter_init = def.vtable.iter_vtable.init_with_value?;
155 let iter_ptr = unsafe { iter_init(ptr) };
156
157 loop {
159 let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
160 let Some(item_ptr) = item_ptr else {
161 break;
162 };
163 unsafe { def.t.call_hash(item_ptr, hasher)? };
164 }
165
166 unsafe {
168 (def.vtable.iter_vtable.dealloc)(iter_ptr);
169 }
170
171 Some(())
172}
173
174unsafe fn hashset_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
176 let shape = a.shape();
177 let def = get_set_def(shape)?;
178
179 let a_ptr = a.ptr();
180 let b_ptr = b.ptr();
181
182 let a_len = unsafe { (def.vtable.len)(a_ptr) };
183 let b_len = unsafe { (def.vtable.len)(b_ptr) };
184
185 if a_len != b_len {
187 return Some(false);
188 }
189
190 let iter_init = def.vtable.iter_vtable.init_with_value?;
192 let iter_ptr = unsafe { iter_init(a_ptr) };
193
194 let mut all_contained = true;
196 loop {
197 let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
198 let Some(item_ptr) = item_ptr else {
199 break;
200 };
201 let contained = unsafe { (def.vtable.contains)(b_ptr, item_ptr) };
202 if !contained {
203 all_contained = false;
204 break;
205 }
206 }
207
208 unsafe {
210 (def.vtable.iter_vtable.dealloc)(iter_ptr);
211 }
212
213 Some(all_contained)
214}
215
216unsafe fn hashset_drop<T: 'static, S: 'static>(ox: OxPtrMut) {
218 unsafe {
219 core::ptr::drop_in_place(ox.as_mut::<HashSet<T, S>>());
220 }
221}
222
223unsafe fn hashset_default<T: 'static, S: Default + BuildHasher + 'static>(ox: OxPtrUninit) -> bool {
225 unsafe { ox.put(HashSet::<T, S>::default()) };
226 true
227}
228
229unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
230where
231 T: Facet<'a> + core::cmp::Eq + core::hash::Hash + 'static,
232 S: Facet<'a> + Default + BuildHasher + 'static,
233{
234 const SHAPE: &'static Shape = &const {
235 const fn build_set_vtable<
236 T: Eq + core::hash::Hash + 'static,
237 S: Default + BuildHasher + 'static,
238 >() -> SetVTable {
239 SetVTable::builder()
240 .init_in_place_with_capacity(hashset_init_in_place_with_capacity::<T, S>)
241 .insert(hashset_insert::<T>)
242 .len(hashset_len::<T>)
243 .contains(hashset_contains::<T>)
244 .iter_vtable(IterVTable {
245 init_with_value: Some(hashset_iter_init::<T>),
246 next: hashset_iter_next::<T>,
247 next_back: None,
248 size_hint: None,
249 dealloc: hashset_iter_dealloc::<T>,
250 })
251 .from_slice(Some(hashset_from_slice::<T, S>))
252 .build()
253 }
254
255 const fn build_type_name<'a, T: Facet<'a>>() -> TypeNameFn {
256 fn type_name_impl<'a, T: Facet<'a>>(
257 _shape: &'static Shape,
258 f: &mut core::fmt::Formatter<'_>,
259 opts: TypeNameOpts,
260 ) -> core::fmt::Result {
261 write!(f, "HashSet")?;
262 if let Some(opts) = opts.for_children() {
263 write!(f, "<")?;
264 T::SHAPE.write_type_name(f, opts)?;
265 write!(f, ">")?;
266 } else {
267 write!(f, "<…>")?;
268 }
269 Ok(())
270 }
271 type_name_impl::<T>
272 }
273
274 ShapeBuilder::for_sized::<Self>("HashSet")
275 .module_path("std::collections::hash_set")
276 .type_name(build_type_name::<T>())
277 .ty(Type::User(UserType::Opaque))
278 .def(Def::Set(SetDef::new(
279 &const { build_set_vtable::<T, S>() },
280 T::SHAPE,
281 )))
282 .type_params(&[
283 TypeParam {
284 name: "T",
285 shape: T::SHAPE,
286 },
287 TypeParam {
288 name: "S",
289 shape: S::SHAPE,
290 },
291 ])
292 .inner(T::SHAPE)
293 .variance(VarianceDesc {
295 base: Variance::Bivariant,
296 deps: &const { [VarianceDep::covariant(T::SHAPE)] },
297 })
298 .vtable_indirect(
299 &const {
300 VTableIndirect {
301 debug: Some(hashset_debug),
302 hash: Some(hashset_hash),
303 partial_eq: Some(hashset_partial_eq),
304 ..VTableIndirect::EMPTY
305 }
306 },
307 )
308 .type_ops_indirect(
309 &const {
310 TypeOpsIndirect {
311 drop_in_place: hashset_drop::<T, S>,
312 default_in_place: Some(hashset_default::<T, S>),
313 clone_into: None,
314 is_truthy: None,
315 }
316 },
317 )
318 .build()
319 };
320}
321
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use core::ptr::NonNull;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// The shape exposes exactly two type parameters: the element type and the hasher.
    #[test]
    fn test_hashset_type_params() {
        let type_params = <HashSet<i32>>::SHAPE.type_params;
        let [element_param, hasher_param] = type_params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(element_param.shape(), i32::SHAPE);
        assert_eq!(hasher_param.shape(), RandomState::SHAPE);
    }

    /// Exercises the set vtable end to end: construct, insert (fresh and duplicate),
    /// iterate, then drop and deallocate.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let set_def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");

        // Allocate backing storage and build an empty set with a small initial capacity.
        let uninit = shape.allocate().unwrap();
        let set_ptr = unsafe { (set_def.vtable.init_in_place_with_capacity)(uninit, 3) };

        // Element count as reported through the vtable.
        let len = || unsafe { (set_def.vtable.len)(set_ptr.as_const()) };

        // Hands a freshly allocated `String` to the set. `insert` takes ownership of the
        // pointee, so `ManuallyDrop` keeps the local copy from being dropped a second time.
        let insert = |text: &str| {
            let mut owned = core::mem::ManuallyDrop::new(text.to_string());
            unsafe {
                (set_def.vtable.insert)(set_ptr, PtrMut::new(NonNull::from(&mut owned).as_ptr()))
            }
        };

        assert_eq!(len(), 0);

        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // Each distinct string is accepted and grows the set by exactly one.
        for (already_inserted, text) in strings.into_iter().enumerate() {
            let accepted = insert(text);
            assert!(accepted, "expected value to be inserted in the HashSet");
            assert_eq!(len(), already_inserted + 1);
        }

        // Re-inserting the same strings is rejected and leaves the length unchanged.
        for text in strings {
            let accepted = insert(text);
            assert!(
                !accepted,
                "expected value to not be inserted in the HashSet"
            );
            assert_eq!(len(), strings.len());
        }

        // Walk the set through the iterator vtable; every element must appear exactly once.
        let iter_init = set_def.vtable.iter_vtable.init_with_value.unwrap();
        let iter_ptr = unsafe { iter_init(set_ptr.as_const()) };

        let mut seen = HashSet::<&str>::new();
        while let Some(item_ptr) = unsafe { (set_def.vtable.iter_vtable.next)(iter_ptr) } {
            let item = unsafe { item_ptr.get::<String>() };
            let first_sighting = seen.insert(item.as_str());
            assert!(first_sighting, "HashSet iterator returned duplicate item");
        }

        unsafe { (set_def.vtable.iter_vtable.dealloc)(iter_ptr) };

        assert_eq!(seen, strings.iter().copied().collect::<HashSet<_>>());

        // Tear down: drop the set in place, then release its backing allocation.
        unsafe {
            shape
                .call_drop_in_place(set_ptr)
                .expect("HashSet<T> should have drop_in_place");
            shape.deallocate_mut(set_ptr).unwrap();
        }
    }
}