Skip to main content

wolfram_expr/
array_buf.rs

1//! Shared implementation backing both [`NumericArray`][crate::NumericArray] and
2//! [`PackedArray`][crate::PackedArray].
3//!
4//! Both types are dense N-dimensional buffers with an element-type tag — the only
5//! difference is the *set* of valid element types (PackedArray supports a strict
6//! subset).
7
8use crate::wxf::NumericArrayEnum;
9use crate::ByteArray;
10
11/// Sealed marker for Rust primitives valid as an array element. The set is
12/// fixed (the C ABI only knows these widths), so external types can't
13/// implement [`ArrayElement`].
14mod sealed {
15    use crate::complex::{Complex32, Complex64};
16    pub trait Sealed {}
17    impl Sealed for i8 {}
18    impl Sealed for i16 {}
19    impl Sealed for i32 {}
20    impl Sealed for i64 {}
21    impl Sealed for u8 {}
22    impl Sealed for u16 {}
23    impl Sealed for u32 {}
24    impl Sealed for u64 {}
25    impl Sealed for f32 {}
26    impl Sealed for f64 {}
27    impl Sealed for Complex32 {}
28    impl Sealed for Complex64 {}
29}
30
31/// Connects a Rust primitive to its element-type discriminant. Implemented
32/// once per `(type, tag)` pair: e.g. `i32: ArrayElement<NumericArrayEnum>`
33/// (with `TAG = Integer32`) and `i32: ArrayElement<PackedArrayEnum>` (with
34/// `TAG = Integer32`). Sealed — only the primitives listed above can
35/// satisfy the `Sealed` super-bound.
36pub trait ArrayElement<Tag: Copy + PartialEq>: Copy + 'static + sealed::Sealed {
37    /// The element-type tag for `Self` under this array kind.
38    const TAG: Tag;
39}
40
41/// Generic dense N-dimensional buffer parameterized by an element-type tag.
42///
43/// `NumericArray = ArrayBuf<NumericArrayEnum>` and
44/// `PackedArray   = ArrayBuf<PackedArrayEnum>`. Each provides specialized
45/// constructors (`from_slice<T: …Element>`) and a typed slice view; shape,
46/// byte access, and element count are shared via this struct.
47#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
48pub struct ArrayBuf<Tag> {
49    pub(crate) data_type: Tag,
50    pub(crate) dimensions: Vec<usize>,
51    pub(crate) bytes: ByteArray,
52}
53
54impl<Tag: Copy + PartialEq> ArrayBuf<Tag> {
55    /// Construct from raw parts. Caller is responsible for ensuring
56    /// `bytes.len() == prod(dimensions) * element_size`.
57    pub fn new(data_type: Tag, dimensions: Vec<usize>, bytes: ByteArray) -> Self {
58        ArrayBuf {
59            data_type,
60            dimensions,
61            bytes,
62        }
63    }
64
65    /// The concrete element-type tag.
66    pub fn data_type(&self) -> Tag {
67        self.data_type
68    }
69
70    /// Multi-dimensional shape.
71    pub fn dimensions(&self) -> &[usize] {
72        &self.dimensions
73    }
74
75    /// Raw byte buffer.
76    pub(crate) fn as_bytes(&self) -> &[u8] {
77        &self.bytes
78    }
79
80    /// Number of dimensions.
81    pub fn rank(&self) -> usize {
82        self.dimensions.len()
83    }
84
85    /// Total element count = product of dimensions.
86    pub fn element_count(&self) -> usize {
87        self.dimensions.iter().product()
88    }
89
90
91    /// Construct from a typed slice. Dimensions must satisfy
92    /// `prod(dimensions) == slice.len()`.
93    pub fn from_slice<T: ArrayElement<Tag>>(dimensions: Vec<usize>, slice: &[T]) -> Self {
94        assert_eq!(
95            dimensions.iter().product::<usize>(),
96            slice.len(),
97            "ArrayBuf::from_slice: dims product must equal slice length"
98        );
99        let bytes: &[u8] = unsafe {
100            std::slice::from_raw_parts(
101                slice.as_ptr() as *const u8,
102                std::mem::size_of_val(slice),
103            )
104        };
105        ArrayBuf::new(T::TAG, dimensions, ByteArray::from(bytes))
106    }
107
108    /// Try to view the buffer as a slice of `T`. Returns `None` if `T`'s tag
109    /// doesn't match this array's [`data_type`][Self::data_type].
110    pub fn try_as_slice<T: ArrayElement<Tag>>(&self) -> Option<&[T]> {
111        if self.data_type != T::TAG {
112            return None;
113        }
114        let bytes = self.as_bytes();
115        let elem_size = std::mem::size_of::<T>();
116        debug_assert_eq!(bytes.len() % elem_size, 0);
117        if bytes.is_empty() {
118            return Some(&[]);
119        }
120        // SAFETY: tag matches T, so the bytes were produced from a `[T]`.
121        Some(unsafe {
122            std::slice::from_raw_parts(
123                bytes.as_ptr() as *const T,
124                bytes.len() / elem_size,
125            )
126        })
127    }
128}
129
130/// Common read API implemented by both the owned [`crate::NumericArray`] /
131/// [`crate::PackedArray`] and the runtime-handle `NumericArray<T>` in
132/// `wolfram-library-link`.
133pub trait NumericArrayRead {
134    /// The element-type tag.
135    fn data_type(&self) -> NumericArrayEnum;
136    /// The multi-dimensional shape (row-major).
137    fn dimensions(&self) -> &[usize];
138    #[doc(hidden)]
139    fn as_bytes(&self) -> &[u8];
140
141    /// Number of dimensions.
142    fn rank(&self) -> usize {
143        self.dimensions().len()
144    }
145    /// Total element count = product of dimensions.
146    fn element_count(&self) -> usize {
147        self.dimensions().iter().product()
148    }
149    #[doc(hidden)]
150    fn byte_count(&self) -> usize {
151        self.as_bytes().len()
152    }
153    #[doc(hidden)]
154    fn element_size(&self) -> usize {
155        self.data_type().size_in_bytes()
156    }
157
158    /// View the buffer as `&[T]` if `T`'s element type matches; else `None`.
159    fn try_as_slice<T: ArrayElement<NumericArrayEnum>>(&self) -> Option<&[T]> {
160        if self.data_type() != T::TAG {
161            return None;
162        }
163        let bytes = self.as_bytes();
164        let elem_size = std::mem::size_of::<T>();
165        debug_assert_eq!(bytes.len() % elem_size, 0);
166        // SAFETY: tag matches, alignment guaranteed by construction.
167        Some(unsafe {
168            std::slice::from_raw_parts(
169                bytes.as_ptr() as *const T,
170                bytes.len() / elem_size,
171            )
172        })
173    }
174}
175
176impl<Tag: Into<NumericArrayEnum> + Copy + PartialEq> NumericArrayRead for ArrayBuf<Tag> {
177    fn data_type(&self) -> NumericArrayEnum {
178        self.data_type.into()
179    }
180    fn dimensions(&self) -> &[usize] {
181        &self.dimensions
182    }
183    fn as_bytes(&self) -> &[u8] {
184        &self.bytes
185    }
186}