1#![doc = include_str!("../README.md")]
2#![deny(warnings, missing_docs)]
3
4pub struct ArrayLayout<const N: usize> {
6 ndim: usize,
7 content: Union<N>,
8}
9
10unsafe impl<const N: usize> Send for ArrayLayout<N> {}
11unsafe impl<const N: usize> Sync for ArrayLayout<N> {}
12
/// Backing storage of an `ArrayLayout`.
///
/// It holds either the packed `[offset, shape.., strides..]` buffer itself
/// (when `ndim <= N`) or a pointer to a heap buffer with the same packing
/// (when `ndim > N`).
///
/// `#[repr(C)]` guarantees that every field starts at offset 0. The layout code
/// relies on this when it treats the address of the union as the start of the
/// inline buffer.
#[repr(C)]
union Union<const N: usize> {
    /// Heap buffer of `1 + 2 * ndim` `usize` slots; the active field when `ndim > N`.
    ptr: NonNull<usize>,
    /// Inline room for `1 + 2 * N` slots; the active field when `ndim <= N`.
    _inlined: (isize, [usize; N], [isize; N]),
}
17
18impl<const N: usize> Clone for ArrayLayout<N> {
19 #[inline]
20 fn clone(&self) -> Self {
21 Self::new(self.shape(), self.strides(), self.offset())
22 }
23}
24
25impl<const N: usize> PartialEq for ArrayLayout<N> {
26 #[inline]
27 fn eq(&self, other: &Self) -> bool {
28 self.ndim == other.ndim && self.content().as_slice() == other.content().as_slice()
29 }
30}
31
32impl<const N: usize> Eq for ArrayLayout<N> {}
33
34impl<const N: usize> Drop for ArrayLayout<N> {
35 fn drop(&mut self) {
36 if let Some(ptr) = self.ptr_allocated() {
37 unsafe { dealloc(ptr.cast().as_ptr(), layout(self.ndim)) }
38 }
39 }
40}
41
/// Selects which dimension varies fastest when elements are laid out or
/// enumerated in order.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Endian {
    /// The last dimension varies fastest (row-major order).
    ///
    /// It receives the smallest stride in `ArrayLayout::new_contiguous` and is
    /// the least-significant digit when `ArrayLayout::element_offset` decodes
    /// an index.
    BigEndian,
    /// The first dimension varies fastest (column-major order).
    ///
    /// It receives the smallest stride in `ArrayLayout::new_contiguous` and is
    /// the least-significant digit when `ArrayLayout::element_offset` decodes
    /// an index.
    LittleEndian,
}
50
51impl<const N: usize> ArrayLayout<N> {
52 pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self {
62 assert_eq!(
64 shape.len(),
65 strides.len(),
66 "shape and strides must have the same length"
67 );
68
69 let mut ans = Self::with_ndim(shape.len());
70 let mut content = ans.content_mut();
71 content.set_offset(offset);
72 content.copy_shape(shape);
73 content.copy_strides(strides);
74 ans
75 }
76
77 pub fn new_contiguous(shape: &[usize], endian: Endian, element_size: usize) -> Self {
87 let mut ans = Self::with_ndim(shape.len());
88 let mut content = ans.content_mut();
89 content.set_offset(0);
90 content.copy_shape(shape);
91 let mut mul = element_size as isize;
92 let push = |i| {
93 content.set_stride(i, mul);
94 mul *= shape[i] as isize;
95 };
96 match endian {
97 Endian::BigEndian => (0..shape.len()).rev().for_each(push),
98 Endian::LittleEndian => (0..shape.len()).for_each(push),
99 }
100 ans
101 }
102
103 #[inline]
105 pub const fn ndim(&self) -> usize {
106 self.ndim
107 }
108
109 #[inline]
111 pub fn offset(&self) -> isize {
112 self.content().offset()
113 }
114
115 #[inline]
117 pub fn shape(&self) -> &[usize] {
118 self.content().shape()
119 }
120
121 #[inline]
123 pub fn strides(&self) -> &[isize] {
124 self.content().strides()
125 }
126
127 #[inline]
135 pub fn num_elements(&self) -> usize {
136 self.shape().iter().product()
137 }
138
139 pub fn element_offset(&self, index: usize, endian: Endian) -> isize {
147 fn offset_forwards(
148 mut rem: usize,
149 shape: impl IntoIterator<Item = usize>,
150 strides: impl IntoIterator<Item = isize>,
151 ) -> isize {
152 let mut ans = 0;
153 for (d, s) in zip(shape, strides) {
154 ans += s * (rem % d) as isize;
155 rem /= d
156 }
157 ans
158 }
159
160 let shape = self.shape().iter().cloned();
161 let strides = self.strides().iter().cloned();
162 self.offset()
163 + match endian {
164 Endian::BigEndian => offset_forwards(index, shape.rev(), strides.rev()),
165 Endian::LittleEndian => offset_forwards(index, shape, strides),
166 }
167 }
168
169 pub fn data_range(&self) -> RangeInclusive<isize> {
171 let content = self.content();
172 let mut start = content.offset();
173 let mut end = content.offset();
174 for (&d, s) in zip(content.shape(), content.strides()) {
175 use std::cmp::Ordering::{Equal, Greater, Less};
176 let i = d as isize - 1;
177 match s.cmp(&0) {
178 Equal => {}
179 Less => start += s * i,
180 Greater => end += s * i,
181 }
182 }
183 start..=end
184 }
185}
186
187mod transform;
188pub use transform::{BroadcastArg, IndexArg, MergeArg, SliceArg, Split, TileArg};
189
190use std::{
191 alloc::{Layout, alloc, dealloc},
192 iter::zip,
193 ops::RangeInclusive,
194 ptr::{NonNull, copy_nonoverlapping},
195 slice::from_raw_parts,
196};
197
198impl<const N: usize> ArrayLayout<N> {
199 #[inline]
200 fn ptr_allocated(&self) -> Option<NonNull<usize>> {
201 const { assert!(N > 0) }
202 if self.ndim > N {
203 Some(unsafe { self.content.ptr })
204 } else {
205 None
206 }
207 }
208
209 #[inline]
210 fn content(&self) -> Content<false> {
211 Content {
212 ptr: self
213 .ptr_allocated()
214 .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
215 ndim: self.ndim,
216 }
217 }
218
219 #[inline]
220 fn content_mut(&mut self) -> Content<true> {
221 Content {
222 ptr: self
223 .ptr_allocated()
224 .unwrap_or(unsafe { NonNull::new_unchecked(&self.content as *const _ as _) }),
225 ndim: self.ndim,
226 }
227 }
228
229 #[inline]
231 fn with_ndim(ndim: usize) -> Self {
232 Self {
233 ndim,
234 content: if ndim <= N {
235 Union {
236 _inlined: (0, [0; N], [0; N]),
237 }
238 } else {
239 Union {
240 ptr: unsafe { NonNull::new_unchecked(alloc(layout(ndim)).cast()) },
241 }
242 },
243 }
244 }
245}
246
/// Raw, non-owning view over a packed layout buffer laid out as
/// `[offset, shape[0..ndim], strides[0..ndim]]`, one `usize`-wide slot each.
///
/// `MUT` gates the setters: only `Content<true>` can write. The view never
/// frees the buffer, and its lifetime is not tracked by the compiler, so
/// callers must not let it (or slices it returns) outlive the buffer.
struct Content<const MUT: bool> {
    /// First slot of the packed buffer.
    ptr: NonNull<usize>,
    /// Number of dimensions described by the buffer.
    ndim: usize,
}

impl<const MUT: bool> Content<MUT> {
    /// The whole packed buffer (offset, shape, strides) as one slice.
    #[inline]
    fn as_slice(&self) -> &[usize] {
        let len = 2 * self.ndim + 1;
        // SAFETY: the buffer holds `1 + 2 * ndim` initialised slots.
        unsafe { from_raw_parts(self.ptr.as_ptr(), len) }
    }

    /// Slot 0, reinterpreted as a signed offset.
    #[inline]
    fn offset(&self) -> isize {
        // SAFETY: slot 0 is in bounds and stores the offset.
        unsafe { self.ptr.as_ptr().cast::<isize>().read() }
    }

    /// Slots `1..=ndim`: the extent of each dimension.
    #[inline]
    fn shape<'a>(&self) -> &'a [usize] {
        let first = self.ptr.as_ptr();
        // SAFETY: slots `1..1 + ndim` are in bounds and initialised.
        unsafe { from_raw_parts(first.add(1), self.ndim) }
    }

    /// Slots `1 + ndim..1 + 2 * ndim`: the signed stride of each dimension.
    #[inline]
    fn strides<'a>(&self) -> &'a [isize] {
        let skip = 1 + self.ndim;
        // SAFETY: the last `ndim` slots are in bounds and store the strides.
        unsafe { from_raw_parts(self.ptr.as_ptr().add(skip).cast::<isize>(), self.ndim) }
    }
}

impl Content<true> {
    /// Overwrites slot 0 with `val`.
    #[inline]
    fn set_offset(&mut self, val: isize) {
        // SAFETY: slot 0 is in bounds and writable through a `Content<true>`.
        unsafe { self.ptr.as_ptr().cast::<isize>().write(val) }
    }

    /// Overwrites the extent of dimension `idx`.
    #[inline]
    fn set_shape(&mut self, idx: usize, val: usize) {
        assert!(idx < self.ndim);
        // SAFETY: `idx < ndim`, so slot `1 + idx` lies inside the shape section.
        unsafe { self.ptr.as_ptr().add(1 + idx).write(val) }
    }

    /// Overwrites the stride of dimension `idx`.
    #[inline]
    fn set_stride(&mut self, idx: usize, val: isize) {
        assert!(idx < self.ndim);
        let slot = 1 + self.ndim + idx;
        // SAFETY: `idx < ndim`, so `slot` lies inside the strides section.
        unsafe { self.ptr.as_ptr().add(slot).cast::<isize>().write(val) }
    }

    /// Copies a full shape (exactly `ndim` extents) into the buffer.
    #[inline]
    fn copy_shape(&mut self, val: &[usize]) {
        assert!(val.len() == self.ndim);
        // SAFETY: the destination has `ndim` slots and cannot overlap `val`,
        // which is a separate borrowed slice.
        unsafe {
            let dst = self.ptr.as_ptr().add(1);
            copy_nonoverlapping(val.as_ptr(), dst, self.ndim)
        }
    }

    /// Copies a full set of strides (exactly `ndim` values) into the buffer.
    #[inline]
    fn copy_strides(&mut self, val: &[isize]) {
        assert!(val.len() == self.ndim);
        // SAFETY: the strides section has `ndim` slots and cannot overlap `val`.
        unsafe {
            let dst = self.ptr.as_ptr().add(1 + self.ndim).cast::<isize>();
            copy_nonoverlapping(val.as_ptr(), dst, self.ndim)
        }
    }
}
310
/// Allocation layout of a heap buffer for `ndim` dimensions: one offset slot
/// plus `ndim` shape slots and `ndim` stride slots, all `usize`-sized.
#[inline]
fn layout(ndim: usize) -> Layout {
    let slots = 2 * ndim + 1;
    Layout::array::<usize>(slots).unwrap()
}