1use burn_backend::{
2 AllocationProperty, DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
3 quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
4};
5use burn_std::BoolStore;
6
7use crate::NdArrayStorage;
8use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
9use alloc::vec::Vec;
10use ndarray::{ArcArray, ArrayD, IxDyn};
11
/// Dynamically-ranked, reference-counted `ndarray` array; the shared element
/// container used by every tensor variant in this backend.
pub type SharedArray<E> = ArcArray<E, IxDyn>;
14
/// Dynamically-typed ndarray tensor: one variant per supported element type,
/// each wrapping an [`NdArrayStorage`] of the corresponding Rust primitive.
///
/// Operations dispatch on the variant (see the `execute_with_*` macros below).
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    F64(NdArrayStorage<f64>),
    F32(NdArrayStorage<f32>),
    I64(NdArrayStorage<i64>),
    I32(NdArrayStorage<i32>),
    I16(NdArrayStorage<i16>),
    I8(NdArrayStorage<i8>),
    U64(NdArrayStorage<u64>),
    U32(NdArrayStorage<u32>),
    U16(NdArrayStorage<u16>),
    U8(NdArrayStorage<u8>),
    Bool(NdArrayStorage<bool>),
}
35
36impl NdArrayTensor {
37 pub(crate) fn bool(self) -> SharedArray<bool> {
39 match self {
40 NdArrayTensor::Bool(storage) => storage.into_shared(),
41 _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
42 }
43 }
44
45 #[inline]
47 pub fn is_borrowed(&self) -> bool {
48 macro_rules! check {
49 ($($variant:ident),*) => {
50 match self {
51 $(NdArrayTensor::$variant(s) => s.is_borrowed(),)*
52 }
53 };
54 }
55 check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
56 }
57}
58
59pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
60where
61 NdArrayTensor: From<SharedArray<E1>>,
62{
63 fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
64 array.mapv(|a| a.elem()).into_shared()
65 }
66
67 if E1::dtype() == dtype {
68 return array.into();
69 }
70
71 match dtype {
72 DType::F64 => cast::<E1, f64>(array).into(),
73 DType::F32 => cast::<E1, f32>(array).into(),
74 DType::Flex32 => cast::<E1, f32>(array).into(),
75 DType::I64 => cast::<E1, i64>(array).into(),
76 DType::I32 => cast::<E1, i32>(array).into(),
77 DType::I16 => cast::<E1, i16>(array).into(),
78 DType::I8 => cast::<E1, i8>(array).into(),
79 DType::U64 => cast::<E1, u64>(array).into(),
80 DType::U32 => cast::<E1, u32>(array).into(),
81 DType::U16 => cast::<E1, u16>(array).into(),
82 DType::U8 => cast::<E1, u8>(array).into(),
83 DType::Bool(BoolStore::Native) => cast::<E1, bool>(array).into(),
84 dtype => panic!("Unsupported dtype: {dtype:?}"),
85 }
86}
87
/// Implements `From<SharedArray<T>>` and `From<NdArrayStorage<T>>` for
/// [`NdArrayTensor`], mapping each primitive element type to its enum variant.
macro_rules! impl_from {
    ($($ty: ty => $dtype: ident),*) => {
        $(impl From<SharedArray<$ty>> for NdArrayTensor {
            fn from(value: SharedArray<$ty>) -> NdArrayTensor {
                // Plain arrays always become owned storage.
                NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
            }
        })*

        $(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
            fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(value)
            }
        })*
    };
}

impl_from!(
    f64 => F64, f32 => F32,
    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
    bool => Bool
);
112
/// Dispatches `$op` on the concrete element type of one tensor, or of a pair
/// of tensors that must share the same dtype.
///
/// Inside `$op`, `$element` (default `E`) is aliased to the matching Rust
/// primitive type. The binary forms panic on a dtype mismatch; the unary
/// forms call `unimplemented!` on an unsupported dtype.
#[macro_export]
macro_rules! execute_with_dtype {
    (($lhs:expr, $rhs:expr),$element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        // Bind the operand expressions exactly once. Previously `$lhs`/`$rhs`
        // were expanded both for the dtype lookup and in the `match`, which
        // double-evaluated side-effectful argument expressions.
        let lhs = $lhs;
        let rhs = $rhs;
        let lhs_dtype = burn_backend::TensorMetadata::dtype(&lhs);
        let rhs_dtype = burn_backend::TensorMetadata::dtype(&rhs);
        match (lhs, rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs.into_shared(), rhs.into_shared()).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary form over every supported dtype.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(storage) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(storage.into_shared()).into()
                }
            )*
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    // Unary form over every supported dtype.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}
184
/// Like [`execute_with_dtype!`], restricted to the float dtypes
/// (`F64`, `F32`). Panics on any other dtype.
#[macro_export]
macro_rules! execute_with_float_dtype {
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
217
/// Like [`execute_with_dtype!`], restricted to the signed and unsigned
/// integer dtypes. Panics on any other dtype.
#[macro_export]
macro_rules! execute_with_int_dtype {
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
252
/// Like [`execute_with_dtype!`], restricted to the numeric dtypes
/// (floats and integers; excludes `Bool`). Panics on any other dtype.
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
289
/// Concatenates a slice of tensors along `$dim`, dispatching on the dtype of
/// the first tensor.
///
/// `$tensors` is expanded several times, so it must be a place expression
/// (a binding or field), not a side-effectful call. Panics when any tensor's
/// dtype differs from the first one's, or when the dtype is not listed.
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        match &$tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                // Collect typed views of every tensor, verifying dtypes match.
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(storage) = t {
                            storage.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype())
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}
319
/// Resolves a `burn_std::FloatDType` to a concrete Rust float type, aliases
/// it as `$element` (default `E`), and evaluates `$op` in that context.
#[macro_export]
macro_rules! execute_with_float_out_dtype {
    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $out_dtype {
            $(
                burn_std::FloatDType::$dtype => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op
                }
            )*
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {other:?}")
        }
    }};
    ($out_dtype:expr, $op:expr) => {{
        $crate::execute_with_float_out_dtype!($out_dtype, E, $op)
    }};

    ($out_dtype:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_float_out_dtype!($out_dtype, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
348
/// Resolves a `burn_std::IntDType` to a concrete Rust integer type, aliases
/// it as `$element` (default `E`), and evaluates `$op` in that context.
#[macro_export]
macro_rules! execute_with_int_out_dtype {
    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $out_dtype {
            $(
                burn_std::IntDType::$dtype => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op
                }
            )*
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {other:?}")
        }
    }};
    ($out_dtype:expr, $op:expr) => {{
        $crate::execute_with_int_out_dtype!($out_dtype, E, $op)
    }};

    ($out_dtype:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_int_out_dtype!($out_dtype, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
378
379impl TensorMetadata for NdArrayTensor {
380 fn dtype(&self) -> DType {
381 match self {
382 NdArrayTensor::F64(_) => DType::F64,
383 NdArrayTensor::F32(_) => DType::F32,
384 NdArrayTensor::I64(_) => DType::I64,
385 NdArrayTensor::I32(_) => DType::I32,
386 NdArrayTensor::I16(_) => DType::I16,
387 NdArrayTensor::I8(_) => DType::I8,
388 NdArrayTensor::U64(_) => DType::U64,
389 NdArrayTensor::U32(_) => DType::U32,
390 NdArrayTensor::U16(_) => DType::U16,
391 NdArrayTensor::U8(_) => DType::U8,
392 NdArrayTensor::Bool(_) => DType::Bool(BoolStore::Native),
393 }
394 }
395
396 fn shape(&self) -> Shape {
397 macro_rules! get_shape {
399 ($($variant:ident),*) => {
400 match self {
401 $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
402 }
403 };
404 }
405 get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
406 }
407
408 fn rank(&self) -> usize {
409 self.shape().num_dims()
410 }
411}
412
/// Convenience accessors over a raw `&[usize]` dimension list.
pub(crate) trait ShapeOps {
    /// Number of dimensions (rank).
    fn num_dims(self) -> usize;
    /// Total number of elements (product of all dimensions).
    fn num_elements(self) -> usize;
    /// Converts into a fixed-size array; panics when the rank is not `N`.
    fn dims<const N: usize>(self) -> [usize; N];
    /// Converts into a [`Shape`].
    fn into_shape(self) -> Shape;
}
419
420impl ShapeOps for &[usize] {
421 fn num_dims(self) -> usize {
422 self.len()
423 }
424
425 fn num_elements(self) -> usize {
426 self.iter().product()
427 }
428
429 fn dims<const N: usize>(self) -> [usize; N] {
430 self.try_into().unwrap()
431 }
432
433 fn into_shape(self) -> Shape {
434 Shape::from(self)
435 }
436}
437
mod utils {
    use burn_std::tensor::is_contiguous;

    use super::*;

    impl NdArrayTensor {
        /// Consumes the tensor and converts it into a [`TensorData`].
        ///
        /// For contiguous, uniquely-owned arrays the underlying buffer is
        /// reused without copying; otherwise elements are copied out via
        /// iteration.
        pub(crate) fn into_data(self) -> TensorData {
            let shape = self.shape();
            let contiguous = self.is_contiguous();

            fn inner<E: Element>(
                shape: Shape,
                is_contiguous: bool,
                array: ArcArray<E, IxDyn>,
            ) -> TensorData {
                let vec = if is_contiguous {
                    match array.try_into_owned_nocopy() {
                        Ok(owned) => {
                            // Reuse the raw buffer: drop any leading offset and
                            // trailing excess so exactly `shape.num_elements()`
                            // elements remain.
                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
                            if let Some(offset) = offset {
                                vec.drain(..offset);
                            }
                            if vec.len() > shape.num_elements() {
                                vec.drain(shape.num_elements()..vec.len());
                            }
                            vec
                        }
                        // Buffer is shared with other handles: copy instead.
                        Err(array) => array.into_iter().collect(),
                    }
                } else {
                    // Non-contiguous layout: iteration yields logical order.
                    array.into_iter().collect()
                };

                TensorData::new(vec, shape)
            }

            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
        }

        /// Whether the underlying array data is laid out contiguously.
        pub(crate) fn is_contiguous(&self) -> bool {
            macro_rules! check_contiguous {
                ($($variant:ident),*) => {
                    match self {
                        $(NdArrayTensor::$variant(storage) => {
                            match storage {
                                NdArrayStorage::Borrowed { .. } => {
                                    // Borrowed storage is treated as contiguous —
                                    // presumably it always wraps a flat, unstrided
                                    // buffer; TODO confirm against NdArrayStorage.
                                    true
                                }
                                NdArrayStorage::Owned(array) => {
                                    let shape = array.shape();
                                    let mut strides = Vec::with_capacity(array.strides().len());
                                    for &stride in array.strides() {
                                        // Zero/negative strides are conservatively
                                        // reported as non-contiguous (early return
                                        // exits the enclosing function).
                                        if stride <= 0 {
                                            return false;
                                        }
                                        strides.push(stride as usize);
                                    }
                                    is_contiguous(shape, &strides)
                                }
                            }
                        })*
                    }
                };
            }
            check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
        }
    }
}
511
/// Converts the first `$n` entries of an indexable dimension list into an
/// ndarray `Dim<[usize; $n]>` with a compile-time dimension count.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
528
/// Reshapes `$array` to `$shape` with a statically-known dimension count
/// (`n …`) or a runtime one (`d …`, up to 6 dimensions).
///
/// Standard-layout arrays reshape without copying when possible; other
/// layouts are first rearranged into standard layout.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape, justdim);
        let array = match $array.is_standard_layout() {
            true => {
                // No copy needed for standard-layout data.
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            // Non-standard layout: materialize a standard-layout copy first.
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
569
/// Slices a tensor with the given slice specification, dispatching on the
/// tensor's variant (all variants by default, or an explicit subset).
#[macro_export]
macro_rules! slice {
    ($tensor:expr, $slices:expr) => {
        slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
    };
    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
        match $tensor {
            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
        }
    };
}
582
impl NdArrayTensor {
    /// Creates a tensor from [`TensorData`].
    ///
    /// When the data's allocation is not native, a zero-copy borrowed storage
    /// is attempted first; if borrowing fails (or the allocation is native),
    /// the data is copied into owned storage.
    pub fn from_data(data: TensorData) -> NdArrayTensor {
        if data.bytes.property() != AllocationProperty::Native {
            match Self::try_from_data_borrowed(data) {
                Ok(tensor) => return tensor,
                // Borrowing rejected the buffer; `data` is handed back intact.
                Err(data) => return Self::from_data_owned(data),
            }
        }
        Self::from_data_owned(data)
    }

    /// Attempts to build a tensor that borrows `data`'s byte buffer without
    /// copying. On failure, returns the reassembled `TensorData` in `Err` so
    /// the caller can fall back to the owned path.
    fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {
        let TensorData {
            bytes,
            shape,
            dtype,
        } = data;

        // Early-returns the borrowed tensor on success; otherwise yields the
        // bytes and shape back so they can be reassembled below.
        macro_rules! try_borrow {
            ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
                match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
                    Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
                    Err((bytes, shape)) => (bytes, shape),
                }
            };
        }

        let (bytes, shape) = match dtype {
            DType::F64 => try_borrow!(f64, F64, bytes, shape),
            DType::F32 => try_borrow!(f32, F32, bytes, shape),
            DType::I64 => try_borrow!(i64, I64, bytes, shape),
            DType::I32 => try_borrow!(i32, I32, bytes, shape),
            DType::I16 => try_borrow!(i16, I16, bytes, shape),
            DType::I8 => try_borrow!(i8, I8, bytes, shape),
            DType::U64 => try_borrow!(u64, U64, bytes, shape),
            DType::U32 => try_borrow!(u32, U32, bytes, shape),
            DType::U16 => try_borrow!(u16, U16, bytes, shape),
            DType::U8 => try_borrow!(u8, U8, bytes, shape),
            DType::Bool(BoolStore::Native) => try_borrow!(bool, Bool, bytes, shape),
            // Other dtypes have no borrowable native representation.
            _ => (bytes, shape),
        };

        Err(TensorData {
            bytes,
            shape,
            dtype,
        })
    }

    /// Copies `data` into freshly-allocated owned ndarray storage.
    fn from_data_owned(data: TensorData) -> NdArrayTensor {
        let shape = data.shape.to_vec();
        macro_rules! execute {
            ($data: expr, [$($dtype: pat => $ty: ty),*]) => {
                match $data.dtype {
                    $( $dtype => {
                        match data.into_vec::<$ty>() {
                            // `unsafe`: skips the shape/len consistency check.
                            // `shape` and `vec` come from the same `TensorData`,
                            // so the element count is presumed to match — TODO
                            // confirm `TensorData` upholds this invariant.
                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
                            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
                        }.into()
                    }, )*
                    other => unimplemented!("Unsupported dtype {other:?}"),
                }
            };
        }

        execute!(data, [
            DType::F64 => f64, DType::F32 => f32,
            DType::I64 => i64, DType::I32 => i32, DType::I16 => i16, DType::I8 => i8,
            DType::U64 => u64, DType::U32 => u32, DType::U16 => u16, DType::U8 => u8,
            DType::Bool(BoolStore::Native) => bool
        ])
    }
}
682
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The quantized values.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// Quantization parameters: a single entry for per-tensor quantization,
    /// or one entry per block for per-block quantization (see `strategy`).
    pub qparams: Vec<QParams<f32>>,
}
693
impl NdArrayQTensor {
    /// Returns the [`QuantizationStrategy`] (including parameters) derived
    /// from the tensor's scheme.
    ///
    /// Only symmetric quantization is handled: per-tensor uses the single
    /// `qparams` entry, per-block uses one entry per block. The match is
    /// intentionally exhaustive so new scheme variants fail to compile here.
    pub fn strategy(&self) -> QuantizationStrategy {
        match self.scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                self.qparams[0].scales,
                self.scheme.value,
            )),
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerBlockSymmetric(
                self.qparams
                    .iter()
                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
                    .collect(),
                block_size,
            ),
        }
    }
}
740
impl QTensorPrimitive for NdArrayQTensor {
    fn scheme(&self) -> &QuantScheme {
        &self.scheme
    }

    /// Default scheme for this backend: the standard scheme with native
    /// (unpacked) value storage.
    fn default_scheme() -> QuantScheme {
        QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)
    }
}
750
751impl TensorMetadata for NdArrayQTensor {
752 fn dtype(&self) -> DType {
753 DType::QFloat(self.scheme)
754 }
755
756 fn shape(&self) -> Shape {
757 self.qtensor.shape()
758 }
759
760 fn rank(&self) -> usize {
761 self.shape().num_dims()
762 }
763}
764
#[cfg(test)]
mod tests {
    use crate::NdArray;
    use alloc::vec;

    use super::*;
    use burn_backend::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };
    use burn_std::rand::get_seeded_rng;

    // Round-trip: TensorData -> NdArrayTensor -> TensorData must be lossless (1-D).
    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    // Same round-trip for a 2-D shape.
    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    // Same round-trip for a 3-D shape.
    #[test]
    fn should_support_into_and_from_data_3d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    // Same round-trip for a 4-D shape.
    #[test]
    fn should_support_into_and_from_data_4d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3, 4, 2]),
            Distribution::Default,
            &mut get_seeded_rng(),
        );
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    // Quantizing with a per-tensor symmetric scheme must yield the matching
    // PerTensorSymmetric strategy carrying the same scale.
    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let qparams = QuantizationParametersPrimitive {
            scales: B::float_from_data(TensorData::from([scale]), &device),
        };
        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                scale,
                QuantValue::Q8S
            ))
        );
    }

    // Non-native allocations must take the zero-copy (borrowed) path.
    #[test]
    fn zero_copy_creates_borrowed_storage_for_non_native() {
        use burn_backend::AllocationProperty;
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        // Mark the buffer as non-native so `from_data` attempts to borrow it.
        let non_native_bytes = Bytes::from_shared(
            bytes::Bytes::copy_from_slice(&bytes),
            AllocationProperty::Other,
        );
        let tensor_data = TensorData::from_bytes(non_native_bytes, Shape::new([2, 2]), DType::F32);

        let tensor = NdArrayTensor::from_data(tensor_data);

        match &tensor {
            NdArrayTensor::F32(storage) => {
                assert!(
                    storage.is_borrowed(),
                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
                     for non-native (e.g. burnpack) TensorData"
                );
                assert!(
                    !storage.is_unique(),
                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
                );
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    // Native allocations must take the owned path (no borrowing overhead).
    #[test]
    fn native_alloc_creates_owned_storage() {
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);

        let tensor = NdArrayTensor::from_data(tensor_data);

        match &tensor {
            NdArrayTensor::F32(storage) => {
                assert!(
                    !storage.is_borrowed(),
                    "PERF REGRESSION: from_data must NOT create borrowed storage \
                     for native TensorData"
                );
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    // Element values must be preserved and addressable after construction.
    #[test]
    fn zero_copy_data_integrity() {
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);

        let tensor = NdArrayTensor::from_data(tensor_data);

        match &tensor {
            NdArrayTensor::F32(storage) => {
                let view = storage.view();
                assert_eq!(view[[0, 0]], 1.0);
                assert_eq!(view[[0, 1]], 2.0);
                assert_eq!(view[[1, 0]], 3.0);
                assert_eq!(view[[1, 1]], 4.0);
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    // Owned-bytes input still round-trips through from_data/into_data.
    #[test]
    fn zero_copy_fallback_when_bytes_owned() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = NdArrayTensor::from_data(data.clone());
        let result = tensor.into_data();

        assert_eq!(data, result, "Data should round-trip correctly");
    }
}