Skip to main content

burn_flex/
layout.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use burn_std::{Shape, Slice};
4
5/// Layout describes how to interpret a linear buffer as an N-dimensional tensor.
6///
7/// Stores shape, strides (in elements, can be negative for flipped dimensions),
8/// and a start offset for views/slices.
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct Layout {
11    shape: Shape,
12    /// Strides in elements. Negative strides enable zero-copy flip.
13    strides: Vec<isize>,
14    start_offset: usize,
15}
16
17/// Compute row-major contiguous strides for a shape (as `usize`).
18pub(crate) fn contiguous_strides_usize(shape: &Shape) -> Vec<usize> {
19    let ndims = shape.num_dims();
20    let mut strides = vec![1usize; ndims];
21    for i in (0..ndims.saturating_sub(1)).rev() {
22        strides[i] = strides[i + 1] * shape[i + 1];
23    }
24    strides
25}
26
27/// Compute the flat offset for the `slice_idx`-th 1D fiber along `dim`.
28///
29/// Enumerates all index combinations for dimensions other than `dim`,
30/// mapping the flat `slice_idx` (0..product of non-dim sizes) to the
31/// corresponding starting offset in a contiguous buffer.
32pub(crate) fn slice_base_offset(
33    slice_idx: usize,
34    shape: &Shape,
35    strides: &[usize],
36    dim: usize,
37) -> usize {
38    let ndims = shape.num_dims();
39    let mut offset = 0;
40    let mut remaining = slice_idx;
41    for d in (0..ndims).rev() {
42        if d == dim {
43            continue;
44        }
45        let s = shape[d];
46        offset += (remaining % s) * strides[d];
47        remaining /= s;
48    }
49    offset
50}
51
52impl Layout {
53    /// Create a new contiguous layout (row-major/C-order).
54    pub fn contiguous(shape: Shape) -> Self {
55        let strides: Vec<isize> = contiguous_strides_usize(&shape)
56            .into_iter()
57            .map(|s| s as isize)
58            .collect();
59
60        Self {
61            shape,
62            strides,
63            start_offset: 0,
64        }
65    }
66
67    /// Create a layout with explicit strides.
68    pub fn new(shape: Shape, strides: Vec<isize>, start_offset: usize) -> Self {
69        debug_assert_eq!(shape.num_dims(), strides.len());
70        Self {
71            shape,
72            strides,
73            start_offset,
74        }
75    }
76
77    /// The shape of the tensor.
78    pub fn shape(&self) -> &Shape {
79        &self.shape
80    }
81
82    /// The strides in elements (can be negative for flipped dimensions).
83    pub fn strides(&self) -> &[isize] {
84        &self.strides
85    }
86
87    /// The start offset for views/slices.
88    pub fn start_offset(&self) -> usize {
89        self.start_offset
90    }
91
92    /// Number of dimensions.
93    pub fn num_dims(&self) -> usize {
94        self.shape.num_dims()
95    }
96
97    /// Total number of elements.
98    pub fn num_elements(&self) -> usize {
99        self.shape.num_elements()
100    }
101
102    /// Check if this layout is contiguous (row-major, positive strides).
103    pub fn is_contiguous(&self) -> bool {
104        if self.shape.num_dims() == 0 {
105            return true;
106        }
107
108        let mut expected_stride = 1isize;
109        for i in (0..self.shape.num_dims()).rev() {
110            if self.strides[i] != expected_stride {
111                return false;
112            }
113            expected_stride *= self.shape[i] as isize;
114        }
115        true
116    }
117
118    /// If contiguous, return (start, end) offsets for direct slice access.
119    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
120        if self.is_contiguous() {
121            Some((self.start_offset, self.start_offset + self.num_elements()))
122        } else {
123            None
124        }
125    }
126
127    /// Transpose: swap two dimensions (zero-copy, metadata only).
128    pub fn transpose(&self, dim1: usize, dim2: usize) -> Self {
129        let mut dims = self.shape.to_vec();
130        let mut strides = self.strides.clone();
131        dims.swap(dim1, dim2);
132        strides.swap(dim1, dim2);
133        Self {
134            shape: Shape::from(dims),
135            strides,
136            start_offset: self.start_offset,
137        }
138    }
139
140    /// Permute: reorder dimensions according to axes (zero-copy, metadata only).
141    ///
142    /// `axes` must be a permutation of 0..ndim.
143    pub fn permute(&self, axes: &[usize]) -> Self {
144        debug_assert_eq!(
145            axes.len(),
146            self.num_dims(),
147            "permute: axes length must match number of dimensions"
148        );
149
150        let new_dims: Vec<usize> = axes.iter().map(|&i| self.shape[i]).collect();
151        let new_strides: Vec<isize> = axes.iter().map(|&i| self.strides[i]).collect();
152
153        Self {
154            shape: Shape::from(new_dims),
155            strides: new_strides,
156            start_offset: self.start_offset,
157        }
158    }
159
160    /// Flip: reverse elements along specified axes (zero-copy, metadata only).
161    ///
162    /// For each flipped axis, negates the stride and adjusts start_offset
163    /// to point to the last element along that dimension.
164    pub fn flip(&self, axes: &[usize]) -> Self {
165        let mut new_strides = self.strides.clone();
166        let mut offset_adjustment: isize = 0;
167
168        for &axis in axes {
169            debug_assert!(
170                axis < self.num_dims(),
171                "flip: axis {} out of bounds for {} dimensions",
172                axis,
173                self.num_dims()
174            );
175
176            let dim_size = self.shape[axis];
177            if dim_size > 1 {
178                // Move start to last element along this axis
179                offset_adjustment += (dim_size as isize - 1) * self.strides[axis];
180                // Negate stride to iterate backwards
181                new_strides[axis] = -new_strides[axis];
182            }
183        }
184
185        let new_start_isize = self.start_offset as isize + offset_adjustment;
186        debug_assert!(new_start_isize >= 0, "flip: negative offset");
187        let new_start = new_start_isize as usize;
188
189        Self {
190            shape: self.shape.clone(),
191            strides: new_strides,
192            start_offset: new_start,
193        }
194    }
195
196    /// Narrow/slice along a dimension (zero-copy, metadata only).
197    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
198        debug_assert!(
199            start + len <= self.shape[dim],
200            "narrow: start ({}) + len ({}) exceeds dimension size ({})",
201            start,
202            len,
203            self.shape[dim]
204        );
205        let mut dims = self.shape.to_vec();
206        dims[dim] = len;
207
208        let new_offset_isize = self.start_offset as isize + self.strides[dim] * start as isize;
209        debug_assert!(new_offset_isize >= 0, "narrow: negative offset");
210        let new_offset = new_offset_isize as usize;
211
212        Self {
213            shape: Shape::from(dims),
214            strides: self.strides.clone(),
215            start_offset: new_offset,
216        }
217    }
218
219    /// Apply slices to create a new layout.
220    ///
221    /// Returns `(new_layout, needs_copy)`:
222    /// - `needs_copy = false`: Can use zero-copy view with new layout
223    /// - `needs_copy = true`: Has negative steps requiring data reordering
224    pub fn slice(&self, slices: &[Slice]) -> (Self, bool) {
225        let ndims = self.num_dims();
226        let mut new_dims = self.shape.to_vec();
227        let mut new_strides = self.strides.clone();
228        let mut new_offset = self.start_offset as isize;
229        let mut needs_copy = false;
230
231        for (dim, slice) in slices.iter().enumerate() {
232            if dim >= ndims {
233                break;
234            }
235
236            let dim_size = self.shape[dim] as isize;
237            let stride = self.strides[dim];
238
239            // Normalize start index (handle negative)
240            let start = if slice.start < 0 {
241                (dim_size + slice.start).max(0) as usize
242            } else {
243                (slice.start as usize).min(dim_size as usize)
244            };
245
246            // Normalize end index (handle negative and None)
247            // Note: Range [start, end) determines WHICH elements to select,
248            // step determines iteration ORDER
249            let end = match slice.end {
250                Some(e) if e < 0 => (dim_size + e).max(0) as usize,
251                Some(e) => (e as usize).min(dim_size as usize),
252                None => dim_size as usize, // Always full range when end is None
253            };
254
255            let step = slice.step;
256            let abs_step = step.unsigned_abs();
257
258            if step > 0 {
259                // Positive step: forward iteration
260                let len = if end > start {
261                    (end - start).div_ceil(abs_step)
262                } else {
263                    0
264                };
265                new_dims[dim] = len;
266                new_strides[dim] = stride * step;
267                new_offset += stride * start as isize;
268            } else {
269                // Negative step: select range then iterate in reverse
270                // Requires copy to reorder elements
271                needs_copy = true;
272                let len = if end > start {
273                    (end - start).div_ceil(abs_step)
274                } else {
275                    0
276                };
277                new_dims[dim] = len;
278                new_strides[dim] = stride; // Will be handled during copy
279            }
280        }
281
282        debug_assert!(new_offset >= 0, "slice: negative offset");
283
284        (
285            Self {
286                shape: Shape::from(new_dims),
287                strides: new_strides,
288                start_offset: new_offset as usize,
289            },
290            needs_copy,
291        )
292    }
293
294    /// Reshape to a new shape. Only works if contiguous with zero offset.
295    ///
296    /// Returns None if not contiguous or has non-zero offset (would require data copy).
297    pub fn reshape(&self, new_shape: Shape) -> Option<Self> {
298        if !self.is_contiguous() || self.start_offset != 0 {
299            return None;
300        }
301        debug_assert_eq!(
302            self.num_elements(),
303            new_shape.num_elements(),
304            "reshape must preserve total elements"
305        );
306        Some(Self::contiguous(new_shape))
307    }
308
309    /// Compute linear index from multi-dimensional indices.
310    pub fn index(&self, indices: &[usize]) -> usize {
311        debug_assert_eq!(indices.len(), self.num_dims());
312        let mut offset = self.start_offset as isize;
313        for (i, &idx) in indices.iter().enumerate() {
314            offset += idx as isize * self.strides[i];
315        }
316        debug_assert!(offset >= 0, "index: negative offset");
317        offset as usize
318    }
319
320    /// Get stride of the innermost (last) dimension.
321    /// Returns 1 for contiguous tensors, larger values for transposed.
322    /// Returns absolute value (ignores flip).
323    pub fn inner_stride(&self) -> usize {
324        self.strides.last().map(|s| s.unsigned_abs()).unwrap_or(1)
325    }
326
327    /// Check if innermost dimension is contiguous (|stride| == 1).
328    /// This enables efficient vectorized inner loops.
329    pub fn has_contiguous_inner(&self) -> bool {
330        self.inner_stride() == 1
331    }
332
333    /// For 2D layouts, get (outer_size, inner_size, outer_stride, inner_stride).
334    /// Returns None if not 2D.
335    pub fn as_2d_strides(&self) -> Option<(usize, usize, isize, isize)> {
336        if self.num_dims() != 2 {
337            return None;
338        }
339        Some((
340            self.shape[0],
341            self.shape[1],
342            self.strides[0],
343            self.strides[1],
344        ))
345    }
346
347    /// Check if all strides are non-negative.
348    pub fn has_positive_strides(&self) -> bool {
349        self.strides.iter().all(|&s| s >= 0)
350    }
351
352    /// Compute strided blocks for efficient iteration.
353    ///
354    /// Returns (block_len, num_blocks, block_stride) where:
355    /// - block_len: number of contiguous elements in each block
356    /// - num_blocks: total number of blocks
357    /// - block_stride: stride between consecutive blocks (0 if single block)
358    ///
359    /// For contiguous tensors: single block covering all elements.
360    /// For transposed/strided: multiple blocks of contiguous data.
361    pub fn strided_blocks(&self) -> StridedBlocks<'_> {
362        let n = self.num_elements();
363        if n == 0 {
364            return StridedBlocks::Single { start: 0, len: 0 };
365        }
366
367        // Fast path: fully contiguous
368        if self.is_contiguous() {
369            return StridedBlocks::Single {
370                start: self.start_offset,
371                len: n,
372            };
373        }
374
375        // Find contiguous inner dimensions (only positive strides)
376        // Start from innermost and work outward while strides match contiguous pattern
377        let ndims = self.num_dims();
378        let mut block_len = 1usize;
379        let mut expected_stride = 1isize;
380
381        for i in (0..ndims).rev() {
382            if self.strides[i] == expected_stride {
383                block_len *= self.shape[i];
384                expected_stride *= self.shape[i] as isize;
385            } else {
386                break;
387            }
388        }
389
390        if block_len == n {
391            // All dimensions contiguous (just offset)
392            return StridedBlocks::Single {
393                start: self.start_offset,
394                len: n,
395            };
396        }
397
398        let num_blocks = n / block_len;
399        StridedBlocks::Multiple {
400            layout: self,
401            block_len,
402            num_blocks,
403        }
404    }
405}
406
407/// Result of strided block analysis.
408#[derive(Debug, Clone)]
409pub enum StridedBlocks<'a> {
410    /// Single contiguous block - direct slice access.
411    Single { start: usize, len: usize },
412    /// Multiple blocks requiring iteration.
413    Multiple {
414        layout: &'a Layout,
415        block_len: usize,
416        num_blocks: usize,
417    },
418}
419
420impl<'a> StridedBlocks<'a> {
421    /// Get the block length (elements per block).
422    pub fn block_len(&self) -> usize {
423        match self {
424            Self::Single { len, .. } => *len,
425            Self::Multiple { block_len, .. } => *block_len,
426        }
427    }
428
429    /// Iterator over block start indices.
430    pub fn block_starts(&self) -> BlockStartIter<'_> {
431        match self {
432            Self::Single { start, .. } => BlockStartIter::Single {
433                start: *start,
434                done: false,
435            },
436            Self::Multiple {
437                layout,
438                block_len,
439                num_blocks,
440            } => {
441                // Calculate dimensions for outer iteration (non-contiguous part)
442                let ndims = layout.num_dims();
443                let mut outer_dims = 0;
444                let mut expected_stride = 1isize;
445
446                for i in (0..ndims).rev() {
447                    if layout.strides[i] == expected_stride {
448                        expected_stride *= layout.shape[i] as isize;
449                    } else {
450                        outer_dims = i + 1;
451                        break;
452                    }
453                }
454
455                BlockStartIter::Multiple {
456                    layout,
457                    multi_index: vec![0; outer_dims],
458                    remaining: *num_blocks,
459                    block_len: *block_len,
460                }
461            }
462        }
463    }
464}
465
466/// Iterator over block start indices.
467pub enum BlockStartIter<'a> {
468    Single {
469        start: usize,
470        done: bool,
471    },
472    Multiple {
473        layout: &'a Layout,
474        multi_index: Vec<usize>,
475        remaining: usize,
476        block_len: usize,
477    },
478}
479
480impl Iterator for BlockStartIter<'_> {
481    type Item = usize;
482
483    fn next(&mut self) -> Option<usize> {
484        match self {
485            Self::Single { start, done } => {
486                if *done {
487                    None
488                } else {
489                    *done = true;
490                    Some(*start)
491                }
492            }
493            Self::Multiple {
494                layout,
495                multi_index,
496                remaining,
497                block_len: _,
498            } => {
499                if *remaining == 0 {
500                    return None;
501                }
502
503                // Compute current block start
504                let outer_dims = multi_index.len();
505                let mut offset = layout.start_offset as isize;
506                for (i, &idx) in multi_index.iter().enumerate() {
507                    offset += idx as isize * layout.strides[i];
508                }
509
510                *remaining -= 1;
511
512                // Advance multi-index for next iteration
513                let shape = &layout.shape;
514                for d in (0..outer_dims).rev() {
515                    multi_index[d] += 1;
516                    if multi_index[d] < shape[d] {
517                        break;
518                    }
519                    multi_index[d] = 0;
520                }
521
522                Some(offset as usize)
523            }
524        }
525    }
526
527    fn size_hint(&self) -> (usize, Option<usize>) {
528        let len = match self {
529            Self::Single { done, .. } => {
530                if *done {
531                    0
532                } else {
533                    1
534                }
535            }
536            Self::Multiple { remaining, .. } => *remaining,
537        };
538        (len, Some(len))
539    }
540}
541
542impl ExactSizeIterator for BlockStartIter<'_> {}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major layout for the given extents.
    fn row_major(dims: &[usize]) -> Layout {
        Layout::contiguous(Shape::from(dims.to_vec()))
    }

    /// Assert that every `(logical index, physical offset)` pair holds.
    fn assert_mapping(layout: &Layout, cases: &[(&[usize], usize)]) {
        for &(logical, physical) in cases {
            assert_eq!(layout.index(logical), physical);
        }
    }

    #[test]
    fn test_contiguous_layout() {
        let layout = row_major(&[2, 3, 4]);
        assert_eq!(layout.strides(), &[12, 4, 1]);
        assert!(layout.is_contiguous());
    }

    #[test]
    fn test_transpose() {
        let transposed = row_major(&[2, 3]).transpose(0, 1);
        assert_eq!(transposed.shape().to_vec(), vec![3, 2]);
        assert_eq!(transposed.strides(), &[1, 3]);
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_narrow() {
        let narrowed = row_major(&[4, 4]).narrow(0, 1, 2);
        assert_eq!(narrowed.shape().to_vec(), vec![2, 4]);
        assert_eq!(narrowed.start_offset(), 4);
    }

    #[test]
    fn test_contiguous_offsets() {
        assert_eq!(row_major(&[2, 3]).contiguous_offsets(), Some((0, 6)));
    }

    #[test]
    fn test_index() {
        let layout = row_major(&[2, 3]);
        assert_mapping(
            &layout,
            &[(&[0, 0], 0), (&[0, 2], 2), (&[1, 0], 3), (&[1, 2], 5)],
        );
    }

    #[test]
    fn test_flip_1d() {
        // [0, 1, 2, 3] with strides [1] becomes strides [-1], start_offset 3.
        let flipped = row_major(&[4]).flip(&[0]);

        assert_eq!(flipped.shape().to_vec(), vec![4]);
        assert_eq!(flipped.strides(), &[-1]);
        assert_eq!(flipped.start_offset(), 3);

        // Logical position k maps to physical position 3 - k.
        assert_mapping(&flipped, &[(&[0], 3), (&[1], 2), (&[2], 1), (&[3], 0)]);
    }

    #[test]
    fn test_flip_2d_axis0() {
        // [[0, 1, 2], [3, 4, 5]] with strides [3, 1];
        // flipping axis 0 gives strides [-3, 1], start_offset 3.
        let flipped = row_major(&[2, 3]).flip(&[0]);

        assert_eq!(flipped.strides(), &[-3, 1]);
        assert_eq!(flipped.start_offset(), 3);

        assert_mapping(
            &flipped,
            &[
                // Flipped row 0 is original row 1.
                (&[0, 0], 3),
                (&[0, 1], 4),
                (&[0, 2], 5),
                // Flipped row 1 is original row 0.
                (&[1, 0], 0),
                (&[1, 1], 1),
                (&[1, 2], 2),
            ],
        );
    }

    #[test]
    fn test_flip_2d_axis1() {
        // [[0, 1, 2], [3, 4, 5]] with strides [3, 1];
        // flipping axis 1 gives strides [3, -1], start_offset 2.
        let flipped = row_major(&[2, 3]).flip(&[1]);

        assert_eq!(flipped.strides(), &[3, -1]);
        assert_eq!(flipped.start_offset(), 2);

        // Flipped column 0 is original column 2.
        assert_mapping(
            &flipped,
            &[
                (&[0, 0], 2),
                (&[0, 1], 1),
                (&[0, 2], 0),
                (&[1, 0], 5),
                (&[1, 1], 4),
                (&[1, 2], 3),
            ],
        );
    }

    #[test]
    fn test_flip_both_axes() {
        // [[0, 1, 2], [3, 4, 5]] -> [[5, 4, 3], [2, 1, 0]]
        let flipped = row_major(&[2, 3]).flip(&[0, 1]);

        assert_eq!(flipped.strides(), &[-3, -1]);
        assert_eq!(flipped.start_offset(), 5); // 3 + 2 = 5

        assert_mapping(
            &flipped,
            &[
                (&[0, 0], 5),
                (&[0, 1], 4),
                (&[0, 2], 3),
                (&[1, 0], 2),
                (&[1, 1], 1),
                (&[1, 2], 0),
            ],
        );
    }
}