burn_ndarray/
tensor.rs

1use core::mem;
2
3use burn_tensor::{
4    DType, Element, Shape, TensorData, TensorMetadata,
5    quantization::{
6        QParams, QTensorPrimitive, QuantLevel, QuantMode, QuantScheme, QuantValue,
7        QuantizationStrategy, SymmetricQuantization,
8    },
9};
10
11use alloc::vec::Vec;
12use ndarray::{ArcArray, ArrayD, IxDyn};
13
/// Concrete storage type for ndarray: an [`ArcArray`] with element type `E`
/// and a dynamic number of dimensions ([`IxDyn`]).
pub type SharedArray<E> = ArcArray<E, IxDyn>;
16
/// Tensor primitive used by the [ndarray backend](crate::NdArray).
///
/// Each variant wraps a [`SharedArray`] holding one concrete element type.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    /// Array of `f64` elements.
    F64(SharedArray<f64>),
    /// Array of `f32` elements.
    F32(SharedArray<f32>),
    /// Array of `i64` elements.
    I64(SharedArray<i64>),
    /// Array of `i32` elements.
    I32(SharedArray<i32>),
    /// Array of `i16` elements.
    I16(SharedArray<i16>),
    /// Array of `i8` elements.
    I8(SharedArray<i8>),
    /// Array of `u64` elements.
    U64(SharedArray<u64>),
    /// Array of `u32` elements.
    U32(SharedArray<u32>),
    /// Array of `u16` elements.
    U16(SharedArray<u16>),
    /// Array of `u8` elements.
    U8(SharedArray<u8>),
    /// Array of `bool` elements.
    Bool(SharedArray<bool>),
}
33
34impl NdArrayTensor {
35    pub(crate) fn bool(self) -> SharedArray<bool> {
36        match self {
37            NdArrayTensor::Bool(arr) => arr,
38            _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
39        }
40    }
41}
42
43pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
44where
45    NdArrayTensor: From<SharedArray<E1>>,
46{
47    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
48        array.mapv(|a| a.elem()).into_shared()
49    }
50
51    if E1::dtype() == dtype {
52        return array.into();
53    }
54
55    match dtype {
56        DType::F64 => cast::<E1, f64>(array).into(),
57        DType::F32 => cast::<E1, f32>(array).into(),
58        DType::Flex32 => cast::<E1, f32>(array).into(),
59        DType::I64 => cast::<E1, i64>(array).into(),
60        DType::I32 => cast::<E1, i32>(array).into(),
61        DType::I16 => cast::<E1, i16>(array).into(),
62        DType::I8 => cast::<E1, i8>(array).into(),
63        DType::U64 => cast::<E1, u64>(array).into(),
64        DType::U32 => cast::<E1, u32>(array).into(),
65        DType::U16 => cast::<E1, u16>(array).into(),
66        DType::U8 => cast::<E1, u8>(array).into(),
67        DType::Bool => cast::<E1, bool>(array).into(),
68        dtype => panic!("Unsupported dtype: {dtype:?}"),
69    }
70}
71
72macro_rules! impl_from {
73    ($($ty: ty => $dtype: ident),*) => {
74        $(impl From<SharedArray<$ty>> for NdArrayTensor {
75           fn from(value: SharedArray<$ty>) -> NdArrayTensor {
76                NdArrayTensor::$dtype(value)
77           }
78        })*
79    };
80}
81
82impl_from!(
83    f64 => F64, f32 => F32,
84    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
85    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
86    bool => Bool
87);
88
/// Macro to execute an operation on a given element type.
///
/// The tensor variant is matched at runtime. Inside the selected arm the
/// identifier `$element` (default: `E`) is declared as a type alias for the
/// variant's element type, `$op` is called with the unwrapped array(s), and
/// `.into()` is applied to the result.
///
/// Supported forms:
/// - `(lhs, rhs), op` / `(lhs, rhs), Elem, op` / `(lhs, rhs), Elem, op, [Variant => ty, ...]`
/// - `tensor, op` / `tensor, Elem, op` / `tensor, Elem, op, [Variant => ty, ...]`
///
/// When no list is given, all eleven variants (floats, ints, unsigned ints and
/// bool) are dispatched.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations whose
/// operands do not hold the same listed data type panic with a data type
/// mismatch. Unary operations on a tensor whose variant is not in the list hit
/// `unimplemented!`.
#[macro_export]
macro_rules! execute_with_dtype {
    // Binary op with an explicit variant list: both operands must be the same variant.
    (($lhs:expr, $rhs:expr),$element:ident,  $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        // Read the dtypes up front: the operands are moved into the match below.
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs, rhs).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    // Binary op: type automatically inferred by the compiler
    // NOTE: the binary arms must stay ahead of the unary ones, because a
    // parenthesised pair `(a, b)` would also parse as a single `$tensor:expr`.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    // Unary op with an explicit variant list.
    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(lhs) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs).into()
                }
            )*
            // Unreachable when the list covers every variant, hence the `allow`.
            // `dtype()` is a plain method call, so it is resolved at the call site.
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}
156
/// Macro to execute an operation on a given element type.
/// Only handles float types.
///
/// Forwards to `execute_with_dtype!` with the variant list restricted to
/// `F64` and `F32`. When no element identifier is given, `E` is used.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations whose
/// operands are not both `F64` or both `F32` panic with a data type mismatch.
/// Unary operations on any other variant hit `unimplemented!`.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
189
/// Macro to execute an operation on a given element type.
/// Only handles int types.
///
/// Forwards to `execute_with_dtype!` with the variant list restricted to the
/// signed (`I64`, `I32`, `I16`, `I8`) and unsigned (`U64`, `U32`, `U16`, `U8`)
/// integer variants. When no element identifier is given, `E` is used.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations whose
/// operands do not hold the same integer data type panic with a data type
/// mismatch. Unary operations on a non-integer variant hit `unimplemented!`.
#[macro_export]
macro_rules! execute_with_int_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
224
/// Macro to execute an operation on a given element type.
/// Only handles numeric types
///
/// Forwards to `execute_with_dtype!` with every float and integer variant in
/// the list; only `Bool` is left out. When no element identifier is given, `E`
/// is used.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations whose
/// operands do not hold the same numeric data type panic with a data type
/// mismatch. Unary operations on a `Bool` tensor hit `unimplemented!`.
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
261
/// Macro to execute a cat (concatenate) operation on a given set of element types.
///
/// The variant of the first tensor in `$tensors` selects the element type.
/// A view is taken of every tensor, the views are passed to
/// `NdArrayOps::concatenate` together with `$dim`, and `.into()` is applied to
/// the result.
///
/// `NdArrayTensor`, `NdArrayOps` and a `dtype()` method on the tensors must be
/// resolvable at the call site.
///
/// # Panics
/// - Since there is no automatic type cast at this time, concatenating tensors
///   whose data type differs from the first tensor's panics with a data type
///   mismatch that names the expected variant and the offending dtype.
/// - If the first tensor's variant is not in the given list.
/// - If `$tensors` is empty (indexing `$tensors[0]`).
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        match &$tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(tensor) = t {
                            tensor.view()
                        } else {
                            // Report the real types rather than a fixed pair.
                            panic!(
                                "Concatenate data type mismatch (expected {}, got {:?})",
                                stringify!($dtype),
                                t.dtype()
                            )
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}
288
289impl TensorMetadata for NdArrayTensor {
290    fn dtype(&self) -> DType {
291        match self {
292            NdArrayTensor::F64(_) => DType::F64,
293            NdArrayTensor::F32(_) => DType::F32,
294            NdArrayTensor::I64(_) => DType::I64,
295            NdArrayTensor::I32(_) => DType::I32,
296            NdArrayTensor::I16(_) => DType::I16,
297            NdArrayTensor::I8(_) => DType::I8,
298            NdArrayTensor::U64(_) => DType::U64,
299            NdArrayTensor::U32(_) => DType::U32,
300            NdArrayTensor::U16(_) => DType::U16,
301            NdArrayTensor::U8(_) => DType::U8,
302            NdArrayTensor::Bool(_) => DType::Bool,
303        }
304    }
305
306    fn shape(&self) -> Shape {
307        execute_with_dtype!(self, E, |a: &ArcArray<E, IxDyn>| Shape::from(
308            a.shape().to_vec()
309        ))
310    }
311
312    fn rank(&self) -> usize {
313        self.shape().num_dims()
314    }
315}
316
/// Shape helpers for values that describe tensor dimensions.
///
/// All methods take `self` by value.
pub(crate) trait ShapeOps {
    /// Number of dimensions.
    fn num_dims(self) -> usize;
    /// Total number of elements.
    fn num_elements(self) -> usize;
    /// The dimensions as a fixed-size array of length `N`.
    fn dims<const N: usize>(self) -> [usize; N];
    /// Converts the dimensions into a [`Shape`].
    fn into_shape(self) -> Shape;
}
323
324impl ShapeOps for &[usize] {
325    fn num_dims(self) -> usize {
326        self.len()
327    }
328
329    fn num_elements(self) -> usize {
330        self.iter().product()
331    }
332
333    fn dims<const N: usize>(self) -> [usize; N] {
334        self.try_into().unwrap()
335    }
336
337    fn into_shape(self) -> Shape {
338        Shape {
339            dims: self.to_vec(),
340        }
341    }
342}
343
344mod utils {
345    use burn_common::tensor::is_contiguous;
346
347    use super::*;
348
349    impl NdArrayTensor {
350        pub(crate) fn into_data(self) -> TensorData {
351            let shape = self.shape();
352            let contiguous = self.is_contiguous();
353
354            fn inner<E: Element>(
355                shape: Shape,
356                is_contiguous: bool,
357                array: ArcArray<E, IxDyn>,
358            ) -> TensorData {
359                let vec = if is_contiguous {
360                    match array.try_into_owned_nocopy() {
361                        Ok(owned) => {
362                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
363                            if let Some(offset) = offset {
364                                vec.drain(..offset);
365                            }
366                            if vec.len() > shape.num_elements() {
367                                vec.drain(shape.num_elements()..vec.len());
368                            }
369                            vec
370                        }
371                        Err(array) => array.into_iter().collect(),
372                    }
373                } else {
374                    array.into_iter().collect()
375                };
376
377                TensorData::new(vec, shape)
378            }
379
380            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
381        }
382
383        pub(crate) fn is_contiguous(&self) -> bool {
384            fn inner<E: Element>(array: &ArcArray<E, IxDyn>) -> bool {
385                let shape = array.shape();
386                let mut strides = Vec::with_capacity(array.strides().len());
387
388                for &stride in array.strides() {
389                    if stride <= 0 {
390                        return false;
391                    }
392                    strides.push(stride as usize);
393                }
394                is_contiguous(shape, &strides)
395            }
396
397            execute_with_dtype!(self, inner)
398        }
399    }
400}
401
/// Converts a slice of usize to a typed dimension.
///
/// Copies the first `$n` entries of `$dims` into a `[usize; $n]` and wraps
/// them in `Dim`. `Dim` must be in scope at the call site.
///
/// # Panics
/// Panics if `$dims` has fewer than `$n` entries (out-of-bounds index).
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut sizes = [0; $n];
        for (axis, size) in sizes.iter_mut().enumerate() {
            *size = $dims[axis];
        }
        let typed: Dim<[usize; $n]> = Dim(sizes);
        typed
    }};
}
418
/// Reshapes an array into a tensor.
///
/// Two forms are accepted:
/// - `ty T, n N, shape S, array A`: reshapes `A` to the first `N` entries of
///   `S.dims` and returns the result as a shared, dynamically-dimensioned array.
/// - `ty T, shape S, array A, d D`: dispatches on the runtime value `D` (1 to 6)
///   to the first form.
///
/// The `ty` argument is only forwarded; the expansion does not otherwise use it.
///
/// # Panics
/// Panics if the array cannot be reshaped to the requested dimensions, or if
/// `D` is outside `1..=6`.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array = match $array.is_standard_layout() {
            // Standard layout: reshape and report an incompatible shape explicitly.
            true => {
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            // Non-standard layout: reshape, then convert the result to standard layout.
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        // `n` is used as an array length in `to_typed_dims!`, so each
        // supported rank needs its own literal arm.
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
459
460impl NdArrayTensor {
461    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
462    pub fn from_data(mut data: TensorData) -> NdArrayTensor {
463        let shape = mem::take(&mut data.shape);
464
465        macro_rules! execute {
466            ($data: expr, [$($dtype: ident => $ty: ty),*]) => {
467                match $data.dtype {
468                    $(DType::$dtype => {
469                        match data.into_vec::<$ty>() {
470                            // Safety: TensorData checks shape validity on creation, so we don't need to repeat that check here
471                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
472                            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
473                        }.into()
474                    },)*
475                    other => unimplemented!("Unsupported dtype {other:?}"),
476                }
477            };
478        }
479
480        execute!(data, [
481            F64 => f64, F32 => f32,
482            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
483            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
484            Bool => bool
485        ])
486    }
487}
488
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The quantized tensor.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The quantization parameters.
    ///
    /// `strategy` reads only the first entry for per-tensor quantization and
    /// every entry for per-block quantization.
    pub qparams: Vec<QParams<f32>>,
}
499
500impl NdArrayQTensor {
501    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
502    pub fn strategy(&self) -> QuantizationStrategy {
503        match self.scheme {
504            QuantScheme {
505                level: QuantLevel::Tensor,
506                mode: QuantMode::Symmetric,
507                value:
508                    QuantValue::Q8F
509                    | QuantValue::Q8S
510                    | QuantValue::E4M3
511                    | QuantValue::E5M2
512                    | QuantValue::Q4F
513                    | QuantValue::Q4S
514                    | QuantValue::E2M1
515                    | QuantValue::Q2F
516                    | QuantValue::Q2S,
517                ..
518            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
519                self.qparams[0].scales,
520                self.scheme.value,
521            )),
522            QuantScheme {
523                level: QuantLevel::Block(block_size),
524                mode: QuantMode::Symmetric,
525                value:
526                    QuantValue::Q8F
527                    | QuantValue::Q8S
528                    | QuantValue::E4M3
529                    | QuantValue::E5M2
530                    | QuantValue::Q4F
531                    | QuantValue::Q4S
532                    | QuantValue::E2M1
533                    | QuantValue::Q2F
534                    | QuantValue::Q2S,
535                ..
536            } => QuantizationStrategy::PerBlockSymmetric(
537                self.qparams
538                    .iter()
539                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
540                    .collect(),
541                block_size,
542            ),
543        }
544    }
545}
546
547impl QTensorPrimitive for NdArrayQTensor {
548    fn scheme(&self) -> &QuantScheme {
549        &self.scheme
550    }
551
552    fn default_scheme() -> QuantScheme {
553        QuantScheme::default().with_store(burn_tensor::quantization::QuantStore::Native)
554    }
555}
556
557impl TensorMetadata for NdArrayQTensor {
558    fn dtype(&self) -> DType {
559        DType::QFloat(self.scheme)
560    }
561
562    fn shape(&self) -> Shape {
563        self.qtensor.shape()
564    }
565
566    fn rank(&self) -> usize {
567        self.shape().num_dims()
568    }
569}
570
#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_common::rand::get_seeded_rng;
    use burn_tensor::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };

    /// Generates seeded random `f32` data of the given shape and checks that
    /// it survives a `from_data` / `into_data` round trip unchanged.
    fn assert_round_trip(shape: Shape) {
        let mut rng = get_seeded_rng();
        let expected = TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut rng);

        let actual = NdArrayTensor::from_data(expected.clone()).into_data();

        assert_eq!(expected, actual);
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_round_trip(Shape::new([3]));
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_round_trip(Shape::new([2, 3]));
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_round_trip(Shape::new([2, 3, 4]));
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_round_trip(Shape::new([2, 3, 4, 2]));
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let device = Default::default();
        let scale: f32 = 0.009_019_608;
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let input = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scales = B::float_from_data(TensorData::from([scale]), &device);
        let qtensor: NdArrayQTensor = B::quantize(
            input,
            &scheme,
            QuantizationParametersPrimitive { scales },
        );

        let expected = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
            scale,
            QuantValue::Q8S,
        ));
        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(qtensor.strategy(), expected);
    }
}
663}