use core::mem;

use burn_tensor::{
    DType, Element, Shape, TensorData, TensorMetadata,
    quantization::{QParams, QTensorPrimitive, QuantLevel, QuantMode, QuantScheme, QuantValue},
};

use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
use alloc::vec::Vec;
use ndarray::{ArcArray, ArrayD, IxDyn};

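/// Reference-counted, dynamically-ranked ndarray storage shared between tensor handles.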
pub type SharedArray<E> = ArcArray<E, IxDyn>;

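/// Concrete tensor primitive for the ndarray backend, with one variant per supported element type.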
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    F64(SharedArray<f64>),
    F32(SharedArray<f32>),
    I64(SharedArray<i64>),
    I32(SharedArray<i32>),
    I16(SharedArray<i16>),
    I8(SharedArray<i8>),
    U64(SharedArray<u64>),
    U32(SharedArray<u32>),
    U16(SharedArray<u16>),
    U8(SharedArray<u8>),
    Bool(SharedArray<bool>),
}

impl NdArrayTensor {
    pub(crate) fn bool(self) -> SharedArray<bool> {
        match self {
            NdArrayTensor::Bool(arr) => arr,
            _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
        }
    }
}

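/// Casts an array with element type `E1` into the [`NdArrayTensor`] variant matching `dtype`,
/// mapping each element through [`Element::elem`]. Returns the array wrapped in its own variant
/// when `dtype` already matches `E1`.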
pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
where
    NdArrayTensor: From<SharedArray<E1>>,
{
    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
        array.mapv(|a| a.elem()).into_shared()
    }

    if E1::dtype() == dtype {
        return array.into();
    }

    match dtype {
        DType::F64 => cast::<E1, f64>(array).into(),
        DType::F32 => cast::<E1, f32>(array).into(),
        DType::Flex32 => cast::<E1, f32>(array).into(),
        DType::I64 => cast::<E1, i64>(array).into(),
        DType::I32 => cast::<E1, i32>(array).into(),
        DType::I16 => cast::<E1, i16>(array).into(),
        DType::I8 => cast::<E1, i8>(array).into(),
        DType::U64 => cast::<E1, u64>(array).into(),
        DType::U32 => cast::<E1, u32>(array).into(),
        DType::U16 => cast::<E1, u16>(array).into(),
        DType::U8 => cast::<E1, u8>(array).into(),
        DType::Bool => cast::<E1, bool>(array).into(),
        dtype => panic!("Unsupported dtype: {dtype:?}"),
    }
}

macro_rules! impl_from {
    ($($ty: ty => $dtype: ident),*) => {
        $(impl From<SharedArray<$ty>> for NdArrayTensor {
            fn from(value: SharedArray<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(value)
            }
        })*
    };
}

impl_from!(
    f64 => F64, f32 => F32,
    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
    bool => Bool
);

#[macro_export]
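/// Dispatches `$op` on the concrete element type of a tensor (or of a pair of tensors with
/// matching dtypes), aliasing that Rust type to the given identifier (`E` by default) inside
/// `$op`, and converting the result back into an [`NdArrayTensor`] with `.into()`.
///
/// A minimal sketch of a unary invocation (illustrative only; `tensor` is assumed to be an
/// [`NdArrayTensor`]):
///
/// ```ignore
/// // Transpose the underlying array, whatever element type it stores.
/// let transposed: NdArrayTensor =
///     execute_with_dtype!(tensor, E, |array: SharedArray<E>| array.reversed_axes());
/// ```
///
/// Binary invocations panic with a data type mismatch when the `lhs` and `rhs` dtypes differ,
/// since no automatic cast is performed.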
macro_rules! execute_with_dtype {
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs, rhs).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(lhs) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs).into()
                }
            )*
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}

#[macro_export]
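/// Dispatches `$op` on float tensors only (`F64` and `F32` variants), aliasing the concrete
/// float type to the given identifier (`E` by default).
///
/// A minimal sketch (illustrative only; `tensor` is assumed to be a float [`NdArrayTensor`]):
///
/// ```ignore
/// // Double every element, whichever float precision the tensor stores.
/// let doubled: NdArrayTensor = execute_with_float_dtype!(tensor, E, |array: SharedArray<E>| {
///     array.mapv(|x| x + x).into_shared()
/// });
/// ```
///
/// # Panics
/// Panics with an unsupported dtype (unary) or a data type mismatch (binary) when a non-float
/// tensor is passed, since no automatic cast is performed.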
macro_rules! execute_with_float_dtype {
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}

#[macro_export]
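/// Dispatches `$op` on integer tensors only (signed and unsigned variants), aliasing the
/// concrete integer type to the given identifier (`E` by default). Panics for any other dtype.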
macro_rules! execute_with_int_dtype {
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}

#[macro_export]
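/// Dispatches `$op` on numeric tensors (all float and integer variants, excluding `Bool`),
/// aliasing the concrete element type to the given identifier (`E` by default). Panics for any
/// other dtype.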
macro_rules! execute_with_numeric_dtype {
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}

#[macro_export]
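/// Concatenates a slice of tensors along `$dim`, dispatching on the dtype of the first tensor
/// and panicking if the remaining tensors do not share it.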
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        match &$tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(tensor) = t {
                            tensor.view()
                        } else {
                            panic!(
                                "Concatenate data type mismatch (expected {}, got {:?})",
                                stringify!($dtype),
                                t.dtype()
                            )
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}

impl TensorMetadata for NdArrayTensor {
    fn dtype(&self) -> DType {
        match self {
            NdArrayTensor::F64(_) => DType::F64,
            NdArrayTensor::F32(_) => DType::F32,
            NdArrayTensor::I64(_) => DType::I64,
            NdArrayTensor::I32(_) => DType::I32,
            NdArrayTensor::I16(_) => DType::I16,
            NdArrayTensor::I8(_) => DType::I8,
            NdArrayTensor::U64(_) => DType::U64,
            NdArrayTensor::U32(_) => DType::U32,
            NdArrayTensor::U16(_) => DType::U16,
            NdArrayTensor::U8(_) => DType::U8,
            NdArrayTensor::Bool(_) => DType::Bool,
        }
    }

    fn shape(&self) -> Shape {
        execute_with_dtype!(self, E, |a: &ArcArray<E, IxDyn>| Shape::from(
            a.shape().to_vec()
        ))
    }

    fn rank(&self) -> usize {
        self.shape().num_dims()
    }
}

pub(crate) trait ShapeOps {
    fn num_dims(self) -> usize;
    fn num_elements(self) -> usize;
    fn dims<const N: usize>(self) -> [usize; N];
    fn into_shape(self) -> Shape;
}

impl ShapeOps for &[usize] {
    fn num_dims(self) -> usize {
        self.len()
    }

    fn num_elements(self) -> usize {
        self.iter().product()
    }

    fn dims<const N: usize>(self) -> [usize; N] {
        self.try_into().unwrap()
    }

    fn into_shape(self) -> Shape {
        Shape {
            dims: self.to_vec(),
        }
    }
}

mod utils {
    use burn_common::tensor::is_contiguous;

    use super::*;

    impl NdArrayTensor {
        pub(crate) fn into_data(self) -> TensorData {
            let shape = self.shape();
            let contiguous = self.is_contiguous();

            fn inner<E: Element>(
                shape: Shape,
                is_contiguous: bool,
                array: ArcArray<E, IxDyn>,
            ) -> TensorData {
                let vec = if is_contiguous {
                    match array.try_into_owned_nocopy() {
                        Ok(owned) => {
                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
                            if let Some(offset) = offset {
                                vec.drain(..offset);
                            }
                            if vec.len() > shape.num_elements() {
                                vec.drain(shape.num_elements()..vec.len());
                            }
                            vec
                        }
                        Err(array) => array.into_iter().collect(),
                    }
                } else {
                    array.into_iter().collect()
                };

                TensorData::new(vec, shape)
            }

            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
        }

        pub(crate) fn is_contiguous(&self) -> bool {
            fn inner<E: Element>(array: &ArcArray<E, IxDyn>) -> bool {
                let shape = array.shape();
                let mut strides = Vec::with_capacity(array.strides().len());

                for &stride in array.strides() {
                    if stride <= 0 {
                        return false;
                    }
                    strides.push(stride as usize);
                }
                is_contiguous(shape, &strides)
            }

            execute_with_dtype!(self, inner)
        }
    }
}

#[macro_export(local_inner_macros)]
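/// Converts a dynamic dimension slice into a fixed-size `Dim` of `$n` dimensions.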
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}

#[macro_export(local_inner_macros)]
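/// Reshapes an ndarray into the given shape, either with a statically known number of
/// dimensions (`n`) or by dispatching on a runtime rank (`d`, up to 6 dimensions).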
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array = match $array.is_standard_layout() {
            true => {
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible (shape={dim:?}): {err:?}");
                    }
                }
            },
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}

impl NdArrayTensor {
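    /// Creates a tensor from [`TensorData`], reinterpreting its raw buffer with the matching
    /// element type and shape.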
    pub fn from_data(mut data: TensorData) -> NdArrayTensor {
        let shape = mem::take(&mut data.shape);

        macro_rules! execute {
            ($data: expr, [$($dtype: ident => $ty: ty),*]) => {
                match $data.dtype {
                    $(DType::$dtype => {
                        match data.into_vec::<$ty>() {
                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
                            Err(err) => panic!("Data should have the same element type as the tensor: {err:?}"),
                        }.into()
                    },)*
                    other => unimplemented!("Unsupported dtype {other:?}"),
                }
            };
        }

        execute!(data, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }
}

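/// A quantized tensor for the ndarray backend.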
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The quantized tensor.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The quantization parameters.
    pub qparams: Vec<QParams<f32>>,
}

impl NdArrayQTensor {
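    /// Returns the quantization strategy, including quantization parameters, for this tensor.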
    pub fn strategy(&self) -> QuantizationStrategy {
        match self.scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                self.qparams[0].scales,
                self.scheme.value,
            )),
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerBlockSymmetric(
                self.qparams
                    .iter()
                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
                    .collect(),
                block_size,
            ),
        }
    }
}

impl QTensorPrimitive for NdArrayQTensor {
    fn scheme(&self) -> &QuantScheme {
        &self.scheme
    }

    fn default_scheme() -> QuantScheme {
        QuantScheme::default().with_store(burn_tensor::quantization::QuantStore::Native)
    }
}

impl TensorMetadata for NdArrayQTensor {
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }

    fn rank(&self) -> usize {
        self.shape().num_dims()
    }
}

#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_common::rand::get_seeded_rng;
    use burn_tensor::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };

    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4, 2]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let qparams = QuantizationParametersPrimitive {
            scales: B::float_from_data(TensorData::from([scale]), &device),
        };
        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                scale,
                QuantValue::Q8S
            ))
        );
    }
}