1use crate::{Error, Result, Shape};
3
4#[derive(Debug, PartialEq, Eq, Clone)]
5pub struct Layout {
6 shape: Shape,
7 stride: Vec<usize>,
9 start_offset: usize,
10}
11
12impl Layout {
13 pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
14 Self {
15 shape,
16 stride,
17 start_offset,
18 }
19 }
20
21 pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
22 let shape = shape.into();
23 let stride = shape.stride_contiguous();
24 Self {
25 shape,
26 stride,
27 start_offset,
28 }
29 }
30
31 pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
32 Self::contiguous_with_offset(shape, 0)
33 }
34
35 pub fn dims(&self) -> &[usize] {
36 self.shape.dims()
37 }
38
39 pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
41 let dim = dim.to_index(&self.shape, "dim")?;
42 Ok(self.dims()[dim])
43 }
44
45 pub fn shape(&self) -> &Shape {
46 &self.shape
47 }
48
49 pub fn stride(&self) -> &[usize] {
50 &self.stride
51 }
52
53 pub fn start_offset(&self) -> usize {
54 self.start_offset
55 }
56
57 pub(crate) fn outer_stride_for_dim(&self, dim: usize) -> Option<usize> {
66 let dims = self.dims();
67 let strides = self.stride();
68
69 let mut expected = 1usize;
71 for i in (dim..dims.len()).rev() {
72 if strides[i] != expected {
73 return None;
74 }
75 expected *= dims[i];
76 }
77
78 if dim == 0 {
79 return Some(expected);
82 }
83
84 let outer_stride = strides[dim - 1];
86 let mut expected_outer = outer_stride;
87 for k in (0..dim - 1).rev() {
88 expected_outer *= dims[k + 1];
89 if strides[k] != expected_outer {
90 return None;
91 }
92 }
93
94 Some(outer_stride)
95 }
96
97 pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
100 if self.is_contiguous() {
101 let start_o = self.start_offset;
102 Some((start_o, start_o + self.shape.elem_count()))
103 } else {
104 None
105 }
106 }
107
108 pub fn is_contiguous(&self) -> bool {
112 self.shape.is_contiguous(&self.stride)
113 }
114
115 pub fn is_fortran_contiguous(&self) -> bool {
117 self.shape.is_fortran_contiguous(&self.stride)
118 }
119
120 pub fn is_scalar(&self) -> bool {
121 let dims = self.dims();
122 dims.is_empty() || dims.iter().all(|d| *d == 1)
123 }
124
125 pub fn is_scalar_broadcast(&self) -> bool {
127 self.stride().iter().all(|s| *s == 0)
128 }
129
130 pub fn is_scalar_like(&self) -> bool {
131 self.is_scalar() || self.is_scalar_broadcast()
132 }
133
134 pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
135 let dims = self.shape().dims();
136 if dim >= dims.len() {
137 Err(Error::DimOutOfRange {
138 shape: self.shape().clone(),
139 dim: dim as i32,
140 op: "narrow",
141 }
142 .bt())?
143 }
144 if start + len > dims[dim] {
145 Err(Error::NarrowInvalidArgs {
146 shape: self.shape.clone(),
147 dim,
148 start,
149 len,
150 msg: "start + len > dim_len",
151 }
152 .bt())?
153 }
154 let mut dims = dims.to_vec();
155 dims[dim] = len;
156 Ok(Self {
157 shape: Shape::from(dims),
158 stride: self.stride.clone(),
159 start_offset: self.start_offset + self.stride[dim] * start,
160 })
161 }
162
163 pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
164 let rank = self.shape.rank();
165 if rank <= dim1 || rank <= dim2 {
166 Err(Error::UnexpectedNumberOfDims {
167 expected: usize::max(dim1, dim2),
168 got: rank,
169 shape: self.shape().clone(),
170 }
171 .bt())?
172 }
173 let mut stride = self.stride().to_vec();
174 let mut dims = self.shape().dims().to_vec();
175 dims.swap(dim1, dim2);
176 stride.swap(dim1, dim2);
177 Ok(Self {
178 shape: Shape::from(dims),
179 stride,
180 start_offset: self.start_offset,
181 })
182 }
183
184 pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
185 let is_permutation =
186 idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
187 if !is_permutation {
188 crate::bail!(
189 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
190 self.dims(),
191 idxs
192 )
193 }
194 let stride = self.stride();
195 let dims = self.shape().dims();
196 let mut perm_stride = stride.to_vec();
197 let mut perm_dims = dims.to_vec();
198 for (i, &idx) in idxs.iter().enumerate() {
199 perm_stride[i] = stride[idx];
200 perm_dims[i] = dims[idx];
201 }
202 Ok(Self {
203 shape: Shape::from(perm_dims),
204 stride: perm_stride,
205 start_offset: self.start_offset,
206 })
207 }
208
209 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
210 let shape = shape.into();
211 if shape.rank() < self.shape().rank() {
212 return Err(Error::BroadcastIncompatibleShapes {
213 src_shape: self.shape().clone(),
214 dst_shape: shape,
215 }
216 .bt());
217 }
218 let added_dims = shape.rank() - self.shape().rank();
219 let mut stride = vec![0; added_dims];
220 for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
221 .iter()
222 .zip(self.dims().iter().zip(self.stride()))
223 {
224 let s = if dst_dim == src_dim {
225 src_stride
226 } else if src_dim != 1 {
227 return Err(Error::BroadcastIncompatibleShapes {
228 src_shape: self.shape().clone(),
229 dst_shape: shape,
230 }
231 .bt());
232 } else {
233 0
234 };
235 stride.push(s)
236 }
237 Ok(Self {
238 shape,
239 stride,
240 start_offset: self.start_offset,
241 })
242 }
243
244 pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> {
245 crate::StridedIndex::from_layout(self)
246 }
247
248 pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
249 let mut block_len = 1usize;
250 let mut contiguous_dims = 0usize; for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
252 if dim == 1 {
254 contiguous_dims += 1;
255 continue;
256 }
257 if stride != block_len {
258 break;
259 }
260 block_len *= dim;
261 contiguous_dims += 1;
262 }
263 let index_dims = self.dims().len() - contiguous_dims;
264 match index_dims {
265 0 => crate::StridedBlocks::SingleBlock {
266 start_offset: self.start_offset,
267 len: block_len,
268 },
269 1 => crate::StridedBlocks::UniformBlocks {
270 start_offset: self.start_offset,
271 block_len,
272 count: self.dims()[0],
273 src_stride: self.stride[0],
274 },
275 _ => {
276 let block_start_index = crate::StridedIndex::new(
277 &self.dims()[..index_dims],
278 &self.stride[..index_dims],
279 self.start_offset,
280 );
281 crate::StridedBlocks::MultipleBlocks {
282 block_start_index,
283 block_len,
284 }
285 }
286 }
287 }
288
289 pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
291 let mut left_broadcast = 1;
292 let mut right_broadcast = 1;
293 let strides = self.stride();
294 let dims = self.dims();
295 let mut start_cont = 0;
296 let mut end_cont = dims.len();
297 for (&s, &d) in strides.iter().zip(dims.iter()) {
298 if s != 0 {
299 break;
300 }
301 start_cont += 1;
302 left_broadcast *= d;
303 }
304 if start_cont == dims.len() {
305 return Some(ContiguousOffsetsWithBroadcast {
306 start: self.start_offset,
307 len: 1,
308 left_broadcast,
309 right_broadcast: 1,
310 });
311 }
312 for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
313 if s != 0 {
314 break;
315 }
316 end_cont -= 1;
317 right_broadcast *= d;
318 }
319 let strides = &strides[start_cont..end_cont];
321 let dims = &dims[start_cont..end_cont];
322 let mut len = 1;
323 for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
324 if stride != len {
325 return None;
326 }
327 len *= dim;
328 }
329 Some(ContiguousOffsetsWithBroadcast {
330 start: self.start_offset,
331 len,
332 left_broadcast,
333 right_broadcast,
334 })
335 }
336}
337
/// A contiguous run of elements together with the number of times it is
/// repeated by broadcasting on either side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContiguousOffsetsWithBroadcast {
    /// Offset of the first element of the contiguous run.
    pub start: usize,
    /// Number of elements in the contiguous run.
    pub len: usize,
    /// Product of the sizes of the leading broadcast dimensions.
    pub left_broadcast: usize,
    /// Product of the sizes of the trailing broadcast dimensions.
    pub right_broadcast: usize,
}