1use crate::{Error, Result, Shape};
2
3#[derive(Debug, PartialEq, Eq, Clone)]
4pub struct Layout {
5 shape: Shape,
6 stride: Vec<usize>,
8 start_offset: usize,
9}
10
11impl Layout {
12 pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
13 Self {
14 shape,
15 stride,
16 start_offset,
17 }
18 }
19
20 pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
21 let shape = shape.into();
22 let stride = shape.stride_contiguous();
23 Self {
24 shape,
25 stride,
26 start_offset,
27 }
28 }
29
30 pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
31 Self::contiguous_with_offset(shape, 0)
32 }
33
34 pub fn dims(&self) -> &[usize] {
35 self.shape.dims()
36 }
37
38 pub fn shape(&self) -> &Shape {
39 &self.shape
40 }
41
42 pub fn stride(&self) -> &[usize] {
43 &self.stride
44 }
45
46 pub fn start_offset(&self) -> usize {
47 self.start_offset
48 }
49
50 pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
53 if self.is_contiguous() {
54 let start_o = self.start_offset;
55 Some((start_o, start_o + self.shape.elem_count()))
56 } else {
57 None
58 }
59 }
60
61 pub fn is_contiguous(&self) -> bool {
65 self.shape.is_contiguous(&self.stride)
66 }
67
68 pub fn is_fortran_contiguous(&self) -> bool {
70 self.shape.is_fortran_contiguous(&self.stride)
71 }
72
73 pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
74 let dims = self.shape().dims();
75 if dim >= dims.len() {
76 Err(Error::DimOutOfRange {
77 shape: self.shape().clone(),
78 dim: dim as i32,
79 op: "narrow",
80 }
81 .bt())?
82 }
83 if start + len > dims[dim] {
84 Err(Error::NarrowInvalidArgs {
85 shape: self.shape.clone(),
86 dim,
87 start,
88 len,
89 msg: "start + len > dim_len",
90 }
91 .bt())?
92 }
93 let mut dims = dims.to_vec();
94 dims[dim] = len;
95 Ok(Self {
96 shape: Shape::from(dims),
97 stride: self.stride.clone(),
98 start_offset: self.start_offset + self.stride[dim] * start,
99 })
100 }
101
102 pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
103 let rank = self.shape.rank();
104 if rank <= dim1 || rank <= dim2 {
105 Err(Error::UnexpectedNumberOfDims {
106 expected: usize::max(dim1, dim2),
107 got: rank,
108 shape: self.shape().clone(),
109 }
110 .bt())?
111 }
112 let mut stride = self.stride().to_vec();
113 let mut dims = self.shape().dims().to_vec();
114 dims.swap(dim1, dim2);
115 stride.swap(dim1, dim2);
116 Ok(Self {
117 shape: Shape::from(dims),
118 stride,
119 start_offset: self.start_offset,
120 })
121 }
122
123 pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
124 let is_permutation =
125 idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
126 if !is_permutation {
127 crate::bail!(
128 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
129 self.dims(),
130 idxs
131 )
132 }
133 let stride = self.stride();
134 let dims = self.shape().dims();
135 let mut perm_stride = stride.to_vec();
136 let mut perm_dims = dims.to_vec();
137 for (i, &idx) in idxs.iter().enumerate() {
138 perm_stride[i] = stride[idx];
139 perm_dims[i] = dims[idx];
140 }
141 Ok(Self {
142 shape: Shape::from(perm_dims),
143 stride: perm_stride,
144 start_offset: self.start_offset,
145 })
146 }
147
148 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
149 let shape = shape.into();
150 if shape.rank() < self.shape().rank() {
151 return Err(Error::BroadcastIncompatibleShapes {
152 src_shape: self.shape().clone(),
153 dst_shape: shape,
154 }
155 .bt());
156 }
157 let added_dims = shape.rank() - self.shape().rank();
158 let mut stride = vec![0; added_dims];
159 for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
160 .iter()
161 .zip(self.dims().iter().zip(self.stride()))
162 {
163 let s = if dst_dim == src_dim {
164 src_stride
165 } else if src_dim != 1 {
166 return Err(Error::BroadcastIncompatibleShapes {
167 src_shape: self.shape().clone(),
168 dst_shape: shape,
169 }
170 .bt());
171 } else {
172 0
173 };
174 stride.push(s)
175 }
176 Ok(Self {
177 shape,
178 stride,
179 start_offset: self.start_offset,
180 })
181 }
182
183 pub(crate) fn strided_index(&self) -> crate::StridedIndex {
184 crate::StridedIndex::from_layout(self)
185 }
186
187 pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
188 let mut block_len = 1;
189 let mut contiguous_dims = 0; for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
191 if stride != block_len {
192 break;
193 }
194 block_len *= dim;
195 contiguous_dims += 1;
196 }
197 let index_dims = self.dims().len() - contiguous_dims;
198 if index_dims == 0 {
199 crate::StridedBlocks::SingleBlock {
200 start_offset: self.start_offset,
201 len: block_len,
202 }
203 } else {
204 let block_start_index = crate::StridedIndex::new(
205 &self.dims()[..index_dims],
206 &self.stride[..index_dims],
207 self.start_offset,
208 );
209 crate::StridedBlocks::MultipleBlocks {
210 block_start_index,
211 block_len,
212 }
213 }
214 }
215
216 pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
218 let mut left_broadcast = 1;
219 let mut right_broadcast = 1;
220 let strides = self.stride();
221 let dims = self.dims();
222 let mut start_cont = 0;
223 let mut end_cont = dims.len();
224 for (&s, &d) in strides.iter().zip(dims.iter()) {
225 if s != 0 {
226 break;
227 }
228 start_cont += 1;
229 left_broadcast *= d;
230 }
231 if start_cont == dims.len() {
232 return Some(ContiguousOffsetsWithBroadcast {
233 start: self.start_offset,
234 len: 1,
235 left_broadcast,
236 right_broadcast: 1,
237 });
238 }
239 for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
240 if s != 0 {
241 break;
242 }
243 end_cont -= 1;
244 right_broadcast *= d;
245 }
246 let strides = &strides[start_cont..end_cont];
248 let dims = &dims[start_cont..end_cont];
249 let mut len = 1;
250 for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
251 if stride != len {
252 return None;
253 }
254 len *= dim;
255 }
256 Some(ContiguousOffsetsWithBroadcast {
257 start: self.start_offset,
258 len,
259 left_broadcast,
260 right_broadcast,
261 })
262 }
263}
264
/// A layout expressed as one contiguous run of `len` elements beginning at storage
/// offset `start`, framed by zero-stride (broadcast) dimensions on either side.
///
/// Produced by `Layout::offsets_b`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContiguousOffsetsWithBroadcast {
    /// Storage offset of the first element.
    pub start: usize,
    /// Number of elements in the contiguous run.
    pub len: usize,
    /// Product of the sizes of the leading dimensions whose stride is 0.
    pub left_broadcast: usize,
    /// Product of the sizes of the trailing dimensions whose stride is 0.
    pub right_broadcast: usize,
}