acme_tensor/shape/
layout.rs1use crate::iter::LayoutIter;
6use crate::shape::dim::stride_offset;
7use crate::shape::{Axis, IntoShape, IntoStride, Rank, Shape, ShapeError, ShapeResult, Stride};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
13#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
14pub struct Layout {
15 pub(crate) offset: usize,
16 pub(crate) shape: Shape,
17 pub(crate) strides: Stride,
18}
19
20impl Layout {
21 pub unsafe fn new(offset: usize, shape: impl IntoShape, strides: impl IntoStride) -> Self {
22 Self {
23 offset,
24 shape: shape.into_shape(),
25 strides: strides.into_stride(),
26 }
27 }
28 pub fn contiguous(shape: impl IntoShape) -> Self {
30 let shape = shape.into_shape();
31 let stride = shape.stride_contiguous();
32 Self {
33 offset: 0,
34 shape,
35 strides: stride,
36 }
37 }
38 pub fn scalar() -> Self {
40 Self::contiguous(())
41 }
42 #[doc(hidden)]
43 pub fn stride_offset(index: impl AsRef<[usize]>, strides: &Stride) -> isize {
45 let mut offset = 0;
46 for (&i, &s) in izip!(index.as_ref(), strides.as_slice()) {
47 offset += stride_offset(i, s);
48 }
49 offset
50 }
51 pub fn broadcast_as(&self, shape: impl IntoShape) -> ShapeResult<Self> {
55 let shape = shape.into_shape();
56 if shape.rank() < self.shape().rank() {
57 return Err(ShapeError::IncompatibleShapes);
58 }
59 let diff = shape.rank() - self.shape().rank();
60 let mut stride = vec![0; *diff];
61 for (&dst_dim, (&src_dim, &src_stride)) in shape[*diff..]
62 .iter()
63 .zip(self.shape().iter().zip(self.strides().iter()))
64 {
65 let s = if dst_dim == src_dim {
66 src_stride
67 } else if src_dim != 1 {
68 return Err(ShapeError::IncompatibleShapes);
69 } else {
70 0
71 };
72 stride.push(s)
73 }
74 let layout = unsafe { Layout::new(0, shape, stride) };
75 Ok(layout)
76 }
77 pub fn is_contiguous(&self) -> bool {
79 self.shape().is_contiguous(&self.strides)
80 }
81 pub fn is_scalar(&self) -> bool {
83 self.shape().is_scalar()
84 }
85 pub fn is_square(&self) -> bool {
88 self.shape().is_square()
89 }
90
91 pub fn iter(&self) -> LayoutIter {
92 LayoutIter::new(self.clone())
93 }
94 pub fn offset(&self) -> usize {
96 self.offset
97 }
98 pub fn offset_from_low_addr_ptr_to_logical_ptr(&self) -> usize {
101 let offset =
102 izip!(self.shape().as_slice(), self.strides().as_slice()).fold(0, |acc, (d, s)| {
103 let d = *d as isize;
104 let s = *s as isize;
105 if s < 0 && d > 1 {
106 acc - s * (d - 1)
107 } else {
108 acc
109 }
110 });
111 debug_assert!(offset >= 0);
112 offset as usize
113 }
114 pub fn rank(&self) -> Rank {
116 debug_assert_eq!(self.strides.len(), *self.shape.rank());
117 self.shape.rank()
118 }
119 pub fn remove_axis(&self, axis: Axis) -> Self {
121 Self {
122 offset: self.offset,
123 shape: self.shape().remove_axis(axis),
124 strides: self.strides().remove_axis(axis),
125 }
126 }
127 pub fn reshape(&mut self, shape: impl IntoShape) {
129 self.shape = shape.into_shape();
130 self.strides = self.shape.stride_contiguous();
131 }
132 pub fn reverse(&mut self) {
134 self.shape.reverse();
135 self.strides.reverse();
136 }
137 pub fn reverse_axes(mut self) -> Layout {
139 self.reverse();
140 self
141 }
142 pub const fn shape(&self) -> &Shape {
144 &self.shape
145 }
146 pub fn size(&self) -> usize {
148 self.shape().size()
149 }
150 pub const fn strides(&self) -> &Stride {
152 &self.strides
153 }
154 pub fn swap_axes(&self, a: Axis, b: Axis) -> Layout {
156 Layout {
157 offset: self.offset,
158 shape: self.shape.swap_axes(a, b),
159 strides: self.strides.swap_axes(a, b),
160 }
161 }
162 pub fn transpose(&self) -> Layout {
164 self.clone().reverse_axes()
165 }
166
167 pub fn with_offset(mut self, offset: usize) -> Self {
168 self.offset = offset;
169 self
170 }
171
172 pub fn with_shape_c(mut self, shape: impl IntoShape) -> Self {
173 self.shape = shape.into_shape();
174 self.strides = self.shape.stride_contiguous();
175 self
176 }
177
178 pub unsafe fn with_shape_unchecked(mut self, shape: impl IntoShape) -> Self {
179 self.shape = shape.into_shape();
180 self
181 }
182
183 pub unsafe fn with_strides_unchecked(mut self, stride: impl IntoStride) -> Self {
184 self.strides = stride.into_stride();
185 self
186 }
187}
188
189impl Layout {
191 pub(crate) fn index<Idx>(&self, idx: Idx) -> usize
192 where
193 Idx: AsRef<[usize]>,
194 {
195 let idx = idx.as_ref();
196 debug_assert_eq!(idx.len(), *self.rank(), "Dimension mismatch");
197 self.index_unchecked(idx)
198 }
199
200 pub(crate) fn index_unchecked<Idx>(&self, idx: Idx) -> usize
201 where
202 Idx: AsRef<[usize]>,
203 {
204 crate::coordinates_to_index::<Idx>(idx, self.strides())
205 }
206
207 pub(crate) fn _matmul(&self, rhs: &Layout) -> Result<Layout, ShapeError> {
208 let shape = self.shape().matmul(rhs.shape())?;
209 let layout = Layout {
210 offset: self.offset(),
211 shape,
212 strides: self.strides().clone(),
213 };
214 Ok(layout)
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::Layout;

    /// A 3x3 contiguous layout maps coordinates to row-major positions.
    #[test]
    fn test_position() {
        let layout = Layout::contiguous((3, 3));

        // The unchecked path at the origin.
        assert_eq!(layout.index_unchecked([0, 0]), 0);

        // The checked path at a few interior points.
        let cases: [([usize; 2], usize); 2] = [([0, 1], 1), ([2, 2], 8)];
        for (coords, expected) in cases {
            assert_eq!(layout.index(coords), expected);
        }
    }
}