facet_reflect/peek/
ndarray.rs

1use super::Peek;
2use core::fmt::Debug;
3use facet_core::{NdArrayDef, PtrConst};
4
5/// Lets you read from an n-dimensional array (implements read-only [`facet_core::NdArrayVTable`] proxies)
6#[derive(Clone, Copy)]
7pub struct PeekNdArray<'mem, 'facet> {
8    pub(crate) value: Peek<'mem, 'facet>,
9    pub(crate) def: NdArrayDef,
10}
11
12impl Debug for PeekNdArray<'_, '_> {
13    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
14        f.debug_struct("PeekNdArray").finish_non_exhaustive()
15    }
16}
17
/// Failure returned when strided access (a raw data pointer or per-dimension byte
/// strides) is requested from an n-dimensional array whose vtable does not offer it.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum StrideError {
    /// The array's vtable provides no strided-access function, so the array is
    /// not strided.
    NotStrided,
}
24
25impl core::fmt::Display for StrideError {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        match self {
28            StrideError::NotStrided => {
29                write!(f, "array is not strided")
30            }
31        }
32    }
33}
34
35impl core::fmt::Debug for StrideError {
36    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
37        match self {
38            StrideError::NotStrided => {
39                write!(f, "StrideError::NotStrided: array is not strided")
40            }
41        }
42    }
43}
44impl<'mem, 'facet> PeekNdArray<'mem, 'facet> {
45    /// Creates a new peek array
46    #[inline]
47    pub fn new(value: Peek<'mem, 'facet>, def: NdArrayDef) -> Self {
48        Self { value, def }
49    }
50
51    /// Get the number of elements in the array
52    #[inline]
53    pub fn count(&self) -> usize {
54        unsafe { (self.def.vtable.count)(self.value.data()) }
55    }
56
57    /// Get the number of elements in the array
58    #[inline]
59    pub fn n_dim(&self) -> usize {
60        unsafe { (self.def.vtable.n_dim)(self.value.data()) }
61    }
62
63    /// Get the i-th dimension of the array
64    #[inline]
65    pub fn dim(&self, i: usize) -> Option<usize> {
66        unsafe { (self.def.vtable.dim)(self.value.data(), i) }
67    }
68
69    /// Get an item from the array at the specified index
70    #[inline]
71    pub fn get(&self, index: usize) -> Option<Peek<'mem, 'facet>> {
72        let item = unsafe { (self.def.vtable.get)(self.value.data(), index)? };
73
74        Some(unsafe { Peek::unchecked_new(item, self.def.t()) })
75    }
76
77    /// Get a pointer to the start of the array
78    #[inline]
79    pub fn as_ptr(&self) -> Result<PtrConst, StrideError> {
80        let Some(as_ptr) = self.def.vtable.as_ptr else {
81            return Err(StrideError::NotStrided);
82        };
83        let ptr = unsafe { as_ptr(self.value.data()) };
84        Ok(ptr)
85    }
86
87    /// Get the i-th stride of the array in bytes
88    #[inline]
89    pub fn byte_stride(&self, i: usize) -> Result<Option<isize>, StrideError> {
90        let Some(byte_stride) = self.def.vtable.byte_stride else {
91            return Err(StrideError::NotStrided);
92        };
93        Ok(unsafe { byte_stride(self.value.data(), i) })
94    }
95
96    /// Peek value getter
97    #[inline]
98    pub fn value(&self) -> Peek<'mem, 'facet> {
99        self.value
100    }
101
102    /// Def getter
103    #[inline]
104    pub fn def(&self) -> NdArrayDef {
105        self.def
106    }
107}