Skip to main content

facet_reflect/poke/
list.rs

1use super::Poke;
2use core::{fmt::Debug, marker::PhantomData};
3use facet_core::{ListDef, PtrMut};
4
/// Iterator over a `PokeList`
///
/// Created by [`PokeList::iter_mut`] (or `IntoIterator` on a `PokeList`); yields one
/// [`Poke`] per list element.
pub struct PokeListIter<'mem, 'facet> {
    /// How elements are reached: offsets from a contiguous element pointer, or a
    /// vtable-provided iterator.
    state: PokeListIterState<'mem>,
    /// Number of elements yielded so far.
    index: usize,
    /// List length captured when the iterator was created.
    len: usize,
    /// List definition supplying the element shape and the iterator vtable.
    def: ListDef,
    /// Ties the yielded items to the `'mem` / `'facet` lifetimes of the originating `Poke`.
    _list: PhantomData<Poke<'mem, 'facet>>,
}
13
14impl<'mem, 'facet> Iterator for PokeListIter<'mem, 'facet> {
15    type Item = Poke<'mem, 'facet>;
16
17    #[inline]
18    fn next(&mut self) -> Option<Self::Item> {
19        let item_ptr = match &mut self.state.kind {
20            PokeListIterStateKind::Ptr { data, stride } => {
21                if self.index >= self.len {
22                    return None;
23                }
24
25                unsafe { data.field(*stride * self.index) }
26            }
27            PokeListIterStateKind::Iter { iter } => unsafe {
28                // The iter vtable returns PtrConst, but we know the underlying data is mutable
29                // because we created this iterator from a PokeList which has mutable access.
30                // We need to convert the const pointer back to mutable.
31                let const_ptr = (self.def.iter_vtable().unwrap().next)(*iter)?;
32                PtrMut::new(const_ptr.as_byte_ptr() as *mut u8)
33            },
34        };
35
36        self.index += 1;
37
38        Some(unsafe { Poke::from_raw_parts(item_ptr, self.def.t()) })
39    }
40
41    #[inline]
42    fn size_hint(&self) -> (usize, Option<usize>) {
43        let remaining = self.len.saturating_sub(self.index);
44        (remaining, Some(remaining))
45    }
46}
47
// `len()` comes from `size_hint`: the length captured at creation minus items yielded.
impl ExactSizeIterator for PokeListIter<'_, '_> {}
49
50impl Drop for PokeListIter<'_, '_> {
51    #[inline]
52    fn drop(&mut self) {
53        match &self.state.kind {
54            PokeListIterStateKind::Iter { iter } => unsafe {
55                (self.def.iter_vtable().unwrap().dealloc)(*iter)
56            },
57            PokeListIterStateKind::Ptr { .. } => {
58                // Nothing to do
59            }
60        }
61    }
62}
63
/// Iteration state for [`PokeListIter`], carrying the `'mem` lifetime of the mutable borrow.
struct PokeListIterState<'mem> {
    /// Which iteration strategy is in use.
    kind: PokeListIterStateKind,
    /// Marks `&'mem mut` (exclusive) access to the list memory.
    _phantom: PhantomData<&'mem mut ()>,
}
68
/// The two strategies [`PokeListIter`] uses to walk a list.
enum PokeListIterStateKind {
    /// Element `i` is reached as `data` offset by `stride * i` bytes
    /// (`data` comes from the list vtable's `as_mut_ptr`).
    Ptr { data: PtrMut, stride: usize },
    /// Opaque state from the list iter vtable's `init_with_value`; freed in
    /// `PokeListIter`'s `Drop` impl via the vtable's `dealloc`.
    Iter { iter: PtrMut },
}
73
/// Lets you mutate a list (implements mutable [`facet_core::ListVTable`] proxies)
///
/// Pairs a [`Poke`] of the list value with the [`ListDef`] whose vtable performs every
/// list operation.
pub struct PokeList<'mem, 'facet> {
    /// The list value being accessed.
    value: Poke<'mem, 'facet>,
    /// List definition: vtable plus element shape (`def.t()`).
    def: ListDef,
}
79
80impl Debug for PokeList<'_, '_> {
81    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
82        f.debug_struct("PokeList").finish_non_exhaustive()
83    }
84}
85
86impl<'mem, 'facet> PokeList<'mem, 'facet> {
87    /// Creates a new poke list
88    ///
89    /// # Safety
90    ///
91    /// The caller must ensure that `def` contains valid vtable function pointers that:
92    /// - Correctly implement the list operations for the actual type
93    /// - Do not cause undefined behavior when called
94    /// - Return pointers within valid memory bounds
95    /// - Match the element type specified in `def.t()`
96    ///
97    /// Violating these requirements can lead to memory safety issues.
98    #[inline]
99    pub const unsafe fn new(value: Poke<'mem, 'facet>, def: ListDef) -> Self {
100        Self { value, def }
101    }
102
103    /// Get the length of the list
104    #[inline]
105    pub fn len(&self) -> usize {
106        unsafe { (self.def.vtable.len)(self.value.data()) }
107    }
108
109    /// Returns true if the list is empty
110    #[inline]
111    pub fn is_empty(&self) -> bool {
112        self.len() == 0
113    }
114
115    /// Get an immutable reference to an item from the list at the specified index
116    #[inline]
117    pub fn get(&self, index: usize) -> Option<crate::Peek<'_, 'facet>> {
118        let item = unsafe { (self.def.vtable.get)(self.value.data(), index, self.value.shape())? };
119
120        Some(unsafe { crate::Peek::unchecked_new(item, self.def.t()) })
121    }
122
123    /// Get a mutable reference to an item from the list at the specified index
124    #[inline]
125    pub fn get_mut(&mut self, index: usize) -> Option<Poke<'_, 'facet>> {
126        let get_mut_fn = self.def.vtable.get_mut?;
127        let item = unsafe { get_mut_fn(self.value.data, index, self.value.shape())? };
128
129        Some(unsafe { Poke::from_raw_parts(item, self.def.t()) })
130    }
131
132    /// Returns a mutable iterator over the list
133    pub fn iter_mut(self) -> PokeListIter<'mem, 'facet> {
134        let state = if let Some(as_mut_ptr_fn) = self.def.vtable.as_mut_ptr {
135            let data = unsafe { as_mut_ptr_fn(self.value.data) };
136            let layout = self
137                .def
138                .t()
139                .layout
140                .sized_layout()
141                .expect("can only iterate over sized list elements");
142            let stride = layout.size();
143
144            PokeListIterState {
145                kind: PokeListIterStateKind::Ptr { data, stride },
146                _phantom: PhantomData,
147            }
148        } else {
149            // Fall back to the immutable iterator, but we know we have mutable access
150            let iter = unsafe {
151                (self.def.iter_vtable().unwrap().init_with_value.unwrap())(self.value.data())
152            };
153            PokeListIterState {
154                kind: PokeListIterStateKind::Iter { iter },
155                _phantom: PhantomData,
156            }
157        };
158
159        PokeListIter {
160            state,
161            index: 0,
162            len: self.len(),
163            def: self.def(),
164            _list: PhantomData,
165        }
166    }
167
168    /// Def getter
169    #[inline]
170    pub const fn def(&self) -> ListDef {
171        self.def
172    }
173
174    /// Converts this `PokeList` back into a `Poke`
175    #[inline]
176    pub fn into_inner(self) -> Poke<'mem, 'facet> {
177        self.value
178    }
179
180    /// Returns a read-only `PeekList` view
181    #[inline]
182    pub fn as_peek_list(&self) -> crate::PeekList<'_, 'facet> {
183        unsafe { crate::PeekList::new(self.value.as_peek(), self.def) }
184    }
185}
186
187impl<'mem, 'facet> IntoIterator for PokeList<'mem, 'facet> {
188    type Item = Poke<'mem, 'facet>;
189    type IntoIter = PokeListIter<'mem, 'facet>;
190
191    #[inline]
192    fn into_iter(self) -> Self::IntoIter {
193        self.iter_mut()
194    }
195}
196
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use super::*;

    #[test]
    fn poke_list_len() {
        let mut v: Vec<i32> = alloc::vec![1, 2, 3, 4, 5];
        let poke = Poke::new(&mut v);
        let list = poke.into_list().unwrap();
        assert_eq!(list.len(), 5);
    }

    #[test]
    fn poke_list_get() {
        let mut v: Vec<i32> = alloc::vec![10, 20, 30];
        let poke = Poke::new(&mut v);
        let list = poke.into_list().unwrap();

        let item = list.get(1).unwrap();
        assert_eq!(*item.get::<i32>().unwrap(), 20);
    }

    #[test]
    fn poke_list_get_mut() {
        let mut v: Vec<i32> = alloc::vec![10, 20, 30];
        let poke = Poke::new(&mut v);
        let mut list = poke.into_list().unwrap();

        {
            let mut item = list.get_mut(1).unwrap();
            item.set(99i32).unwrap();
        }

        // Verify the change
        let item = list.get(1).unwrap();
        assert_eq!(*item.get::<i32>().unwrap(), 99);
    }

    #[test]
    fn poke_list_iter_mut() {
        let mut v: Vec<i32> = alloc::vec![1, 2, 3];
        let poke = Poke::new(&mut v);
        let list = poke.into_list().unwrap();

        let mut sum = 0;
        for mut item in list {
            let val = *item.get::<i32>().unwrap();
            item.set(val * 10).unwrap();
            sum += val;
        }

        assert_eq!(sum, 6);
        assert_eq!(v, alloc::vec![10, 20, 30]);
    }

    /// An empty list reports empty and its iterator yields nothing.
    #[test]
    fn poke_list_empty() {
        let mut v: Vec<i32> = Vec::new();
        let poke = Poke::new(&mut v);
        let list = poke.into_list().unwrap();

        assert!(list.is_empty());
        assert_eq!(list.into_iter().count(), 0);
    }

    /// `ExactSizeIterator::len` starts at the list length and drops by one per yielded item.
    #[test]
    fn poke_list_iter_mut_len_tracks_progress() {
        let mut v: Vec<i32> = alloc::vec![1, 2, 3];
        let poke = Poke::new(&mut v);
        let list = poke.into_list().unwrap();

        let mut iter = list.iter_mut();
        assert_eq!(iter.len(), 3);
        assert!(iter.next().is_some());
        assert_eq!(iter.len(), 2);
    }
}