ndarray_layout/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(warnings, missing_docs)]
3
4/// An array layout allow N dimensions inlined.
5pub struct ArrayLayout<const N: usize> {
6    ndim: usize,
7    content: Union<N>,
8}
9
10unsafe impl<const N: usize> Send for ArrayLayout<N> {}
11unsafe impl<const N: usize> Sync for ArrayLayout<N> {}
12
/// Storage for the metadata of an `ArrayLayout`: either inline words or a heap pointer.
///
/// The inline variant is read as a flat run of words `[offset, shape.., strides..]`
/// through a pointer to the union itself, so every field must start at offset 0.
/// `#[repr(C)]` guarantees that for unions; the default representation makes no
/// such guarantee.
///
/// NOTE(review): the inline variant is a tuple, whose internal field order is not
/// guaranteed by the language either. A `#[repr(C)]` struct would pin it, but that
/// requires changing the constructor in `with_ndim` at the same time.
#[repr(C)]
union Union<const N: usize> {
    /// Heap copy of the metadata; the active field when `ndim > N`.
    ptr: NonNull<usize>,
    /// Inline metadata `(offset, shape, strides)`; the active field when `ndim <= N`.
    /// It is accessed through a raw pointer to the union rather than by name.
    _inlined: (isize, [usize; N], [isize; N]),
}
17
18impl<const N: usize> Clone for ArrayLayout<N> {
19    #[inline]
20    fn clone(&self) -> Self {
21        Self::new(self.shape(), self.strides(), self.offset())
22    }
23}
24
25impl<const N: usize> PartialEq for ArrayLayout<N> {
26    #[inline]
27    fn eq(&self, other: &Self) -> bool {
28        self.ndim == other.ndim && self.content().as_slice() == other.content().as_slice()
29    }
30}
31
32impl<const N: usize> Eq for ArrayLayout<N> {}
33
34impl<const N: usize> Drop for ArrayLayout<N> {
35    fn drop(&mut self) {
36        if let Some(ptr) = self.ptr_allocated() {
37            unsafe { dealloc(ptr.cast().as_ptr(), layout(self.ndim)) }
38        }
39    }
40}
41
/// Order in which dimensions are arranged in the layout metadata.
///
/// For example, [`ArrayLayout::new_contiguous`] with `BigEndian` gives the last
/// dimension the smallest stride (`element_size`), while `LittleEndian` gives it
/// to the first dimension.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Endian {
    /// Big-endian: dimensions with a larger span come earlier in the metadata.
    BigEndian,
    /// Little-endian: dimensions with a smaller span come earlier in the metadata.
    LittleEndian,
}
50
51impl<const N: usize> ArrayLayout<N> {
52    /// Creates a new Layout with the given shape, strides, and offset.
53    ///
54    /// ```rust
55    /// # use ndarray_layout::ArrayLayout;
56    /// let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
57    /// assert_eq!(layout.offset(), 20);
58    /// assert_eq!(layout.shape(), &[2, 3, 4]);
59    /// assert_eq!(layout.strides(), &[12, -4, 1]);
60    /// ```
61    pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self {
62        // check
63        assert_eq!(
64            shape.len(),
65            strides.len(),
66            "shape and strides must have the same length"
67        );
68
69        let mut ans = Self::with_ndim(shape.len());
70        let mut content = ans.content_mut();
71        content.set_offset(offset);
72        content.copy_shape(shape);
73        content.copy_strides(strides);
74        ans
75    }
76
77    /// Creates a new contiguous Layout with the given shape.
78    ///
79    /// ```rust
80    /// # use ndarray_layout::{Endian, ArrayLayout};
81    /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
82    /// assert_eq!(layout.offset(), 0);
83    /// assert_eq!(layout.shape(), &[2, 3, 4]);
84    /// assert_eq!(layout.strides(), &[4, 8, 24]);
85    /// ```
86    pub fn new_contiguous(shape: &[usize], endian: Endian, element_size: usize) -> Self {
87        let mut ans = Self::with_ndim(shape.len());
88        let mut content = ans.content_mut();
89        content.set_offset(0);
90        content.copy_shape(shape);
91        let mut mul = element_size as isize;
92        let push = |i| {
93            content.set_stride(i, mul);
94            mul *= shape[i] as isize;
95        };
96        match endian {
97            Endian::BigEndian => (0..shape.len()).rev().for_each(push),
98            Endian::LittleEndian => (0..shape.len()).for_each(push),
99        }
100        ans
101    }
102
103    /// Gets offset.
104    #[inline]
105    pub const fn ndim(&self) -> usize {
106        self.ndim
107    }
108
109    /// Gets offset.
110    #[inline]
111    pub fn offset(&self) -> isize {
112        self.content().offset()
113    }
114
115    /// Gets shape.
116    #[inline]
117    pub fn shape(&self) -> &[usize] {
118        self.content().shape()
119    }
120
121    /// Gets strides.
122    #[inline]
123    pub fn strides(&self) -> &[isize] {
124        self.content().strides()
125    }
126
127    /// Calculates the number of elements in the array.
128    ///
129    /// ```rust
130    /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout};
131    /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 20);
132    /// assert_eq!(layout.num_elements(), 24);
133    /// ```
134    #[inline]
135    pub fn num_elements(&self) -> usize {
136        self.shape().iter().product()
137    }
138
139    /// Calculates the offset of element at the given `index`.
140    ///
141    /// ```rust
142    /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout};
143    /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 4);
144    /// assert_eq!(layout.element_offset(22, BigEndian), 88); // 88 <- (22 % 4 * 4) + (22 / 4 % 3 * 16) + (22 / 4 / 3 % 2 * 48)
145    /// ```
146    pub fn element_offset(&self, index: usize, endian: Endian) -> isize {
147        fn offset_forwards(
148            mut rem: usize,
149            shape: impl IntoIterator<Item = usize>,
150            strides: impl IntoIterator<Item = isize>,
151        ) -> isize {
152            let mut ans = 0;
153            for (d, s) in zip(shape, strides) {
154                ans += s * (rem % d) as isize;
155                rem /= d
156            }
157            ans
158        }
159
160        let shape = self.shape().iter().cloned();
161        let strides = self.strides().iter().cloned();
162        self.offset()
163            + match endian {
164                Endian::BigEndian => offset_forwards(index, shape.rev(), strides.rev()),
165                Endian::LittleEndian => offset_forwards(index, shape, strides),
166            }
167    }
168
169    /// Calculates the range of data in bytes to determine the location of the memory area that the array needs to access.
170    pub fn data_range(&self) -> RangeInclusive<isize> {
171        let content = self.content();
172        let mut start = content.offset();
173        let mut end = content.offset();
174        for (&d, s) in zip(content.shape(), content.strides()) {
175            use std::cmp::Ordering::{Equal, Greater, Less};
176            let i = d as isize - 1;
177            match s.cmp(&0) {
178                Equal => {}
179                Less => start += s * i,
180                Greater => end += s * i,
181            }
182        }
183        start..=end
184    }
185}
186
187mod transform;
188pub use transform::{BroadcastArg, IndexArg, MergeArg, SliceArg, Split, TileArg};
189
190use std::{
191    alloc::{Layout, alloc, dealloc},
192    iter::zip,
193    ops::RangeInclusive,
194    ptr::{NonNull, copy_nonoverlapping},
195    slice::from_raw_parts,
196};
197
198impl<const N: usize> ArrayLayout<N> {
199    #[inline]
200    fn ptr_allocated(&self) -> Option<NonNull<usize>> {
201        const { assert!(N > 0) }
202        if self.ndim > N {
203            Some(unsafe { self.content.ptr })
204        } else {
205            None
206        }
207    }
208
209    #[inline]
210    fn content(&self) -> Content<false> {
211        Content {
212            ptr: self
213                .ptr_allocated()
214                .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
215            ndim: self.ndim,
216        }
217    }
218
219    #[inline]
220    fn content_mut(&mut self) -> Content<true> {
221        Content {
222            ptr: self
223                .ptr_allocated()
224                .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
225            ndim: self.ndim,
226        }
227    }
228
229    /// Create a new ArrayLayout with the given dimensions.
230    #[inline]
231    fn with_ndim(ndim: usize) -> Self {
232        Self {
233            ndim,
234            content: if ndim <= N {
235                Union {
236                    _inlined: (0, [0; N], [0; N]),
237                }
238            } else {
239                Union {
240                    ptr: unsafe { NonNull::new_unchecked(alloc(layout(ndim)).cast()) },
241                }
242            },
243        }
244    }
245}
246
/// Untyped view of layout metadata: an offset word, then `ndim` shape words,
/// then `ndim` stride words.
///
/// `MUT` only selects which accessors exist (`Content<true>` adds the setters).
/// The view carries no lifetime, so it does not keep the owning layout borrowed.
struct Content<const MUT: bool> {
    /// Address of the first metadata word (the offset).
    ptr: NonNull<usize>,
    /// Number of dimensions described by the metadata.
    ndim: usize,
}
251
252impl<const MUT: bool> Content<MUT> {
253    #[inline]
254    fn as_slice(&self) -> &[usize] {
255        unsafe { from_raw_parts(self.ptr.as_ptr(), 1 + self.ndim * 2) }
256    }
257
258    #[inline]
259    fn offset(&self) -> isize {
260        unsafe { self.ptr.cast().read() }
261    }
262
263    #[inline]
264    fn shape<'a>(&self) -> &'a [usize] {
265        unsafe { from_raw_parts(self.ptr.add(1).as_ptr(), self.ndim) }
266    }
267
268    #[inline]
269    fn strides<'a>(&self) -> &'a [isize] {
270        unsafe { from_raw_parts(self.ptr.add(1 + self.ndim).cast().as_ptr(), self.ndim) }
271    }
272}
273
274impl Content<true> {
275    #[inline]
276    fn set_offset(&mut self, val: isize) {
277        unsafe { self.ptr.cast().write(val) }
278    }
279
280    #[inline]
281    fn set_shape(&mut self, idx: usize, val: usize) {
282        assert!(idx < self.ndim);
283        unsafe { self.ptr.add(1 + idx).write(val) }
284    }
285
286    #[inline]
287    fn set_stride(&mut self, idx: usize, val: isize) {
288        assert!(idx < self.ndim);
289        unsafe { self.ptr.add(1 + idx + self.ndim).cast().write(val) }
290    }
291
292    #[inline]
293    fn copy_shape(&mut self, val: &[usize]) {
294        assert!(val.len() == self.ndim);
295        unsafe { copy_nonoverlapping(val.as_ptr(), self.ptr.add(1).as_ptr(), self.ndim) }
296    }
297
298    #[inline]
299    fn copy_strides(&mut self, val: &[isize]) {
300        assert!(val.len() == self.ndim);
301        unsafe {
302            copy_nonoverlapping(
303                val.as_ptr(),
304                self.ptr.add(1 + self.ndim).cast().as_ptr(),
305                self.ndim,
306            )
307        }
308    }
309}
310
/// Heap layout for the metadata of an `ndim`-dimensional array: one offset word
/// plus `ndim` shape words and `ndim` stride words, all `usize`-sized and aligned.
#[inline]
fn layout(ndim: usize) -> Layout {
    let words = 2 * ndim + 1;
    Layout::array::<usize>(words).unwrap()
}