1use crate::{Error, Result};
2use super::{Dim, Shape};
3
/// Describes how the logical elements of a tensor map onto a flat storage buffer.
///
/// The element at multi-index `[i0, i1, ..]` lives at storage index
/// `start_offset + i0 * stride[0] + i1 * stride[1] + ..`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
    /// Logical dimensions of the view.
    pub(crate) shape: Shape,
    /// Storage step taken when the index along each dimension increases by one.
    /// NOTE(review): expected to have exactly one entry per dimension of `shape`;
    /// `Layout::new` does not check this.
    pub(crate) stride: Vec<usize>,
    /// Storage index of the first logical element.
    pub(crate) start_offset: usize,
}
10
11impl Layout {
12 pub fn new<S: Into<Shape>>(shape: S, stride: Vec<usize>, start_offset: usize) -> Self {
13 Self {
14 shape: shape.into(), stride, start_offset
15 }
16 }
17
18 pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
19 let shape = shape.into();
20 let stride = shape.stride_contiguous();
21 Self {
22 shape,
23 stride,
24 start_offset: 0,
25 }
26 }
27
28 pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
29 let shape = shape.into();
30 let stride = shape.stride_contiguous();
31 Self {
32 shape,
33 stride,
34 start_offset,
35 }
36 }
37
38 pub fn dims(&self) -> &[usize] {
39 self.shape.dims()
40 }
41
42 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
43 let dim = dim.to_index(&self.shape, "dim")?;
44 Ok(self.dims()[dim])
45 }
46
47 pub fn shape(&self) -> &Shape {
48 &self.shape
49 }
50
51 pub fn stride(&self) -> &[usize] {
52 &self.stride
53 }
54
55 pub fn start_offset(&self) -> usize {
56 self.start_offset
57 }
58
59 pub fn element_count(&self) -> usize {
60 self.shape().element_count()
61 }
62
63 pub fn is_contiguous(&self) -> bool {
64 self.shape.is_contiguous(&self.stride)
65 }
66
67 pub fn slice(&self, dim: usize, start: usize, end: usize, step: usize) -> Result<Self> {
68 let dims = self.shape().dims();
69
70 if dim >= dims.len() {
71 Err(Error::DimOutOfRange {
72 shape: self.shape().clone(),
73 dim: dim as i32,
74 op: "slice"
75 })?;
76 }
77
78 if step == 0 {
79 return Err(Error::NarrowInvalidArgs {
80 shape: self.shape.clone(),
81 dim, start, len: 0,
82 msg: "step cannot be 0",
83 }.into());
84 }
85
86 if start > end || end > dims[dim] {
87 return Err(Error::NarrowInvalidArgs {
88 shape: self.shape.clone(),
89 dim, start, len: end.saturating_sub(start),
90 msg: "index out of range",
91 }.into());
92 }
93
94 let new_len = if start == end { 0 } else { (start..end).step_by(step).len() };
95
96 let mut new_dims = dims.to_vec();
97 new_dims[dim] = new_len;
98
99 let mut new_stride = self.stride.clone();
100 new_stride[dim] *= step;
101
102 Ok(Self::new(
103 new_dims,
104 new_stride,
105 self.start_offset + self.stride[dim] * start
106 ))
107 }
108
109 pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
110 self.slice(dim, start, start + len, 1)
111 }
112
113 pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
114 let rank = self.shape.rank();
115 if rank <= dim1 || rank <= dim2 {
116 Err(Error::UnexpectedNumberOfDims {
117 expected: usize::max(dim1, dim2),
118 got: rank,
119 shape: self.shape().clone(),
120 })?
121 }
122
123 let mut stride = self.stride().to_vec();
124 let mut dims = self.shape().dims().to_vec();
125 dims.swap(dim1, dim2);
126 stride.swap(dim1, dim2);
127
128 Ok(Self::new(dims, stride, self.start_offset))
129 }
130
131 pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
132 let shape = shape.into();
133 if shape.rank() < self.shape().rank() {
134 return Err(Error::BroadcastIncompatibleShapes {
135 src_shape: self.shape().clone(),
136 dst_shape: shape,
137 })?;
138 }
139
140 let added_dims = shape.rank() - self.shape().rank();
141 let mut stride = vec![0; added_dims];
142 for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
143 .iter()
144 .zip(self.dims().iter().zip(self.stride()))
145 {
146 let s = if dst_dim == src_dim {
147 src_stride
148 } else if src_dim != 1 {
149 return Err(Error::BroadcastIncompatibleShapes {
150 src_shape: self.shape().clone(),
151 dst_shape: shape,
152 })?;
153 } else {
154 0
155 };
156 stride.push(s)
157 }
158 Ok(Self {
159 shape,
160 stride,
161 start_offset: self.start_offset,
162 })
163 }
164
165 pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
166 let is_permutation =
167 idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
168 if !is_permutation {
169 crate::bail!(
170 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
171 self.dims(),
172 idxs
173 )
174 }
175 let stride = self.stride();
176 let dims = self.shape().dims();
177 let mut perm_stride = stride.to_vec();
178 let mut perm_dims = dims.to_vec();
179 for (i, &idx) in idxs.iter().enumerate() {
180 perm_stride[i] = stride[idx];
181 perm_dims[i] = dims[idx];
182 }
183 Ok(Self {
184 shape: Shape::from(perm_dims),
185 stride: perm_stride,
186 start_offset: self.start_offset,
187 })
188 }
189
190 pub fn storage_indices(&self) -> StorageIndices {
199 StorageIndices::from_layout(self)
200 }
201}
202
/// Iterator over the storage indices of a [`Layout`]'s elements.
///
/// Indices are produced with the last dimension varying fastest. A contiguous
/// layout is handled by a simple counter; any other layout uses a strided
/// multi-index walker.
#[derive(Debug, Clone)]
pub enum StorageIndices<'a> {
    /// General strided walk, used when the layout is not contiguous.
    UncontiguousStorageIndices(UncontiguousStorageIndices<'a>),
    /// Consecutive range `start_offset..start_offset + element_count`.
    ContiguousStorageIndices(ContiguousStorageIndices),
}
212
213impl<'a> StorageIndices<'a> {
214 pub fn from_layout(l: &'a Layout) -> Self {
215 if l.is_contiguous() {
216 Self::ContiguousStorageIndices(ContiguousStorageIndices::from_layout(l))
217 } else {
218 Self::UncontiguousStorageIndices(UncontiguousStorageIndices::from_layout(l))
219 }
220 }
221
222 pub fn reset(&mut self) {
223 match self {
224 Self::UncontiguousStorageIndices(index) => index.reset(),
225 Self::ContiguousStorageIndices(index) => index.reset(),
226 }
227 }
228
229 pub fn len(&self) -> usize {
230 match self {
231 Self::UncontiguousStorageIndices(index) => index.len(),
232 Self::ContiguousStorageIndices(index) => index.len(),
233 }
234 }
235}
236
237impl<'a> Iterator for StorageIndices<'a> {
238 type Item = usize;
239
240 fn next(&mut self) -> Option<Self::Item> {
241 match self {
242 Self::ContiguousStorageIndices(i) => i.next(),
243 Self::UncontiguousStorageIndices(i) => i.next(),
244 }
245 }
246}
247
/// Counter over the consecutive storage range of a contiguous layout.
#[derive(Debug, Clone)]
pub struct ContiguousStorageIndices {
    /// First index of the range (the layout's `start_offset`); target of `reset`.
    init_storage_index: usize,
    /// Index that the next call to `next` will yield.
    storage_index: usize,
    /// One past the last index of the range (exclusive bound).
    end_index: usize,
}
254
255impl ContiguousStorageIndices {
256 fn from_layout(l: &Layout) -> Self {
257 Self {
258 init_storage_index: l.start_offset(),
259 storage_index: l.start_offset(),
260 end_index: l.start_offset() + l.element_count(),
261 }
262 }
263
264 fn reset(&mut self) {
265 self.storage_index = self.init_storage_index;
266 }
267
268 fn len(&self) -> usize {
269 self.end_index - self.init_storage_index
270 }
271}
272
273impl Iterator for ContiguousStorageIndices {
274 type Item = usize;
275
276 fn next(&mut self) -> Option<Self::Item> {
277 if self.storage_index >= self.end_index {
278 None
279 } else {
280 let index = self.storage_index;
281 self.storage_index += 1;
282 Some(index)
283 }
284 }
285}
286
287impl<S: Into<Shape>> From<S> for Layout {
288 fn from(value: S) -> Self {
289 Layout::contiguous(value.into())
290 }
291}
292
/// Strided walker over the storage indices of an arbitrary (non-contiguous) layout.
///
/// It keeps a per-dimension multi-index and updates the storage index
/// incrementally as the multi-index advances.
#[derive(Debug, Clone)]
pub struct UncontiguousStorageIndices<'a> {
    /// First storage index (`start_offset`), or `None` when there are no elements;
    /// target of `reset`.
    init_storage_index: Option<usize>,
    /// Storage index the next call to `next` will yield; `None` once exhausted.
    next_storage_index: Option<usize>,
    /// Per-dimension position of the element at `next_storage_index`.
    multi_index: Vec<usize>,
    /// Dimension sizes, borrowed from the layout.
    dims: &'a [usize],
    /// Per-dimension strides, borrowed from the layout.
    stride: &'a [usize],
    /// Total number of elements (product of `dims`).
    len: usize,
}
302
303impl<'a> UncontiguousStorageIndices<'a> {
304 fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
305 let elem_count: usize = dims.iter().product();
306 let next_storage_index = if elem_count == 0 {
307 None
308 } else {
309 Some(start_offset)
311 };
312 UncontiguousStorageIndices {
313 init_storage_index: next_storage_index,
314 next_storage_index,
315 multi_index: vec![0; dims.len()],
316 dims,
317 stride,
318 len: elem_count,
319 }
320 }
321
322 fn from_layout(l: &'a Layout) -> Self {
323 Self::new(l.dims(), l.stride(), l.start_offset())
324 }
325
326 pub fn reset(&mut self) {
327 self.next_storage_index = self.init_storage_index;
328 }
329
330 pub fn len(&self) -> usize {
331 self.len
332 }
333}
334
335impl Iterator for UncontiguousStorageIndices<'_> {
336 type Item = usize;
337
338 fn next(&mut self) -> Option<Self::Item> {
339 let storage_index = self.next_storage_index?;
340 let mut updated = false;
341 let mut next_storage_index = storage_index;
342 for ((multi_i, max_i), stride_i) in self
343 .multi_index
344 .iter_mut()
345 .zip(self.dims.iter())
346 .zip(self.stride.iter())
347 .rev()
348 {
349 let next_i = *multi_i + 1;
350 if next_i < *max_i {
351 *multi_i = next_i;
352 updated = true;
353 next_storage_index += stride_i;
354 break;
355 } else {
356 next_storage_index -= *multi_i * stride_i;
357 *multi_i = 0
358 }
359 }
360 self.next_storage_index = if updated {
361 Some(next_storage_index)
362 } else {
363 None
364 };
365 Some(storage_index)
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::{Layout, StorageIndices};

    /// Collects every storage index of `layout` in iteration order.
    fn indices(layout: &Layout) -> Vec<usize> {
        StorageIndices::from_layout(layout).collect()
    }

    #[test]
    fn test_strided_index1() {
        // A contiguous layout visits its storage linearly.
        let layout = Layout::contiguous((2, 5, 4));
        assert_eq!(indices(&layout), (0..40).collect::<Vec<_>>());
    }

    #[test]
    fn test_strided_index2() {
        // Narrowing the middle dimension keeps the strides and shifts the offset.
        let layout = Layout::contiguous((2, 3, 3));
        let layout = layout.narrow(1, 1, 1).unwrap();
        assert_eq!(layout.dims(), &[2, 1, 3]);
        assert_eq!(layout.stride(), &[9, 3, 1]);
        assert_eq!(layout.start_offset(), 3);
        assert_eq!(indices(&layout), vec![3, 4, 5, 12, 13, 14]);
    }

    #[test]
    fn test_slice_with_step() {
        let layout = Layout::contiguous(vec![6]);
        let sliced = layout.slice(0, 1, 6, 2).unwrap();
        assert_eq!(sliced.dims(), &[3]);
        assert_eq!(sliced.stride(), &[2]);
        assert_eq!(sliced.start_offset(), 1);
        assert_eq!(indices(&sliced), vec![1, 3, 5]);

        // An empty range yields a zero-sized dimension and no indices.
        let empty = layout.slice(0, 3, 3, 1).unwrap();
        assert_eq!(empty.dims(), &[0]);
        assert!(indices(&empty).is_empty());
    }

    #[test]
    fn test_slice_invalid_args() {
        let layout = Layout::contiguous(vec![6]);
        assert!(layout.slice(0, 0, 6, 0).is_err()); // step of 0
        assert!(layout.slice(0, 2, 7, 1).is_err()); // end past the dimension
        assert!(layout.slice(0, 4, 2, 1).is_err()); // start after end
        assert!(layout.slice(1, 0, 1, 1).is_err()); // dimension out of range
    }

    #[test]
    fn test_transpose_and_permute() {
        let layout = Layout::contiguous((2, 3));
        let t = layout.transpose(0, 1).unwrap();
        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.stride(), &[1, 3]);
        assert_eq!(indices(&t), vec![0, 3, 1, 4, 2, 5]);

        let p = layout.permute(&[1, 0]).unwrap();
        assert_eq!(p, t);
        assert!(layout.permute(&[0, 0]).is_err());
        assert!(layout.transpose(0, 2).is_err());
    }
}
394