burn_ndarray/tensor.rs

use burn_tensor::{
    quantization::{
        AffineQuantization, QParams, QTensorPrimitive, QuantizationScheme, QuantizationStrategy,
        QuantizationType, SymmetricQuantization,
    },
    DType, Element, Shape, TensorData, TensorMetadata,
};

use ndarray::{ArcArray, Array, Dim, IxDyn};

use crate::element::QuantElement;

/// Tensor primitive used by the [ndarray backend](crate::NdArray).
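///
/// # Example
///
/// A minimal sketch (not compiled as a doctest) of building a tensor from
/// [`TensorData`]:
///
/// ```ignore
/// use burn_tensor::{TensorData, TensorMetadata};
///
/// // `from_data` reshapes the flat buffer according to the shape carried by
/// // the `TensorData`.
/// let data = TensorData::from([[1.0f32, 2.0], [3.0, 4.0]]);
/// let tensor = NdArrayTensor::<f32>::from_data(data);
/// assert_eq!(tensor.shape().num_dims(), 2);
/// ```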
#[derive(new, Debug, Clone)]
pub struct NdArrayTensor<E> {
    /// Dynamic array that contains the data of type E.
    pub array: ArcArray<E, IxDyn>,
}

impl<E: Element> TensorMetadata for NdArrayTensor<E> {
    fn dtype(&self) -> DType {
        E::dtype()
    }

    fn shape(&self) -> Shape {
        Shape::from(self.array.shape().to_vec())
    }
}

/// Float tensor primitive.
#[derive(Debug, Clone)]
pub enum NdArrayTensorFloat {
    /// 32-bit float.
    F32(NdArrayTensor<f32>),
    /// 64-bit float.
    F64(NdArrayTensor<f64>),
}

impl From<NdArrayTensor<f32>> for NdArrayTensorFloat {
    fn from(value: NdArrayTensor<f32>) -> Self {
        NdArrayTensorFloat::F32(value)
    }
}

impl From<NdArrayTensor<f64>> for NdArrayTensorFloat {
    fn from(value: NdArrayTensor<f64>) -> Self {
        NdArrayTensorFloat::F64(value)
    }
}

impl TensorMetadata for NdArrayTensorFloat {
    fn dtype(&self) -> DType {
        match self {
            NdArrayTensorFloat::F32(tensor) => tensor.dtype(),
            NdArrayTensorFloat::F64(tensor) => tensor.dtype(),
        }
    }

    fn shape(&self) -> Shape {
        match self {
            NdArrayTensorFloat::F32(tensor) => tensor.shape(),
            NdArrayTensorFloat::F64(tensor) => tensor.shape(),
        }
    }
}

/// Macro to create a new [float tensor](NdArrayTensorFloat) based on the element type.
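///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming a generic
/// `E: FloatNdArrayElement` is in scope at the call site. The element type of
/// the wrapped expression is inferred per match arm, so the same expression
/// yields an `F32` or `F64` variant depending on `E::dtype()`:
///
/// ```ignore
/// // `data` is a `TensorData` created elsewhere.
/// let tensor: NdArrayTensorFloat = new_tensor_float!(NdArrayTensor::from_data(data));
/// ```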
#[macro_export]
macro_rules! new_tensor_float {
    // Op executed with default dtype
    ($tensor:expr) => {{
        match E::dtype() {
            burn_tensor::DType::F64 => $crate::NdArrayTensorFloat::F64($tensor),
            burn_tensor::DType::F32 => $crate::NdArrayTensorFloat::F32($tensor),
            // FloatNdArrayElement is only implemented for f64 and f32
            _ => unimplemented!("Unsupported dtype"),
        }
    }};
}

/// Macro to execute an operation with a given float element type.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations on operands with
/// different floating point precisions will panic with a data type mismatch.
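///
/// # Examples
///
/// Minimal sketches (not compiled as doctests); the `NdArrayMathOps` helpers
/// shown here stand in for whatever op is actually being dispatched:
///
/// ```ignore
/// // Binary op: both operands must hold the same float dtype, otherwise the
/// // macro panics with a data type mismatch.
/// let sum = execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add);
///
/// // Unary op with an explicit element alias `E`, for closures that need to
/// // name the concrete element type (f32 or f64).
/// let scaled = execute_with_float_dtype!(tensor, E, |t: NdArrayTensor<E>| {
///     NdArrayTensor::new(t.array.mapv(|x| x * 0.5).into_shared())
/// });
/// ```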
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op: generic type cannot be inferred for an operation
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op: type automatically inferred by the compiler but return type is not a float tensor
    (($lhs:expr, $rhs:expr) => $op:expr) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $op(lhs, rhs)
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $op(lhs, rhs)
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $crate::NdArrayTensorFloat::F64($op(tensor)),
            $crate::NdArrayTensorFloat::F32(tensor) => $crate::NdArrayTensorFloat::F32($op(tensor)),
        }
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(tensor))
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(tensor))
            }
        }
    }};

    // Unary op: type automatically inferred by the compiler but return type is not a float tensor
    ($tensor:expr => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $op(tensor),
            $crate::NdArrayTensorFloat::F32(tensor) => $op(tensor),
        }
    }};

    // Unary op: generic type cannot be inferred for an operation and return type is not a float tensor
    ($tensor:expr, $element:ident => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $op(tensor)
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $op(tensor)
            }
        }
    }};
}

#[cfg(test)]
mod utils {
    use super::*;
    use crate::element::FloatNdArrayElement;

    impl<E> NdArrayTensor<E>
    where
        E: Default + Clone,
    {
        /// Converts the tensor into [`TensorData`], avoiding a copy when the
        /// backing array is contiguous and uniquely owned.
        pub(crate) fn into_data(self) -> TensorData
        where
            E: FloatNdArrayElement,
        {
            let shape = self.shape();

            let vec = if self.is_contiguous() {
                match self.array.try_into_owned_nocopy() {
                    Ok(owned) => {
                        let (vec, _offset) = owned.into_raw_vec_and_offset();
                        vec
                    }
                    Err(array) => array.into_iter().collect(),
                }
            } else {
                self.array.into_iter().collect()
            };

            TensorData::new(vec, shape)
        }

        /// Returns `true` when the array is packed in row-major (C-contiguous) order.
        pub(crate) fn is_contiguous(&self) -> bool {
            let shape = self.array.shape();
            let strides = self.array.strides();

            if shape.is_empty() {
                return true;
            }

            if shape.len() == 1 {
                return strides[0] == 1;
            }

            let mut prev_stride = 1;
            let mut current_num_elems_shape = 1;

            // Row-major check: strides must be positive, strictly increasing from the
            // innermost dimension outward, and each outer stride must equal the number
            // of elements covered by the inner dimensions.
            for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
                let stride = if *stride <= 0 {
                    return false;
                } else {
                    *stride as usize
                };
                if i > 0 {
                    if current_num_elems_shape != stride {
                        return false;
                    }

                    if prev_stride >= stride {
                        return false;
                    }
                }

                current_num_elems_shape *= shape;
                prev_stride = stride;
            }

            true
        }
    }
}

/// Converts a slice of usize to a typed dimension.
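///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); `ndarray::Dim` must be in
/// scope at the call site:
///
/// ```ignore
/// use ndarray::Dim;
///
/// // Turn a runtime list of dimensions into a statically sized `Dim`.
/// let shape = [2usize, 3, 4];
/// let dim: Dim<[usize; 3]> = to_typed_dims!(3, shape, justdim);
/// ```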
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}

/// Reshapes an array into a tensor.
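///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `ndarray::Dim` and
/// `NdArrayTensor` are in scope. The macro dispatches on the runtime dimension
/// count (1 to 6) to a statically dimensioned reshape:
///
/// ```ignore
/// let shape = Shape::from(vec![2, 3]);
/// let array = ndarray::Array::from_vec(vec![0.0f32; 6]).into_shared().into_dyn();
/// let tensor: NdArrayTensor<f32> = reshape!(ty f32, shape shape, array array, d 2);
/// ```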
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
            true => $array
                .to_shape(dim)
                .expect("Safe to change shape without relayout")
                .into_shared(),
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        let array = array.into_dyn();

        NdArrayTensor::new(array)
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}

impl<E> NdArrayTensor<E>
where
    E: Element,
{
    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
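    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest); when the dtype stored in
    /// the data differs from `E`, the elements are converted on the fly:
    ///
    /// ```ignore
    /// let data = TensorData::from([[1.0f64, 2.0], [3.0, 4.0]]);
    /// // Backs an f32 tensor even though the payload was created as f64.
    /// let tensor = NdArrayTensor::<f32>::from_data(data);
    /// ```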
    pub fn from_data(data: TensorData) -> NdArrayTensor<E> {
        let shape: Shape = data.shape.clone().into();
        let into_array = |data: TensorData| match data.into_vec::<E>() {
            Ok(vec) => Array::from_vec(vec).into_shared(),
            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
        };
        let to_array = |data: TensorData| Array::from_iter(data.iter::<E>()).into_shared();

        let array = if data.dtype == E::dtype() {
            into_array(data)
        } else {
            to_array(data)
        };
        let ndims = shape.num_dims();

        reshape!(
            ty E,
            shape shape,
            array array,
            d ndims
        )
    }
}

/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor<Q: QuantElement> {
    /// The quantized tensor.
    pub qtensor: NdArrayTensor<Q>,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
    /// The quantization parameters.
    pub qparams: QParams<f32, Q>,
}

impl<Q: QuantElement> NdArrayQTensor<Q> {
    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
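    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest) of dispatching on the
    /// returned strategy:
    ///
    /// ```ignore
    /// match qtensor.strategy() {
    ///     QuantizationStrategy::PerTensorAffineInt8(_q) => { /* scale + offset */ }
    ///     QuantizationStrategy::PerTensorSymmetricInt8(_q) => { /* scale only */ }
    /// }
    /// ```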
    pub fn strategy(&self) -> QuantizationStrategy {
        match self.scheme {
            QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
                QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
                    self.qparams.scale,
                    self.qparams.offset.unwrap().elem(),
                ))
            }
            QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
                QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
                    self.qparams.scale,
                ))
            }
        }
    }
}

impl<Q: QuantElement> QTensorPrimitive for NdArrayQTensor<Q> {
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }
}

impl<Q: QuantElement> TensorMetadata for NdArrayQTensor<Q> {
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }
}

#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_common::rand::get_seeded_rng;
    use burn_tensor::{
        ops::{FloatTensorOps, IntTensorOps, QTensorOps},
        quantization::{AffineQuantization, QuantizationParametersPrimitive, QuantizationType},
        Distribution,
    };

    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::<f32>::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::<f32>::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::<f32>::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4, 2]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::<f32>::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
        let qparams = QuantizationParametersPrimitive {
            scale: B::float_from_data(TensorData::from([0.009_019_608]), &device),
            offset: Some(B::int_from_data(TensorData::from([72]), &device)),
        };
        let qtensor: NdArrayQTensor<i8> = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
        );
    }
}
477}