
// burn_flex/tensor.rs

#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;

use burn_backend::{DType, Element, TensorData, TensorMetadata};
use burn_std::{Bytes, Shape, bf16, f16};

use crate::layout::Layout;

/// CPU tensor primitive for the Flex backend.
///
/// Uses type-erased byte storage with runtime dtype and Arc-based sharing.
/// Clone is O(1) (refcount increment). Copy-on-write for mutations.
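///
/// A minimal usage sketch of the sharing semantics (mirroring
/// `test_clone_is_cheap` / `test_cow_on_mutation` below):
///
/// ```ignore
/// let a = FlexTensor::from_data(TensorData::from([1.0f32, 2.0, 3.0, 4.0]));
/// let mut b = a.clone();           // O(1): bumps the Arc refcount
/// assert!(!a.is_unique());         // storage is now shared
/// b.storage_mut::<f32>()[0] = 9.0; // copy-on-write: `a` is untouched
/// assert_eq!(a.storage::<f32>()[0], 1.0);
/// ```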
#[derive(Clone)]
pub struct FlexTensor {
    /// Shared byte storage. Clone increments refcount.
    data: Arc<Bytes>,
    /// Layout describing shape, strides, and offset.
    layout: Layout,
    /// Runtime data type.
    dtype: DType,
}

impl fmt::Debug for FlexTensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FlexTensor")
            .field("shape", self.layout.shape())
            .field("dtype", &self.dtype)
            .field("contiguous", &self.layout.is_contiguous())
            .field("unique", &self.is_unique())
            .finish()
    }
}

impl FlexTensor {
    /// Create a new tensor from bytes, layout, and dtype.
    pub fn new(data: Bytes, layout: Layout, dtype: DType) -> Self {
        Self {
            data: Arc::new(data),
            layout,
            dtype,
        }
    }

    /// Create a tensor from TensorData.
    pub fn from_data(data: TensorData) -> Self {
        let shape = data.shape.clone();
        let layout = Layout::contiguous(shape);
        let dtype = data.dtype;
        Self {
            data: Arc::new(data.bytes),
            layout,
            dtype,
        }
    }

    /// Convert tensor to TensorData.
    ///
    /// If non-contiguous or shared, this will copy data.
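    ///
    /// A sketch of the copy/truncation behavior (see
    /// `test_into_data_narrowed_at_offset_zero` below):
    ///
    /// ```ignore
    /// let t = FlexTensor::from_data(TensorData::new(
    ///     vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
    ///     vec![2, 3],
    /// ));
    /// let row = t.narrow(0, 0, 1); // contiguous prefix view over a 24-byte buffer
    /// let data = row.into_data();  // copies: 3 f32s = 12 bytes, not 24
    /// assert_eq!(data.bytes.len(), 12);
    /// ```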
    pub fn into_data(self) -> TensorData {
        if self.layout.is_contiguous() && self.layout.start_offset() == 0 {
            let expected_bytes = self.layout.num_elements() * dtype_size(self.dtype);
            assert!(
                expected_bytes <= self.data.len(),
                "into_data: buffer ({} bytes) too small for {} elements of {:?}",
                self.data.len(),
                self.layout.num_elements(),
                self.dtype
            );
            if self.data.len() == expected_bytes {
                // Buffer exactly matches logical size; try zero-copy unwrap
                match Arc::try_unwrap(self.data) {
                    Ok(bytes) => TensorData {
                        bytes,
                        shape: self.layout.shape().clone(),
                        dtype: self.dtype,
                    },
                    Err(arc) => {
                        let bytes = Bytes::from_bytes_vec((*arc)[..expected_bytes].to_vec());
                        TensorData {
                            bytes,
                            shape: self.layout.shape().clone(),
                            dtype: self.dtype,
                        }
                    }
                }
            } else {
                // Contiguous at offset 0 but buffer is oversized (e.g., narrowed view).
                // Truncate to exact logical size.
                let bytes = Bytes::from_bytes_vec(self.data[..expected_bytes].to_vec());
                TensorData {
                    bytes,
                    shape: self.layout.shape().clone(),
                    dtype: self.dtype,
                }
            }
        } else {
            // Non-contiguous or non-zero offset: copy to contiguous layout
            self.to_contiguous().into_data()
        }
    }

    /// Check if this tensor has exclusive ownership of its data.
    ///
    /// When true, in-place mutations are safe without copying.
    #[inline]
    pub fn is_unique(&self) -> bool {
        Arc::strong_count(&self.data) == 1
    }

    /// Get the layout.
    pub fn layout(&self) -> &Layout {
        &self.layout
    }

    /// Create a new tensor with a different layout but sharing the same data.
    ///
    /// This is a zero-copy operation used for operations like flip, transpose, etc.
    pub fn with_layout(self, layout: Layout) -> Self {
        Self {
            data: self.data,
            layout,
            dtype: self.dtype,
        }
    }

    /// Get the dtype.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Check if tensor is contiguous.
    pub fn is_contiguous(&self) -> bool {
        self.layout.is_contiguous()
    }

    /// Get the raw bytes (read-only).
    pub fn bytes(&self) -> &[u8] {
        &self.data
    }

    /// Get a clone of the Arc for sharing data with a new layout.
    ///
    /// Use this for zero-copy view operations (reshape, transpose, slice).
    pub fn data_arc(&self) -> Arc<Bytes> {
        Arc::clone(&self.data)
    }

    /// Create a tensor from shared data, layout, and dtype.
    ///
    /// Use this for zero-copy view operations.
    pub fn from_arc(data: Arc<Bytes>, layout: Layout, dtype: DType) -> Self {
        Self {
            data,
            layout,
            dtype,
        }
    }

    /// Zero-copy typed view of the full storage buffer.
    ///
    /// Use with `StridedIter` for non-contiguous access, or with
    /// `layout().contiguous_offsets()` for the contiguous fast path.
    ///
    /// # Panics
    /// Panics if `E::dtype()` doesn't match the tensor's dtype.
    /// Note: Bool tensors are stored as u8, so both Bool(Native) and Bool(U8)
    /// dtypes accept u8 access.
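    ///
    /// A hedged access sketch combining both paths (`StridedIter` here is
    /// `crate::strided_index::StridedIter`, as used by `copy_contiguous`):
    ///
    /// ```ignore
    /// let vals: &[f32] = tensor.storage();
    /// if let Some((start, end)) = tensor.layout().contiguous_offsets() {
    ///     // Contiguous fast path: one slice covers the logical tensor.
    ///     for v in &vals[start..end] { /* ... */ }
    /// } else {
    ///     // Strided path: the iterator yields storage indices.
    ///     for idx in crate::strided_index::StridedIter::new(tensor.layout()) {
    ///         let v = vals[idx];
    ///     }
    /// }
    /// ```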
    pub fn storage<E: Element + bytemuck::Pod>(&self) -> &[E] {
        assert!(
            E::dtype() == self.dtype
                || (matches!(
                    self.dtype,
                    DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
                ) && E::dtype() == DType::U8),
            "storage: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        bytemuck::cast_slice(&self.data)
    }

    /// Mutable typed view with copy-on-write semantics.
    ///
    /// If the tensor is shared (refcount > 1), this will copy the data first.
    /// For in-place operations, prefer `try_storage_mut()` which returns None
    /// if shared, allowing you to choose an alternative strategy.
    ///
    /// # Panics
    /// Panics if `E::dtype()` doesn't match the tensor's dtype.
    /// Note: Bool tensors are stored as u8, so both Bool(Native) and Bool(U8)
    /// dtypes accept u8 access.
    pub fn storage_mut<E: Element + bytemuck::Pod>(&mut self) -> &mut [E] {
        assert!(
            E::dtype() == self.dtype
                || (matches!(
                    self.dtype,
                    DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
                ) && E::dtype() == DType::U8),
            "storage_mut: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        // COW: clone data if shared
        let bytes = Arc::make_mut(&mut self.data);
        bytemuck::cast_slice_mut(bytes)
    }

    /// Try to get mutable storage without copying.
    ///
    /// Returns `Some` if tensor is uniquely owned, `None` if shared.
    /// Use this when you want to avoid the implicit copy in `storage_mut()`.
    /// Note: Bool tensors are stored as u8, so both Bool(Native) and Bool(U8)
    /// dtypes accept u8 access.
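    ///
    /// A sketch of the try-then-fallback pattern this enables:
    ///
    /// ```ignore
    /// if let Some(vals) = tensor.try_storage_mut::<f32>() {
    ///     vals[0] = 0.0; // in place: sole owner, no copy
    /// } else {
    ///     // Shared: pick another strategy, e.g. allocate a fresh
    ///     // output buffer instead of paying storage_mut's COW copy.
    /// }
    /// ```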
    pub fn try_storage_mut<E: Element + bytemuck::Pod>(&mut self) -> Option<&mut [E]> {
        assert!(
            E::dtype() == self.dtype
                || (matches!(
                    self.dtype,
                    DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8)
                ) && E::dtype() == DType::U8),
            "try_storage_mut: dtype mismatch (expected {:?}, got {:?})",
            self.dtype,
            E::dtype()
        );
        if self.is_unique() {
            // Safe: we're the only owner
            let bytes = Arc::get_mut(&mut self.data)?;
            Some(bytemuck::cast_slice_mut(bytes))
        } else {
            None
        }
    }

    /// Get a typed slice view (zero-copy when the layout maps to a single
    /// contiguous range of the buffer).
    ///
    /// Returns `None` if the dtype doesn't match `E` or the tensor is
    /// non-contiguous.
    pub fn as_slice<E: Element + bytemuck::Pod>(&self) -> Option<&[E]> {
        if E::dtype() != self.dtype {
            return None;
        }
        let storage: &[E] = self.storage();
        self.layout
            .contiguous_offsets()
            .map(|(start, end)| &storage[start..end])
    }

    /// Create a tensor with the given shape and dtype.
    ///
    /// The buffer is zero-initialized, so [`Self::zeros`] simply delegates
    /// here.
    pub fn empty(shape: Shape, dtype: DType) -> Self {
        let num_elements = shape.num_elements();
        let elem_size = dtype_size(dtype);
        let bytes = Bytes::from_bytes_vec(alloc::vec![0u8; num_elements * elem_size]);
        let layout = Layout::contiguous(shape);
        Self {
            data: Arc::new(bytes),
            layout,
            dtype,
        }
    }

    /// Create a tensor filled with zeros.
    pub fn zeros(shape: Shape, dtype: DType) -> Self {
        Self::empty(shape, dtype)
    }

    /// Create a tensor filled with `n` copies of a typed value.
    pub fn filled_typed<E: bytemuck::Pod + Send + Sync>(
        shape: Shape,
        dtype: DType,
        value: E,
    ) -> Self {
        assert_eq!(
            dtype_size(dtype),
            core::mem::size_of::<E>(),
            "filled_typed: dtype size mismatch"
        );
        let n = shape.num_elements();
        let data = alloc::vec![value; n];
        let bytes = Bytes::from_elems(data);
        Self {
            data: Arc::new(bytes),
            layout: Layout::contiguous(shape),
            dtype,
        }
    }

    /// Copy to contiguous layout if needed.
    pub fn to_contiguous(&self) -> Self {
        // Fast path requires the logical tensor to cover the whole buffer.
        // A contiguous prefix view (e.g. [8, 5] sliced to [5, 5]) has
        // canonical strides and offset 0 but an oversized buffer, and would
        // otherwise mislead callers that read `storage()` / `bytes()` by
        // length (e.g. the SIMD `mask_fill_*` kernels).
        if self.is_contiguous()
            && self.layout.start_offset() == 0
            && self.data.len() == self.layout.num_elements() * dtype_size(self.dtype)
        {
            return self.clone();
        }

        // Copy data to new contiguous buffer
        match self.dtype {
            DType::F64 => self.copy_contiguous::<f64>(),
            DType::F32 => self.copy_contiguous::<f32>(),
            DType::F16 => self.copy_contiguous::<f16>(),
            DType::BF16 => self.copy_contiguous::<bf16>(),
            DType::I64 => self.copy_contiguous::<i64>(),
            DType::I32 => self.copy_contiguous::<i32>(),
            DType::I16 => self.copy_contiguous::<i16>(),
            DType::I8 => self.copy_contiguous::<i8>(),
            DType::U64 => self.copy_contiguous::<u64>(),
            DType::U32 => self.copy_contiguous::<u32>(),
            DType::U16 => self.copy_contiguous::<u16>(),
            DType::U8 => self.copy_contiguous::<u8>(),
            DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8) => {
                self.copy_contiguous::<u8>()
            }
            DType::Bool(burn_std::BoolStore::U32) => {
                panic!("burn-flex: Bool(U32) storage is not yet supported")
            }
            _ => panic!("Unsupported dtype for contiguous copy: {:?}", self.dtype),
        }
    }

    fn copy_contiguous<E: Element + bytemuck::Pod>(&self) -> Self {
        let src: &[E] = bytemuck::cast_slice(&self.data);
        let n = self.layout.num_elements();
        let mut dst = Vec::with_capacity(n);

        // Squeeze size-1 dims and merge adjacent stride-contiguous
        // runs so e.g. a permuted `[N, H, W, C]` ConvNeXt layer-norm
        // input becomes a plain 2D `[H*W, C]` transpose that the
        // tiled copy below handles at near-memcpy speed. Without the
        // collapse, the 4D ND fallback scalar-walks the tensor.
        let collapsed = collapse_for_copy(self.layout.shape(), self.layout.strides());
        let (shape, strides) = collapsed.as_slices();
        let offset = self.layout.start_offset() as isize;
        let all_positive = strides.iter().all(|&s| s >= 0);

        if shape.len() <= 1 && all_positive {
            // 0-D scalar or 1-D run with a uniform stride. Empty
            // collapsed shape means rank 0 (numel 1); otherwise
            // numel is the single dim's size (which may be 0 for
            // zero-sized 1D tensors, so don't clamp via `.max(1)`).
            let collapsed_numel = if shape.is_empty() { 1 } else { shape[0] };
            debug_assert_eq!(n, collapsed_numel);
            // SAFETY: capacity is n; we fill every position below.
            unsafe { dst.set_len(n) };
            if shape.is_empty() {
                if n > 0 {
                    dst[0] = src[offset as usize];
                }
            } else {
                let len = shape[0];
                let stride = strides[0];
                if stride == 1 {
                    dst[..len].copy_from_slice(&src[offset as usize..offset as usize + len]);
                } else {
                    for (i, slot) in dst.iter_mut().take(len).enumerate() {
                        let idx = (offset + i as isize * stride) as usize;
                        *slot = src[idx];
                    }
                }
            }
        } else if shape.len() == 2 && all_positive {
            // 2D positive-stride (transpose-like): tile both dims so
            // reads stay in cache. The loop-nesting chooser inside
            // `copy_2d_tiled` picks whichever ordering puts the
            // smaller source stride on the innermost loop.
            debug_assert_eq!(shape[0] * shape[1], n, "2D strides must cover all elements");
            // SAFETY: capacity is n; `copy_2d_tiled` writes every
            // `(row, col)` position exactly once.
            unsafe { dst.set_len(n) };
            copy_2d_tiled(
                &mut dst, src, offset, shape[0], shape[1], strides[0], strides[1],
            );
        } else {
            // General fallback: covers negative strides (flipped
            // tensors) and ND layouts that can't collapse to ≤2D.
            for idx in crate::strided_index::StridedIter::new(&self.layout) {
                dst.push(src[idx]);
            }
        }

        let bytes = Bytes::from_elems(dst);
        let layout = Layout::contiguous(self.layout.shape().clone());
        Self {
            data: Arc::new(bytes),
            layout,
            dtype: self.dtype,
        }
    }

    /// Reshape tensor. Zero-copy if contiguous.
    pub fn reshape(&self, new_shape: Shape) -> Self {
        assert_eq!(
            self.layout.num_elements(),
            new_shape.num_elements(),
            "reshape must preserve total elements"
        );

        if let Some(new_layout) = self.layout.reshape(new_shape.clone()) {
            Self {
                data: Arc::clone(&self.data),
                layout: new_layout,
                dtype: self.dtype,
            }
        } else {
            // Non-contiguous: copy first
            self.to_contiguous().reshape(new_shape)
        }
    }

    /// Transpose two dimensions. Zero-copy (metadata only).
    pub fn transpose(&self, dim1: usize, dim2: usize) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.transpose(dim1, dim2),
            dtype: self.dtype,
        }
    }

    /// Narrow/slice along a dimension. Zero-copy (metadata only).
    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.narrow(dim, start, len),
            dtype: self.dtype,
        }
    }

    /// Permute dimensions according to axes. Zero-copy (metadata only).
    pub fn permute(&self, axes: &[usize]) -> Self {
        Self {
            data: Arc::clone(&self.data),
            layout: self.layout.permute(axes),
            dtype: self.dtype,
        }
    }
}

impl TensorMetadata for FlexTensor {
    fn dtype(&self) -> DType {
        self.dtype
    }

    fn shape(&self) -> Shape {
        self.layout.shape().clone()
    }

    fn rank(&self) -> usize {
        self.layout.num_dims()
    }
}

/// Max rank we're willing to handle without falling back to the
/// strided iterator. Burn tensors are capped at 8 dims in practice.
const COLLAPSE_MAX_RANK: usize = 8;

/// Collapsed layout result of [`collapse_for_copy`], stored in stack
/// arrays so `to_contiguous()` doesn't have to hit the allocator on
/// its hot path.
#[derive(Debug, Clone, Copy)]
struct CollapsedLayout {
    ndim: usize,
    shape: [usize; COLLAPSE_MAX_RANK],
    strides: [isize; COLLAPSE_MAX_RANK],
}

impl CollapsedLayout {
    #[inline]
    fn as_slices(&self) -> (&[usize], &[isize]) {
        (&self.shape[..self.ndim], &self.strides[..self.ndim])
    }
}

/// Collapse a shape/stride pair into the minimum-rank equivalent
/// layout for a contiguous copy:
///
/// 1. Squeeze size-1 dims (a size-1 dim is never actually stepped,
///    so its stride is irrelevant).
/// 2. Merge adjacent dims `(i, i+1)` when
///    `stride[i] == stride[i+1] * shape[i+1]`, which means the two
///    dims form a single logical run through memory.
///
/// Canonical example: a 4D ConvNeXt input `[1, 244, 224, 48]` with
/// strides `[2_623_488, 224, 1, 54656]` (from
/// `[N, C, H, W].permute([0, 2, 3, 1])`) collapses to 2D
/// `[54656, 48]` with strides `[1, 54656]`.
///
/// If the input rank exceeds [`COLLAPSE_MAX_RANK`], no collapse is
/// attempted: the returned `ndim` is clamped to the stack-array
/// capacity but stays above 2, so the caller's rank gates
/// (`shape.len() <= 1`, `shape.len() == 2`) both fail and it falls
/// through to the generic strided path. The shape/stride contents
/// are meaningless in that case and must not be read.
///
/// PRECONDITION: the caller must gate on all-positive strides before
/// using the collapsed layout. The merge rule assumes positive
/// strides and will produce iteration-order-incorrect results for
/// flipped tensors.
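///
/// A worked mini-instance of the merge rule (it mirrors
/// `test_collapse_for_copy_already_contiguous_3d` below): shape
/// `[2, 3, 4]` with strides `[12, 4, 1]` merges twice, since
/// `12 == 4 * 3` and `4 == 1 * 4`, collapsing to `[24]` with
/// stride `[1]`.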
fn collapse_for_copy(shape: &[usize], strides: &[isize]) -> CollapsedLayout {
    let mut out = CollapsedLayout {
        ndim: 0,
        shape: [0; COLLAPSE_MAX_RANK],
        strides: [0; COLLAPSE_MAX_RANK],
    };

    // Bail out to the caller's fallback if the rank is too large to
    // fit our stack buffer. In practice this never triggers (burn
    // tensors are ≤8 dims), but leaving the `ndim` high signals the
    // caller to take the generic strided path.
    if shape.len() > COLLAPSE_MAX_RANK {
        out.ndim = shape.len().min(COLLAPSE_MAX_RANK);
        return out;
    }

    // Single forward sweep: squeeze size-1 dims and merge whenever
    // the current dim's `stride * size` equals the previous output
    // dim's stride (i.e. the two form a contiguous run).
    //
    // Use `checked_mul` so a pathological layout whose stride math
    // would overflow `isize` simply fails to merge rather than
    // wrapping into an incorrect merge decision. Real tensors can't
    // hit this (total numel is bounded by `isize::MAX`), but
    // hand-built layouts passed through the test paths could.
    for (&s, &st) in shape.iter().zip(strides.iter()) {
        if s == 1 {
            continue;
        }
        let merge = out.ndim > 0
            && (s as isize)
                .checked_mul(st)
                .is_some_and(|run| out.strides[out.ndim - 1] == run);
        if merge {
            out.shape[out.ndim - 1] *= s;
            out.strides[out.ndim - 1] = st;
        } else {
            out.shape[out.ndim] = s;
            out.strides[out.ndim] = st;
            out.ndim += 1;
        }
    }

    out
}

/// Tiled 2D copy from a strided source into a contiguous destination.
/// The loop nesting is chosen so the innermost read walks whichever
/// source stride is smaller, which keeps the hot loop in cache even
/// for transpose-like layouts.
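///
/// A worked instance (matching `test_collapse_for_copy_transpose_2d`
/// below): a transposed `[5, 3]` view has `row_stride = 1` and
/// `col_stride = 5`, so `row_stride <= col_stride` selects the
/// row-inside-col nesting and consecutive inner-loop reads are
/// adjacent in memory.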
#[inline]
fn copy_2d_tiled<E: Copy>(
    dst: &mut [E],
    src: &[E],
    offset: isize,
    rows: usize,
    cols: usize,
    row_stride: isize,
    col_stride: isize,
) {
    const TILE: usize = 16;

    if row_stride <= col_stride {
        // row-inside-col: the inner loop walks `row_stride` (smaller).
        for col_tile in (0..cols).step_by(TILE) {
            let col_end = (col_tile + TILE).min(cols);
            for row_tile in (0..rows).step_by(TILE) {
                let row_end = (row_tile + TILE).min(rows);
                for col in col_tile..col_end {
                    let col_base = offset + col as isize * col_stride;
                    for row in row_tile..row_end {
                        let idx = (col_base + row as isize * row_stride) as usize;
                        // SAFETY: caller set `dst.len() == rows * cols`
                        // and each `(row, col)` is visited once.
                        unsafe {
                            *dst.get_unchecked_mut(row * cols + col) = src[idx];
                        }
                    }
                }
            }
        }
    } else {
        // col-inside-row: the inner loop walks `col_stride` (smaller).
        for row_tile in (0..rows).step_by(TILE) {
            let row_end = (row_tile + TILE).min(rows);
            for col_tile in (0..cols).step_by(TILE) {
                let col_end = (col_tile + TILE).min(cols);
                for row in row_tile..row_end {
                    let row_base =
                        offset + row as isize * row_stride + col_tile as isize * col_stride;
                    let dst_base = row * cols + col_tile;
                    for c in 0..(col_end - col_tile) {
                        let idx = (row_base + c as isize * col_stride) as usize;
                        // SAFETY: same as above.
                        unsafe {
                            *dst.get_unchecked_mut(dst_base + c) = src[idx];
                        }
                    }
                }
            }
        }
    }
}

/// Get the size in bytes for a dtype element.
///
/// Matches `burn_std::DType::size()` semantics: Bool(Native) and Bool(U8) are
/// 1 byte, Bool(U32) is 4 bytes. This makes buffer-size validation correct
/// regardless of which BoolStore variant the dtype carries.
///
/// # Panics
///
/// Panics if the dtype has a zero-byte element size. `burn_std::DType::size()`
/// returns 0 for sub-byte quantized dtypes (Q4F, Q4S, Q2F, Q2S, and most
/// `QuantStore::PackedNative` variants). burn-flex does not yet support these
/// packed quantization formats; passing them here would silently produce
/// empty allocations in `FlexTensor::empty`, truncated buffers in `into_data`,
/// and zero-byte memcpys in `repeat_dim`. The panic turns all three into a
/// loud, actionable failure at the dispatch boundary.
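///
/// Hedged examples, per the sizes stated above: `dtype_size(DType::F32)`
/// is 4, `dtype_size(DType::Bool(burn_std::BoolStore::U8))` is 1, and a
/// sub-byte packed quantized dtype panics.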
pub(crate) fn dtype_size(dtype: DType) -> usize {
    // Delegate to burn-std's canonical size to stay in sync.
    let size = dtype.size();
    assert!(
        size > 0,
        "burn-flex: dtype {:?} has zero-byte element size (sub-byte packed \
         quantization is not yet supported)",
        dtype
    );
    size
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_from_data_roundtrip() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let tensor = FlexTensor::from_data(data.clone());
        let result = tensor.into_data();
        assert_eq!(data.shape, result.shape);
        assert_eq!(data.dtype, result.dtype);
    }

    #[test]
    fn test_collapse_for_copy_squeezes_size1_and_merges_contig() {
        // Permuted ConvNeXt input: [1, 48, 244, 224].permute([0,2,3,1]).
        let shape = vec![1, 244, 224, 48];
        let strides = vec![2_623_488_isize, 224, 1, 54656];
        let collapsed = collapse_for_copy(&shape, &strides);
        let (s, st) = collapsed.as_slices();
        assert_eq!(s, &[54656, 48]);
        assert_eq!(st, &[1, 54656]);
    }

    #[test]
    fn test_collapse_for_copy_already_contiguous_3d() {
        let collapsed = collapse_for_copy(&[2, 3, 4], &[12, 4, 1]);
        let (s, st) = collapsed.as_slices();
        assert_eq!(s, &[24]);
        assert_eq!(st, &[1]);
    }

    #[test]
    fn test_collapse_for_copy_transpose_2d() {
        let collapsed = collapse_for_copy(&[5, 3], &[1, 5]);
        let (s, st) = collapsed.as_slices();
        assert_eq!(s, &[5, 3]);
        assert_eq!(st, &[1, 5]);
    }

    #[test]
    fn test_collapse_for_copy_all_size1() {
        let collapsed = collapse_for_copy(&[1, 1, 1], &[0, 0, 0]);
        let (s, st) = collapsed.as_slices();
        assert!(s.is_empty());
        assert!(st.is_empty());
    }

    /// Regression: an empty 1D view produced by `narrow` at a
    /// non-zero offset forces `copy_contiguous` to run (it can't
    /// early-return via the contiguous-at-offset-0 shortcut). The
    /// old `debug_assert_eq!(n, shape.product().max(1))` tripped
    /// for this shape because `.max(1)` produced 1 while the true
    /// numel is 0.
    #[test]
    fn test_to_contiguous_zero_sized_narrowed() {
        let t = FlexTensor::from_data(TensorData::new(
            (0..6).map(|i| i as f32).collect::<Vec<_>>(),
            vec![6],
        ));
        // narrow(dim, start=3, len=0): shape [0], start_offset 3.
        let empty_view = t.narrow(0, 3, 0);
        assert_eq!(empty_view.shape().to_vec(), vec![0]);
        assert_ne!(empty_view.layout().start_offset(), 0);

        let contig = empty_view.to_contiguous();
        assert_eq!(contig.shape().to_vec(), vec![0]);
        assert_eq!(contig.layout().start_offset(), 0);
        assert_eq!(contig.into_data().bytes.len(), 0);
    }

    /// Regression for #4855: a prefix view (e.g. `narrow(dim, 0, n)`) has
    /// canonical contiguous strides and start_offset 0, but its underlying
    /// buffer is still the larger original. `to_contiguous` must materialize
    /// a right-sized copy so callers keying off `storage().len()` (like the
    /// SIMD `mask_fill_*` kernels reached from `triu`/`tril` in LU on tall
    /// matrices) don't walk past the logical shape.
    #[test]
    fn test_to_contiguous_prefix_view_shrinks_buffer() {
        let data: Vec<f32> = (0..40).map(|i| i as f32).collect();
        let t = FlexTensor::from_data(TensorData::new(data, vec![8, 5]));

        let prefix = t.narrow(0, 0, 5);
        assert_eq!(prefix.shape().to_vec(), vec![5, 5]);
        assert_eq!(prefix.layout().strides(), &[5, 1]);
        assert_eq!(prefix.layout().start_offset(), 0);
        assert!(prefix.is_contiguous());
        assert_eq!(prefix.storage::<f32>().len(), 40);

        let contig = prefix.to_contiguous();
        assert_eq!(contig.storage::<f32>().len(), 25);
        assert_eq!(contig.layout().num_elements(), 25);
        assert_eq!(
            contig.storage::<f32>(),
            &(0..5)
                .flat_map(|r| (0..5).map(move |c| (r * 5 + c) as f32))
                .collect::<Vec<_>>()[..]
        );
    }

    /// 4D permuted layout round-trips through the collapse + tiled
    /// copy path. Mirrors the ConvNeXt channels-last permute.
    #[test]
    fn test_to_contiguous_4d_permuted_matches_naive() {
        let dims = [1, 48, 4, 5];
        let n: usize = dims.iter().product();
        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let t = FlexTensor::from_data(TensorData::new(data.clone(), dims.to_vec()));
        let permuted = t.permute(&[0, 2, 3, 1]);
        assert!(!permuted.is_contiguous());

        let contig = permuted.to_contiguous();
        assert!(contig.is_contiguous());
        assert_eq!(contig.shape().to_vec(), vec![1, 4, 5, 48]);

        // Expected via manual strided walk of the source.
        let mut expected = Vec::with_capacity(n);
        for h in 0..4 {
            for w in 0..5 {
                for c in 0..48 {
                    let idx = c * 20 + h * 5 + w;
                    expected.push(data[idx]);
                }
            }
        }

        let result_data = contig.into_data();
        let values = result_data.as_slice::<f32>().unwrap();
        assert_eq!(values, expected.as_slice());
    }

    /// Exercise the `row_stride > col_stride` branch of the 2D tiled
    /// copy (the ConvNeXt case hits the other branch).
    #[test]
    fn test_to_contiguous_2d_row_stride_gt_col_stride() {
        // `slice(s![..;2, ..])` on a [6, 3] contiguous tensor gives a
        // [3, 3] view with strides [6, 1] that doesn't collapse, so
        // the 2D branch runs with row_stride > col_stride.
        let data: Vec<f32> = (0..18).map(|i| i as f32).collect();
        let t = FlexTensor::from_data(TensorData::new(data, vec![6, 3]));
        let stepped = crate::ops::slice::slice(
            t,
            &[
                burn_std::Slice::new(0, Some(6), 2),
                burn_std::Slice::new(0, None, 1),
            ],
        );
        // Verify the layout matches what the branch requires.
        assert_eq!(stepped.layout().shape().to_vec(), vec![3, 3]);
        assert_eq!(stepped.layout().strides(), &[6, 1]);
        assert!(!stepped.layout().is_contiguous());

        let contig = stepped.to_contiguous();
        assert!(contig.is_contiguous());
        assert_eq!(contig.shape().to_vec(), vec![3, 3]);

        let result_data = contig.into_data();
        let values = result_data.as_slice::<f32>().unwrap();
        // Expected: rows 0, 2, 4 of the original 6x3 tensor.
        let expected = vec![
            0.0f32, 1.0, 2.0, // row 0
            6.0, 7.0, 8.0, // row 2
            12.0, 13.0, 14.0, // row 4
        ];
        assert_eq!(values, expected.as_slice());
    }

    #[test]
    fn test_reshape() {
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tensor = FlexTensor::from_data(data);
        let reshaped = tensor.reshape(Shape::from(vec![3, 2]));
        assert_eq!(reshaped.shape().to_vec(), vec![3, 2]);
    }

    #[test]
    fn test_transpose() {
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tensor = FlexTensor::from_data(data);
        let transposed = tensor.transpose(0, 1);
        assert_eq!(transposed.shape().to_vec(), vec![3, 2]);
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_clone_is_cheap() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = FlexTensor::from_data(data);

        // Before clone, tensor is unique
        assert!(tensor.is_unique());

        // Clone shares data
        let cloned = tensor.clone();
        assert!(!tensor.is_unique());
        assert!(!cloned.is_unique());

        // Both point to same data
        assert!(core::ptr::eq(
            tensor.bytes().as_ptr(),
            cloned.bytes().as_ptr()
        ));
    }

    #[test]
    fn test_cow_on_mutation() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = FlexTensor::from_data(data);
        let mut cloned = tensor.clone();

        // Both share data
        assert!(!tensor.is_unique());
        assert!(!cloned.is_unique());

        // Mutate cloned - triggers COW
        let storage: &mut [f32] = cloned.storage_mut();
        storage[0] = 99.0;

        // Now cloned has its own copy, tensor is unique again
        assert!(tensor.is_unique());
        assert!(cloned.is_unique());

        // Data is different
        assert_ne!(tensor.bytes().as_ptr(), cloned.bytes().as_ptr());
        assert_eq!(tensor.storage::<f32>()[0], 1.0);
        assert_eq!(cloned.storage::<f32>()[0], 99.0);
    }

    #[test]
    fn test_into_data_narrowed_at_offset_zero() {
        // [1, 2, 3, 4, 5, 6] shape [2, 3]
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tensor = FlexTensor::from_data(data);
        // narrow to first row: shape [1, 3], offset 0, contiguous
        let narrowed = tensor.narrow(0, 0, 1);
        assert!(narrowed.is_contiguous());
        assert_eq!(narrowed.layout().start_offset(), 0);

        let result = narrowed.into_data();
        assert_eq!(result.shape.to_vec(), vec![1, 3]);
        // Must have exactly 3 f32s = 12 bytes, not 24
        assert_eq!(result.bytes.len(), 3 * core::mem::size_of::<f32>());
        let values: Vec<f32> = result.to_vec().unwrap();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }
}