1use alloc::vec;
2use alloc::vec::Vec;
3use burn_std::{Shape, Slice};
4
/// Describes how a tensor's logical elements map onto a flat buffer.
///
/// The element at multi-index `[i0, i1, ...]` is located at buffer position
/// `start_offset + i0 * strides[0] + i1 * strides[1] + ...` (see [`Layout::index`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Layout {
    /// Logical size of each dimension.
    shape: Shape,
    /// Step, in elements, between consecutive indices of each dimension. There is one
    /// entry per dimension, and entries may be negative (e.g. after [`Layout::flip`]).
    strides: Vec<isize>,
    /// Buffer position, in elements, of the element at index `[0, 0, ..., 0]`.
    start_offset: usize,
}
16
17pub(crate) fn contiguous_strides_usize(shape: &Shape) -> Vec<usize> {
19 let ndims = shape.num_dims();
20 let mut strides = vec![1usize; ndims];
21 for i in (0..ndims.saturating_sub(1)).rev() {
22 strides[i] = strides[i + 1] * shape[i + 1];
23 }
24 strides
25}
26
27pub(crate) fn slice_base_offset(
33 slice_idx: usize,
34 shape: &Shape,
35 strides: &[usize],
36 dim: usize,
37) -> usize {
38 let ndims = shape.num_dims();
39 let mut offset = 0;
40 let mut remaining = slice_idx;
41 for d in (0..ndims).rev() {
42 if d == dim {
43 continue;
44 }
45 let s = shape[d];
46 offset += (remaining % s) * strides[d];
47 remaining /= s;
48 }
49 offset
50}
51
52impl Layout {
53 pub fn contiguous(shape: Shape) -> Self {
55 let strides: Vec<isize> = contiguous_strides_usize(&shape)
56 .into_iter()
57 .map(|s| s as isize)
58 .collect();
59
60 Self {
61 shape,
62 strides,
63 start_offset: 0,
64 }
65 }
66
67 pub fn new(shape: Shape, strides: Vec<isize>, start_offset: usize) -> Self {
69 debug_assert_eq!(shape.num_dims(), strides.len());
70 Self {
71 shape,
72 strides,
73 start_offset,
74 }
75 }
76
77 pub fn shape(&self) -> &Shape {
79 &self.shape
80 }
81
82 pub fn strides(&self) -> &[isize] {
84 &self.strides
85 }
86
87 pub fn start_offset(&self) -> usize {
89 self.start_offset
90 }
91
92 pub fn num_dims(&self) -> usize {
94 self.shape.num_dims()
95 }
96
97 pub fn num_elements(&self) -> usize {
99 self.shape.num_elements()
100 }
101
102 pub fn is_contiguous(&self) -> bool {
104 if self.shape.num_dims() == 0 {
105 return true;
106 }
107
108 let mut expected_stride = 1isize;
109 for i in (0..self.shape.num_dims()).rev() {
110 if self.strides[i] != expected_stride {
111 return false;
112 }
113 expected_stride *= self.shape[i] as isize;
114 }
115 true
116 }
117
118 pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
120 if self.is_contiguous() {
121 Some((self.start_offset, self.start_offset + self.num_elements()))
122 } else {
123 None
124 }
125 }
126
127 pub fn transpose(&self, dim1: usize, dim2: usize) -> Self {
129 let mut dims = self.shape.to_vec();
130 let mut strides = self.strides.clone();
131 dims.swap(dim1, dim2);
132 strides.swap(dim1, dim2);
133 Self {
134 shape: Shape::from(dims),
135 strides,
136 start_offset: self.start_offset,
137 }
138 }
139
140 pub fn permute(&self, axes: &[usize]) -> Self {
144 debug_assert_eq!(
145 axes.len(),
146 self.num_dims(),
147 "permute: axes length must match number of dimensions"
148 );
149
150 let new_dims: Vec<usize> = axes.iter().map(|&i| self.shape[i]).collect();
151 let new_strides: Vec<isize> = axes.iter().map(|&i| self.strides[i]).collect();
152
153 Self {
154 shape: Shape::from(new_dims),
155 strides: new_strides,
156 start_offset: self.start_offset,
157 }
158 }
159
160 pub fn flip(&self, axes: &[usize]) -> Self {
165 let mut new_strides = self.strides.clone();
166 let mut offset_adjustment: isize = 0;
167
168 for &axis in axes {
169 debug_assert!(
170 axis < self.num_dims(),
171 "flip: axis {} out of bounds for {} dimensions",
172 axis,
173 self.num_dims()
174 );
175
176 let dim_size = self.shape[axis];
177 if dim_size > 1 {
178 offset_adjustment += (dim_size as isize - 1) * self.strides[axis];
180 new_strides[axis] = -new_strides[axis];
182 }
183 }
184
185 let new_start_isize = self.start_offset as isize + offset_adjustment;
186 debug_assert!(new_start_isize >= 0, "flip: negative offset");
187 let new_start = new_start_isize as usize;
188
189 Self {
190 shape: self.shape.clone(),
191 strides: new_strides,
192 start_offset: new_start,
193 }
194 }
195
196 pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
198 debug_assert!(
199 start + len <= self.shape[dim],
200 "narrow: start ({}) + len ({}) exceeds dimension size ({})",
201 start,
202 len,
203 self.shape[dim]
204 );
205 let mut dims = self.shape.to_vec();
206 dims[dim] = len;
207
208 let new_offset_isize = self.start_offset as isize + self.strides[dim] * start as isize;
209 debug_assert!(new_offset_isize >= 0, "narrow: negative offset");
210 let new_offset = new_offset_isize as usize;
211
212 Self {
213 shape: Shape::from(dims),
214 strides: self.strides.clone(),
215 start_offset: new_offset,
216 }
217 }
218
219 pub fn slice(&self, slices: &[Slice]) -> (Self, bool) {
225 let ndims = self.num_dims();
226 let mut new_dims = self.shape.to_vec();
227 let mut new_strides = self.strides.clone();
228 let mut new_offset = self.start_offset as isize;
229 let mut needs_copy = false;
230
231 for (dim, slice) in slices.iter().enumerate() {
232 if dim >= ndims {
233 break;
234 }
235
236 let dim_size = self.shape[dim] as isize;
237 let stride = self.strides[dim];
238
239 let start = if slice.start < 0 {
241 (dim_size + slice.start).max(0) as usize
242 } else {
243 (slice.start as usize).min(dim_size as usize)
244 };
245
246 let end = match slice.end {
250 Some(e) if e < 0 => (dim_size + e).max(0) as usize,
251 Some(e) => (e as usize).min(dim_size as usize),
252 None => dim_size as usize, };
254
255 let step = slice.step;
256 let abs_step = step.unsigned_abs();
257
258 if step > 0 {
259 let len = if end > start {
261 (end - start).div_ceil(abs_step)
262 } else {
263 0
264 };
265 new_dims[dim] = len;
266 new_strides[dim] = stride * step;
267 new_offset += stride * start as isize;
268 } else {
269 needs_copy = true;
272 let len = if end > start {
273 (end - start).div_ceil(abs_step)
274 } else {
275 0
276 };
277 new_dims[dim] = len;
278 new_strides[dim] = stride; }
280 }
281
282 debug_assert!(new_offset >= 0, "slice: negative offset");
283
284 (
285 Self {
286 shape: Shape::from(new_dims),
287 strides: new_strides,
288 start_offset: new_offset as usize,
289 },
290 needs_copy,
291 )
292 }
293
294 pub fn reshape(&self, new_shape: Shape) -> Option<Self> {
298 if !self.is_contiguous() || self.start_offset != 0 {
299 return None;
300 }
301 debug_assert_eq!(
302 self.num_elements(),
303 new_shape.num_elements(),
304 "reshape must preserve total elements"
305 );
306 Some(Self::contiguous(new_shape))
307 }
308
309 pub fn index(&self, indices: &[usize]) -> usize {
311 debug_assert_eq!(indices.len(), self.num_dims());
312 let mut offset = self.start_offset as isize;
313 for (i, &idx) in indices.iter().enumerate() {
314 offset += idx as isize * self.strides[i];
315 }
316 debug_assert!(offset >= 0, "index: negative offset");
317 offset as usize
318 }
319
320 pub fn inner_stride(&self) -> usize {
324 self.strides.last().map(|s| s.unsigned_abs()).unwrap_or(1)
325 }
326
327 pub fn has_contiguous_inner(&self) -> bool {
330 self.inner_stride() == 1
331 }
332
333 pub fn as_2d_strides(&self) -> Option<(usize, usize, isize, isize)> {
336 if self.num_dims() != 2 {
337 return None;
338 }
339 Some((
340 self.shape[0],
341 self.shape[1],
342 self.strides[0],
343 self.strides[1],
344 ))
345 }
346
347 pub fn has_positive_strides(&self) -> bool {
349 self.strides.iter().all(|&s| s >= 0)
350 }
351
352 pub fn strided_blocks(&self) -> StridedBlocks<'_> {
362 let n = self.num_elements();
363 if n == 0 {
364 return StridedBlocks::Single { start: 0, len: 0 };
365 }
366
367 if self.is_contiguous() {
369 return StridedBlocks::Single {
370 start: self.start_offset,
371 len: n,
372 };
373 }
374
375 let ndims = self.num_dims();
378 let mut block_len = 1usize;
379 let mut expected_stride = 1isize;
380
381 for i in (0..ndims).rev() {
382 if self.strides[i] == expected_stride {
383 block_len *= self.shape[i];
384 expected_stride *= self.shape[i] as isize;
385 } else {
386 break;
387 }
388 }
389
390 if block_len == n {
391 return StridedBlocks::Single {
393 start: self.start_offset,
394 len: n,
395 };
396 }
397
398 let num_blocks = n / block_len;
399 StridedBlocks::Multiple {
400 layout: self,
401 block_len,
402 num_blocks,
403 }
404 }
405}
406
/// A decomposition of a [`Layout`]'s elements into equally sized runs that are
/// contiguous in the underlying buffer, as produced by [`Layout::strided_blocks`].
#[derive(Debug, Clone)]
pub enum StridedBlocks<'a> {
    /// All elements form one run of `len` elements beginning at buffer position `start`.
    Single { start: usize, len: usize },
    /// `num_blocks` runs of `block_len` contiguous elements each. The start position of
    /// each run is produced by [`StridedBlocks::block_starts`].
    Multiple {
        /// The layout being decomposed.
        layout: &'a Layout,
        /// Number of elements in each run.
        block_len: usize,
        /// Number of runs.
        num_blocks: usize,
    },
}
419
420impl<'a> StridedBlocks<'a> {
421 pub fn block_len(&self) -> usize {
423 match self {
424 Self::Single { len, .. } => *len,
425 Self::Multiple { block_len, .. } => *block_len,
426 }
427 }
428
429 pub fn block_starts(&self) -> BlockStartIter<'_> {
431 match self {
432 Self::Single { start, .. } => BlockStartIter::Single {
433 start: *start,
434 done: false,
435 },
436 Self::Multiple {
437 layout,
438 block_len,
439 num_blocks,
440 } => {
441 let ndims = layout.num_dims();
443 let mut outer_dims = 0;
444 let mut expected_stride = 1isize;
445
446 for i in (0..ndims).rev() {
447 if layout.strides[i] == expected_stride {
448 expected_stride *= layout.shape[i] as isize;
449 } else {
450 outer_dims = i + 1;
451 break;
452 }
453 }
454
455 BlockStartIter::Multiple {
456 layout,
457 multi_index: vec![0; outer_dims],
458 remaining: *num_blocks,
459 block_len: *block_len,
460 }
461 }
462 }
463 }
464}
465
/// Iterator over the buffer positions at which each contiguous block begins.
/// Created by [`StridedBlocks::block_starts`].
pub enum BlockStartIter<'a> {
    /// Yields `start` exactly once.
    Single {
        start: usize,
        /// Set once `start` has been yielded.
        done: bool,
    },
    /// Walks every combination of the outer-dimension indices in row-major order.
    Multiple {
        layout: &'a Layout,
        /// Current index over the outer dimensions; the last entry varies fastest.
        multi_index: Vec<usize>,
        /// Number of block starts not yet yielded.
        remaining: usize,
        /// Length of each block (not read by `next`/`size_hint`).
        block_len: usize,
    },
}
479
480impl Iterator for BlockStartIter<'_> {
481 type Item = usize;
482
483 fn next(&mut self) -> Option<usize> {
484 match self {
485 Self::Single { start, done } => {
486 if *done {
487 None
488 } else {
489 *done = true;
490 Some(*start)
491 }
492 }
493 Self::Multiple {
494 layout,
495 multi_index,
496 remaining,
497 block_len: _,
498 } => {
499 if *remaining == 0 {
500 return None;
501 }
502
503 let outer_dims = multi_index.len();
505 let mut offset = layout.start_offset as isize;
506 for (i, &idx) in multi_index.iter().enumerate() {
507 offset += idx as isize * layout.strides[i];
508 }
509
510 *remaining -= 1;
511
512 let shape = &layout.shape;
514 for d in (0..outer_dims).rev() {
515 multi_index[d] += 1;
516 if multi_index[d] < shape[d] {
517 break;
518 }
519 multi_index[d] = 0;
520 }
521
522 Some(offset as usize)
523 }
524 }
525 }
526
527 fn size_hint(&self) -> (usize, Option<usize>) {
528 let len = match self {
529 Self::Single { done, .. } => {
530 if *done {
531 0
532 } else {
533 1
534 }
535 }
536 Self::Multiple { remaining, .. } => *remaining,
537 };
538 (len, Some(len))
539 }
540}
541
// `size_hint` returns identical lower and upper bounds for every variant, so the
// default `ExactSizeIterator::len()` is exact.
impl ExactSizeIterator for BlockStartIter<'_> {}
543
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_contiguous_layout() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3, 4]));
        assert_eq!(layout.strides(), &[12, 4, 1]);
        assert!(layout.is_contiguous());
    }

    #[test]
    fn test_transpose() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let transposed = layout.transpose(0, 1);
        assert_eq!(transposed.shape().to_vec(), vec![3, 2]);
        assert_eq!(transposed.strides(), &[1, 3]);
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_permute() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3, 4]));
        let permuted = layout.permute(&[2, 0, 1]);
        assert_eq!(permuted.shape().to_vec(), vec![4, 2, 3]);
        assert_eq!(permuted.strides(), &[1, 12, 4]);
        assert_eq!(permuted.start_offset(), 0);
        assert!(!permuted.is_contiguous());
    }

    #[test]
    fn test_narrow() {
        let layout = Layout::contiguous(Shape::from(vec![4, 4]));
        let narrowed = layout.narrow(0, 1, 2);
        assert_eq!(narrowed.shape().to_vec(), vec![2, 4]);
        assert_eq!(narrowed.start_offset(), 4);
    }

    #[test]
    fn test_contiguous_offsets() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        assert_eq!(layout.contiguous_offsets(), Some((0, 6)));
    }

    #[test]
    fn test_contiguous_offsets_after_narrow_and_transpose() {
        let layout = Layout::contiguous(Shape::from(vec![4, 4]));
        // Narrowing the outer dimension keeps rows packed, shifted by one row.
        assert_eq!(layout.narrow(0, 1, 2).contiguous_offsets(), Some((4, 12)));
        // A transposed view is not contiguous.
        assert_eq!(layout.transpose(0, 1).contiguous_offsets(), None);
    }

    #[test]
    fn test_reshape() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let reshaped = layout
            .reshape(Shape::from(vec![3, 2]))
            .expect("contiguous layout at offset 0 must reshape");
        assert_eq!(reshaped.shape().to_vec(), vec![3, 2]);
        assert_eq!(reshaped.strides(), &[2, 1]);
        assert_eq!(reshaped.start_offset(), 0);

        // Non-contiguous views are rejected.
        assert!(layout.transpose(0, 1).reshape(Shape::from(vec![6])).is_none());
        // Contiguous views with a non-zero start offset are rejected as well.
        assert!(layout.narrow(0, 1, 1).reshape(Shape::from(vec![3])).is_none());
    }

    #[test]
    fn test_index() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        assert_eq!(layout.index(&[0, 0]), 0);
        assert_eq!(layout.index(&[0, 2]), 2);
        assert_eq!(layout.index(&[1, 0]), 3);
        assert_eq!(layout.index(&[1, 2]), 5);
    }

    #[test]
    fn test_inner_stride_and_2d_strides() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        assert_eq!(layout.inner_stride(), 1);
        assert!(layout.has_contiguous_inner());
        assert!(layout.has_positive_strides());
        assert_eq!(layout.as_2d_strides(), Some((2, 3, 3, 1)));

        let transposed = layout.transpose(0, 1);
        assert_eq!(transposed.inner_stride(), 3);
        assert!(!transposed.has_contiguous_inner());
        assert_eq!(transposed.as_2d_strides(), Some((3, 2, 1, 3)));

        // A reversed inner dimension still has unit (absolute) inner stride.
        let flipped = layout.flip(&[1]);
        assert_eq!(flipped.inner_stride(), 1);
        assert!(flipped.has_contiguous_inner());
        assert!(!flipped.has_positive_strides());

        let cube = Layout::contiguous(Shape::from(vec![2, 2, 2]));
        assert_eq!(cube.as_2d_strides(), None);
    }

    #[test]
    fn test_flip_1d() {
        let layout = Layout::contiguous(Shape::from(vec![4]));
        let flipped = layout.flip(&[0]);

        assert_eq!(flipped.shape().to_vec(), vec![4]);
        assert_eq!(flipped.strides(), &[-1]);
        assert_eq!(flipped.start_offset(), 3);

        assert_eq!(flipped.index(&[0]), 3);
        assert_eq!(flipped.index(&[1]), 2);
        assert_eq!(flipped.index(&[2]), 1);
        assert_eq!(flipped.index(&[3]), 0);
    }

    #[test]
    fn test_flip_2d_axis0() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let flipped = layout.flip(&[0]);

        assert_eq!(flipped.strides(), &[-3, 1]);
        assert_eq!(flipped.start_offset(), 3);

        assert_eq!(flipped.index(&[0, 0]), 3);
        assert_eq!(flipped.index(&[0, 1]), 4);
        assert_eq!(flipped.index(&[0, 2]), 5);
        assert_eq!(flipped.index(&[1, 0]), 0);
        assert_eq!(flipped.index(&[1, 1]), 1);
        assert_eq!(flipped.index(&[1, 2]), 2);
    }

    #[test]
    fn test_flip_2d_axis1() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let flipped = layout.flip(&[1]);

        assert_eq!(flipped.strides(), &[3, -1]);
        assert_eq!(flipped.start_offset(), 2);

        assert_eq!(flipped.index(&[0, 0]), 2);
        assert_eq!(flipped.index(&[0, 1]), 1);
        assert_eq!(flipped.index(&[0, 2]), 0);
        assert_eq!(flipped.index(&[1, 0]), 5);
        assert_eq!(flipped.index(&[1, 1]), 4);
        assert_eq!(flipped.index(&[1, 2]), 3);
    }

    #[test]
    fn test_flip_both_axes() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let flipped = layout.flip(&[0, 1]);

        assert_eq!(flipped.strides(), &[-3, -1]);
        assert_eq!(flipped.start_offset(), 5);

        assert_eq!(flipped.index(&[0, 0]), 5);
        assert_eq!(flipped.index(&[0, 1]), 4);
        assert_eq!(flipped.index(&[0, 2]), 3);
        assert_eq!(flipped.index(&[1, 0]), 2);
        assert_eq!(flipped.index(&[1, 1]), 1);
        assert_eq!(flipped.index(&[1, 2]), 0);
    }

    #[test]
    fn test_strided_blocks_contiguous() {
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let blocks = layout.strided_blocks();
        assert!(matches!(blocks, StridedBlocks::Single { start: 0, len: 6 }));
        assert_eq!(blocks.block_len(), 6);
        assert_eq!(blocks.block_starts().collect::<Vec<_>>(), vec![0]);
    }

    #[test]
    fn test_strided_blocks_empty() {
        let layout = Layout::contiguous(Shape::from(vec![0, 3]));
        let blocks = layout.strided_blocks();
        assert!(matches!(blocks, StridedBlocks::Single { start: 0, len: 0 }));
        assert_eq!(blocks.block_len(), 0);
    }

    #[test]
    fn test_strided_blocks_narrowed_inner() {
        // Shape [2, 2], strides [4, 1], start offset 1: each row is a packed run of 2.
        let layout = Layout::contiguous(Shape::from(vec![2, 4]));
        let narrowed = layout.narrow(1, 1, 2);
        let blocks = narrowed.strided_blocks();
        assert!(matches!(
            blocks,
            StridedBlocks::Multiple {
                block_len: 2,
                num_blocks: 2,
                ..
            }
        ));
        assert_eq!(blocks.block_len(), 2);

        let starts = blocks.block_starts();
        assert_eq!(starts.len(), 2);
        assert_eq!(starts.collect::<Vec<_>>(), vec![1, 5]);
    }

    #[test]
    fn test_strided_blocks_transposed() {
        // Shape [3, 2], strides [1, 3]: no packed inner run, one block per element.
        let layout = Layout::contiguous(Shape::from(vec![2, 3]));
        let transposed = layout.transpose(0, 1);
        let blocks = transposed.strided_blocks();
        assert_eq!(blocks.block_len(), 1);
        assert_eq!(
            blocks.block_starts().collect::<Vec<_>>(),
            vec![0, 3, 1, 4, 2, 5]
        );
    }
}