Skip to main content

facet_reflect/peek/
ndarray.rs

1use super::Peek;
2use core::fmt::Debug;
3use facet_core::{NdArrayDef, PtrConst};
4
5/// Lets you read from an n-dimensional array (implements read-only [`facet_core::NdArrayVTable`] proxies)
6#[derive(Clone, Copy)]
7pub struct PeekNdArray<'mem, 'facet> {
8    value: Peek<'mem, 'facet>,
9    def: NdArrayDef,
10}
11
12impl Debug for PeekNdArray<'_, '_> {
13    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
14        f.debug_struct("PeekNdArray").finish_non_exhaustive()
15    }
16}
17
/// Error that can occur when trying to access an n-dimensional array as strided
/// memory.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum StrideError {
    /// The array's vtable does not provide the strided-access operations
    /// (`as_ptr` / `byte_stride`), so it cannot be viewed as strided memory.
    NotStrided,
}

impl core::fmt::Display for StrideError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive match so adding a variant forces a message to be written.
        match self {
            StrideError::NotStrided => f.write_str("array is not strided"),
        }
    }
}

impl core::fmt::Debug for StrideError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Intentionally more descriptive than a derived `Debug` (which would
        // print just `NotStrided`): includes the type path and the message.
        match self {
            StrideError::NotStrided => {
                f.write_str("StrideError::NotStrided: array is not strided")
            }
        }
    }
}

/// Lets `StrideError` participate in the standard error ecosystem
/// (`Box<dyn Error>`, `?` conversions, error reporters). `core::error::Error`
/// keeps this usable without `std`.
impl core::error::Error for StrideError {}
44impl<'mem, 'facet> PeekNdArray<'mem, 'facet> {
45    /// Creates a new peek array
46    ///
47    /// # Safety
48    ///
49    /// The caller must ensure that `def` contains valid vtable function pointers that:
50    /// - Correctly implement the ndarray operations for the actual type
51    /// - Do not cause undefined behavior when called
52    /// - Return pointers within valid memory bounds
53    /// - Match the element type specified in `def.t()`
54    ///
55    /// Violating these requirements can lead to memory safety issues.
56    #[inline]
57    pub const unsafe fn new(value: Peek<'mem, 'facet>, def: NdArrayDef) -> Self {
58        Self { value, def }
59    }
60
61    /// Get the number of elements in the array
62    #[inline]
63    pub fn count(&self) -> usize {
64        unsafe { (self.def.vtable.count)(self.value.data()) }
65    }
66
67    /// Get the number of elements in the array
68    #[inline]
69    pub fn n_dim(&self) -> usize {
70        unsafe { (self.def.vtable.n_dim)(self.value.data()) }
71    }
72
73    /// Get the i-th dimension of the array
74    #[inline]
75    pub fn dim(&self, i: usize) -> Option<usize> {
76        unsafe { (self.def.vtable.dim)(self.value.data(), i) }
77    }
78
79    /// Get an item from the array at the specified index
80    #[inline]
81    pub fn get(&self, index: usize) -> Option<Peek<'mem, 'facet>> {
82        let item = unsafe { (self.def.vtable.get)(self.value.data(), index)? };
83
84        Some(unsafe { Peek::unchecked_new(item, self.def.t()) })
85    }
86
87    /// Get a pointer to the start of the array
88    #[inline]
89    pub fn as_ptr(&self) -> Result<PtrConst, StrideError> {
90        let Some(as_ptr) = self.def.vtable.as_ptr else {
91            return Err(StrideError::NotStrided);
92        };
93        let ptr = unsafe { as_ptr(self.value.data()) };
94        Ok(ptr)
95    }
96
97    /// Get the i-th stride of the array in bytes
98    #[inline]
99    pub fn byte_stride(&self, i: usize) -> Result<Option<isize>, StrideError> {
100        let Some(byte_stride) = self.def.vtable.byte_stride else {
101            return Err(StrideError::NotStrided);
102        };
103        Ok(unsafe { byte_stride(self.value.data(), i) })
104    }
105
106    /// Peek value getter
107    #[inline]
108    pub const fn value(&self) -> Peek<'mem, 'facet> {
109        self.value
110    }
111
112    /// Def getter
113    #[inline]
114    pub const fn def(&self) -> NdArrayDef {
115        self.def
116    }
117}