Skip to main content

burn_ndarray/
tensor.rs

1use core::mem;
2
3use burn_backend::{
4    DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
5    quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
6};
7
8use crate::NdArrayStorage;
9use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
10use alloc::vec::Vec;
11use ndarray::{ArcArray, ArrayD, IxDyn};
12
/// Concrete storage type for ndarray (owned with COW semantics via Arc).
///
/// `ArcArray` shares the underlying buffer between clones; a clone is cheap
/// (reference-count bump) and the data is copied lazily on mutation.
pub type SharedArray<E> = ArcArray<E, IxDyn>;
15
/// Tensor primitive used by the [ndarray backend](crate::NdArray).
///
/// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`.
/// When data is borrowed from external sources (like burnpack files),
/// it remains zero-copy until a mutating operation is performed.
///
/// One variant exists per supported element dtype; the active variant is the
/// runtime dtype reported by [`TensorMetadata::dtype`].
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    F64(NdArrayStorage<f64>),
    F32(NdArrayStorage<f32>),
    I64(NdArrayStorage<i64>),
    I32(NdArrayStorage<i32>),
    I16(NdArrayStorage<i16>),
    I8(NdArrayStorage<i8>),
    U64(NdArrayStorage<u64>),
    U32(NdArrayStorage<u32>),
    U16(NdArrayStorage<u16>),
    U8(NdArrayStorage<u8>),
    Bool(NdArrayStorage<bool>),
}
36
37impl NdArrayTensor {
38    /// Extract bool array, converting to owned if necessary.
39    pub(crate) fn bool(self) -> SharedArray<bool> {
40        match self {
41            NdArrayTensor::Bool(storage) => storage.into_shared(),
42            _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
43        }
44    }
45
46    /// Returns true if this tensor uses borrowed (zero-copy) storage.
47    #[inline]
48    pub fn is_borrowed(&self) -> bool {
49        macro_rules! check {
50            ($($variant:ident),*) => {
51                match self {
52                    $(NdArrayTensor::$variant(s) => s.is_borrowed(),)*
53                }
54            };
55        }
56        check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
57    }
58}
59
60pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
61where
62    NdArrayTensor: From<SharedArray<E1>>,
63{
64    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
65        array.mapv(|a| a.elem()).into_shared()
66    }
67
68    if E1::dtype() == dtype {
69        return array.into();
70    }
71
72    match dtype {
73        DType::F64 => cast::<E1, f64>(array).into(),
74        DType::F32 => cast::<E1, f32>(array).into(),
75        DType::Flex32 => cast::<E1, f32>(array).into(),
76        DType::I64 => cast::<E1, i64>(array).into(),
77        DType::I32 => cast::<E1, i32>(array).into(),
78        DType::I16 => cast::<E1, i16>(array).into(),
79        DType::I8 => cast::<E1, i8>(array).into(),
80        DType::U64 => cast::<E1, u64>(array).into(),
81        DType::U32 => cast::<E1, u32>(array).into(),
82        DType::U16 => cast::<E1, u16>(array).into(),
83        DType::U8 => cast::<E1, u8>(array).into(),
84        DType::Bool => cast::<E1, bool>(array).into(),
85        dtype => panic!("Unsupported dtype: {dtype:?}"),
86    }
87}
88
// Generates `From` conversions into `NdArrayTensor` for each `element type =>
// enum variant` pair, from both raw shared arrays and storage wrappers.
macro_rules! impl_from {
    ($($ty: ty => $dtype: ident),*) => {
        // From SharedArray (owned) -> NdArrayTensor
        $(impl From<SharedArray<$ty>> for NdArrayTensor {
           fn from(value: SharedArray<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
           }
        })*

        // From NdArrayStorage -> NdArrayTensor
        $(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
           fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(value)
           }
        })*
    };
}

impl_from!(
    f64 => F64, f32 => F32,
    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
    bool => Bool
);
113
/// Macro to execute an operation on a given element type.
///
/// Extracts the storage from NdArrayTensor, converts to SharedArray, and passes to operation.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_dtype {
    // Binary op over an explicit list of supported dtype variants.
    (($lhs:expr, $rhs:expr),$element:ident,  $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        // Capture dtypes before the match moves `$lhs`/`$rhs`; the mismatch
        // arm still needs them for the panic message.
        let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    #[allow(unused)]
                    type $element = $ty;
                    // Convert storage to SharedArray for compatibility with existing operations
                    $op(lhs.into_shared(), rhs.into_shared()).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    // Unary op over an explicit list of supported dtype variants.
    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(storage) => {
                    #[allow(unused)]
                    type $element = $ty;
                    // Convert to SharedArray for compatibility with most operations
                    $op(storage.into_shared()).into()
                }
            )*
            // Only reachable when the dtype list is a strict subset of the variants
            // (e.g. the float-only / int-only wrappers below).
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}
185
/// Macro to execute an operation on a given element type.
/// Only handles float types.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
218
/// Macro to execute an operation on a given element type.
/// Only handles int types (signed and unsigned).
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// integer precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_int_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
253
/// Macro to execute an operation on a given element type.
/// Only handles numeric types (floats and ints, not bool).
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
290
/// Macro to execute a cat operation on a given set of element types.
///
/// Uses zero-copy views from storage for concatenation. The dtype of the
/// first tensor selects the arm; all other tensors must share that dtype.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        // Dispatch on the first tensor's variant.
        match &$tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(storage) = t {
                            // Use storage.view() for zero-copy access
                            storage.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype())
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}
320
321impl TensorMetadata for NdArrayTensor {
322    fn dtype(&self) -> DType {
323        match self {
324            NdArrayTensor::F64(_) => DType::F64,
325            NdArrayTensor::F32(_) => DType::F32,
326            NdArrayTensor::I64(_) => DType::I64,
327            NdArrayTensor::I32(_) => DType::I32,
328            NdArrayTensor::I16(_) => DType::I16,
329            NdArrayTensor::I8(_) => DType::I8,
330            NdArrayTensor::U64(_) => DType::U64,
331            NdArrayTensor::U32(_) => DType::U32,
332            NdArrayTensor::U16(_) => DType::U16,
333            NdArrayTensor::U8(_) => DType::U8,
334            NdArrayTensor::Bool(_) => DType::Bool,
335        }
336    }
337
338    fn shape(&self) -> Shape {
339        // Use storage's shape method (works for both borrowed and owned)
340        macro_rules! get_shape {
341            ($($variant:ident),*) => {
342                match self {
343                    $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
344                }
345            };
346        }
347        get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
348    }
349
350    fn rank(&self) -> usize {
351        self.shape().num_dims()
352    }
353}
354
/// Convenience shape queries implemented on dimension slices (`&[usize]`).
pub(crate) trait ShapeOps {
    // Number of dimensions (slice length).
    fn num_dims(self) -> usize;
    // Total element count (product of all dimensions).
    fn num_elements(self) -> usize;
    // Fixed-size dimension array; panics if the rank does not equal `N`.
    fn dims<const N: usize>(self) -> [usize; N];
    // Convert into the backend-agnostic `Shape` type.
    fn into_shape(self) -> Shape;
}
361
362impl ShapeOps for &[usize] {
363    fn num_dims(self) -> usize {
364        self.len()
365    }
366
367    fn num_elements(self) -> usize {
368        self.iter().product()
369    }
370
371    fn dims<const N: usize>(self) -> [usize; N] {
372        self.try_into().unwrap()
373    }
374
375    fn into_shape(self) -> Shape {
376        Shape::from(self)
377    }
378}
379
mod utils {
    use burn_std::tensor::is_contiguous;

    use super::*;

    impl NdArrayTensor {
        /// Consume the tensor and convert it into [`TensorData`].
        ///
        /// For contiguous, uniquely-owned arrays the underlying buffer is
        /// reclaimed without copying; otherwise elements are collected in
        /// logical (row-major) order.
        pub(crate) fn into_data(self) -> TensorData {
            let shape = self.shape();
            let contiguous = self.is_contiguous();

            fn inner<E: Element>(
                shape: Shape,
                is_contiguous: bool,
                array: ArcArray<E, IxDyn>,
            ) -> TensorData {
                let vec = if is_contiguous {
                    match array.try_into_owned_nocopy() {
                        Ok(owned) => {
                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
                            // Drop leading elements before the view's start offset.
                            if let Some(offset) = offset {
                                vec.drain(..offset);
                            }
                            // Drop trailing elements beyond the logical element count.
                            if vec.len() > shape.num_elements() {
                                vec.drain(shape.num_elements()..vec.len());
                            }
                            vec
                        }
                        // Buffer is shared elsewhere: fall back to copying.
                        Err(array) => array.into_iter().collect(),
                    }
                } else {
                    // Non-contiguous layout: iterate elements in logical order.
                    array.into_iter().collect()
                };

                TensorData::new(vec, shape)
            }

            // Convert storage to owned array before extracting data
            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
        }

        /// Whether the underlying memory is contiguous in row-major order.
        pub(crate) fn is_contiguous(&self) -> bool {
            // For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous)
            // For owned data, we check the strides
            macro_rules! check_contiguous {
                ($($variant:ident),*) => {
                    match self {
                        $(NdArrayTensor::$variant(storage) => {
                            match storage {
                                NdArrayStorage::Borrowed { .. } => {
                                    // Borrowed storage requires contiguous row-major data
                                    // (see NdArrayStorage::from_borrowed documentation)
                                    true
                                }
                                NdArrayStorage::Owned(array) => {
                                    let shape = array.shape();
                                    let mut strides = Vec::with_capacity(array.strides().len());
                                    for &stride in array.strides() {
                                        // Negative or zero strides can never be contiguous;
                                        // this `return` exits `is_contiguous` itself.
                                        if stride <= 0 {
                                            return false;
                                        }
                                        strides.push(stride as usize);
                                    }
                                    is_contiguous(shape, &strides)
                                }
                            }
                        })*
                    }
                };
            }
            check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
        }
    }
}
453
/// Converts a slice of usize to a typed dimension.
///
/// `$n` must be a constant expression; `$dims` must be indexable with at
/// least `$n` elements (extra entries are ignored).
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        // Copy the first $n entries into a fixed-size array for ndarray's `Dim`.
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
470
/// Reshapes an array into a tensor.
///
/// The second form dispatches on a runtime rank `$D` (1..=6) to the
/// const-rank form.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape, justdim);
        let array = match $array.is_standard_layout() {
            // Standard (row-major) layout can be reshaped without copying.
            true => {
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            // Non-standard layout: rearrange into standard layout first (copies).
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
511
/// Slice a tensor.
///
/// The one-argument-list form covers every dtype variant; the second form
/// restricts dispatch to an explicit list of variants.
#[macro_export]
macro_rules! slice {
    ($tensor:expr, $slices:expr) => {
        slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
    };
    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
        match $tensor {
            // Zero-copy view into storage, sliced by NdArrayOps.
            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
        }
    };
}
524
impl NdArrayTensor {
    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
    ///
    /// This method attempts zero-copy loading when possible. If the data has properly
    /// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise,
    /// it falls back to copying the data.
    ///
    /// Zero-copy loading works when:
    /// - The data's bytes are properly aligned for the element type
    /// - The bytes can be borrowed (e.g., from mmap'd file or static data)
    pub fn from_data(data: TensorData) -> NdArrayTensor {
        // Try borrowed storage first, fall back to owned if not possible
        match Self::try_from_data_borrowed(data) {
            Ok(tensor) => tensor,
            Err(data) => Self::from_data_owned(data),
        }
    }

    /// Try to create a tensor with borrowed storage (zero-copy).
    ///
    /// Takes ownership of TensorData and returns it back on failure.
    /// No cloning occurs - bytes are moved into storage or returned on failure.
    ///
    /// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data).
    fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {
        // Destructure so bytes/shape can be moved into storage and recovered on failure.
        let TensorData {
            bytes,
            shape,
            dtype,
        } = data;

        // On success, returns early from `try_from_data_borrowed`; on failure,
        // evaluates to the recovered `(bytes, shape)` pair.
        macro_rules! try_borrow {
            ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
                match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
                    Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
                    Err((bytes, shape)) => (bytes, shape),
                }
            };
        }

        // Try to create borrowed storage; get bytes back on failure
        let (bytes, shape) = match dtype {
            DType::F64 => try_borrow!(f64, F64, bytes, shape),
            DType::F32 => try_borrow!(f32, F32, bytes, shape),
            DType::I64 => try_borrow!(i64, I64, bytes, shape),
            DType::I32 => try_borrow!(i32, I32, bytes, shape),
            DType::I16 => try_borrow!(i16, I16, bytes, shape),
            DType::I8 => try_borrow!(i8, I8, bytes, shape),
            DType::U64 => try_borrow!(u64, U64, bytes, shape),
            DType::U32 => try_borrow!(u32, U32, bytes, shape),
            DType::U16 => try_borrow!(u16, U16, bytes, shape),
            DType::U8 => try_borrow!(u8, U8, bytes, shape),
            DType::Bool => try_borrow!(bool, Bool, bytes, shape),
            _ => (bytes, shape), // QFloat not supported for zero-copy
        };

        // Rebuild the original TensorData for the owned fallback path.
        Err(TensorData {
            bytes,
            shape,
            dtype,
        })
    }

    /// Create a tensor with owned storage.
    ///
    /// This may or may not copy data depending on whether the underlying bytes
    /// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned,
    /// no copy occurs; otherwise data is copied to a new allocation.
    fn from_data_owned(mut data: TensorData) -> NdArrayTensor {
        // Take the shape out first; `data` is consumed by `into_vec` below.
        let shape = mem::take(&mut data.shape);

        macro_rules! execute {
            ($data: expr, [$($dtype: ident => $ty: ty),*]) => {
                match $data.dtype {
                    $(DType::$dtype => {
                        match data.into_vec::<$ty>() {
                            // Safety: TensorData checks shape validity on creation
                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
                            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
                        }.into()
                    },)*
                    other => unimplemented!("Unsupported dtype {other:?}"),
                }
            };
        }

        execute!(data, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }
}
619
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The quantized tensor.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The quantization parameters.
    ///
    /// For per-tensor schemes only the first entry is used; per-block schemes
    /// use all entries (see [`NdArrayQTensor::strategy`]).
    pub qparams: Vec<QParams<f32>>,
}
630
impl NdArrayQTensor {
    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
    ///
    /// Note: this match is intentionally exhaustive with no wildcard arm, so
    /// adding a new `QuantLevel`/`QuantMode`/`QuantValue` forces an update here.
    pub fn strategy(&self) -> QuantizationStrategy {
        match self.scheme {
            // Per-tensor symmetric: a single scale for the whole tensor.
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                self.qparams[0].scales,
                self.scheme.value,
            )),
            // Per-block symmetric: one scale per block of `block_size` elements.
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerBlockSymmetric(
                self.qparams
                    .iter()
                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
                    .collect(),
                block_size,
            ),
        }
    }
}
677
678impl QTensorPrimitive for NdArrayQTensor {
679    fn scheme(&self) -> &QuantScheme {
680        &self.scheme
681    }
682
683    fn default_scheme() -> QuantScheme {
684        QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)
685    }
686}
687
impl TensorMetadata for NdArrayQTensor {
    /// Quantized-float dtype carrying the full scheme.
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    /// Shape is delegated to the underlying quantized tensor.
    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }

    /// Number of dimensions.
    fn rank(&self) -> usize {
        self.shape().num_dims()
    }
}
701
#[cfg(test)]
mod tests {
    use crate::NdArray;
    use alloc::vec;

    use super::*;
    use burn_backend::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };
    use burn_std::rand::get_seeded_rng;

    // Round-trip tests: TensorData -> NdArrayTensor -> TensorData must be lossless
    // for ranks 1 through 4.

    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4, 2]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let qparams = QuantizationParametersPrimitive {
            scales: B::float_from_data(TensorData::from([scale]), &device),
        };
        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                scale,
                QuantValue::Q8S
            ))
        );
    }

    // ==========================================================================
    // Zero-copy integration tests
    // These tests verify end-to-end zero-copy behavior through NdArrayTensor.
    // ==========================================================================

    #[test]
    fn zero_copy_creates_borrowed_storage() {
        // Verify that from_data creates borrowed storage when possible.
        // Note: For native allocations, Bytes::clone() copies data internally,
        // but the storage type (Borrowed) is preserved, which is important for
        // the is_unique() behavior that triggers copy-on-write.
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);

        let tensor = NdArrayTensor::from_data(tensor_data);

        match &tensor {
            NdArrayTensor::F32(storage) => {
                assert!(
                    storage.is_borrowed(),
                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
                     for properly aligned TensorData with Bytes"
                );
                assert!(
                    !storage.is_unique(),
                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
                );
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_data_integrity() {
        // Verify data is correctly accessible through borrowed storage
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);

        let tensor = NdArrayTensor::from_data(tensor_data);

        match &tensor {
            NdArrayTensor::F32(storage) => {
                let view = storage.view();
                assert_eq!(view[[0, 0]], 1.0);
                assert_eq!(view[[0, 1]], 2.0);
                assert_eq!(view[[1, 0]], 3.0);
                assert_eq!(view[[1, 1]], 4.0);
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_fallback_when_bytes_owned() {
        // When TensorData owns bytes exclusively, it may use the copy path
        // This is expected behavior - verify it still works correctly
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = NdArrayTensor::from_data(data.clone());
        let result = tensor.into_data();

        assert_eq!(data, result, "Data should round-trip correctly");
    }
}