Skip to main content

facet_reflect/poke/
set.rs

1use core::mem::ManuallyDrop;
2
3use facet_core::{Facet, SetDef};
4
5use crate::{HeapValue, ReflectError, ReflectErrorKind};
6
7use super::Poke;
8
9/// Lets you mutate a set (implements mutable [`facet_core::SetVTable`] proxies)
10pub struct PokeSet<'mem, 'facet> {
11    value: Poke<'mem, 'facet>,
12    def: SetDef,
13}
14
15impl<'mem, 'facet> core::fmt::Debug for PokeSet<'mem, 'facet> {
16    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
17        f.debug_struct("PokeSet").finish_non_exhaustive()
18    }
19}
20
21impl<'mem, 'facet> PokeSet<'mem, 'facet> {
22    /// Creates a new poke set
23    ///
24    /// # Safety
25    ///
26    /// The caller must ensure that `def` contains valid vtable function pointers that:
27    /// - Correctly implement the set operations for the actual type
28    /// - Do not cause undefined behavior when called
29    /// - Return pointers within valid memory bounds
30    /// - Match the element type specified in `def.t()`
31    ///
32    /// Violating these requirements can lead to memory safety issues.
33    #[inline]
34    pub const unsafe fn new(value: Poke<'mem, 'facet>, def: SetDef) -> Self {
35        Self { value, def }
36    }
37
38    fn err(&self, kind: ReflectErrorKind) -> ReflectError {
39        self.value.err(kind)
40    }
41
42    /// Get the number of entries in the set
43    #[inline]
44    pub fn len(&self) -> usize {
45        unsafe { (self.def.vtable.len)(self.value.data()) }
46    }
47
48    /// Returns true if the set is empty
49    #[inline]
50    pub fn is_empty(&self) -> bool {
51        self.len() == 0
52    }
53
54    /// Check if the set contains a value
55    #[inline]
56    pub fn contains(&self, value: &impl Facet<'facet>) -> Result<bool, ReflectError> {
57        self.contains_peek(crate::Peek::new(value))
58    }
59
60    /// Check if the set contains a value (using a `Peek`)
61    #[inline]
62    pub fn contains_peek(&self, value: crate::Peek<'_, 'facet>) -> Result<bool, ReflectError> {
63        if self.def.t() == value.shape() {
64            return Ok(unsafe { (self.def.vtable.contains)(self.value.data(), value.data()) });
65        }
66        Err(self.err(ReflectErrorKind::WrongShape {
67            expected: self.def.t(),
68            actual: value.shape(),
69        }))
70    }
71
72    /// Insert a value into the set. Returns `true` if the value was newly
73    /// inserted, `false` if it was already present.
74    pub fn insert<T: Facet<'facet>>(&mut self, value: T) -> Result<bool, ReflectError> {
75        if self.def.t() != T::SHAPE {
76            return Err(self.err(ReflectErrorKind::WrongShape {
77                expected: self.def.t(),
78                actual: T::SHAPE,
79            }));
80        }
81
82        let mut value = ManuallyDrop::new(value);
83        let inserted = unsafe {
84            let value_ptr = facet_core::PtrMut::new(&mut value as *mut ManuallyDrop<T> as *mut u8);
85            (self.def.vtable.insert)(self.value.data_mut(), value_ptr)
86        };
87        Ok(inserted)
88    }
89
90    /// Type-erased [`insert`](Self::insert).
91    ///
92    /// Accepts a [`HeapValue`] whose shape must match the set's element type. The value is
93    /// moved into the set. Returns `true` if the value was newly inserted.
94    pub fn insert_from_heap<const BORROW: bool>(
95        &mut self,
96        value: HeapValue<'facet, BORROW>,
97    ) -> Result<bool, ReflectError> {
98        if self.def.t() != value.shape() {
99            return Err(self.err(ReflectErrorKind::WrongShape {
100                expected: self.def.t(),
101                actual: value.shape(),
102            }));
103        }
104
105        let mut value = value;
106        let guard = value
107            .guard
108            .take()
109            .expect("HeapValue guard was already taken");
110        let inserted = unsafe {
111            let value_ptr = facet_core::PtrMut::new(guard.ptr.as_ptr());
112            (self.def.vtable.insert)(self.value.data_mut(), value_ptr)
113        };
114        drop(guard);
115        Ok(inserted)
116    }
117
118    /// Returns an iterator over the values in the set (read-only).
119    #[inline]
120    pub fn iter(&self) -> crate::PeekSetIter<'_, 'facet> {
121        self.as_peek_set().iter()
122    }
123
124    /// Def getter
125    #[inline]
126    pub const fn def(&self) -> SetDef {
127        self.def
128    }
129
130    /// Converts this `PokeSet` back into a `Poke`
131    #[inline]
132    pub fn into_inner(self) -> Poke<'mem, 'facet> {
133        self.value
134    }
135
136    /// Returns a read-only `PeekSet` view
137    #[inline]
138    pub fn as_peek_set(&self) -> crate::PeekSet<'_, 'facet> {
139        unsafe { crate::PeekSet::new(self.value.as_peek(), self.def) }
140    }
141}
142
#[cfg(test)]
mod tests {
    use alloc::collections::BTreeSet;

    use super::*;

    #[test]
    fn poke_set_len_and_insert() {
        let mut s: BTreeSet<i32> = BTreeSet::new();
        let poke = Poke::new(&mut s);
        let mut set = poke.into_set().unwrap();
        assert_eq!(set.len(), 0);

        assert!(set.insert(1i32).unwrap());
        assert!(set.insert(2i32).unwrap());
        assert!(!set.insert(1i32).unwrap());

        assert_eq!(set.len(), 2);
        assert!(s.contains(&1));
        assert!(s.contains(&2));
    }

    #[test]
    fn poke_set_contains() {
        let mut s: BTreeSet<i32> = BTreeSet::new();
        s.insert(42);
        let poke = Poke::new(&mut s);
        let set = poke.into_set().unwrap();

        assert!(set.contains(&42i32).unwrap());
        assert!(!set.contains(&7i32).unwrap());
    }

    #[test]
    fn poke_set_insert_from_heap() {
        let mut s: BTreeSet<i32> = BTreeSet::new();
        let poke = Poke::new(&mut s);
        let mut set = poke.into_set().unwrap();

        let hv = crate::Partial::alloc::<i32>()
            .unwrap()
            .set(7i32)
            .unwrap()
            .build()
            .unwrap();
        assert!(set.insert_from_heap(hv).unwrap());
        assert!(s.contains(&7));
    }

    #[test]
    fn poke_set_rejects_wrong_element_shape() {
        let mut s: BTreeSet<i32> = BTreeSet::new();
        s.insert(1);
        let poke = Poke::new(&mut s);
        let mut set = poke.into_set().unwrap();

        // `u8` is not the set's element type (`i32`): every entry point must refuse it
        // instead of handing a mismatched pointer to the vtable.
        assert!(set.insert(2u8).is_err());
        assert!(set.contains(&1u8).is_err());

        let hv = crate::Partial::alloc::<u8>()
            .unwrap()
            .set(3u8)
            .unwrap()
            .build()
            .unwrap();
        assert!(set.insert_from_heap(hv).is_err());

        // None of the rejected calls modified the set.
        assert_eq!(set.len(), 1);
        assert!(!set.is_empty());
        assert_eq!(s.len(), 1);
        assert!(s.contains(&1));
    }
}