// burn_ndarray/tensor.rs

1use burn_backend::{
2    AllocationProperty, DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
3    quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
4};
5use burn_std::BoolStore;
6
7use crate::NdArrayStorage;
8use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
9use alloc::vec::Vec;
10use ndarray::{ArcArray, ArrayD, IxDyn};
11
/// Concrete storage type for ndarray (owned with COW semantics via Arc).
///
/// Clones of an `ArcArray` share the same buffer; mutation copies on write.
pub type SharedArray<E> = ArcArray<E, IxDyn>;
14
/// Tensor primitive used by the [ndarray backend](crate::NdArray).
///
/// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`.
/// When data is borrowed from external sources (like burnpack files),
/// it remains zero-copy until a mutating operation is performed.
///
/// Each variant carries the storage for one supported element type; the
/// active variant determines the tensor's dtype.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    // Floating-point element types.
    F64(NdArrayStorage<f64>),
    F32(NdArrayStorage<f32>),
    // Signed integer element types.
    I64(NdArrayStorage<i64>),
    I32(NdArrayStorage<i32>),
    I16(NdArrayStorage<i16>),
    I8(NdArrayStorage<i8>),
    // Unsigned integer element types.
    U64(NdArrayStorage<u64>),
    U32(NdArrayStorage<u32>),
    U16(NdArrayStorage<u16>),
    U8(NdArrayStorage<u8>),
    // Boolean elements stored natively as `bool`.
    Bool(NdArrayStorage<bool>),
}
35
36impl NdArrayTensor {
37    /// Extract bool array, converting to owned if necessary.
38    pub(crate) fn bool(self) -> SharedArray<bool> {
39        match self {
40            NdArrayTensor::Bool(storage) => storage.into_shared(),
41            _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
42        }
43    }
44
45    /// Returns true if this tensor uses borrowed (zero-copy) storage.
46    #[inline]
47    pub fn is_borrowed(&self) -> bool {
48        macro_rules! check {
49            ($($variant:ident),*) => {
50                match self {
51                    $(NdArrayTensor::$variant(s) => s.is_borrowed(),)*
52                }
53            };
54        }
55        check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
56    }
57}
58
59pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
60where
61    NdArrayTensor: From<SharedArray<E1>>,
62{
63    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
64        array.mapv(|a| a.elem()).into_shared()
65    }
66
67    if E1::dtype() == dtype {
68        return array.into();
69    }
70
71    match dtype {
72        DType::F64 => cast::<E1, f64>(array).into(),
73        DType::F32 => cast::<E1, f32>(array).into(),
74        DType::Flex32 => cast::<E1, f32>(array).into(),
75        DType::I64 => cast::<E1, i64>(array).into(),
76        DType::I32 => cast::<E1, i32>(array).into(),
77        DType::I16 => cast::<E1, i16>(array).into(),
78        DType::I8 => cast::<E1, i8>(array).into(),
79        DType::U64 => cast::<E1, u64>(array).into(),
80        DType::U32 => cast::<E1, u32>(array).into(),
81        DType::U16 => cast::<E1, u16>(array).into(),
82        DType::U8 => cast::<E1, u8>(array).into(),
83        DType::Bool(BoolStore::Native) => cast::<E1, bool>(array).into(),
84        dtype => panic!("Unsupported dtype: {dtype:?}"),
85    }
86}
87
/// Implements `From<SharedArray<T>>` and `From<NdArrayStorage<T>>` for
/// [`NdArrayTensor`], mapping each element type to its enum variant.
macro_rules! impl_from {
    ($($ty: ty => $dtype: ident),*) => {
        // From SharedArray (owned) -> NdArrayTensor
        $(impl From<SharedArray<$ty>> for NdArrayTensor {
           fn from(value: SharedArray<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
           }
        })*

        // From NdArrayStorage -> NdArrayTensor
        $(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
           fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(value)
           }
        })*
    };
}

// One conversion pair per supported element type / variant.
impl_from!(
    f64 => F64, f32 => F32,
    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
    bool => Bool
);
112
/// Macro to execute an operation on a given element type.
///
/// Extracts the storage from NdArrayTensor, converts to SharedArray, and passes to operation.
///
/// Arms:
/// - `(($lhs, $rhs), $element, $op, [list])` — binary op over an explicit dtype list.
/// - `(($lhs, $rhs), $op)` / `(($lhs, $rhs), $element, $op)` — binary op over all dtypes.
/// - `($tensor, $element, $op, [list])` — unary op over an explicit dtype list.
/// - `($tensor, $op)` / `($tensor, $element, $op)` — unary op over all dtypes.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_dtype {
    (($lhs:expr, $rhs:expr),$element:ident,  $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        // Capture dtypes up front for the mismatch message: `$lhs`/`$rhs`
        // are moved by the match below.
        let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    // Bind the caller-visible element type alias for `$op`.
                    #[allow(unused)]
                    type $element = $ty;
                    // Convert storage to SharedArray for compatibility with existing operations
                    $op(lhs.into_shared(), rhs.into_shared()).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(storage) => {
                    #[allow(unused)]
                    type $element = $ty;
                    // Convert to SharedArray for compatibility with most operations
                    $op(storage.into_shared()).into()
                }
            )*
            // Reachable only when the dtype list does not cover every variant.
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}
184
/// Macro to execute an operation on a given element type.
/// Only handles float types (`F64`, `F32`).
///
/// Delegates to [`execute_with_dtype`] with the float dtype list.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
217
/// Macro to execute an operation on a given element type.
/// Only handles int types (signed and unsigned, 8–64 bit).
///
/// Delegates to [`execute_with_dtype`] with the integer dtype list.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// integer data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_int_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
252
/// Macro to execute an operation on a given element type.
/// Only handles numeric types (floats and integers; bool is excluded).
///
/// Delegates to [`execute_with_dtype`] with the numeric dtype list.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// numeric data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
289
/// Macro to execute a cat operation on a given set of element types.
///
/// Uses zero-copy views from storage for concatenation. The dtype of the
/// first tensor selects the match arm; every other tensor must have the
/// same dtype.
///
/// # Panics
/// Since there is no automatic type cast at this time, concatenating tensors
/// with different data types will panic with a data type mismatch.
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        // Dispatch on the first tensor's variant.
        match &$tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(storage) = t {
                            // Use storage.view() for zero-copy access
                            storage.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype())
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            // Reachable only when the dtype list does not cover every variant.
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}
319
/// Macro to execute an operation that returns a given element type.
///
/// Unlike [`execute_with_dtype`], this dispatches on an output
/// `burn_std::FloatDType` value (not on a tensor) and only binds the
/// `$element` type alias for `$op`.
#[macro_export]
macro_rules! execute_with_float_out_dtype {
    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $out_dtype {
            $(
                burn_std::FloatDType::$dtype => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op
                }
            )*
            // Reachable only when the dtype list does not cover every variant.
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {other:?}")
        }
    }};
    // Unary op: type automatically inferred by the compiler
    ($out_dtype:expr, $op:expr) => {{
        $crate::execute_with_float_out_dtype!($out_dtype, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($out_dtype:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_float_out_dtype!($out_dtype, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
348
/// Macro to execute an operation that returns a given element type.
///
/// Integer counterpart of [`execute_with_float_out_dtype`]: dispatches on
/// a `burn_std::IntDType` value and binds the `$element` alias for `$op`.
#[macro_export]
macro_rules! execute_with_int_out_dtype {
    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $out_dtype {
            $(
                burn_std::IntDType::$dtype => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op
                }
            )*
            // Reachable only when the dtype list does not cover every variant.
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {other:?}")
        }
    }};
    // Unary op: type automatically inferred by the compiler
    ($out_dtype:expr, $op:expr) => {{
        $crate::execute_with_int_out_dtype!($out_dtype, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($out_dtype:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_int_out_dtype!($out_dtype, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
378
379impl TensorMetadata for NdArrayTensor {
380    fn dtype(&self) -> DType {
381        match self {
382            NdArrayTensor::F64(_) => DType::F64,
383            NdArrayTensor::F32(_) => DType::F32,
384            NdArrayTensor::I64(_) => DType::I64,
385            NdArrayTensor::I32(_) => DType::I32,
386            NdArrayTensor::I16(_) => DType::I16,
387            NdArrayTensor::I8(_) => DType::I8,
388            NdArrayTensor::U64(_) => DType::U64,
389            NdArrayTensor::U32(_) => DType::U32,
390            NdArrayTensor::U16(_) => DType::U16,
391            NdArrayTensor::U8(_) => DType::U8,
392            NdArrayTensor::Bool(_) => DType::Bool(BoolStore::Native),
393        }
394    }
395
396    fn shape(&self) -> Shape {
397        // Use storage's shape method (works for both borrowed and owned)
398        macro_rules! get_shape {
399            ($($variant:ident),*) => {
400                match self {
401                    $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
402                }
403            };
404        }
405        get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
406    }
407
408    fn rank(&self) -> usize {
409        self.shape().num_dims()
410    }
411}
412
/// Convenience operations over raw shape slices (`&[usize]`).
pub(crate) trait ShapeOps {
    /// Number of dimensions (rank).
    fn num_dims(self) -> usize;
    /// Total number of elements (product of all dimensions).
    fn num_elements(self) -> usize;
    /// Converts to a fixed-size dims array; panics on rank mismatch.
    fn dims<const N: usize>(self) -> [usize; N];
    /// Converts into a [`Shape`].
    fn into_shape(self) -> Shape;
}
419
420impl ShapeOps for &[usize] {
421    fn num_dims(self) -> usize {
422        self.len()
423    }
424
425    fn num_elements(self) -> usize {
426        self.iter().product()
427    }
428
429    fn dims<const N: usize>(self) -> [usize; N] {
430        self.try_into().unwrap()
431    }
432
433    fn into_shape(self) -> Shape {
434        Shape::from(self)
435    }
436}
437
438mod utils {
439    use burn_std::tensor::is_contiguous;
440
441    use super::*;
442
443    impl NdArrayTensor {
444        pub(crate) fn into_data(self) -> TensorData {
445            let shape = self.shape();
446            let contiguous = self.is_contiguous();
447
448            fn inner<E: Element>(
449                shape: Shape,
450                is_contiguous: bool,
451                array: ArcArray<E, IxDyn>,
452            ) -> TensorData {
453                let vec = if is_contiguous {
454                    match array.try_into_owned_nocopy() {
455                        Ok(owned) => {
456                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
457                            if let Some(offset) = offset {
458                                vec.drain(..offset);
459                            }
460                            if vec.len() > shape.num_elements() {
461                                vec.drain(shape.num_elements()..vec.len());
462                            }
463                            vec
464                        }
465                        Err(array) => array.into_iter().collect(),
466                    }
467                } else {
468                    array.into_iter().collect()
469                };
470
471                TensorData::new(vec, shape)
472            }
473
474            // Convert storage to owned array before extracting data
475            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
476        }
477
478        pub(crate) fn is_contiguous(&self) -> bool {
479            // For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous)
480            // For owned data, we check the strides
481            macro_rules! check_contiguous {
482                ($($variant:ident),*) => {
483                    match self {
484                        $(NdArrayTensor::$variant(storage) => {
485                            match storage {
486                                NdArrayStorage::Borrowed { .. } => {
487                                    // Borrowed storage requires contiguous row-major data
488                                    // (see NdArrayStorage::from_borrowed documentation)
489                                    true
490                                }
491                                NdArrayStorage::Owned(array) => {
492                                    let shape = array.shape();
493                                    let mut strides = Vec::with_capacity(array.strides().len());
494                                    for &stride in array.strides() {
495                                        if stride <= 0 {
496                                            return false;
497                                        }
498                                        strides.push(stride as usize);
499                                    }
500                                    is_contiguous(shape, &strides)
501                                }
502                            }
503                        })*
504                    }
505                };
506            }
507            check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
508        }
509    }
510}
511
/// Converts a slice of usize to a typed dimension.
///
/// Copies the first `$n` entries of `$dims` into a fixed-size array and
/// wraps it as an ndarray `Dim`. Panics if `$dims` has fewer than `$n` entries.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
528
/// Reshapes an array into a tensor.
///
/// The `d $D` arm dispatches a runtime rank (1–6) to the compile-time
/// `n $n` arm, which reshapes via a typed `Dim`. Standard-layout arrays
/// reshape without copying; other layouts are first rewritten to standard
/// (row-major) layout.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape, justdim);
        let array = match $array.is_standard_layout() {
            true => {
                // Contiguous row-major: reshape is a metadata-only operation.
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            // Non-standard layout: materialize as standard layout first.
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
569
/// Slice a tensor.
///
/// Dispatches on the tensor's variant and applies `NdArrayOps::slice` to a
/// zero-copy view of the storage. The two-argument form covers all variants.
#[macro_export]
macro_rules! slice {
    ($tensor:expr, $slices:expr) => {
        slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
    };
    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
        match $tensor {
            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
        }
    };
}
582
impl NdArrayTensor {
    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
    ///
    /// This method attempts zero-copy loading when possible. If the data has properly
    /// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise,
    /// it falls back to copying the data.
    ///
    /// Zero-copy loading works when:
    /// - The data's bytes are properly aligned for the element type
    /// - The bytes can be borrowed (e.g., from mmap'd file or static data)
    pub fn from_data(data: TensorData) -> NdArrayTensor {
        // Only use Borrowed storage for non-native allocations (e.g., burnpack mmap/file).
        // For native Rust heap allocations (the common case), go directly to owned storage:
        // `from_data_owned` reclaims the Vec zero-copy via `into_vec`, while
        // Borrowed storage would trigger a full memcopy on every single operation.
        if data.bytes.property() != AllocationProperty::Native {
            match Self::try_from_data_borrowed(data) {
                Ok(tensor) => return tensor,
                // Borrowing failed (e.g. misaligned); `data` is handed back untouched.
                Err(data) => return Self::from_data_owned(data),
            }
        }
        Self::from_data_owned(data)
    }

    /// Try to create a tensor with borrowed storage (zero-copy).
    ///
    /// Takes ownership of TensorData and returns it back on failure.
    /// No cloning occurs - bytes are moved into storage or returned on failure.
    ///
    /// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data).
    fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {
        let TensorData {
            bytes,
            shape,
            dtype,
        } = data;

        // Returns early on success; on failure evaluates to the reclaimed
        // (bytes, shape) pair so the TensorData can be reassembled below.
        macro_rules! try_borrow {
            ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
                match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
                    Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
                    Err((bytes, shape)) => (bytes, shape),
                }
            };
        }

        // Try to create borrowed storage; get bytes back on failure
        let (bytes, shape) = match dtype {
            DType::F64 => try_borrow!(f64, F64, bytes, shape),
            DType::F32 => try_borrow!(f32, F32, bytes, shape),
            DType::I64 => try_borrow!(i64, I64, bytes, shape),
            DType::I32 => try_borrow!(i32, I32, bytes, shape),
            DType::I16 => try_borrow!(i16, I16, bytes, shape),
            DType::I8 => try_borrow!(i8, I8, bytes, shape),
            DType::U64 => try_borrow!(u64, U64, bytes, shape),
            DType::U32 => try_borrow!(u32, U32, bytes, shape),
            DType::U16 => try_borrow!(u16, U16, bytes, shape),
            DType::U8 => try_borrow!(u8, U8, bytes, shape),
            DType::Bool(BoolStore::Native) => try_borrow!(bool, Bool, bytes, shape),
            _ => (bytes, shape), // QFloat not supported for zero-copy
        };

        Err(TensorData {
            bytes,
            shape,
            dtype,
        })
    }

    /// Create a tensor with owned storage.
    ///
    /// This may or may not copy data depending on whether the underlying bytes
    /// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned,
    /// no copy occurs; otherwise data is copied to a new allocation.
    fn from_data_owned(data: TensorData) -> NdArrayTensor {
        let shape = data.shape.to_vec(); // TODO: into_vec

        macro_rules! execute {
            ($data: expr, [$($dtype: pat => $ty: ty),*]) => {
                match $data.dtype {
                    $( $dtype => {
                        match data.into_vec::<$ty>() {
                            // SAFETY: presumably `TensorData` guarantees the vec length
                            // matches the element count of `shape`, so the unchecked
                            // constructor only skips re-validation — TODO confirm invariant.
                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
                            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
                        }.into()
                    }, )*
                    other => unimplemented!("Unsupported dtype {other:?}"),
                }
            };
        }

        execute!(data, [
            DType::F64 => f64, DType::F32 => f32,
            DType::I64 => i64, DType::I32 => i32, DType::I16 => i16, DType::I8 => i8,
            DType::U64 => u64, DType::U32 => u32, DType::U16 => u16, DType::U8 => u8,
            DType::Bool(BoolStore::Native) => bool
        ])
    }
}
682
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The quantized tensor.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The quantization parameters.
    ///
    /// For per-tensor schemes a single entry is used; for per-block schemes
    /// one entry per block (see [`NdArrayQTensor::strategy`]).
    pub qparams: Vec<QParams<f32>>,
}
693
impl NdArrayQTensor {
    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
    ///
    /// Only symmetric quantization is represented here: per-tensor schemes use
    /// the single scale in `qparams[0]`, per-block schemes use one scale per block.
    pub fn strategy(&self) -> QuantizationStrategy {
        match self.scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                // Per-tensor: a single scale applies to the whole tensor.
                self.qparams[0].scales,
                self.scheme.value,
            )),
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerBlockSymmetric(
                // Per-block: one symmetric quantization per block scale.
                self.qparams
                    .iter()
                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
                    .collect(),
                block_size,
            ),
        }
    }
}
740
impl QTensorPrimitive for NdArrayQTensor {
    fn scheme(&self) -> &QuantScheme {
        &self.scheme
    }

    /// Default scheme for this backend uses native quantized value storage.
    fn default_scheme() -> QuantScheme {
        QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)
    }
}
750
751impl TensorMetadata for NdArrayQTensor {
752    fn dtype(&self) -> DType {
753        DType::QFloat(self.scheme)
754    }
755
756    fn shape(&self) -> Shape {
757        self.qtensor.shape()
758    }
759
760    fn rank(&self) -> usize {
761        self.shape().num_dims()
762    }
763}
764
765#[cfg(test)]
766mod tests {
767    use crate::NdArray;
768    use alloc::vec;
769
770    use super::*;
771    use burn_backend::{
772        Distribution,
773        ops::{FloatTensorOps, QTensorOps},
774        quantization::{QuantStore, QuantizationParametersPrimitive},
775    };
776    use burn_std::rand::get_seeded_rng;
777
778    #[test]
779    fn should_support_into_and_from_data_1d() {
780        let data_expected = TensorData::random::<f32, _, _>(
781            Shape::new([3]),
782            Distribution::Default,
783            &mut get_seeded_rng(),
784        );
785        let tensor = NdArrayTensor::from_data(data_expected.clone());
786
787        let data_actual = tensor.into_data();
788
789        assert_eq!(data_expected, data_actual);
790    }
791
792    #[test]
793    fn should_support_into_and_from_data_2d() {
794        let data_expected = TensorData::random::<f32, _, _>(
795            Shape::new([2, 3]),
796            Distribution::Default,
797            &mut get_seeded_rng(),
798        );
799        let tensor = NdArrayTensor::from_data(data_expected.clone());
800
801        let data_actual = tensor.into_data();
802
803        assert_eq!(data_expected, data_actual);
804    }
805
806    #[test]
807    fn should_support_into_and_from_data_3d() {
808        let data_expected = TensorData::random::<f32, _, _>(
809            Shape::new([2, 3, 4]),
810            Distribution::Default,
811            &mut get_seeded_rng(),
812        );
813        let tensor = NdArrayTensor::from_data(data_expected.clone());
814
815        let data_actual = tensor.into_data();
816
817        assert_eq!(data_expected, data_actual);
818    }
819
820    #[test]
821    fn should_support_into_and_from_data_4d() {
822        let data_expected = TensorData::random::<f32, _, _>(
823            Shape::new([2, 3, 4, 2]),
824            Distribution::Default,
825            &mut get_seeded_rng(),
826        );
827        let tensor = NdArrayTensor::from_data(data_expected.clone());
828
829        let data_actual = tensor.into_data();
830
831        assert_eq!(data_expected, data_actual);
832    }
833
834    #[test]
835    fn should_support_qtensor_strategy() {
836        type B = NdArray<f32, i64, i8>;
837        let scale: f32 = 0.009_019_608;
838        let device = Default::default();
839
840        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
841        let scheme = QuantScheme::default()
842            .with_value(QuantValue::Q8S)
843            .with_store(QuantStore::Native);
844        let qparams = QuantizationParametersPrimitive {
845            scales: B::float_from_data(TensorData::from([scale]), &device),
846        };
847        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);
848
849        assert_eq!(qtensor.scheme(), &scheme);
850        assert_eq!(
851            qtensor.strategy(),
852            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
853                scale,
854                QuantValue::Q8S
855            ))
856        );
857    }
858
859    // ==========================================================================
860    // Zero-copy integration tests
861    // These tests verify end-to-end zero-copy behavior through NdArrayTensor.
862    // ==========================================================================
863
864    #[test]
865    fn zero_copy_creates_borrowed_storage_for_non_native() {
866        // Verify that from_data creates borrowed storage for non-native allocations
867        // (e.g. burnpack mmap/file data tagged with AllocationProperty::Other or File).
868        // Native heap allocations intentionally use Owned storage for performance.
869        use burn_backend::AllocationProperty;
870        use burn_std::Bytes;
871
872        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
873        let bytes = Bytes::from_elems(data);
874        // Tag as Other to simulate burnpack / mmap data (non-native backing storage)
875        let non_native_bytes = Bytes::from_shared(
876            bytes::Bytes::copy_from_slice(&bytes),
877            AllocationProperty::Other,
878        );
879        let tensor_data = TensorData::from_bytes(non_native_bytes, Shape::new([2, 2]), DType::F32);
880
881        let tensor = NdArrayTensor::from_data(tensor_data);
882
883        match &tensor {
884            NdArrayTensor::F32(storage) => {
885                assert!(
886                    storage.is_borrowed(),
887                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
888                     for non-native (e.g. burnpack) TensorData"
889                );
890                assert!(
891                    !storage.is_unique(),
892                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
893                );
894            }
895            _ => panic!("Expected F32 tensor"),
896        }
897    }
898
899    #[test]
900    fn native_alloc_creates_owned_storage() {
901        // Native heap allocations must use Owned storage to avoid the memcpy.
902        use burn_std::Bytes;
903
904        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
905        let bytes = Bytes::from_elems(data); // AllocationProperty::Native
906        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);
907
908        let tensor = NdArrayTensor::from_data(tensor_data);
909
910        match &tensor {
911            NdArrayTensor::F32(storage) => {
912                assert!(
913                    !storage.is_borrowed(),
914                    "PERF REGRESSION: from_data must NOT create borrowed storage \
915                     for native TensorData"
916                );
917            }
918            _ => panic!("Expected F32 tensor"),
919        }
920    }
921
922    #[test]
923    fn zero_copy_data_integrity() {
924        // Verify data is correctly accessible through borrowed storage
925        use burn_std::Bytes;
926
927        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
928        let bytes = Bytes::from_elems(data);
929        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);
930
931        let tensor = NdArrayTensor::from_data(tensor_data);
932
933        match &tensor {
934            NdArrayTensor::F32(storage) => {
935                let view = storage.view();
936                assert_eq!(view[[0, 0]], 1.0);
937                assert_eq!(view[[0, 1]], 2.0);
938                assert_eq!(view[[1, 0]], 3.0);
939                assert_eq!(view[[1, 1]], 4.0);
940            }
941            _ => panic!("Expected F32 tensor"),
942        }
943    }
944
945    #[test]
946    fn zero_copy_fallback_when_bytes_owned() {
947        // When TensorData owns bytes exclusively, it may use the copy path
948        // This is expected behavior - verify it still works correctly
949        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
950        let tensor = NdArrayTensor::from_data(data.clone());
951        let result = tensor.into_data();
952
953        assert_eq!(data, result, "Data should round-trip correctly");
954    }
955}