Skip to main content

facet_reflect/poke/
map.rs

1use core::mem::ManuallyDrop;
2
3use facet_core::{Facet, MapDef};
4
5use crate::{HeapValue, ReflectError, ReflectErrorKind};
6
7use super::Poke;
8
9/// Lets you mutate a map (implements mutable [`facet_core::MapVTable`] proxies)
10pub struct PokeMap<'mem, 'facet> {
11    value: Poke<'mem, 'facet>,
12    def: MapDef,
13}
14
15impl<'mem, 'facet> core::fmt::Debug for PokeMap<'mem, 'facet> {
16    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
17        f.debug_struct("PokeMap").finish_non_exhaustive()
18    }
19}
20
21impl<'mem, 'facet> PokeMap<'mem, 'facet> {
22    /// Creates a new poke map
23    ///
24    /// # Safety
25    ///
26    /// The caller must ensure that `def` contains valid vtable function pointers that:
27    /// - Correctly implement the map operations for the actual type
28    /// - Do not cause undefined behavior when called
29    /// - Return pointers within valid memory bounds
30    /// - Match the key and value types specified in `def.k()` and `def.v()`
31    ///
32    /// Violating these requirements can lead to memory safety issues.
33    #[inline]
34    pub const unsafe fn new(value: Poke<'mem, 'facet>, def: MapDef) -> Self {
35        Self { value, def }
36    }
37
38    fn err(&self, kind: ReflectErrorKind) -> ReflectError {
39        self.value.err(kind)
40    }
41
42    /// Get the number of entries in the map
43    #[inline]
44    pub fn len(&self) -> usize {
45        unsafe { (self.def.vtable.len)(self.value.data()) }
46    }
47
48    /// Returns true if the map is empty
49    #[inline]
50    pub fn is_empty(&self) -> bool {
51        self.len() == 0
52    }
53
54    /// Check if the map contains a key
55    #[inline]
56    pub fn contains_key(&self, key: &impl Facet<'facet>) -> Result<bool, ReflectError> {
57        self.contains_key_peek(crate::Peek::new(key))
58    }
59
60    /// Check if the map contains a key (using a `Peek`)
61    #[inline]
62    pub fn contains_key_peek(&self, key: crate::Peek<'_, 'facet>) -> Result<bool, ReflectError> {
63        if self.def.k() == key.shape() {
64            return Ok(unsafe { (self.def.vtable.contains_key)(self.value.data(), key.data()) });
65        }
66
67        Err(self.err(ReflectErrorKind::WrongShape {
68            expected: self.def.k(),
69            actual: key.shape(),
70        }))
71    }
72
73    /// Get a value from the map for the given key, as a read-only `Peek`
74    #[inline]
75    pub fn get(
76        &self,
77        key: &impl Facet<'facet>,
78    ) -> Result<Option<crate::Peek<'_, 'facet>>, ReflectError> {
79        self.get_peek(crate::Peek::new(key))
80    }
81
82    /// Get a value from the map for the given key (using a `Peek`), as a read-only `Peek`
83    #[inline]
84    pub fn get_peek(
85        &self,
86        key: crate::Peek<'_, 'facet>,
87    ) -> Result<Option<crate::Peek<'_, 'facet>>, ReflectError> {
88        if self.def.k() != key.shape() {
89            return Err(self.err(ReflectErrorKind::WrongShape {
90                expected: self.def.k(),
91                actual: key.shape(),
92            }));
93        }
94
95        let value_ptr = unsafe { (self.def.vtable.get_value_ptr)(self.value.data(), key.data()) };
96        if value_ptr.is_null() {
97            return Ok(None);
98        }
99        let value_ptr = facet_core::PtrConst::new_sized(value_ptr);
100        Ok(Some(unsafe {
101            crate::Peek::unchecked_new(value_ptr, self.def.v())
102        }))
103    }
104
105    /// Insert a key-value pair into the map.
106    ///
107    /// Both key and value must have shapes matching the map's key and value types.
108    /// The key and value are moved into the map.
109    pub fn insert<K, V>(&mut self, key: K, value: V) -> Result<(), ReflectError>
110    where
111        K: Facet<'facet>,
112        V: Facet<'facet>,
113    {
114        if self.def.k() != K::SHAPE {
115            return Err(self.err(ReflectErrorKind::WrongShape {
116                expected: self.def.k(),
117                actual: K::SHAPE,
118            }));
119        }
120        if self.def.v() != V::SHAPE {
121            return Err(self.err(ReflectErrorKind::WrongShape {
122                expected: self.def.v(),
123                actual: V::SHAPE,
124            }));
125        }
126
127        // The insert vtable moves the key and value (via ptr::read), so we need to
128        // hand over temporary storage that we will not drop afterwards.
129        let mut key = ManuallyDrop::new(key);
130        let mut value = ManuallyDrop::new(value);
131        unsafe {
132            let key_ptr = facet_core::PtrMut::new(&mut key as *mut ManuallyDrop<K> as *mut u8);
133            let value_ptr = facet_core::PtrMut::new(&mut value as *mut ManuallyDrop<V> as *mut u8);
134            (self.def.vtable.insert)(self.value.data_mut(), key_ptr, value_ptr);
135        }
136
137        Ok(())
138    }
139
140    /// Type-erased [`insert`](Self::insert).
141    ///
142    /// Accepts [`HeapValue`]s for key and value; their shapes must match the map's key and
143    /// value types. Both values are moved into the map.
144    pub fn insert_from_heap<const KB: bool, const VB: bool>(
145        &mut self,
146        key: HeapValue<'facet, KB>,
147        value: HeapValue<'facet, VB>,
148    ) -> Result<(), ReflectError> {
149        if self.def.k() != key.shape() {
150            return Err(self.err(ReflectErrorKind::WrongShape {
151                expected: self.def.k(),
152                actual: key.shape(),
153            }));
154        }
155        if self.def.v() != value.shape() {
156            return Err(self.err(ReflectErrorKind::WrongShape {
157                expected: self.def.v(),
158                actual: value.shape(),
159            }));
160        }
161
162        let mut key = key;
163        let mut value = value;
164        let key_guard = key.guard.take().expect("key HeapValue guard already taken");
165        let value_guard = value
166            .guard
167            .take()
168            .expect("value HeapValue guard already taken");
169        unsafe {
170            let key_ptr = facet_core::PtrMut::new(key_guard.ptr.as_ptr());
171            let value_ptr = facet_core::PtrMut::new(value_guard.ptr.as_ptr());
172            (self.def.vtable.insert)(self.value.data_mut(), key_ptr, value_ptr);
173        }
174        drop(key_guard);
175        drop(value_guard);
176        Ok(())
177    }
178
179    /// Returns an iterator over the key-value pairs in the map (read-only).
180    #[inline]
181    pub fn iter(&self) -> crate::PeekMapIter<'_, 'facet> {
182        self.as_peek_map().iter()
183    }
184
185    /// Def getter
186    #[inline]
187    pub const fn def(&self) -> MapDef {
188        self.def
189    }
190
191    /// Converts this `PokeMap` back into a `Poke`
192    #[inline]
193    pub fn into_inner(self) -> Poke<'mem, 'facet> {
194        self.value
195    }
196
197    /// Returns a read-only `PeekMap` view
198    #[inline]
199    pub fn as_peek_map(&self) -> crate::PeekMap<'_, 'facet> {
200        unsafe { crate::PeekMap::new(self.value.as_peek(), self.def) }
201    }
202}
203
#[cfg(test)]
mod tests {
    use alloc::collections::BTreeMap;

    use super::*;

    #[test]
    fn poke_map_len_and_insert() {
        let mut m: BTreeMap<String, i32> = BTreeMap::new();
        let poke = Poke::new(&mut m);
        let mut map = poke.into_map().unwrap();
        assert_eq!(map.len(), 0);
        map.insert(String::from("one"), 1i32).unwrap();
        map.insert(String::from("two"), 2i32).unwrap();
        assert_eq!(map.len(), 2);

        assert_eq!(m.get("one"), Some(&1));
        assert_eq!(m.get("two"), Some(&2));
    }

    #[test]
    fn poke_map_contains_and_get() {
        let mut m: BTreeMap<String, i32> = BTreeMap::new();
        m.insert(String::from("a"), 10);
        let poke = Poke::new(&mut m);
        let map = poke.into_map().unwrap();

        let key = String::from("a");
        assert!(map.contains_key(&key).unwrap());

        let v = map.get(&key).unwrap().unwrap();
        assert_eq!(*v.get::<i32>().unwrap(), 10);
    }

    /// Missing keys are reported as absent, and keys of the wrong shape are rejected.
    #[test]
    fn poke_map_missing_key_and_wrong_key_shape() {
        let mut m: BTreeMap<String, i32> = BTreeMap::new();
        m.insert(String::from("a"), 10);
        let poke = Poke::new(&mut m);
        let map = poke.into_map().unwrap();

        let missing = String::from("b");
        assert!(!map.contains_key(&missing).unwrap());
        assert!(map.get(&missing).unwrap().is_none());

        // A `u64` key cannot be looked up in a `String`-keyed map.
        assert!(map.contains_key(&7u64).is_err());
        assert!(map.get(&7u64).is_err());
    }

    /// Inserting a key or value of the wrong shape fails and leaves the map untouched.
    #[test]
    fn poke_map_insert_rejects_wrong_shapes() {
        let mut m: BTreeMap<String, i32> = BTreeMap::new();
        let poke = Poke::new(&mut m);
        let mut map = poke.into_map().unwrap();
        assert!(map.is_empty());

        assert!(map.insert(1u32, 1i32).is_err());
        assert!(map.insert(String::from("x"), 1u8).is_err());
        assert!(map.is_empty());

        assert!(m.is_empty());
    }

    #[test]
    fn poke_map_insert_from_heap() {
        let mut m: BTreeMap<String, i32> = BTreeMap::new();
        let poke = Poke::new(&mut m);
        let mut map = poke.into_map().unwrap();

        let key = crate::Partial::alloc::<String>()
            .unwrap()
            .set(String::from("k"))
            .unwrap()
            .build()
            .unwrap();
        let value = crate::Partial::alloc::<i32>()
            .unwrap()
            .set(42i32)
            .unwrap()
            .build()
            .unwrap();
        map.insert_from_heap(key, value).unwrap();

        assert_eq!(m.get("k"), Some(&42));
    }

    /// A heap key of the wrong shape is rejected before anything is moved into the map.
    #[test]
    fn poke_map_insert_from_heap_rejects_wrong_key_shape() {
        let mut m: BTreeMap<String, i32> = BTreeMap::new();
        let poke = Poke::new(&mut m);
        let mut map = poke.into_map().unwrap();

        let key = crate::Partial::alloc::<i32>()
            .unwrap()
            .set(1i32)
            .unwrap()
            .build()
            .unwrap();
        let value = crate::Partial::alloc::<i32>()
            .unwrap()
            .set(2i32)
            .unwrap()
            .build()
            .unwrap();
        assert!(map.insert_from_heap(key, value).is_err());
        assert!(map.is_empty());

        assert!(m.is_empty());
    }
}