1use core::hash::BuildHasher;
2use std::collections::HashSet;
3
4use crate::{PtrConst, PtrMut, PtrUninit};
5
6use crate::{
7 Def, Facet, HashProxy, IterVTable, OxPtrConst, OxPtrMut, OxRef, SetDef, SetVTable, Shape,
8 ShapeBuilder, Type, TypeNameFn, TypeNameOpts, TypeOpsIndirect, TypeParam, UserType,
9 VTableIndirect,
10};
11
12type HashSetIterator<'mem, T> = std::collections::hash_set::Iter<'mem, T>;
13
14unsafe fn hashset_init_in_place_with_capacity<T, S: Default + BuildHasher>(
15 uninit: PtrUninit,
16 capacity: usize,
17) -> PtrMut {
18 unsafe {
19 uninit.put(HashSet::<T, S>::with_capacity_and_hasher(
20 capacity,
21 S::default(),
22 ))
23 }
24}
25
26unsafe fn hashset_insert<T: Eq + core::hash::Hash + 'static>(ptr: PtrMut, item: PtrMut) -> bool {
27 unsafe {
28 let set = ptr.as_mut::<HashSet<T>>();
29 let item = item.read::<T>();
30 set.insert(item)
31 }
32}
33
34unsafe fn hashset_len<T: 'static>(ptr: PtrConst) -> usize {
35 unsafe { ptr.get::<HashSet<T>>().len() }
36}
37
38unsafe fn hashset_contains<T: Eq + core::hash::Hash + 'static>(
39 ptr: PtrConst,
40 item: PtrConst,
41) -> bool {
42 unsafe { ptr.get::<HashSet<T>>().contains(item.get()) }
43}
44
45unsafe fn hashset_iter_init<T: 'static>(ptr: PtrConst) -> PtrMut {
46 unsafe {
47 let set = ptr.get::<HashSet<T>>();
48 let iter: HashSetIterator<'_, T> = set.iter();
49 let iter_state = Box::new(iter);
50 PtrMut::new(Box::into_raw(iter_state) as *mut u8)
51 }
52}
53
54unsafe fn hashset_iter_next<T: 'static>(iter_ptr: PtrMut) -> Option<PtrConst> {
55 unsafe {
56 let state = iter_ptr.as_mut::<HashSetIterator<'static, T>>();
57 state.next().map(|value| PtrConst::new(value as *const T))
58 }
59}
60
61unsafe fn hashset_iter_dealloc<T>(iter_ptr: PtrMut) {
62 unsafe {
63 drop(Box::from_raw(
64 iter_ptr.as_ptr::<HashSetIterator<'_, T>>() as *mut HashSetIterator<'_, T>
65 ));
66 }
67}
68
69#[inline]
71fn get_set_def(shape: &'static Shape) -> Option<&'static SetDef> {
72 match shape.def {
73 Def::Set(ref def) => Some(def),
74 _ => None,
75 }
76}
77
78unsafe fn hashset_debug(
80 ox: OxPtrConst,
81 f: &mut core::fmt::Formatter<'_>,
82) -> Option<core::fmt::Result> {
83 let shape = ox.shape();
84 let def = get_set_def(shape)?;
85 let ptr = ox.ptr();
86
87 let mut debug_set = f.debug_set();
88
89 let iter_init = def.vtable.iter_vtable.init_with_value?;
91 let iter_ptr = unsafe { iter_init(ptr) };
92
93 loop {
95 let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
96 let Some(item_ptr) = item_ptr else {
97 break;
98 };
99 let item_ox = unsafe { OxRef::new(item_ptr, def.t) };
102 debug_set.entry(&item_ox);
103 }
104
105 unsafe {
107 (def.vtable.iter_vtable.dealloc)(iter_ptr);
108 }
109
110 Some(debug_set.finish())
111}
112
113unsafe fn hashset_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
115 let shape = ox.shape();
116 let def = get_set_def(shape)?;
117 let ptr = ox.ptr();
118
119 use core::hash::Hash;
120
121 let len = unsafe { (def.vtable.len)(ptr) };
123 len.hash(hasher);
124
125 let iter_init = def.vtable.iter_vtable.init_with_value?;
127 let iter_ptr = unsafe { iter_init(ptr) };
128
129 loop {
131 let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
132 let Some(item_ptr) = item_ptr else {
133 break;
134 };
135 unsafe { def.t.call_hash(item_ptr, hasher)? };
136 }
137
138 unsafe {
140 (def.vtable.iter_vtable.dealloc)(iter_ptr);
141 }
142
143 Some(())
144}
145
146unsafe fn hashset_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
148 let shape = a.shape();
149 let def = get_set_def(shape)?;
150
151 let a_ptr = a.ptr();
152 let b_ptr = b.ptr();
153
154 let a_len = unsafe { (def.vtable.len)(a_ptr) };
155 let b_len = unsafe { (def.vtable.len)(b_ptr) };
156
157 if a_len != b_len {
159 return Some(false);
160 }
161
162 let iter_init = def.vtable.iter_vtable.init_with_value?;
164 let iter_ptr = unsafe { iter_init(a_ptr) };
165
166 let mut all_contained = true;
168 loop {
169 let item_ptr = unsafe { (def.vtable.iter_vtable.next)(iter_ptr) };
170 let Some(item_ptr) = item_ptr else {
171 break;
172 };
173 let contained = unsafe { (def.vtable.contains)(b_ptr, item_ptr) };
174 if !contained {
175 all_contained = false;
176 break;
177 }
178 }
179
180 unsafe {
182 (def.vtable.iter_vtable.dealloc)(iter_ptr);
183 }
184
185 Some(all_contained)
186}
187
188unsafe fn hashset_drop<T: 'static, S: 'static>(ox: OxPtrMut) {
190 unsafe {
191 core::ptr::drop_in_place(ox.as_mut::<HashSet<T, S>>());
192 }
193}
194
195unsafe fn hashset_default<T: 'static, S: Default + BuildHasher + 'static>(ox: OxPtrMut) {
197 unsafe { ox.ptr().as_uninit().put(HashSet::<T, S>::default()) };
198}
199
200unsafe impl<'a, T, S> Facet<'a> for HashSet<T, S>
201where
202 T: Facet<'a> + core::cmp::Eq + core::hash::Hash + 'static,
203 S: Facet<'a> + Default + BuildHasher + 'static,
204{
205 const SHAPE: &'static Shape = &const {
206 const fn build_set_vtable<
207 T: Eq + core::hash::Hash + 'static,
208 S: Default + BuildHasher + 'static,
209 >() -> SetVTable {
210 SetVTable::builder()
211 .init_in_place_with_capacity(hashset_init_in_place_with_capacity::<T, S>)
212 .insert(hashset_insert::<T>)
213 .len(hashset_len::<T>)
214 .contains(hashset_contains::<T>)
215 .iter_vtable(IterVTable {
216 init_with_value: Some(hashset_iter_init::<T>),
217 next: hashset_iter_next::<T>,
218 next_back: None,
219 size_hint: None,
220 dealloc: hashset_iter_dealloc::<T>,
221 })
222 .build()
223 }
224
225 const fn build_type_name<'a, T: Facet<'a>>() -> TypeNameFn {
226 fn type_name_impl<'a, T: Facet<'a>>(
227 _shape: &'static Shape,
228 f: &mut core::fmt::Formatter<'_>,
229 opts: TypeNameOpts,
230 ) -> core::fmt::Result {
231 write!(f, "HashSet")?;
232 if let Some(opts) = opts.for_children() {
233 write!(f, "<")?;
234 T::SHAPE.write_type_name(f, opts)?;
235 write!(f, ">")?;
236 } else {
237 write!(f, "<…>")?;
238 }
239 Ok(())
240 }
241 type_name_impl::<T>
242 }
243
244 ShapeBuilder::for_sized::<Self>("HashSet")
245 .type_name(build_type_name::<T>())
246 .ty(Type::User(UserType::Opaque))
247 .def(Def::Set(SetDef::new(
248 &const { build_set_vtable::<T, S>() },
249 T::SHAPE,
250 )))
251 .type_params(&[
252 TypeParam {
253 name: "T",
254 shape: T::SHAPE,
255 },
256 TypeParam {
257 name: "S",
258 shape: S::SHAPE,
259 },
260 ])
261 .vtable_indirect(
262 &const {
263 VTableIndirect {
264 debug: Some(hashset_debug),
265 hash: Some(hashset_hash),
266 partial_eq: Some(hashset_partial_eq),
267 ..VTableIndirect::EMPTY
268 }
269 },
270 )
271 .type_ops_indirect(
272 &const {
273 TypeOpsIndirect {
274 drop_in_place: hashset_drop::<T, S>,
275 default_in_place: Some(hashset_default::<T, S>),
276 clone_into: None,
277 is_truthy: None,
278 }
279 },
280 )
281 .build()
282 };
283}
284
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use core::ptr::NonNull;
    use std::collections::HashSet;
    use std::hash::RandomState;

    use super::*;

    /// `HashSet<T>` reports both generic parameters: the element type `T` and
    /// the defaulted hasher builder `S = RandomState`.
    #[test]
    fn test_hashset_type_params() {
        let [element_param, hasher_param] = <HashSet<i32>>::SHAPE.type_params else {
            panic!("HashSet<T> should have 2 type params")
        };
        assert_eq!(element_param.shape(), i32::SHAPE);
        assert_eq!(hasher_param.shape(), RandomState::SHAPE);
    }

    /// Drives a `HashSet<String>` entirely through its set vtable. It creates
    /// the set with a capacity, inserts fresh and duplicate values, iterates,
    /// then drops and frees it.
    #[test]
    fn test_hashset_vtable_1_new_insert_iter_drop() {
        facet_testhelpers::setup();

        let shape = <HashSet<String>>::SHAPE;
        let set_def = shape
            .def
            .into_set()
            .expect("HashSet<T> should have a set definition");

        let uninit = shape.allocate().unwrap();
        // SAFETY: `uninit` was allocated for this shape, i.e. for a `HashSet<String>`.
        let set_ptr = unsafe { (set_def.vtable.init_in_place_with_capacity)(uninit, 3) };

        // Length of the set as reported by the vtable.
        // SAFETY: `set_ptr` is an initialized `HashSet<String>` until it is dropped below.
        let vtable_len = || unsafe { (set_def.vtable.len)(set_ptr.as_const()) };

        // Passes a freshly allocated `String` to the vtable's `insert`. The value
        // is wrapped in `ManuallyDrop` because `insert` takes ownership of it.
        let vtable_insert = |text: &str| {
            let mut owned = core::mem::ManuallyDrop::new(text.to_string());
            let item = PtrMut::new(NonNull::from(&mut owned).as_ptr());
            // SAFETY: `set_ptr` is an initialized `HashSet<String>` and `item`
            // points to an initialized `String` that is never dropped here.
            unsafe { (set_def.vtable.insert)(set_ptr, item) }
        };

        assert_eq!(vtable_len(), 0);

        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        for (index, &text) in strings.iter().enumerate() {
            let inserted = vtable_insert(text);
            assert!(inserted, "expected value to be inserted in the HashSet");
            assert_eq!(vtable_len(), index + 1);
        }

        for &text in &strings {
            let inserted = vtable_insert(text);
            assert!(
                !inserted,
                "expected value to not be inserted in the HashSet"
            );
            assert_eq!(vtable_len(), strings.len());
        }

        let iter_init = set_def.vtable.iter_vtable.init_with_value.unwrap();
        // SAFETY: `set_ptr` is an initialized `HashSet<String>`.
        let iter_ptr = unsafe { iter_init(set_ptr.as_const()) };

        let mut seen = HashSet::<&str>::new();
        // SAFETY: `iter_ptr` is a live iterator over the set, which is not
        // modified while iterating.
        while let Some(item_ptr) = unsafe { (set_def.vtable.iter_vtable.next)(iter_ptr) } {
            // SAFETY: every element of the set is a `String`.
            let item = unsafe { item_ptr.get::<String>() };
            let first_time = seen.insert(item.as_str());
            assert!(first_time, "HashSet iterator returned duplicate item");
        }

        // SAFETY: `iter_ptr` came from `init_with_value` and is freed exactly once.
        unsafe { (set_def.vtable.iter_vtable.dealloc)(iter_ptr) };

        let expected: HashSet<&str> = strings.iter().copied().collect();
        assert_eq!(seen, expected);

        // SAFETY: the set is initialized and owned by this test; it is dropped
        // in place first, then its allocation is released.
        unsafe {
            shape
                .call_drop_in_place(set_ptr)
                .expect("HashSet<T> should have drop_in_place");

            shape.deallocate_mut(set_ptr).unwrap();
        }
    }
}