1use crate::{Error, Result, Shape};
3
4#[derive(Debug, PartialEq, Eq, Clone)]
5pub struct Layout {
6 shape: Shape,
7 stride: Vec<usize>,
9 start_offset: usize,
10}
11
12impl Layout {
13 pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
14 Self {
15 shape,
16 stride,
17 start_offset,
18 }
19 }
20
21 pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
22 let shape = shape.into();
23 let stride = shape.stride_contiguous();
24 Self {
25 shape,
26 stride,
27 start_offset,
28 }
29 }
30
31 pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
32 Self::contiguous_with_offset(shape, 0)
33 }
34
35 pub fn dims(&self) -> &[usize] {
36 self.shape.dims()
37 }
38
39 pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
41 let dim = dim.to_index(&self.shape, "dim")?;
42 Ok(self.dims()[dim])
43 }
44
45 pub fn shape(&self) -> &Shape {
46 &self.shape
47 }
48
49 pub fn stride(&self) -> &[usize] {
50 &self.stride
51 }
52
53 pub fn start_offset(&self) -> usize {
54 self.start_offset
55 }
56
57 pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
60 if self.is_contiguous() {
61 let start_o = self.start_offset;
62 Some((start_o, start_o + self.shape.elem_count()))
63 } else {
64 None
65 }
66 }
67
68 pub fn is_contiguous(&self) -> bool {
72 self.shape.is_contiguous(&self.stride)
73 }
74
75 pub fn is_fortran_contiguous(&self) -> bool {
77 self.shape.is_fortran_contiguous(&self.stride)
78 }
79
80 pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
81 let dims = self.shape().dims();
82 if dim >= dims.len() {
83 Err(Error::DimOutOfRange {
84 shape: self.shape().clone(),
85 dim: dim as i32,
86 op: "narrow",
87 }
88 .bt())?
89 }
90 if start + len > dims[dim] {
91 Err(Error::NarrowInvalidArgs {
92 shape: self.shape.clone(),
93 dim,
94 start,
95 len,
96 msg: "start + len > dim_len",
97 }
98 .bt())?
99 }
100 let mut dims = dims.to_vec();
101 dims[dim] = len;
102 Ok(Self {
103 shape: Shape::from(dims),
104 stride: self.stride.clone(),
105 start_offset: self.start_offset + self.stride[dim] * start,
106 })
107 }
108
109 pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
110 let rank = self.shape.rank();
111 if rank <= dim1 || rank <= dim2 {
112 Err(Error::UnexpectedNumberOfDims {
113 expected: usize::max(dim1, dim2),
114 got: rank,
115 shape: self.shape().clone(),
116 }
117 .bt())?
118 }
119 let mut stride = self.stride().to_vec();
120 let mut dims = self.shape().dims().to_vec();
121 dims.swap(dim1, dim2);
122 stride.swap(dim1, dim2);
123 Ok(Self {
124 shape: Shape::from(dims),
125 stride,
126 start_offset: self.start_offset,
127 })
128 }
129
130 pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
131 let is_permutation =
132 idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
133 if !is_permutation {
134 crate::bail!(
135 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
136 self.dims(),
137 idxs
138 )
139 }
140 let stride = self.stride();
141 let dims = self.shape().dims();
142 let mut perm_stride = stride.to_vec();
143 let mut perm_dims = dims.to_vec();
144 for (i, &idx) in idxs.iter().enumerate() {
145 perm_stride[i] = stride[idx];
146 perm_dims[i] = dims[idx];
147 }
148 Ok(Self {
149 shape: Shape::from(perm_dims),
150 stride: perm_stride,
151 start_offset: self.start_offset,
152 })
153 }
154
155 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
156 let shape = shape.into();
157 if shape.rank() < self.shape().rank() {
158 return Err(Error::BroadcastIncompatibleShapes {
159 src_shape: self.shape().clone(),
160 dst_shape: shape,
161 }
162 .bt());
163 }
164 let added_dims = shape.rank() - self.shape().rank();
165 let mut stride = vec![0; added_dims];
166 for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
167 .iter()
168 .zip(self.dims().iter().zip(self.stride()))
169 {
170 let s = if dst_dim == src_dim {
171 src_stride
172 } else if src_dim != 1 {
173 return Err(Error::BroadcastIncompatibleShapes {
174 src_shape: self.shape().clone(),
175 dst_shape: shape,
176 }
177 .bt());
178 } else {
179 0
180 };
181 stride.push(s)
182 }
183 Ok(Self {
184 shape,
185 stride,
186 start_offset: self.start_offset,
187 })
188 }
189
190 pub(crate) fn strided_index(&self) -> crate::StridedIndex {
191 crate::StridedIndex::from_layout(self)
192 }
193
194 pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
195 let mut block_len = 1;
196 let mut contiguous_dims = 0; for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
198 if stride != block_len {
199 break;
200 }
201 block_len *= dim;
202 contiguous_dims += 1;
203 }
204 let index_dims = self.dims().len() - contiguous_dims;
205 if index_dims == 0 {
206 crate::StridedBlocks::SingleBlock {
207 start_offset: self.start_offset,
208 len: block_len,
209 }
210 } else {
211 let block_start_index = crate::StridedIndex::new(
212 &self.dims()[..index_dims],
213 &self.stride[..index_dims],
214 self.start_offset,
215 );
216 crate::StridedBlocks::MultipleBlocks {
217 block_start_index,
218 block_len,
219 }
220 }
221 }
222
223 pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
225 let mut left_broadcast = 1;
226 let mut right_broadcast = 1;
227 let strides = self.stride();
228 let dims = self.dims();
229 let mut start_cont = 0;
230 let mut end_cont = dims.len();
231 for (&s, &d) in strides.iter().zip(dims.iter()) {
232 if s != 0 {
233 break;
234 }
235 start_cont += 1;
236 left_broadcast *= d;
237 }
238 if start_cont == dims.len() {
239 return Some(ContiguousOffsetsWithBroadcast {
240 start: self.start_offset,
241 len: 1,
242 left_broadcast,
243 right_broadcast: 1,
244 });
245 }
246 for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
247 if s != 0 {
248 break;
249 }
250 end_cont -= 1;
251 right_broadcast *= d;
252 }
253 let strides = &strides[start_cont..end_cont];
255 let dims = &dims[start_cont..end_cont];
256 let mut len = 1;
257 for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
258 if stride != len {
259 return None;
260 }
261 len *= dim;
262 }
263 Some(ContiguousOffsetsWithBroadcast {
264 start: self.start_offset,
265 len,
266 left_broadcast,
267 right_broadcast,
268 })
269 }
270}
271
/// A densely packed run of storage positions together with the broadcast factors that
/// surround it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ContiguousOffsetsWithBroadcast {
    /// Storage position of the first element of the run.
    pub start: usize,
    /// Number of elements in the run.
    pub len: usize,
    /// Product of the sizes of the leading broadcast dimensions.
    pub left_broadcast: usize,
    /// Product of the sizes of the trailing broadcast dimensions.
    pub right_broadcast: usize,
}