burn_ndarray/
tensor.rs

1use core::mem;
2
3use burn_tensor::{
4    DType, Element, Shape, TensorData, TensorMetadata,
5    quantization::{
6        QParams, QTensorPrimitive, QuantizationMode, QuantizationScheme, QuantizationStrategy,
7        QuantizationType, SymmetricQuantization,
8    },
9};
10
11use alloc::vec::Vec;
12use ndarray::{ArcArray, ArrayD, IxDyn};
13
14use crate::element::QuantElement;
15
/// Tensor primitive used by the [ndarray backend](crate::NdArray).
///
/// This is a thin wrapper over a reference-counted dynamic-rank ndarray, so
/// `Clone` is cheap (shared buffer). `#[derive(new)]` generates the
/// `NdArrayTensor::new(array)` constructor used throughout this file.
#[derive(new, Debug, Clone)]
pub struct NdArrayTensor<E> {
    /// Dynamic array that contains the data of type E.
    pub array: ArcArray<E, IxDyn>,
}
22
23impl<E: Element> TensorMetadata for NdArrayTensor<E> {
24    fn dtype(&self) -> DType {
25        E::dtype()
26    }
27
28    fn shape(&self) -> Shape {
29        Shape::from(self.array.shape().to_vec())
30    }
31}
32
/// Float tensor primitive.
///
/// The ndarray backend supports two floating-point precisions at runtime;
/// this enum dispatches between them (see the `execute_with_float_dtype!`
/// macro below for how operations are routed per variant).
#[derive(Debug, Clone)]
pub enum NdArrayTensorFloat {
    /// 32-bit float.
    F32(NdArrayTensor<f32>),
    /// 64-bit float.
    F64(NdArrayTensor<f64>),
}
41
42impl From<NdArrayTensor<f32>> for NdArrayTensorFloat {
43    fn from(value: NdArrayTensor<f32>) -> Self {
44        NdArrayTensorFloat::F32(value)
45    }
46}
47
48impl From<NdArrayTensor<f64>> for NdArrayTensorFloat {
49    fn from(value: NdArrayTensor<f64>) -> Self {
50        NdArrayTensorFloat::F64(value)
51    }
52}
53
54impl TensorMetadata for NdArrayTensorFloat {
55    fn dtype(&self) -> DType {
56        match self {
57            NdArrayTensorFloat::F32(tensor) => tensor.dtype(),
58            NdArrayTensorFloat::F64(tensor) => tensor.dtype(),
59        }
60    }
61
62    fn shape(&self) -> Shape {
63        match self {
64            NdArrayTensorFloat::F32(tensor) => tensor.shape(),
65            NdArrayTensorFloat::F64(tensor) => tensor.shape(),
66        }
67    }
68}
69
/// Macro to create a new [float tensor](NdArrayTensorFloat) based on the element type.
///
/// NOTE: this macro is unhygienic — it calls `E::dtype()`, so the call site
/// must have a type parameter named `E` in scope.
#[macro_export]
macro_rules! new_tensor_float {
    // Op executed with default dtype
    ($tensor:expr) => {{
        match E::dtype() {
            burn_tensor::DType::F64 => $crate::NdArrayTensorFloat::F64($tensor),
            burn_tensor::DType::F32 => $crate::NdArrayTensorFloat::F32($tensor),
            // FloatNdArrayElement only implemented for f64 and f32
            _ => unimplemented!("Unsupported dtype"),
        }
    }};
}
83
/// Macro to execute an operation a given element type.
///
/// Dispatches `$op` on the concrete precision (`f32` or `f64`) wrapped by a
/// [`NdArrayTensorFloat`]. Arm selection:
/// - `(lhs, rhs)` tuple input: binary op; bare `$tensor`: unary op.
/// - An extra `$element` identifier binds a local `type $element = f32/f64;`
///   alias so `$op` can name the element type explicitly.
/// - `=> $op` (instead of `, $op`) means the result is returned as-is rather
///   than re-wrapped into a [`NdArrayTensorFloat`] variant.
///
/// # Panics
/// Since there is no automatic type cast at this time, binary operations for different
/// floating point precision data types will panic with a data type mismatch.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary op: type automatically inferred by the compiler
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        // Capture dtypes before the match below moves the tensors,
        // so the mismatch panic message can still report them.
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op: generic type cannot be inferred for an operation,
    // so a `type $element = ...;` alias is emitted for `$op` to use.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(lhs, rhs))
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(lhs, rhs))
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Binary op: type automatically inferred by the compiler but return type is not a float tensor
    (($lhs:expr, $rhs:expr) => $op:expr) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            ($crate::NdArrayTensorFloat::F64(lhs), $crate::NdArrayTensorFloat::F64(rhs)) => {
                $op(lhs, rhs)
            }
            ($crate::NdArrayTensorFloat::F32(lhs), $crate::NdArrayTensorFloat::F32(rhs)) => {
                $op(lhs, rhs)
            }
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};

    // Unary op: type automatically inferred by the compiler
    ($tensor:expr, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $crate::NdArrayTensorFloat::F64($op(tensor)),
            $crate::NdArrayTensorFloat::F32(tensor) => $crate::NdArrayTensorFloat::F32($op(tensor)),
        }
    }};

    // Unary op: generic type cannot be inferred for an operation
    ($tensor:expr, $element:ident, $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $crate::NdArrayTensorFloat::F64($op(tensor))
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $crate::NdArrayTensorFloat::F32($op(tensor))
            }
        }
    }};

    // Unary op: type automatically inferred by the compiler but return type is not a float tensor
    ($tensor:expr => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => $op(tensor),
            $crate::NdArrayTensorFloat::F32(tensor) => $op(tensor),
        }
    }};

    // Unary op: generic type cannot be inferred for an operation and return type is not a float tensor
    ($tensor:expr, $element:ident => $op:expr) => {{
        match $tensor {
            $crate::NdArrayTensorFloat::F64(tensor) => {
                type $element = f64;
                $op(tensor)
            }
            $crate::NdArrayTensorFloat::F32(tensor) => {
                type $element = f32;
                $op(tensor)
            }
        }
    }};
}
191
mod utils {
    use super::*;

    impl<E> NdArrayTensor<E>
    where
        E: Element,
    {
        /// Consumes the tensor and returns its contents as [`TensorData`].
        ///
        /// Fast path: when the array is contiguous and uniquely owned, the
        /// backing `Vec` is reused without copying the elements. Otherwise
        /// (shared buffer or non-contiguous view) the elements are gathered
        /// through an iterator into a fresh allocation.
        pub(crate) fn into_data(self) -> TensorData {
            let shape = self.shape();

            let vec = if self.is_contiguous() {
                // `try_into_owned_nocopy` succeeds only when this ArcArray is
                // the sole owner of its buffer.
                match self.array.try_into_owned_nocopy() {
                    Ok(owned) => {
                        // The owned buffer may start before the first element
                        // of the view; drop the leading `offset` elements so
                        // index 0 of `vec` is the tensor's first element.
                        // NOTE(review): assumes a contiguous view ends exactly
                        // at the end of the buffer (no trailing slack) —
                        // confirm against `TensorData::new`'s expectations.
                        let (mut vec, offset) = owned.into_raw_vec_and_offset();
                        if let Some(offset) = offset {
                            vec.drain(..offset);
                        }
                        vec
                    }
                    // Buffer is shared: fall back to copying element by element.
                    Err(array) => array.into_iter().collect(),
                }
            } else {
                // Non-contiguous view: iteration follows the logical order.
                self.array.into_iter().collect()
            };

            TensorData::new(vec, shape)
        }

        /// Returns `true` when the array is laid out as a dense row-major
        /// (C-order) buffer, i.e. its logical iteration order matches memory
        /// order with no gaps.
        ///
        /// The check is conservative: negative or zero strides, and a 1-D
        /// array whose stride is not exactly 1, report `false` even in cases
        /// that might still be densely packed (e.g. size-1 dimensions), which
        /// merely forces the copying path in `into_data`.
        pub(crate) fn is_contiguous(&self) -> bool {
            let shape = self.array.shape();
            let strides = self.array.strides();

            // 0-dim (scalar) array: trivially contiguous.
            if shape.is_empty() {
                return true;
            }

            if shape.len() == 1 {
                return strides[0] == 1;
            }

            let mut prev_stride = 1;
            let mut current_num_elems_shape = 1;

            // Walk dimensions from innermost to outermost. For row-major
            // contiguity, each dimension's stride must equal the number of
            // elements in all dimensions inner to it.
            for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
                let stride = if *stride <= 0 {
                    // Negative (reversed view) or zero (broadcast) strides
                    // are never treated as contiguous.
                    return false;
                } else {
                    *stride as usize
                };
                if i > 0 {
                    // Stride must match the accumulated inner element count...
                    if current_num_elems_shape != stride {
                        return false;
                    }

                    // ...and strides must not shrink going outward.
                    if prev_stride > stride {
                        return false;
                    }
                }

                current_num_elems_shape *= shape;
                prev_stride = stride;
            }

            true
        }
    }
}
259
/// Converts a slice of usize to a typed dimension.
///
/// Copies the first `$n` entries of `$dims` into a fixed-size array and wraps
/// it as an ndarray `Dim`. Unhygienic: expects `Dim` to be in scope at the
/// call site, and panics (index out of bounds) if `$dims` has fewer than `$n`
/// elements.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
276
/// Reshapes an array into a tensor.
///
/// The first form reshapes `$array` to the statically-dimensioned shape
/// `$shape` with `$n` axes, avoiding a copy when the array is already in
/// standard (row-major) layout. The second form dispatches on a runtime rank
/// `$D` (1..=6) to the first form.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
            true => {
                // Standard layout: `to_shape` can reinterpret in place.
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            // Non-standard layout: materialize a standard-layout copy first.
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        // Erase the static rank back to IxDyn for storage in NdArrayTensor.
        let array = array.into_dyn();

        NdArrayTensor::new(array)
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
319
impl<E> NdArrayTensor<E>
where
    E: Element,
{
    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
    ///
    /// Takes the shape out of `data` (leaving it defaulted) so the element
    /// buffer can then be consumed by value via `into_vec`.
    ///
    /// # Panics
    /// Panics if the data's element type does not match `E`.
    pub fn from_data(mut data: TensorData) -> NdArrayTensor<E> {
        let shape = mem::take(&mut data.shape);

        let array = match data.into_vec::<E>() {
            // Safety: TensorData checks shape validity on creation, so we don't need to repeat that check here
            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
        };

        NdArrayTensor::new(array)
    }
}
337
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor<Q: QuantElement> {
    /// The quantized tensor.
    pub qtensor: NdArrayTensor<Q>,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
    /// The quantization parameters.
    // One entry per quantization block; `strategy()` below reads only
    // qparams[0] for the per-tensor scheme.
    pub qparams: Vec<QParams<f32, Q>>,
}
348
impl<Q: QuantElement> NdArrayQTensor<Q> {
    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
    ///
    /// For the per-tensor symmetric int8 scheme, the single scale lives in
    /// `qparams[0]`; this panics if `qparams` is empty.
    pub fn strategy(&self) -> QuantizationStrategy {
        // The match is exhaustive over the currently supported schemes; adding
        // a scheme variant will force this to be updated at compile time.
        match self.scheme {
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
                QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
                    self.qparams[0].scale,
                ))
            }
        }
    }
}
361
impl<Q: QuantElement> QTensorPrimitive for NdArrayQTensor<Q> {
    /// Borrows the scheme this tensor was quantized with.
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }
}
367
impl<Q: QuantElement> TensorMetadata for NdArrayQTensor<Q> {
    /// Quantized tensors report a `QFloat` dtype carrying their scheme.
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    /// Shape of the underlying quantized-value tensor.
    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }
}
377
#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_common::rand::get_seeded_rng;
    use burn_tensor::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantizationParametersPrimitive, QuantizationType},
    };

    /// Asserts that random f32 data of the given shape survives a
    /// `from_data` -> `into_data` round trip unchanged.
    fn assert_data_roundtrip(shape: Shape) {
        let data_expected =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut get_seeded_rng());
        let tensor = NdArrayTensor::<f32>::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_data_roundtrip(Shape::new([3]));
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_data_roundtrip(Shape::new([2, 3]));
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_data_roundtrip(Shape::new([2, 3, 4]));
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_data_roundtrip(Shape::new([2, 3, 4, 2]));
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme =
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8);
        let qparams = QuantizationParametersPrimitive {
            scale: B::float_from_data(TensorData::from([scale]), &device),
            offset: None,
        };
        let qtensor: NdArrayQTensor<i8> = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale))
        );
    }
}
467}