1use crate::{Error, Result, Shape};
4
/// Describes how a tensor's elements are placed in their underlying storage:
/// the logical shape, one stride per dimension, and the offset of the first element.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
    /// Logical size of each dimension.
    shape: Shape,
    /// Distance, in elements, between consecutive indices along each dimension
    /// (one entry per dimension of `shape`).
    stride: Vec<usize>,
    /// Element offset of the first element in the underlying storage.
    start_offset: usize,
}
12
13impl Layout {
14 pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
15 Self {
16 shape,
17 stride,
18 start_offset,
19 }
20 }
21
22 pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
23 let shape = shape.into();
24 let stride = shape.stride_contiguous();
25 Self {
26 shape,
27 stride,
28 start_offset,
29 }
30 }
31
32 pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
33 Self::contiguous_with_offset(shape, 0)
34 }
35
36 pub fn dims(&self) -> &[usize] {
37 self.shape.dims()
38 }
39
40 pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
42 let dim = dim.to_index(&self.shape, "dim")?;
43 Ok(self.dims()[dim])
44 }
45
46 pub fn shape(&self) -> &Shape {
47 &self.shape
48 }
49
50 pub fn stride(&self) -> &[usize] {
51 &self.stride
52 }
53
54 pub fn start_offset(&self) -> usize {
55 self.start_offset
56 }
57
58 pub(crate) fn outer_stride_for_dim(&self, dim: usize) -> Option<usize> {
67 let dims = self.dims();
68 let strides = self.stride();
69
70 let mut expected = 1usize;
72 for i in (dim..dims.len()).rev() {
73 if strides[i] != expected {
74 return None;
75 }
76 expected *= dims[i];
77 }
78
79 if dim == 0 {
80 return Some(expected);
83 }
84
85 let outer_stride = strides[dim - 1];
87 let mut expected_outer = outer_stride;
88 for k in (0..dim - 1).rev() {
89 expected_outer *= dims[k + 1];
90 if strides[k] != expected_outer {
91 return None;
92 }
93 }
94
95 Some(outer_stride)
96 }
97
98 pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
101 if self.is_contiguous() {
102 let start_o = self.start_offset;
103 Some((start_o, start_o + self.shape.elem_count()))
104 } else {
105 None
106 }
107 }
108
109 pub fn is_contiguous(&self) -> bool {
113 self.shape.is_contiguous(&self.stride)
114 }
115
116 pub fn is_fortran_contiguous(&self) -> bool {
118 self.shape.is_fortran_contiguous(&self.stride)
119 }
120
121 pub fn is_scalar(&self) -> bool {
122 let dims = self.dims();
123 dims.is_empty() || dims.iter().all(|d| *d == 1)
124 }
125
126 pub fn is_scalar_broadcast(&self) -> bool {
128 self.stride().iter().all(|s| *s == 0)
129 }
130
131 pub fn is_scalar_like(&self) -> bool {
132 self.is_scalar() || self.is_scalar_broadcast()
133 }
134
135 pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
136 let dims = self.shape().dims();
137 if dim >= dims.len() {
138 Err(Error::DimOutOfRange {
139 shape: self.shape().clone(),
140 dim: dim as i32,
141 op: "narrow",
142 }
143 .bt())?
144 }
145 if start + len > dims[dim] {
146 Err(Error::NarrowInvalidArgs {
147 shape: self.shape.clone(),
148 dim,
149 start,
150 len,
151 msg: "start + len > dim_len",
152 }
153 .bt())?
154 }
155 let mut dims = dims.to_vec();
156 dims[dim] = len;
157 Ok(Self {
158 shape: Shape::from(dims),
159 stride: self.stride.clone(),
160 start_offset: self.start_offset + self.stride[dim] * start,
161 })
162 }
163
164 pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
165 let rank = self.shape.rank();
166 if rank <= dim1 || rank <= dim2 {
167 Err(Error::UnexpectedNumberOfDims {
168 expected: usize::max(dim1, dim2),
169 got: rank,
170 shape: self.shape().clone(),
171 }
172 .bt())?
173 }
174 let mut stride = self.stride().to_vec();
175 let mut dims = self.shape().dims().to_vec();
176 dims.swap(dim1, dim2);
177 stride.swap(dim1, dim2);
178 Ok(Self {
179 shape: Shape::from(dims),
180 stride,
181 start_offset: self.start_offset,
182 })
183 }
184
185 pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
186 let is_permutation =
187 idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
188 if !is_permutation {
189 crate::bail!(
190 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
191 self.dims(),
192 idxs
193 )
194 }
195 let stride = self.stride();
196 let dims = self.shape().dims();
197 let mut perm_stride = stride.to_vec();
198 let mut perm_dims = dims.to_vec();
199 for (i, &idx) in idxs.iter().enumerate() {
200 perm_stride[i] = stride[idx];
201 perm_dims[i] = dims[idx];
202 }
203 Ok(Self {
204 shape: Shape::from(perm_dims),
205 stride: perm_stride,
206 start_offset: self.start_offset,
207 })
208 }
209
210 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
211 let shape = shape.into();
212 if shape.rank() < self.shape().rank() {
213 return Err(Error::BroadcastIncompatibleShapes {
214 src_shape: self.shape().clone(),
215 dst_shape: shape,
216 }
217 .bt());
218 }
219 let added_dims = shape.rank() - self.shape().rank();
220 let mut stride = vec![0; added_dims];
221 for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
222 .iter()
223 .zip(self.dims().iter().zip(self.stride()))
224 {
225 let s = if dst_dim == src_dim {
226 src_stride
227 } else if src_dim != 1 {
228 return Err(Error::BroadcastIncompatibleShapes {
229 src_shape: self.shape().clone(),
230 dst_shape: shape,
231 }
232 .bt());
233 } else {
234 0
235 };
236 stride.push(s)
237 }
238 Ok(Self {
239 shape,
240 stride,
241 start_offset: self.start_offset,
242 })
243 }
244
245 pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> {
246 crate::StridedIndex::from_layout(self)
247 }
248
249 pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
250 let mut block_len = 1usize;
251 let mut contiguous_dims = 0usize; for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
253 if dim == 1 {
255 contiguous_dims += 1;
256 continue;
257 }
258 if stride != block_len {
259 break;
260 }
261 block_len *= dim;
262 contiguous_dims += 1;
263 }
264 let index_dims = self.dims().len() - contiguous_dims;
265 match index_dims {
266 0 => crate::StridedBlocks::SingleBlock {
267 start_offset: self.start_offset,
268 len: block_len,
269 },
270 1 => crate::StridedBlocks::UniformBlocks {
271 start_offset: self.start_offset,
272 block_len,
273 count: self.dims()[0],
274 src_stride: self.stride[0],
275 },
276 _ => {
277 let block_start_index = crate::StridedIndex::new(
278 &self.dims()[..index_dims],
279 &self.stride[..index_dims],
280 self.start_offset,
281 );
282 crate::StridedBlocks::MultipleBlocks {
283 block_start_index,
284 block_len,
285 }
286 }
287 }
288 }
289}