Skip to main content

facet_core/impls/alloc/
btreeset.rs

1use alloc::boxed::Box;
2use alloc::collections::BTreeSet;
3
4use crate::{PtrConst, PtrMut, PtrUninit};
5
6use crate::{
7    Def, Facet, IterVTable, OxPtrMut, OxPtrUninit, SetDef, SetVTable, Shape, ShapeBuilder,
8    TypeNameFn, TypeNameOpts, TypeOpsIndirect, TypeParam, VTableIndirect, Variance, VarianceDep,
9    VarianceDesc,
10};
11
12type BTreeSetIterator<'mem, T> = alloc::collections::btree_set::Iter<'mem, T>;
13
14unsafe extern "C" fn btreeset_init_in_place_with_capacity<T>(
15    uninit: PtrUninit,
16    _capacity: usize,
17) -> PtrMut {
18    unsafe { uninit.put(BTreeSet::<T>::new()) }
19}
20
21unsafe extern "C" fn btreeset_insert<T: Eq + Ord + 'static>(ptr: PtrMut, item: PtrMut) -> bool {
22    unsafe {
23        let set = ptr.as_mut::<BTreeSet<T>>();
24        let item = item.read::<T>();
25        set.insert(item)
26    }
27}
28
29unsafe extern "C" fn btreeset_len<T: 'static>(ptr: PtrConst) -> usize {
30    unsafe { ptr.get::<BTreeSet<T>>().len() }
31}
32
33unsafe extern "C" fn btreeset_contains<T: Eq + Ord + 'static>(
34    ptr: PtrConst,
35    item: PtrConst,
36) -> bool {
37    unsafe { ptr.get::<BTreeSet<T>>().contains(item.get()) }
38}
39
40unsafe extern "C" fn btreeset_iter_init<T: 'static>(ptr: PtrConst) -> PtrMut {
41    unsafe {
42        let set = ptr.get::<BTreeSet<T>>();
43        let iter: BTreeSetIterator<'_, T> = set.iter();
44        let iter_state = Box::new(iter);
45        PtrMut::new(Box::into_raw(iter_state) as *mut u8)
46    }
47}
48
49unsafe fn btreeset_iter_next<T: 'static>(iter_ptr: PtrMut) -> Option<PtrConst> {
50    unsafe {
51        let state = iter_ptr.as_mut::<BTreeSetIterator<'static, T>>();
52        state.next().map(|value| PtrConst::new(value as *const T))
53    }
54}
55
56unsafe fn btreeset_iter_next_back<T: 'static>(iter_ptr: PtrMut) -> Option<PtrConst> {
57    unsafe {
58        let state = iter_ptr.as_mut::<BTreeSetIterator<'static, T>>();
59        state
60            .next_back()
61            .map(|value| PtrConst::new(value as *const T))
62    }
63}
64
65unsafe extern "C" fn btreeset_iter_dealloc<T>(iter_ptr: PtrMut) {
66    unsafe {
67        drop(Box::from_raw(
68            iter_ptr.as_ptr::<BTreeSetIterator<'_, T>>() as *mut BTreeSetIterator<'_, T>
69        ));
70    }
71}
72
73/// Build a BTreeSet from a contiguous slice of elements.
74///
75/// # Safety
76/// - `set` must point to uninitialized memory
77/// - `elements_ptr` must point to `count` consecutive initialized T values
78/// - Elements are moved out and should not be dropped by caller
79unsafe extern "C" fn btreeset_from_slice<T: Eq + Ord + 'static>(
80    set: PtrUninit,
81    elements_ptr: *mut u8,
82    count: usize,
83) -> PtrMut {
84    unsafe {
85        let elements = elements_ptr as *mut T;
86        let mut btreeset = BTreeSet::<T>::new();
87        for i in 0..count {
88            let elem = core::ptr::read(elements.add(i));
89            btreeset.insert(elem);
90        }
91        set.put(btreeset)
92    }
93}
94
95/// Drop implementation for `BTreeSet<T>`
96unsafe fn btreeset_drop<T>(ox: OxPtrMut) {
97    unsafe {
98        core::ptr::drop_in_place(ox.ptr().as_ptr::<BTreeSet<T>>() as *mut BTreeSet<T>);
99    }
100}
101
102/// Default implementation for `BTreeSet<T>`
103unsafe fn btreeset_default<T>(ox: OxPtrUninit) -> bool {
104    unsafe { ox.put(BTreeSet::<T>::new()) };
105    true
106}
107
/// Reflection metadata for `BTreeSet<T>`.
///
/// Wires the free functions above into a `SetVTable` (construction, insertion,
/// length, membership, double-ended iteration, bulk construction from a slice).
/// It also registers type operations (drop, default) and a type-name printer.
/// Only `T: Ord` is strictly needed by `BTreeSet`; `Eq` is implied by `Ord` and is
/// listed for symmetry with the helper bounds.
unsafe impl<'a, T> Facet<'a> for BTreeSet<T>
where
    T: Facet<'a> + core::cmp::Eq + core::cmp::Ord + 'static,
{
    const SHAPE: &'static crate::Shape = &const {
        // Builds the set vtable at compile time from the monomorphized helpers.
        const fn build_set_vtable<T: Eq + Ord + 'static>() -> SetVTable {
            SetVTable::builder()
                .init_in_place_with_capacity(btreeset_init_in_place_with_capacity::<T>)
                .insert(btreeset_insert::<T>)
                .len(btreeset_len::<T>)
                .contains(btreeset_contains::<T>)
                .iter_vtable(IterVTable {
                    init_with_value: Some(btreeset_iter_init::<T>),
                    next: btreeset_iter_next::<T>,
                    // `btree_set::Iter` is double-ended, so reverse iteration is offered.
                    next_back: Some(btreeset_iter_next_back::<T>),
                    // No size-hint callback is registered here.
                    size_hint: None,
                    dealloc: btreeset_iter_dealloc::<T>,
                })
                .from_slice(Some(btreeset_from_slice::<T>))
                .build()
        }

        // Produces a printer that renders `BTreeSet<T>` using T's own type name,
        // or `BTreeSet<…>` when `opts.for_children()` returns `None`.
        const fn build_type_name<'a, T: Facet<'a>>() -> TypeNameFn {
            fn type_name_impl<'a, T: Facet<'a>>(
                _shape: &'static Shape,
                f: &mut core::fmt::Formatter<'_>,
                opts: TypeNameOpts,
            ) -> core::fmt::Result {
                write!(f, "BTreeSet")?;
                if let Some(opts) = opts.for_children() {
                    write!(f, "<")?;
                    T::SHAPE.write_type_name(f, opts)?;
                    write!(f, ">")?;
                } else {
                    write!(f, "<…>")?;
                }
                Ok(())
            }
            type_name_impl::<T>
        }

        ShapeBuilder::for_sized::<Self>("BTreeSet")
            .module_path("alloc::collections::btree_set")
            .type_name(build_type_name::<T>())
            // No direct (non-indirect) vtable entries are provided for this type.
            .vtable_indirect(&VTableIndirect::EMPTY)
            .type_ops_indirect(
                &const {
                    TypeOpsIndirect {
                        drop_in_place: btreeset_drop::<T>,
                        default_in_place: Some(btreeset_default::<T>),
                        // The impl does not require `T: Clone`, so no clone hook is wired.
                        clone_into: None,
                        is_truthy: None,
                    }
                },
            )
            .def(Def::Set(SetDef::new(
                &const { build_set_vtable::<T>() },
                T::SHAPE,
            )))
            .type_params(&[TypeParam {
                name: "T",
                shape: T::SHAPE,
            }])
            .inner(T::SHAPE)
            // BTreeSet<T> propagates T's variance: the base is Bivariant and the
            // only dependency is T, declared covariant.
            .variance(VarianceDesc {
                base: Variance::Bivariant,
                deps: &const { [VarianceDep::covariant(T::SHAPE)] },
            })
            .build()
    };
}
180
#[cfg(test)]
mod tests {
    use core::ptr::NonNull;

    use alloc::collections::BTreeSet;
    use alloc::string::String;
    use alloc::vec::Vec;

    use super::*;

    /// `BTreeSet<T>` exposes exactly one type parameter, whose shape is `T`'s.
    #[test]
    fn test_btreeset_type_params() {
        let [type_param_1] = <BTreeSet<i32>>::SHAPE.type_params else {
            panic!("BTreeSet<T> should have 1 type param")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
    }

    /// Drives the set vtable end to end: init, insert (with dedup), iterate, drop.
    #[test]
    fn test_btreeset_vtable_1_new_insert_iter_drop() {
        facet_testhelpers::setup();

        let btreeset_shape = <BTreeSet<String>>::SHAPE;
        let btreeset_def = btreeset_shape
            .def
            .into_set()
            .expect("BTreeSet<T> should have a set definition");

        // Allocate memory for the BTreeSet
        let btreeset_uninit_ptr = btreeset_shape.allocate().unwrap();

        // Create the BTreeSet
        let btreeset_ptr =
            unsafe { (btreeset_def.vtable.init_in_place_with_capacity)(btreeset_uninit_ptr, 0) };

        // The BTreeSet is empty, so ensure its length is 0
        let btreeset_actual_length = unsafe { (btreeset_def.vtable.len)(btreeset_ptr.as_const()) };
        assert_eq!(btreeset_actual_length, 0);

        // 5 sample values to insert
        let strings = ["foo", "bar", "bazz", "fizzbuzz", "fifth thing"];

        // Insert the 5 values into the BTreeSet
        let mut btreeset_length = 0;
        for string in strings {
            // Create the value
            let mut new_value = string.to_string();

            // Insert the value
            let did_insert = unsafe {
                (btreeset_def.vtable.insert)(
                    btreeset_ptr,
                    PtrMut::new(NonNull::from(&mut new_value).as_ptr()),
                )
            };

            // The value now belongs to the BTreeSet, so forget it
            core::mem::forget(new_value);

            assert!(did_insert, "expected value to be inserted in the BTreeSet");

            // Ensure the BTreeSet's length increased by 1
            btreeset_length += 1;
            let btreeset_actual_length =
                unsafe { (btreeset_def.vtable.len)(btreeset_ptr.as_const()) };
            assert_eq!(btreeset_actual_length, btreeset_length);
        }

        // Insert the same 5 values again, ensuring they are deduplicated
        for string in strings {
            // Create the value
            let mut new_value = string.to_string();

            // Try to insert the value
            let did_insert = unsafe {
                (btreeset_def.vtable.insert)(
                    btreeset_ptr,
                    PtrMut::new(NonNull::from(&mut new_value).as_ptr()),
                )
            };

            // `insert` took ownership by reading the value out; since it was a
            // duplicate, the set dropped it. Forget our copy to avoid a double drop.
            core::mem::forget(new_value);

            assert!(
                !did_insert,
                "expected value to not be inserted in the BTreeSet"
            );

            // Ensure the BTreeSet's length did not increase
            let btreeset_actual_length =
                unsafe { (btreeset_def.vtable.len)(btreeset_ptr.as_const()) };
            assert_eq!(btreeset_actual_length, btreeset_length);
        }

        // Create a new iterator over the BTreeSet
        let iter_init_with_value_fn = btreeset_def.vtable.iter_vtable.init_with_value.unwrap();
        let btreeset_iter_ptr = unsafe { iter_init_with_value_fn(btreeset_ptr.as_const()) };

        // Collect all the items from the BTreeSet's iterator
        let mut iter_items = Vec::<&str>::new();
        loop {
            // Get the next item from the iterator
            let item_ptr = unsafe { (btreeset_def.vtable.iter_vtable.next)(btreeset_iter_ptr) };
            let Some(item_ptr) = item_ptr else {
                break;
            };

            let item = unsafe { item_ptr.get::<String>() };

            // Add the item into the list of items returned from the iterator
            iter_items.push(&**item);
        }

        // Deallocate the iterator
        unsafe {
            (btreeset_def.vtable.iter_vtable.dealloc)(btreeset_iter_ptr);
        }

        // BTrees iterate in sorted order, so ensure the iterator returned
        // each item in order
        let mut strings_sorted = strings.to_vec();
        strings_sorted.sort();
        assert_eq!(iter_items, strings_sorted);

        // Drop the BTreeSet in place
        unsafe {
            btreeset_shape
                .call_drop_in_place(btreeset_ptr)
                .expect("BTreeSet<T> should have drop_in_place");
        }

        // Deallocate the memory
        unsafe { btreeset_shape.deallocate_mut(btreeset_ptr).unwrap() };
    }

    /// Covers the `contains` and `next_back` vtable entries, which the test
    /// above does not exercise.
    #[test]
    fn test_btreeset_vtable_2_contains_iter_back() {
        facet_testhelpers::setup();

        let btreeset_shape = <BTreeSet<i32>>::SHAPE;
        let btreeset_def = btreeset_shape
            .def
            .into_set()
            .expect("BTreeSet<T> should have a set definition");

        let btreeset_uninit_ptr = btreeset_shape.allocate().unwrap();
        let btreeset_ptr =
            unsafe { (btreeset_def.vtable.init_in_place_with_capacity)(btreeset_uninit_ptr, 0) };

        // Insert out of order; i32 has no destructor, so nothing needs forgetting.
        for value in [3i32, 1, 2] {
            let mut new_value = value;
            let did_insert = unsafe {
                (btreeset_def.vtable.insert)(
                    btreeset_ptr,
                    PtrMut::new(NonNull::from(&mut new_value).as_ptr()),
                )
            };
            assert!(did_insert, "expected {value} to be inserted");
        }

        // Membership checks, including a value that was never inserted
        for (probe, expected) in [(1i32, true), (2, true), (3, true), (4, false)] {
            let found = unsafe {
                (btreeset_def.vtable.contains)(
                    btreeset_ptr.as_const(),
                    PtrConst::new(&probe as *const i32),
                )
            };
            assert_eq!(found, expected, "contains({probe})");
        }

        // Reverse iteration must yield descending order
        let iter_init = btreeset_def.vtable.iter_vtable.init_with_value.unwrap();
        let next_back = btreeset_def
            .vtable
            .iter_vtable
            .next_back
            .expect("BTreeSet iterator should be double-ended");
        let iter_ptr = unsafe { iter_init(btreeset_ptr.as_const()) };
        let mut reversed = Vec::new();
        while let Some(item_ptr) = unsafe { next_back(iter_ptr) } {
            reversed.push(unsafe { *item_ptr.get::<i32>() });
        }
        unsafe {
            (btreeset_def.vtable.iter_vtable.dealloc)(iter_ptr);
        }
        assert_eq!(reversed, [3, 2, 1]);

        unsafe {
            btreeset_shape
                .call_drop_in_place(btreeset_ptr)
                .expect("BTreeSet<T> should have drop_in_place");
        }
        unsafe { btreeset_shape.deallocate_mut(btreeset_ptr).unwrap() };
    }
}