burn_ndarray/tensor.rs

use core::mem;

use burn_tensor::{
    DType, Element, Shape, TensorData, TensorMetadata,
    quantization::{QParams, QTensorPrimitive, QuantLevel, QuantMode, QuantScheme, QuantValue},
};

use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
use alloc::vec::Vec;
use ndarray::{ArcArray, ArrayD, IxDyn};

/// Concrete storage type for ndarray
pub type SharedArray<E> = ArcArray<E, IxDyn>;

/// Tensor primitive used by the [ndarray backend](crate::NdArray).
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    F64(SharedArray<f64>),
    F32(SharedArray<f32>),
    I64(SharedArray<i64>),
    I32(SharedArray<i32>),
    I16(SharedArray<i16>),
    I8(SharedArray<i8>),
    U64(SharedArray<u64>),
    U32(SharedArray<u32>),
    U16(SharedArray<u16>),
    U8(SharedArray<u8>),
    Bool(SharedArray<bool>),
}

impl NdArrayTensor {
    pub(crate) fn bool(self) -> SharedArray<bool> {
        match self {
            NdArrayTensor::Bool(arr) => arr,
            _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
        }
    }
}

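/// Casts the elements of `array` to the requested `DType`, returning the matching
/// `NdArrayTensor` variant. When `dtype` already matches `E1`, the array is wrapped
/// as-is without copying.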
pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
where
    NdArrayTensor: From<SharedArray<E1>>,
{
    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
        array.mapv(|a| a.elem()).into_shared()
    }

    if E1::dtype() == dtype {
        return array.into();
    }

    match dtype {
        DType::F64 => cast::<E1, f64>(array).into(),
        DType::F32 => cast::<E1, f32>(array).into(),
        DType::Flex32 => cast::<E1, f32>(array).into(),
        DType::I64 => cast::<E1, i64>(array).into(),
        DType::I32 => cast::<E1, i32>(array).into(),
        DType::I16 => cast::<E1, i16>(array).into(),
        DType::I8 => cast::<E1, i8>(array).into(),
        DType::U64 => cast::<E1, u64>(array).into(),
        DType::U32 => cast::<E1, u32>(array).into(),
        DType::U16 => cast::<E1, u16>(array).into(),
        DType::U8 => cast::<E1, u8>(array).into(),
        DType::Bool => cast::<E1, bool>(array).into(),
        dtype => panic!("Unsupported dtype: {dtype:?}"),
    }
}

macro_rules! impl_from {
    ($($ty: ty => $dtype: ident),*) => {
        $(impl From<SharedArray<$ty>> for NdArrayTensor {
           fn from(value: SharedArray<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(value)
           }
        })*
    };
}

impl_from!(
    f64 => F64, f32 => F32,
    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
    bool => Bool
);

/// Macro to execute an operation for a given element type.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations on tensors with
/// different data types will panic with a data type mismatch.
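///
/// # Example
///
/// A minimal usage sketch (not from the original source); `tensor` is assumed to be an existing
/// `NdArrayTensor`, with `SharedArray` in scope:
///
/// ```rust,ignore
/// // Reverse the axes (a transpose) while preserving whatever dtype the tensor holds.
/// let transposed: NdArrayTensor =
///     execute_with_dtype!(tensor, E, |arr: SharedArray<E>| arr.reversed_axes());
/// ```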
#[macro_export]
macro_rules! execute_with_dtype {
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs, rhs).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(lhs) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs).into()
                }
            )*
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}

/// Macro to execute an operation for a given element type.
/// Only handles float types (`f64`, `f32`).
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations on tensors with
/// different floating point precisions will panic with a data type mismatch.
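///
/// # Example
///
/// A minimal usage sketch (not from the original source); `lhs` and `rhs` are assumed to be
/// existing float `NdArrayTensor`s with matching dtype and shape, with `SharedArray` in scope:
///
/// ```rust,ignore
/// // Element-wise addition, dispatched on the concrete float type (f32 or f64).
/// let sum: NdArrayTensor =
///     execute_with_float_dtype!((lhs, rhs), E, |a: SharedArray<E>, b: SharedArray<E>| {
///         (&a + &b).into_shared()
///     });
/// ```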
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}

/// Macro to execute an operation for a given element type.
/// Only handles integer types (signed and unsigned).
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations on tensors with
/// different integer types will panic with a data type mismatch.
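///
/// # Example
///
/// A minimal usage sketch (not from the original source); `tensor` is assumed to be an existing
/// integer `NdArrayTensor`, with `SharedArray` in scope:
///
/// ```rust,ignore
/// // Double every element, dispatched on the concrete integer type.
/// let doubled: NdArrayTensor = execute_with_int_dtype!(tensor, I, |arr: SharedArray<I>| {
///     arr.mapv(|x| x + x).into_shared()
/// });
/// ```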
#[macro_export]
macro_rules! execute_with_int_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}

/// Macro to execute an operation for a given element type.
/// Only handles numeric types (floats and integers).
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations on tensors with
/// different numeric types will panic with a data type mismatch.
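///
/// # Example
///
/// A minimal usage sketch (not from the original source); `lhs` and `rhs` are assumed to be
/// existing numeric `NdArrayTensor`s with matching dtype and shape, with `SharedArray` in scope:
///
/// ```rust,ignore
/// // Element-wise multiplication over any numeric dtype, as long as both sides match.
/// let product: NdArrayTensor =
///     execute_with_numeric_dtype!((lhs, rhs), E, |a: SharedArray<E>, b: SharedArray<E>| {
///         (&a * &b).into_shared()
///     });
/// ```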
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}

/// Macro to execute a cat (concatenate) operation on a given set of element types.
///
/// # Panics
/// Since there is no automatic type cast at this time, concatenating tensors with different
/// data types will panic with a data type mismatch.
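///
/// # Example
///
/// A minimal usage sketch (not from the original source); `tensors` is assumed to be a
/// `Vec<NdArrayTensor>` whose elements all share the same dtype, with `NdArrayTensor` and
/// `NdArrayOps` in scope at the call site:
///
/// ```rust,ignore
/// // Concatenate along dimension 0, dispatching on the dtype of the first tensor.
/// let out: NdArrayTensor = cat_with_dtype!(tensors, 0, [F64, F32, I64, I32, Bool]);
/// ```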
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        match &$tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(tensor) = t {
                            tensor.view()
                        } else {
                            panic!(
                                "Concatenate data type mismatch (expected {:?}, got {:?})",
                                $tensors[0].dtype(),
                                t.dtype()
                            )
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}

impl TensorMetadata for NdArrayTensor {
    fn dtype(&self) -> DType {
        match self {
            NdArrayTensor::F64(_) => DType::F64,
            NdArrayTensor::F32(_) => DType::F32,
            NdArrayTensor::I64(_) => DType::I64,
            NdArrayTensor::I32(_) => DType::I32,
            NdArrayTensor::I16(_) => DType::I16,
            NdArrayTensor::I8(_) => DType::I8,
            NdArrayTensor::U64(_) => DType::U64,
            NdArrayTensor::U32(_) => DType::U32,
            NdArrayTensor::U16(_) => DType::U16,
            NdArrayTensor::U8(_) => DType::U8,
            NdArrayTensor::Bool(_) => DType::Bool,
        }
    }

    fn shape(&self) -> Shape {
        execute_with_dtype!(self, E, |a: &ArcArray<E, IxDyn>| Shape::from(
            a.shape().to_vec()
        ))
    }

    fn rank(&self) -> usize {
        self.shape().num_dims()
    }
}

pub(crate) trait ShapeOps {
    fn num_dims(self) -> usize;
    fn num_elements(self) -> usize;
    fn dims<const N: usize>(self) -> [usize; N];
    fn into_shape(self) -> Shape;
}

impl ShapeOps for &[usize] {
    fn num_dims(self) -> usize {
        self.len()
    }

    fn num_elements(self) -> usize {
        self.iter().product()
    }

    fn dims<const N: usize>(self) -> [usize; N] {
        self.try_into().unwrap()
    }

    fn into_shape(self) -> Shape {
        Shape {
            dims: self.to_vec(),
        }
    }
}

mod utils {
    use burn_std::tensor::is_contiguous;

    use super::*;

    impl NdArrayTensor {
        pub(crate) fn into_data(self) -> TensorData {
            let shape = self.shape();
            let contiguous = self.is_contiguous();

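            // Helper: extracts the element buffer for one concrete element type. For contiguous,
            // uniquely owned arrays it reuses the underlying `Vec` without copying (trimming any
            // leading offset and trailing excess); otherwise it copies by iterating in logical
            // order.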
            fn inner<E: Element>(
                shape: Shape,
                is_contiguous: bool,
                array: ArcArray<E, IxDyn>,
            ) -> TensorData {
                let vec = if is_contiguous {
                    match array.try_into_owned_nocopy() {
                        Ok(owned) => {
                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
                            if let Some(offset) = offset {
                                vec.drain(..offset);
                            }
                            if vec.len() > shape.num_elements() {
                                vec.drain(shape.num_elements()..vec.len());
                            }
                            vec
                        }
                        Err(array) => array.into_iter().collect(),
                    }
                } else {
                    array.into_iter().collect()
                };

                TensorData::new(vec, shape)
            }

            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
        }

        pub(crate) fn is_contiguous(&self) -> bool {
            fn inner<E: Element>(array: &ArcArray<E, IxDyn>) -> bool {
                let shape = array.shape();
                let mut strides = Vec::with_capacity(array.strides().len());

                for &stride in array.strides() {
                    if stride <= 0 {
                        return false;
                    }
                    strides.push(stride as usize);
                }
                is_contiguous(shape, &strides)
            }

            execute_with_dtype!(self, inner)
        }
    }
}

/// Converts a slice of usize to a typed dimension.
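///
/// # Example
///
/// A minimal usage sketch (not from the original source); `ndarray::Dim` must be in scope at the
/// call site:
///
/// ```rust,ignore
/// use ndarray::Dim;
///
/// let shape = [2usize, 3, 4];
/// // Convert the runtime slice into a fixed-size, typed dimension.
/// let dim: Dim<[usize; 3]> = to_typed_dims!(3, shape, justdim);
/// ```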
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}

/// Reshapes an array into a tensor.
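///
/// # Example
///
/// A minimal usage sketch (not from the original source); `array` is assumed to be an existing
/// `SharedArray<f32>` and `out_shape` a `Shape` with two dimensions:
///
/// ```rust,ignore
/// // Reshape into a dynamic-dimension array, copying only if the layout is not standard.
/// let reshaped = reshape!(ty f32, n 2, shape out_shape, array array);
/// ```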
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array = match $array.is_standard_layout() {
            true => {
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}

impl NdArrayTensor {
    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
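    ///
    /// # Example
    ///
    /// A minimal usage sketch (not from the original source):
    ///
    /// ```rust,ignore
    /// // Build an f32 tensor from a 2x2 data blob; the variant follows the data's dtype.
    /// let tensor = NdArrayTensor::from_data(TensorData::from([[1.0f32, 2.0], [3.0, 4.0]]));
    /// assert_eq!(tensor.dtype(), DType::F32);
    /// ```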
    pub fn from_data(mut data: TensorData) -> NdArrayTensor {
        let shape = mem::take(&mut data.shape);

        macro_rules! execute {
            ($data: expr, [$($dtype: ident => $ty: ty),*]) => {
                match $data.dtype {
                    $(DType::$dtype => {
                        match data.into_vec::<$ty>() {
                            // Safety: TensorData checks shape validity on creation, so we don't need to repeat that check here
                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
                            Err(err) => panic!("Data should have the same element type as the tensor: {err:?}"),
                        }.into()
                    },)*
                    other => unimplemented!("Unsupported dtype {other:?}"),
                }
            };
        }

        execute!(data, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }
}

/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The quantized tensor.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The quantization parameters.
    pub qparams: Vec<QParams<f32>>,
}

impl NdArrayQTensor {
    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
    pub fn strategy(&self) -> QuantizationStrategy {
        match self.scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                self.qparams[0].scales,
                self.scheme.value,
            )),
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerBlockSymmetric(
                self.qparams
                    .iter()
                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
                    .collect(),
                block_size,
            ),
        }
    }
}

impl QTensorPrimitive for NdArrayQTensor {
    fn scheme(&self) -> &QuantScheme {
        &self.scheme
    }

    fn default_scheme() -> QuantScheme {
        QuantScheme::default().with_store(burn_tensor::quantization::QuantStore::Native)
    }
}

impl TensorMetadata for NdArrayQTensor {
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }

    fn rank(&self) -> usize {
        self.shape().num_dims()
    }
}

#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_std::rand::get_seeded_rng;
    use burn_tensor::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };

    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4, 2]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let qparams = QuantizationParametersPrimitive {
            scales: B::float_from_data(TensorData::from([scale]), &device),
        };
        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                scale,
                QuantValue::Q8S
            ))
        );
    }
}