acme_tensor/shape/
shape.rs

1/*
2   Appellation: shape <mod>
3   Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::{Axis, Rank, ShapeError, Stride};
6use crate::iter::zip;
7use crate::prelude::{Ixs, SwapAxes};
8#[cfg(not(feature = "std"))]
9use alloc::vec;
10use core::ops::{self, Deref};
11#[cfg(feature = "std")]
12use std::vec;
13
14/// A shape is a description of the number of elements in each dimension.
15#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
16#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize,))]
17pub struct Shape(Vec<usize>);
18
19impl Shape {
20    pub fn new(shape: Vec<usize>) -> Self {
21        Self(shape)
22    }
23    /// Creates a new shape of rank 0.
24    pub fn scalar() -> Self {
25        Self(Vec::new())
26    }
27
28    pub fn stride_offset(index: &[usize], strides: &Stride) -> Ixs {
29        let mut offset = 0;
30        for (&i, &s) in index.iter().zip(strides.as_slice()) {
31            offset += super::dim::stride_offset(i, s);
32        }
33        offset
34    }
35
36    pub fn with_capacity(capacity: usize) -> Self {
37        Self(Vec::with_capacity(capacity))
38    }
39    /// Creates a new shape of the given rank with all dimensions set to 0.
40    pub fn zeros(rank: usize) -> Self {
41        Self(vec![0; rank])
42    }
43    /// Get a reference to the shape as a slice.
44    pub fn as_slice(&self) -> &[usize] {
45        &self.0
46    }
47    /// Get a mutable reference to the shape as a slice.
48    pub fn as_slice_mut(&mut self) -> &mut [usize] {
49        &mut self.0
50    }
51
52    pub fn check_size(&self) -> Result<usize, ShapeError> {
53        let size_nonzero = self
54            .as_slice()
55            .iter()
56            .filter(|&&d| d != 0)
57            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
58            .ok_or(ShapeError::Overflow)?;
59        if size_nonzero > core::isize::MAX as usize {
60            Err(ShapeError::Overflow)
61        } else {
62            Ok(self.size())
63        }
64    }
65    /// Decrement the dimensions of the shape by 1,
66    /// returning a new shape.
67    pub fn dec(&self) -> Self {
68        let mut shape = self.clone();
69        shape.dec_inplace();
70        shape
71    }
72    /// Decrement the dimensions of the shape by 1, inplace.
73    pub fn dec_inplace(&mut self) {
74        for dim in self.iter_mut() {
75            *dim -= 1;
76        }
77    }
78    /// Decrement the dimension at the given [Axis] by 1.
79    pub fn dec_axis(&mut self, axis: Axis) {
80        self[axis] -= 1;
81    }
82    /// Attempts to create a one-dimensional shape that describes the
83    /// diagonal of the current shape.
84    pub fn diag(&self) -> Shape {
85        Self::new(i![self.nrows()])
86    }
87
88    pub fn first_index(&self) -> Option<Vec<usize>> {
89        if self.is_empty() {
90            return None;
91        }
92        Some(vec![0; *self.rank()])
93    }
94    pub fn get_final_position(&self) -> Vec<usize> {
95        self.dec().to_vec()
96    }
97    /// Inserts a new dimension along the given [Axis], inplace.
98    pub fn insert(&mut self, index: Axis, dim: usize) {
99        self.0.insert(*index, dim)
100    }
101    /// Inserts a new dimension along the given [Axis].
102    pub fn insert_axis(&self, index: Axis) -> Self {
103        let mut shape = self.clone();
104        shape.insert(index, 1);
105        shape
106    }
107    /// Returns true if the strides are C contiguous (aka row major).
108    pub fn is_contiguous(&self, stride: &Stride) -> bool {
109        if self.0.len() != stride.len() {
110            return false;
111        }
112        let mut acc = 1;
113        for (&stride, &dim) in stride.iter().zip(self.iter()).rev() {
114            if stride != acc {
115                return false;
116            }
117            acc *= dim;
118        }
119        true
120    }
121    /// Returns true if the shape is a scalar.
122    pub fn is_scalar(&self) -> bool {
123        self.0.is_empty()
124    }
125    /// Checks to see if the shape is square
126    pub fn is_square(&self) -> bool {
127        self.iter().all(|&dim| dim == self[0])
128    }
129    /// Creates an immutable iterator over the elements of the shape
130    pub fn iter(&self) -> core::slice::Iter<usize> {
131        self.0.iter()
132    }
133    /// Creates a mutable iterator over the elements of the shape.
134    pub fn iter_mut(&mut self) -> core::slice::IterMut<usize> {
135        self.0.iter_mut()
136    }
137    /// The number of columns in the shape.
138    pub fn ncols(&self) -> usize {
139        if self.len() >= 2 {
140            self[1]
141        } else if self.len() == 1 {
142            1
143        } else {
144            0
145        }
146    }
147    #[doc(hidden)]
148    /// Iteration -- Use self as size, and return next index after `index`
149    /// or None if there are no more.
150    // FIXME: use &Self for index or even &mut?
151    #[inline]
152    pub fn next_for<D>(&self, index: D) -> Option<Vec<usize>>
153    where
154        D: AsRef<[usize]>,
155    {
156        let mut index = index.as_ref().to_vec();
157        let mut done = false;
158        for (&dim, ix) in zip(self.as_slice(), index.as_mut_slice()).rev() {
159            *ix += 1;
160            if *ix == dim {
161                *ix = 0;
162            } else {
163                done = true;
164                break;
165            }
166        }
167        if done {
168            Some(index)
169        } else {
170            None
171        }
172    }
173    /// The number of rows in the shape.
174    pub fn nrows(&self) -> usize {
175        if self.len() >= 1 {
176            self[0]
177        } else {
178            0
179        }
180    }
181    /// Removes and returns the last dimension of the shape.
182    pub fn pop(&mut self) -> Option<usize> {
183        self.0.pop()
184    }
185    /// Add a new dimension to the shape.
186    pub fn push(&mut self, dim: usize) {
187        self.0.push(dim)
188    }
189    /// Get the number of dimensions, or [Rank], of the shape
190    pub fn rank(&self) -> Rank {
191        self.0.len().into()
192    }
193    /// Remove the dimension at the given [Axis],
194    pub fn remove(&mut self, index: Axis) -> usize {
195        self.0.remove(*index)
196    }
197    /// Remove the dimension at the given [Axis].
198    pub fn remove_axis(&self, index: Axis) -> Shape {
199        let mut shape = self.clone();
200        shape.remove(index);
201        shape
202    }
203    /// Reverse the dimensions of the shape.
204    pub fn reverse(&mut self) {
205        self.0.reverse()
206    }
207    /// Set the dimension at the given [Axis].
208    pub fn set(&mut self, index: Axis, dim: usize) {
209        self[index] = dim
210    }
211    /// The number of elements in the shape.
212    pub fn size(&self) -> usize {
213        self.0.iter().product()
214    }
215    /// Swap the dimensions of the current [Shape] at the given [Axis].
216    pub fn swap(&mut self, a: Axis, b: Axis) {
217        self.0.swap(a.axis(), b.axis())
218    }
219    /// Swap the dimensions at the given [Axis], creating a new [Shape]
220    pub fn swap_axes(&self, swap: Axis, with: Axis) -> Self {
221        let mut shape = self.clone();
222        shape.swap(swap, with);
223        shape
224    }
225    /// A utilitarian function for converting the shape to a vector.
226    pub fn to_vec(&self) -> Vec<usize> {
227        self.0.clone()
228    }
229}
230
231// Internal methods
232#[allow(dead_code)]
233#[doc(hidden)]
234impl Shape {
235    pub fn default_strides(&self) -> Stride {
236        // Compute default array strides
237        // Shape (a, b, c) => Give strides (b * c, c, 1)
238        let mut strides = Stride::zeros(self.rank());
239        // For empty arrays, use all zero strides.
240        if self.iter().all(|&d| d != 0) {
241            let mut it = strides.as_slice_mut().iter_mut().rev();
242            // Set first element to 1
243            if let Some(rs) = it.next() {
244                *rs = 1;
245            }
246            let mut cum_prod = 1;
247            for (rs, dim) in it.zip(self.iter().rev()) {
248                cum_prod *= *dim;
249                *rs = cum_prod;
250            }
251        }
252        strides
253    }
254
255    pub(crate) fn matmul(&self, other: &Self) -> Result<Self, ShapeError> {
256        if self.rank() == 2 && other.rank() == 2 {
257            return Ok(Self::from((self[0], other[1])));
258        } else if self.rank() == 2 && other.rank() == 1 {
259            return Ok(Self::from(self[0]));
260        } else if self.rank() == 1 && other.rank() == 2 {
261            return Ok(Self::from(other[0]));
262        } else if self.rank() == 1 && other.rank() == 1 {
263            return Ok(Self::scalar());
264        }
265        Err(ShapeError::IncompatibleShapes)
266    }
267
268    pub(crate) fn matmul_shape(&self, other: &Self) -> Result<Self, ShapeError> {
269        if *self.rank() != 2 || *other.rank() != 2 || self[1] != other[0] {
270            return Err(ShapeError::IncompatibleShapes);
271        }
272        Ok(Self::from((self[0], other[1])))
273    }
274
275    pub(crate) fn stride_contiguous(&self) -> Stride {
276        let mut stride: Vec<_> = self
277            .0
278            .iter()
279            .rev()
280            .scan(1, |prod, u| {
281                let prod_pre_mult = *prod;
282                *prod *= u;
283                Some(prod_pre_mult)
284            })
285            .collect();
286        stride.reverse();
287        stride.into()
288    }
289
290    pub(crate) fn upcast(&self, to: &Shape, stride: &Stride) -> Option<Stride> {
291        let mut new_stride = to.as_slice().to_vec();
292        // begin at the back (the least significant dimension)
293        // size of the axis has to either agree or `from` has to be 1
294        if to.rank() < self.rank() {
295            return None;
296        }
297
298        let mut iter = new_stride.as_mut_slice().iter_mut().rev();
299        for ((er, es), dr) in self
300            .as_slice()
301            .iter()
302            .rev()
303            .zip(stride.as_slice().iter().rev())
304            .zip(iter.by_ref())
305        {
306            /* update strides */
307            if *dr == *er {
308                /* keep stride */
309                *dr = *es;
310            } else if *er == 1 {
311                /* dead dimension, zero stride */
312                *dr = 0
313            } else {
314                return None;
315            }
316        }
317
318        /* set remaining strides to zero */
319        for dr in iter {
320            *dr = 0;
321        }
322
323        Some(new_stride.into())
324    }
325}
326
327impl AsRef<[usize]> for Shape {
328    fn as_ref(&self) -> &[usize] {
329        &self.0
330    }
331}
332
333impl AsMut<[usize]> for Shape {
334    fn as_mut(&mut self) -> &mut [usize] {
335        &mut self.0
336    }
337}
338
339impl Deref for Shape {
340    type Target = [usize];
341
342    fn deref(&self) -> &Self::Target {
343        &self.0
344    }
345}
346
347impl Extend<usize> for Shape {
348    fn extend<I: IntoIterator<Item = usize>>(&mut self, iter: I) {
349        self.0.extend(iter)
350    }
351}
352
353impl From<Shape> for Vec<usize> {
354    fn from(shape: Shape) -> Self {
355        shape.0
356    }
357}
358
359impl_partial_eq!(Shape -> 0: [[usize], Vec<usize>]);
360
361impl SwapAxes for Shape {
362    fn swap_axes(&self, a: Axis, b: Axis) -> Self {
363        self.swap_axes(a, b)
364    }
365}
366
367impl FromIterator<usize> for Shape {
368    fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
369        Self(Vec::from_iter(iter))
370    }
371}
372
373impl IntoIterator for Shape {
374    type Item = usize;
375    type IntoIter = vec::IntoIter<Self::Item>;
376
377    fn into_iter(self) -> Self::IntoIter {
378        self.0.into_iter()
379    }
380}
381
382impl<'a> IntoIterator for &'a mut Shape {
383    type Item = &'a mut usize;
384    type IntoIter = core::slice::IterMut<'a, usize>;
385
386    fn into_iter(self) -> Self::IntoIter {
387        self.0.iter_mut()
388    }
389}
390
391impl ops::Index<usize> for Shape {
392    type Output = usize;
393
394    fn index(&self, index: usize) -> &Self::Output {
395        &self.0[index]
396    }
397}
398
399impl ops::Index<Axis> for Shape {
400    type Output = usize;
401
402    fn index(&self, index: Axis) -> &Self::Output {
403        &self.0[*index]
404    }
405}
406
407impl ops::IndexMut<usize> for Shape {
408    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
409        &mut self.0[index]
410    }
411}
412
413impl ops::IndexMut<Axis> for Shape {
414    fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
415        &mut self.0[*index]
416    }
417}
418
419impl ops::Index<ops::Range<usize>> for Shape {
420    type Output = [usize];
421
422    fn index(&self, index: ops::Range<usize>) -> &Self::Output {
423        &self.0[index]
424    }
425}
426
427impl ops::Index<ops::RangeTo<usize>> for Shape {
428    type Output = [usize];
429
430    fn index(&self, index: ops::RangeTo<usize>) -> &Self::Output {
431        &self.0[index]
432    }
433}
434
435impl ops::Index<ops::RangeFrom<usize>> for Shape {
436    type Output = [usize];
437
438    fn index(&self, index: ops::RangeFrom<usize>) -> &Self::Output {
439        &self.0[index]
440    }
441}
442
443impl ops::Index<ops::RangeFull> for Shape {
444    type Output = [usize];
445
446    fn index(&self, index: ops::RangeFull) -> &Self::Output {
447        &self.0[index]
448    }
449}
450
451impl ops::Index<ops::RangeInclusive<usize>> for Shape {
452    type Output = [usize];
453
454    fn index(&self, index: ops::RangeInclusive<usize>) -> &Self::Output {
455        &self.0[index]
456    }
457}
458
459impl ops::Index<ops::RangeToInclusive<usize>> for Shape {
460    type Output = [usize];
461
462    fn index(&self, index: ops::RangeToInclusive<usize>) -> &Self::Output {
463        &self.0[index]
464    }
465}
466
467unsafe impl Send for Shape {}
468
469unsafe impl Sync for Shape {}
470
471impl From<()> for Shape {
472    fn from(_: ()) -> Self {
473        Self::default()
474    }
475}
476
477impl From<usize> for Shape {
478    fn from(dim: usize) -> Self {
479        Self(vec![dim])
480    }
481}
482
483impl From<Vec<usize>> for Shape {
484    fn from(shape: Vec<usize>) -> Self {
485        Self(shape)
486    }
487}
488
489impl From<&[usize]> for Shape {
490    fn from(shape: &[usize]) -> Self {
491        Self(shape.to_vec())
492    }
493}
494
495impl<const N: usize> From<[usize; N]> for Shape {
496    fn from(shape: [usize; N]) -> Self {
497        Self(shape.to_vec())
498    }
499}
500
501impl From<(usize,)> for Shape {
502    fn from(shape: (usize,)) -> Self {
503        Self(vec![shape.0])
504    }
505}
506
507impl From<(usize, usize)> for Shape {
508    fn from(shape: (usize, usize)) -> Self {
509        Self(vec![shape.0, shape.1])
510    }
511}
512
513impl From<(usize, usize, usize)> for Shape {
514    fn from(shape: (usize, usize, usize)) -> Self {
515        Self(vec![shape.0, shape.1, shape.2])
516    }
517}
518
519impl From<(usize, usize, usize, usize)> for Shape {
520    fn from(shape: (usize, usize, usize, usize)) -> Self {
521        Self(vec![shape.0, shape.1, shape.2, shape.3])
522    }
523}
524
525impl From<(usize, usize, usize, usize, usize)> for Shape {
526    fn from(shape: (usize, usize, usize, usize, usize)) -> Self {
527        Self(vec![shape.0, shape.1, shape.2, shape.3, shape.4])
528    }
529}
530
531impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
532    fn from(shape: (usize, usize, usize, usize, usize, usize)) -> Self {
533        Self(vec![shape.0, shape.1, shape.2, shape.3, shape.4, shape.5])
534    }
535}
536
537// macro_rules! tuple_vec {
538//     ($($n:tt),*) => {
539//         vec![$($n,)*]
540//     };
541
542// }
543
544// macro_rules! impl_from_tuple {
545//     ($($n:tt: $name:ident),+) => {
546//         impl<$($name),+> From<($($name,)+)> for Shape
547//         where
548//             $($name: Into<usize>,)+
549//         {
550//             fn from(shape: ($($name,)+)) -> Self {
551//                 Self(vec![$($name.into(),)+])
552//             }
553//         }
554//     };
555// }
556
557// impl_from_tuple!(A: A);