1use core::mem;
2
3use burn_tensor::{
4 DType, Element, Shape, TensorData, TensorMetadata,
5 quantization::{
6 QParams, QTensorPrimitive, QuantLevel, QuantMode, QuantScheme, QuantValue,
7 QuantizationStrategy, SymmetricQuantization,
8 },
9};
10
11use alloc::vec::Vec;
12use ndarray::{ArcArray, ArrayD, IxDyn};
13
/// Shared n-dimensional array: an `ndarray` [`ArcArray`] of element type `E` whose number
/// of dimensions is dynamic ([`IxDyn`]).
pub type SharedArray<E> = ArcArray<E, IxDyn>;
16
/// Tensor whose element type is chosen at runtime.
///
/// Each variant wraps a [`SharedArray`] of one concrete element type; the variant in use
/// determines the value reported by `TensorMetadata::dtype`.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    // Floating point element types.
    F64(SharedArray<f64>),
    F32(SharedArray<f32>),
    // Signed integer element types.
    I64(SharedArray<i64>),
    I32(SharedArray<i32>),
    I16(SharedArray<i16>),
    I8(SharedArray<i8>),
    // Unsigned integer element types.
    U64(SharedArray<u64>),
    U32(SharedArray<u32>),
    U16(SharedArray<u16>),
    U8(SharedArray<u8>),
    // Boolean element type.
    Bool(SharedArray<bool>),
}
33
34impl NdArrayTensor {
35 pub(crate) fn bool(self) -> SharedArray<bool> {
36 match self {
37 NdArrayTensor::Bool(arr) => arr,
38 _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
39 }
40 }
41}
42
43pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
44where
45 NdArrayTensor: From<SharedArray<E1>>,
46{
47 fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
48 array.mapv(|a| a.elem()).into_shared()
49 }
50
51 if E1::dtype() == dtype {
52 return array.into();
53 }
54
55 match dtype {
56 DType::F64 => cast::<E1, f64>(array).into(),
57 DType::F32 => cast::<E1, f32>(array).into(),
58 DType::Flex32 => cast::<E1, f32>(array).into(),
59 DType::I64 => cast::<E1, i64>(array).into(),
60 DType::I32 => cast::<E1, i32>(array).into(),
61 DType::I16 => cast::<E1, i16>(array).into(),
62 DType::I8 => cast::<E1, i8>(array).into(),
63 DType::U64 => cast::<E1, u64>(array).into(),
64 DType::U32 => cast::<E1, u32>(array).into(),
65 DType::U16 => cast::<E1, u16>(array).into(),
66 DType::U8 => cast::<E1, u8>(array).into(),
67 DType::Bool => cast::<E1, bool>(array).into(),
68 dtype => panic!("Unsupported dtype: {dtype:?}"),
69 }
70}
71
72macro_rules! impl_from {
73 ($($ty: ty => $dtype: ident),*) => {
74 $(impl From<SharedArray<$ty>> for NdArrayTensor {
75 fn from(value: SharedArray<$ty>) -> NdArrayTensor {
76 NdArrayTensor::$dtype(value)
77 }
78 })*
79 };
80}
81
82impl_from!(
83 f64 => F64, f32 => F32,
84 i64 => I64, i32 => I32, i16 => I16, i8 => I8,
85 u64 => U64, u32 => U32, u16 => U16, u8 => U8,
86 bool => Bool
87);
88
/// Runs `$op` on the typed array(s) stored inside one tensor, or inside a pair of tensors.
///
/// In every generated match arm the identifier given as `$element` (`E` when the argument
/// is omitted) is declared as a type alias for that arm's element type, and the value
/// returned by `$op` is converted with `.into()`.
///
/// An explicit `[Variant => type, ...]` list restricts the dispatch to those variants; the
/// forms without a list cover every variant of `NdArrayTensor`.
///
/// # Panics
///
/// - Pair form: panics with a "Data type mismatch" message when the two tensors are not the
///   same listed variant.
/// - Single form: hits `unimplemented!` when the tensor's variant is not in the list.
#[macro_export]
macro_rules! execute_with_dtype {
    // Pair of tensors, explicit alias name and explicit variant list.
    (($lhs:expr, $rhs:expr),$element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        // The dtypes are read up front because the match below takes both operands by value.
        let lhs_dtype = burn_tensor::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_tensor::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs, rhs).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    // Pair of tensors, default alias name `E`, every variant.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    // Pair of tensors, explicit alias name, every variant.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    // Single tensor, explicit alias name and explicit variant list.
    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(lhs) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs).into()
                }
            )*
            // Unreachable when the list names every variant, hence the lint allowance.
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    // Single tensor, default alias name `E`, every variant.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    // Single tensor, explicit alias name, every variant.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}
156
/// Same dispatch as `execute_with_dtype!`, restricted to the floating point variants
/// (`F64` and `F32`).
///
/// Accepts a single tensor or a `(lhs, rhs)` pair, with an optional identifier naming the
/// element type alias (`E` when omitted). Any other variant takes the fallback arm of
/// `execute_with_dtype!`.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Pair of tensors, default alias name `E`.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    // Pair of tensors, explicit alias name.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    // Single tensor, default alias name `E`.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    // Single tensor, explicit alias name.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
189
/// Same dispatch as `execute_with_dtype!`, restricted to the signed and unsigned integer
/// variants (`I64`, `I32`, `I16`, `I8`, `U64`, `U32`, `U16`, `U8`).
///
/// Accepts a single tensor or a `(lhs, rhs)` pair, with an optional identifier naming the
/// element type alias (`E` when omitted). Any other variant takes the fallback arm of
/// `execute_with_dtype!`.
#[macro_export]
macro_rules! execute_with_int_dtype {
    // Pair of tensors, default alias name `E`.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    // Pair of tensors, explicit alias name.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Single tensor, default alias name `E`.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    // Single tensor, explicit alias name.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
224
/// Same dispatch as `execute_with_dtype!`, restricted to the float and integer variants
/// (every variant except `Bool`).
///
/// Accepts a single tensor or a `(lhs, rhs)` pair, with an optional identifier naming the
/// element type alias (`E` when omitted). A `Bool` tensor takes the fallback arm of
/// `execute_with_dtype!`.
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    // Pair of tensors, default alias name `E`.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    // Pair of tensors, explicit alias name.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Single tensor, default alias name `E`.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    // Single tensor, explicit alias name.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
261
/// Concatenates a collection of tensors along `$dim`.
///
/// The variant of the first tensor selects the element type. Every tensor is then viewed
/// as that variant and the views are handed to `NdArrayOps::concatenate`; the result is
/// converted with `.into()`.
///
/// `$tensors` is indexed with `[0]`, so it must contain at least one tensor. `NdArrayOps`
/// and a `dtype()` method on the tensors must be in scope at the call site.
///
/// # Panics
///
/// - When a tensor is not the same variant as the first one; the message names the
///   expected variant and the dtype actually found.
/// - When the first tensor's variant is not in the `[...]` list.
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        match &$tensors[0] {
            $($crate::NdArrayTensor::$dtype(_) => {
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let $crate::NdArrayTensor::$dtype(tensor) = t {
                            tensor.view()
                        } else {
                            // Report the real dtypes rather than a fixed pair of names.
                            panic!(
                                "Concatenate data type mismatch (expected {}, got {:?})",
                                stringify!($dtype),
                                t.dtype()
                            )
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}
288
289impl TensorMetadata for NdArrayTensor {
290 fn dtype(&self) -> DType {
291 match self {
292 NdArrayTensor::F64(_) => DType::F64,
293 NdArrayTensor::F32(_) => DType::F32,
294 NdArrayTensor::I64(_) => DType::I64,
295 NdArrayTensor::I32(_) => DType::I32,
296 NdArrayTensor::I16(_) => DType::I16,
297 NdArrayTensor::I8(_) => DType::I8,
298 NdArrayTensor::U64(_) => DType::U64,
299 NdArrayTensor::U32(_) => DType::U32,
300 NdArrayTensor::U16(_) => DType::U16,
301 NdArrayTensor::U8(_) => DType::U8,
302 NdArrayTensor::Bool(_) => DType::Bool,
303 }
304 }
305
306 fn shape(&self) -> Shape {
307 execute_with_dtype!(self, E, |a: &ArcArray<E, IxDyn>| Shape::from(
308 a.shape().to_vec()
309 ))
310 }
311
312 fn rank(&self) -> usize {
313 self.shape().num_dims()
314 }
315}
316
317pub(crate) trait ShapeOps {
318 fn num_dims(self) -> usize;
319 fn num_elements(self) -> usize;
320 fn dims<const N: usize>(self) -> [usize; N];
321 fn into_shape(self) -> Shape;
322}
323
324impl ShapeOps for &[usize] {
325 fn num_dims(self) -> usize {
326 self.len()
327 }
328
329 fn num_elements(self) -> usize {
330 self.iter().product()
331 }
332
333 fn dims<const N: usize>(self) -> [usize; N] {
334 self.try_into().unwrap()
335 }
336
337 fn into_shape(self) -> Shape {
338 Shape {
339 dims: self.to_vec(),
340 }
341 }
342}
343
344mod utils {
345 use burn_common::tensor::is_contiguous;
346
347 use super::*;
348
349 impl NdArrayTensor {
350 pub(crate) fn into_data(self) -> TensorData {
351 let shape = self.shape();
352 let contiguous = self.is_contiguous();
353
354 fn inner<E: Element>(
355 shape: Shape,
356 is_contiguous: bool,
357 array: ArcArray<E, IxDyn>,
358 ) -> TensorData {
359 let vec = if is_contiguous {
360 match array.try_into_owned_nocopy() {
361 Ok(owned) => {
362 let (mut vec, offset) = owned.into_raw_vec_and_offset();
363 if let Some(offset) = offset {
364 vec.drain(..offset);
365 }
366 if vec.len() > shape.num_elements() {
367 vec.drain(shape.num_elements()..vec.len());
368 }
369 vec
370 }
371 Err(array) => array.into_iter().collect(),
372 }
373 } else {
374 array.into_iter().collect()
375 };
376
377 TensorData::new(vec, shape)
378 }
379
380 execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
381 }
382
383 pub(crate) fn is_contiguous(&self) -> bool {
384 fn inner<E: Element>(array: &ArcArray<E, IxDyn>) -> bool {
385 let shape = array.shape();
386 let mut strides = Vec::with_capacity(array.strides().len());
387
388 for &stride in array.strides() {
389 if stride <= 0 {
390 return false;
391 }
392 strides.push(stride as usize);
393 }
394 is_contiguous(shape, &strides)
395 }
396
397 execute_with_dtype!(self, inner)
398 }
399 }
400}
401
/// Builds a fixed-rank `Dim<[usize; $n]>` from the first `$n` entries of `$dims`.
///
/// `Dim` must be in scope at the call site. Indexing `$dims` panics when it holds fewer
/// than `$n` entries.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    ($n:expr, $dims:expr, justdim) => {{
        let dims: [usize; $n] = core::array::from_fn(|axis| $dims[axis]);
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
418
/// Reshapes `$array` to `$shape` and returns the result as a dynamically-dimensioned
/// shared array.
///
/// Two forms are accepted:
///
/// - `ty, n, shape, array`: `n` is the number of dimensions; the first `n` entries of
///   `$shape.dims` become the target dimensions.
/// - `ty, shape, array, d`: picks `n` from the runtime value `$D` (1 through 6) and
///   forwards to the first form.
///
/// `$ty` is accepted by both forms but is not used by the expansion itself.
///
/// # Panics
///
/// - When `$D` is outside `1..=6`.
/// - When `to_shape` rejects the target dimensions.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array = match $array.is_standard_layout() {
            // Standard layout: reshape directly and share the result.
            true => {
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            // Any other layout: reshape, then bring the result into standard layout
            // before sharing it.
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        // The fixed-rank form needs a literal rank, so each supported rank is spelled out.
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
459
460impl NdArrayTensor {
461 pub fn from_data(mut data: TensorData) -> NdArrayTensor {
463 let shape = mem::take(&mut data.shape);
464
465 macro_rules! execute {
466 ($data: expr, [$($dtype: ident => $ty: ty),*]) => {
467 match $data.dtype {
468 $(DType::$dtype => {
469 match data.into_vec::<$ty>() {
470 Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
472 Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
473 }.into()
474 },)*
475 other => unimplemented!("Unsupported dtype {other:?}"),
476 }
477 };
478 }
479
480 execute!(data, [
481 F64 => f64, F32 => f32,
482 I64 => i64, I32 => i32, I16 => i16, I8 => i8,
483 U64 => u64, U32 => u32, U16 => u16, U8 => u8,
484 Bool => bool
485 ])
486 }
487}
488
/// A quantized tensor: the stored values together with the scheme and parameters that
/// describe how they were quantized.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The tensor holding the quantized values.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The quantization parameters; `strategy` reads the `scales` of each entry.
    pub qparams: Vec<QParams<f32>>,
}
499
500impl NdArrayQTensor {
501 pub fn strategy(&self) -> QuantizationStrategy {
503 match self.scheme {
504 QuantScheme {
505 level: QuantLevel::Tensor,
506 mode: QuantMode::Symmetric,
507 value:
508 QuantValue::Q8F
509 | QuantValue::Q8S
510 | QuantValue::E4M3
511 | QuantValue::E5M2
512 | QuantValue::Q4F
513 | QuantValue::Q4S
514 | QuantValue::E2M1
515 | QuantValue::Q2F
516 | QuantValue::Q2S,
517 ..
518 } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
519 self.qparams[0].scales,
520 self.scheme.value,
521 )),
522 QuantScheme {
523 level: QuantLevel::Block(block_size),
524 mode: QuantMode::Symmetric,
525 value:
526 QuantValue::Q8F
527 | QuantValue::Q8S
528 | QuantValue::E4M3
529 | QuantValue::E5M2
530 | QuantValue::Q4F
531 | QuantValue::Q4S
532 | QuantValue::E2M1
533 | QuantValue::Q2F
534 | QuantValue::Q2S,
535 ..
536 } => QuantizationStrategy::PerBlockSymmetric(
537 self.qparams
538 .iter()
539 .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
540 .collect(),
541 block_size,
542 ),
543 }
544 }
545}
546
547impl QTensorPrimitive for NdArrayQTensor {
548 fn scheme(&self) -> &QuantScheme {
549 &self.scheme
550 }
551
552 fn default_scheme() -> QuantScheme {
553 QuantScheme::default().with_store(burn_tensor::quantization::QuantStore::Native)
554 }
555}
556
557impl TensorMetadata for NdArrayQTensor {
558 fn dtype(&self) -> DType {
559 DType::QFloat(self.scheme)
560 }
561
562 fn shape(&self) -> Shape {
563 self.qtensor.shape()
564 }
565
566 fn rank(&self) -> usize {
567 self.shape().num_dims()
568 }
569}
570
#[cfg(test)]
mod tests {
    use crate::NdArray;

    use super::*;
    use burn_common::rand::get_seeded_rng;
    use burn_tensor::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };

    /// Converts random `f32` data of the given shape into a tensor and back, and checks
    /// that the data is unchanged.
    fn assert_data_round_trip(shape: Shape) {
        let data_expected =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut get_seeded_rng());

        let data_actual = NdArrayTensor::from_data(data_expected.clone()).into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_data_round_trip(Shape::new([3]));
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_data_round_trip(Shape::new([2, 3]));
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_data_round_trip(Shape::new([2, 3, 4]));
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_data_round_trip(Shape::new([2, 3, 4, 2]));
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let device = Default::default();
        let scale: f32 = 0.009_019_608;

        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let input = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scales = B::float_from_data(TensorData::from([scale]), &device);

        let qtensor: NdArrayQTensor =
            B::quantize(input, &scheme, QuantizationParametersPrimitive { scales });

        // The tensor keeps the scheme it was quantized with...
        assert_eq!(qtensor.scheme(), &scheme);

        // ...and derives a per-tensor symmetric strategy from the provided scale.
        let expected = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
            scale,
            QuantValue::Q8S,
        ));
        assert_eq!(qtensor.strategy(), expected);
    }
}