acme_tensor/shape/
layout.rs

1/*
2    Appellation: layout <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::iter::LayoutIter;
6use crate::shape::dim::stride_offset;
7use crate::shape::{Axis, IntoShape, IntoStride, Rank, Shape, ShapeError, ShapeResult, Stride};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11/// The layout describes the memory layout of a tensor.
12#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
13#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
14pub struct Layout {
15    pub(crate) offset: usize,
16    pub(crate) shape: Shape,
17    pub(crate) strides: Stride,
18}
19
20impl Layout {
21    pub unsafe fn new(offset: usize, shape: impl IntoShape, strides: impl IntoStride) -> Self {
22        Self {
23            offset,
24            shape: shape.into_shape(),
25            strides: strides.into_stride(),
26        }
27    }
28    /// Create a new layout with a contiguous stride.
29    pub fn contiguous(shape: impl IntoShape) -> Self {
30        let shape = shape.into_shape();
31        let stride = shape.stride_contiguous();
32        Self {
33            offset: 0,
34            shape,
35            strides: stride,
36        }
37    }
38    /// Create a new layout with a scalar stride.
39    pub fn scalar() -> Self {
40        Self::contiguous(())
41    }
42    #[doc(hidden)]
43    /// Return stride offset for index.
44    pub fn stride_offset(index: impl AsRef<[usize]>, strides: &Stride) -> isize {
45        let mut offset = 0;
46        for (&i, &s) in izip!(index.as_ref(), strides.as_slice()) {
47            offset += stride_offset(i, s);
48        }
49        offset
50    }
51    /// Broadcast the layout to a new shape.
52    ///
53    /// The new shape must have the same or higher rank than the current shape.
54    pub fn broadcast_as(&self, shape: impl IntoShape) -> ShapeResult<Self> {
55        let shape = shape.into_shape();
56        if shape.rank() < self.shape().rank() {
57            return Err(ShapeError::IncompatibleShapes);
58        }
59        let diff = shape.rank() - self.shape().rank();
60        let mut stride = vec![0; *diff];
61        for (&dst_dim, (&src_dim, &src_stride)) in shape[*diff..]
62            .iter()
63            .zip(self.shape().iter().zip(self.strides().iter()))
64        {
65            let s = if dst_dim == src_dim {
66                src_stride
67            } else if src_dim != 1 {
68                return Err(ShapeError::IncompatibleShapes);
69            } else {
70                0
71            };
72            stride.push(s)
73        }
74        let layout = unsafe { Layout::new(0, shape, stride) };
75        Ok(layout)
76    }
77    /// Determine if the current layout is contiguous or not.
78    pub fn is_contiguous(&self) -> bool {
79        self.shape().is_contiguous(&self.strides)
80    }
81    /// Checks to see if the layout is empy; i.e. a scalar of Rank(0)
82    pub fn is_scalar(&self) -> bool {
83        self.shape().is_scalar()
84    }
85    /// A function for determining if the layout is square.
86    /// An n-dimensional object is square if all of its dimensions are equal.
87    pub fn is_square(&self) -> bool {
88        self.shape().is_square()
89    }
90
91    pub fn iter(&self) -> LayoutIter {
92        LayoutIter::new(self.clone())
93    }
94    /// Peek the offset of the layout.
95    pub fn offset(&self) -> usize {
96        self.offset
97    }
98    /// Returns the offset from the lowest-address element to the logically first
99    /// element.
100    pub fn offset_from_low_addr_ptr_to_logical_ptr(&self) -> usize {
101        let offset =
102            izip!(self.shape().as_slice(), self.strides().as_slice()).fold(0, |acc, (d, s)| {
103                let d = *d as isize;
104                let s = *s as isize;
105                if s < 0 && d > 1 {
106                    acc - s * (d - 1)
107                } else {
108                    acc
109                }
110            });
111        debug_assert!(offset >= 0);
112        offset as usize
113    }
114    /// Return the rank (number of dimensions) of the layout.
115    pub fn rank(&self) -> Rank {
116        debug_assert_eq!(self.strides.len(), *self.shape.rank());
117        self.shape.rank()
118    }
119    /// Remove an axis from the current layout, returning the new layout.
120    pub fn remove_axis(&self, axis: Axis) -> Self {
121        Self {
122            offset: self.offset,
123            shape: self.shape().remove_axis(axis),
124            strides: self.strides().remove_axis(axis),
125        }
126    }
127    /// Reshape the layout to a new shape.
128    pub fn reshape(&mut self, shape: impl IntoShape) {
129        self.shape = shape.into_shape();
130        self.strides = self.shape.stride_contiguous();
131    }
132    /// Reverse the order of the axes.
133    pub fn reverse(&mut self) {
134        self.shape.reverse();
135        self.strides.reverse();
136    }
137    /// Reverse the order of the axes.
138    pub fn reverse_axes(mut self) -> Layout {
139        self.reverse();
140        self
141    }
142    /// Get a reference to the shape of the layout.
143    pub const fn shape(&self) -> &Shape {
144        &self.shape
145    }
146    /// Get a reference to the number of elements in the layout.
147    pub fn size(&self) -> usize {
148        self.shape().size()
149    }
150    /// Get a reference to the stride of the layout.
151    pub const fn strides(&self) -> &Stride {
152        &self.strides
153    }
154    /// Swap the axes of the layout.
155    pub fn swap_axes(&self, a: Axis, b: Axis) -> Layout {
156        Layout {
157            offset: self.offset,
158            shape: self.shape.swap_axes(a, b),
159            strides: self.strides.swap_axes(a, b),
160        }
161    }
162    /// Transpose the layout.
163    pub fn transpose(&self) -> Layout {
164        self.clone().reverse_axes()
165    }
166
167    pub fn with_offset(mut self, offset: usize) -> Self {
168        self.offset = offset;
169        self
170    }
171
172    pub fn with_shape_c(mut self, shape: impl IntoShape) -> Self {
173        self.shape = shape.into_shape();
174        self.strides = self.shape.stride_contiguous();
175        self
176    }
177
178    pub unsafe fn with_shape_unchecked(mut self, shape: impl IntoShape) -> Self {
179        self.shape = shape.into_shape();
180        self
181    }
182
183    pub unsafe fn with_strides_unchecked(mut self, stride: impl IntoStride) -> Self {
184        self.strides = stride.into_stride();
185        self
186    }
187}
188
189// Internal methods
190impl Layout {
191    pub(crate) fn index<Idx>(&self, idx: Idx) -> usize
192    where
193        Idx: AsRef<[usize]>,
194    {
195        let idx = idx.as_ref();
196        debug_assert_eq!(idx.len(), *self.rank(), "Dimension mismatch");
197        self.index_unchecked(idx)
198    }
199
200    pub(crate) fn index_unchecked<Idx>(&self, idx: Idx) -> usize
201    where
202        Idx: AsRef<[usize]>,
203    {
204        crate::coordinates_to_index::<Idx>(idx, self.strides())
205    }
206
207    pub(crate) fn _matmul(&self, rhs: &Layout) -> Result<Layout, ShapeError> {
208        let shape = self.shape().matmul(rhs.shape())?;
209        let layout = Layout {
210            offset: self.offset(),
211            shape,
212            strides: self.strides().clone(),
213        };
214        Ok(layout)
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::Layout;

    /// A contiguous 3x3 layout maps coordinates to the positions listed below,
    /// which are those of row-major order.
    #[test]
    fn test_position() {
        let layout = Layout::contiguous((3, 3));
        // `index_unchecked` is the variant without the debug-only rank check.
        assert_eq!(layout.index_unchecked([0, 0]), 0);
        let cases = [([0, 1], 1), ([2, 2], 8)];
        for (coords, expected) in cases {
            assert_eq!(layout.index(coords), expected);
        }
    }
}