ndarray_layout/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(warnings, missing_docs)]
3
4/// An array layout allow N dimensions inlined.
5pub struct ArrayLayout<const N: usize> {
6    ndim: usize,
7    content: Union<N>,
8}
9
10unsafe impl<const N: usize> Send for ArrayLayout<N> {}
11unsafe impl<const N: usize> Sync for ArrayLayout<N> {}
12
/// Storage for layout metadata: either the words themselves (inline) or a
/// pointer to a heap buffer holding them.
///
/// The accessors reinterpret the union's own address as the first word of the
/// inline record, so every field must start at offset 0. `#[repr(C)]`
/// guarantees that; the default representation leaves field offsets
/// unspecified.
#[repr(C)]
union Union<const N: usize> {
    /// Heap buffer of `1 + 2 * ndim` words, active when `ndim > N`.
    ptr: NonNull<usize>,
    /// Inline buffer of `1 + 2 * N` words, active when `ndim <= N`. It is only
    /// ever accessed as a flat run of `usize` words.
    _inlined: (isize, [usize; N], [isize; N]),
}
17
18impl<const N: usize> Clone for ArrayLayout<N> {
19    #[inline]
20    fn clone(&self) -> Self {
21        Self::new(self.shape(), self.strides(), self.offset())
22    }
23}
24
25impl<const N: usize> PartialEq for ArrayLayout<N> {
26    #[inline]
27    fn eq(&self, other: &Self) -> bool {
28        self.ndim == other.ndim && self.content().as_slice() == other.content().as_slice()
29    }
30}
31
32impl<const N: usize> Eq for ArrayLayout<N> {}
33
34impl<const N: usize> Drop for ArrayLayout<N> {
35    fn drop(&mut self) {
36        if let Some(ptr) = self.ptr_allocated() {
37            unsafe { dealloc(ptr.cast().as_ptr(), layout(self.ndim)) }
38        }
39    }
40}
41
/// Order in which dimensions are stored in the layout metadata.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Endian {
    /// Big-endian: dimensions spanning a larger range come earlier in the
    /// metadata, so the last dimension is the innermost (row-major order).
    BigEndian,
    /// Little-endian: dimensions spanning a smaller range come earlier in the
    /// metadata, so the first dimension is the innermost.
    LittleEndian,
}
50
51impl<const N: usize> ArrayLayout<N> {
52    /// Creates a new Layout with the given shape, strides, and offset.
53    ///
54    /// ```rust
55    /// # use ndarray_layout::ArrayLayout;
56    /// let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
57    /// assert_eq!(layout.offset(), 20);
58    /// assert_eq!(layout.shape(), &[2, 3, 4]);
59    /// assert_eq!(layout.strides(), &[12, -4, 1]);
60    /// ```
61    pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self {
62        // check
63        assert_eq!(
64            shape.len(),
65            strides.len(),
66            "shape and strides must have the same length"
67        );
68
69        let mut ans = Self::with_ndim(shape.len());
70        let mut content = ans.content_mut();
71        content.set_offset(offset);
72        content.copy_shape(shape);
73        content.copy_strides(strides);
74        ans
75    }
76
77    /// Creates a new contiguous Layout with the given shape.
78    ///
79    /// ```rust
80    /// # use ndarray_layout::{Endian, ArrayLayout};
81    /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4);
82    /// assert_eq!(layout.offset(), 0);
83    /// assert_eq!(layout.shape(), &[2, 3, 4]);
84    /// assert_eq!(layout.strides(), &[4, 8, 24]);
85    /// ```
86    pub fn new_contiguous(shape: &[usize], endian: Endian, element_size: usize) -> Self {
87        let mut ans = Self::with_ndim(shape.len());
88        let mut content = ans.content_mut();
89        content.set_offset(0);
90        content.copy_shape(shape);
91        let mut mul = element_size as isize;
92        let push = |i| {
93            content.set_stride(i, mul);
94            mul *= shape[i] as isize;
95        };
96        match endian {
97            Endian::BigEndian => (0..shape.len()).rev().for_each(push),
98            Endian::LittleEndian => (0..shape.len()).for_each(push),
99        }
100        ans
101    }
102
103    /// Gets offset.
104    #[inline]
105    pub const fn ndim(&self) -> usize {
106        self.ndim
107    }
108
109    /// Gets offset.
110    #[inline]
111    pub fn offset(&self) -> isize {
112        self.content().offset()
113    }
114
115    /// Gets shape.
116    #[inline]
117    pub fn shape(&self) -> &[usize] {
118        self.content().shape()
119    }
120
121    /// Gets strides.
122    #[inline]
123    pub fn strides(&self) -> &[isize] {
124        self.content().strides()
125    }
126
127    /// Calculate the range of data in bytes to determine the location of the memory area that the tensor needs to access.
128    pub fn data_range(&self) -> RangeInclusive<isize> {
129        let content = self.content();
130        let mut start = content.offset();
131        let mut end = content.offset();
132        for (&d, s) in zip(content.shape(), content.strides()) {
133            use std::cmp::Ordering::{Equal, Greater, Less};
134            let i = d as isize - 1;
135            match s.cmp(&0) {
136                Equal => {}
137                Less => start += s * i,
138                Greater => end += s * i,
139            }
140        }
141        start..=end
142    }
143}
144
145mod transform;
146pub use transform::{BroadcastArg, IndexArg, MergeArg, SliceArg, Split, TileArg};
147
148use std::{
149    alloc::{alloc, dealloc, Layout},
150    iter::zip,
151    ops::RangeInclusive,
152    ptr::{copy_nonoverlapping, NonNull},
153    slice::from_raw_parts,
154};
155
156impl<const N: usize> ArrayLayout<N> {
157    #[inline]
158    fn ptr_allocated(&self) -> Option<NonNull<usize>> {
159        const { assert!(N > 0) }
160        if self.ndim > N {
161            Some(unsafe { self.content.ptr })
162        } else {
163            None
164        }
165    }
166
167    #[inline]
168    fn content(&self) -> Content<false> {
169        Content {
170            ptr: self
171                .ptr_allocated()
172                .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
173            ndim: self.ndim,
174        }
175    }
176
177    #[inline]
178    fn content_mut(&mut self) -> Content<true> {
179        Content {
180            ptr: self
181                .ptr_allocated()
182                .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
183            ndim: self.ndim,
184        }
185    }
186
187    /// Create a new ArrayLayout with the given dimensions.
188    #[inline]
189    fn with_ndim(ndim: usize) -> Self {
190        Self {
191            ndim,
192            content: if ndim <= N {
193                Union {
194                    _inlined: (0, [0; N], [0; N]),
195                }
196            } else {
197                Union {
198                    ptr: unsafe { NonNull::new_unchecked(alloc(layout(ndim)).cast()) },
199                }
200            },
201        }
202    }
203}
204
/// Raw view over a metadata record stored as contiguous `usize` words:
/// `[offset, shape[0], .., shape[ndim-1], strides[0], .., strides[ndim-1]]`.
///
/// `MUT` decides whether the write accessors are available.
struct Content<const MUT: bool> {
    /// Number of dimensions described by the record.
    ndim: usize,
    /// Address of the first word of the record (the offset).
    ptr: NonNull<usize>,
}
209
210impl<const MUT: bool> Content<MUT> {
211    #[inline]
212    fn as_slice(&self) -> &[usize] {
213        unsafe { from_raw_parts(self.ptr.as_ptr(), 1 + self.ndim * 2) }
214    }
215
216    #[inline]
217    fn offset(&self) -> isize {
218        unsafe { self.ptr.cast().read() }
219    }
220
221    #[inline]
222    fn shape<'a>(&self) -> &'a [usize] {
223        unsafe { from_raw_parts(self.ptr.add(1).as_ptr(), self.ndim) }
224    }
225
226    #[inline]
227    fn strides<'a>(&self) -> &'a [isize] {
228        unsafe { from_raw_parts(self.ptr.add(1 + self.ndim).cast().as_ptr(), self.ndim) }
229    }
230}
231
232impl Content<true> {
233    #[inline]
234    fn set_offset(&mut self, val: isize) {
235        unsafe { self.ptr.cast().write(val) }
236    }
237
238    #[inline]
239    fn set_shape(&mut self, idx: usize, val: usize) {
240        assert!(idx < self.ndim);
241        unsafe { self.ptr.add(1 + idx).write(val) }
242    }
243
244    #[inline]
245    fn set_stride(&mut self, idx: usize, val: isize) {
246        assert!(idx < self.ndim);
247        unsafe { self.ptr.add(1 + idx + self.ndim).cast().write(val) }
248    }
249
250    #[inline]
251    fn copy_shape(&mut self, val: &[usize]) {
252        assert!(val.len() == self.ndim);
253        unsafe { copy_nonoverlapping(val.as_ptr(), self.ptr.add(1).as_ptr(), self.ndim) }
254    }
255
256    #[inline]
257    fn copy_strides(&mut self, val: &[isize]) {
258        assert!(val.len() == self.ndim);
259        unsafe {
260            copy_nonoverlapping(
261                val.as_ptr(),
262                self.ptr.add(1 + self.ndim).cast().as_ptr(),
263                self.ndim,
264            )
265        }
266    }
267}
268
/// Heap layout for an out-of-line metadata record: one offset word followed by
/// `ndim` shape words and `ndim` stride words.
#[inline]
fn layout(ndim: usize) -> Layout {
    let words = 2 * ndim + 1;
    Layout::array::<usize>(words).unwrap()
}