Skip to main content

// facet_reflect/poke/ndarray.rs

1use core::fmt::Debug;
2use facet_core::{NdArrayDef, PtrMut};
3
4use crate::peek::StrideError;
5
6use super::Poke;
7
8/// Lets you mutate an n-dimensional array (implements mutable [`facet_core::NdArrayVTable`] proxies)
9pub struct PokeNdArray<'mem, 'facet> {
10    value: Poke<'mem, 'facet>,
11    def: NdArrayDef,
12}
13
14impl Debug for PokeNdArray<'_, '_> {
15    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16        f.debug_struct("PokeNdArray").finish_non_exhaustive()
17    }
18}
19
20impl<'mem, 'facet> PokeNdArray<'mem, 'facet> {
21    /// Creates a new poke ndarray.
22    ///
23    /// # Safety
24    ///
25    /// The caller must ensure that `def` contains valid vtable function pointers that
26    /// correctly implement the ndarray operations for the actual type, and that the
27    /// element type matches `def.t()`.
28    #[inline]
29    pub const unsafe fn new(value: Poke<'mem, 'facet>, def: NdArrayDef) -> Self {
30        Self { value, def }
31    }
32
33    /// Get the total number of elements in the array.
34    #[inline]
35    pub fn count(&self) -> usize {
36        unsafe { (self.def.vtable.count)(self.value.data()) }
37    }
38
39    /// Get the number of dimensions.
40    #[inline]
41    pub fn n_dim(&self) -> usize {
42        unsafe { (self.def.vtable.n_dim)(self.value.data()) }
43    }
44
45    /// Get the i-th dimension.
46    #[inline]
47    pub fn dim(&self, i: usize) -> Option<usize> {
48        unsafe { (self.def.vtable.dim)(self.value.data(), i) }
49    }
50
51    /// Get a read-only view of the item at the given flat index.
52    #[inline]
53    pub fn get(&self, index: usize) -> Option<crate::Peek<'_, 'facet>> {
54        let item = unsafe { (self.def.vtable.get)(self.value.data(), index)? };
55        Some(unsafe { crate::Peek::unchecked_new(item, self.def.t()) })
56    }
57
58    /// Get a mutable view of the item at the given flat index.
59    ///
60    /// Returns `None` if the underlying ndarray doesn't provide mutable access or
61    /// if the index is out of bounds.
62    pub fn get_mut(&mut self, index: usize) -> Option<Poke<'_, 'facet>> {
63        let get_mut_fn = self.def.vtable.get_mut?;
64        let item = unsafe { get_mut_fn(self.value.data_mut(), index)? };
65        Some(unsafe { Poke::from_raw_parts(item, self.def.t()) })
66    }
67
68    /// Get a mutable pointer to the start of the data buffer (if the array is strided).
69    #[inline]
70    pub fn as_mut_ptr(&mut self) -> Result<PtrMut, StrideError> {
71        let Some(as_mut_ptr) = self.def.vtable.as_mut_ptr else {
72            return Err(StrideError::NotStrided);
73        };
74        Ok(unsafe { as_mut_ptr(self.value.data_mut()) })
75    }
76
77    /// Get the i-th stride in bytes.
78    #[inline]
79    pub fn byte_stride(&self, i: usize) -> Result<Option<isize>, StrideError> {
80        let Some(byte_stride) = self.def.vtable.byte_stride else {
81            return Err(StrideError::NotStrided);
82        };
83        Ok(unsafe { byte_stride(self.value.data(), i) })
84    }
85
86    /// Def getter.
87    #[inline]
88    pub const fn def(&self) -> NdArrayDef {
89        self.def
90    }
91
92    /// Converts this `PokeNdArray` back into a `Poke`.
93    #[inline]
94    pub const fn into_inner(self) -> Poke<'mem, 'facet> {
95        self.value
96    }
97
98    /// Returns a read-only `PeekNdArray` view.
99    #[inline]
100    pub fn as_peek_ndarray(&self) -> crate::PeekNdArray<'_, 'facet> {
101        unsafe { crate::PeekNdArray::new(self.value.as_peek(), self.def) }
102    }
103}