1use core::mem;
2
3use burn_backend::{
4 DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
5 quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
6};
7
8use crate::NdArrayStorage;
9use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
10use alloc::vec::Vec;
11use ndarray::{ArcArray, ArrayD, IxDyn};
12
/// Dynamically-ranked, reference-counted `ndarray` array shared between tensor handles.
pub type SharedArray<E> = ArcArray<E, IxDyn>;
15
/// Dtype-erased tensor for the ndarray backend: one variant per supported
/// element type, each wrapping an [`NdArrayStorage`] of that type.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum NdArrayTensor {
    F64(NdArrayStorage<f64>),
    F32(NdArrayStorage<f32>),
    I64(NdArrayStorage<i64>),
    I32(NdArrayStorage<i32>),
    I16(NdArrayStorage<i16>),
    I8(NdArrayStorage<i8>),
    U64(NdArrayStorage<u64>),
    U32(NdArrayStorage<u32>),
    U16(NdArrayStorage<u16>),
    U8(NdArrayStorage<u8>),
    Bool(NdArrayStorage<bool>),
}
36
37impl NdArrayTensor {
38 pub(crate) fn bool(self) -> SharedArray<bool> {
40 match self {
41 NdArrayTensor::Bool(storage) => storage.into_shared(),
42 _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
43 }
44 }
45
46 #[inline]
48 pub fn is_borrowed(&self) -> bool {
49 macro_rules! check {
50 ($($variant:ident),*) => {
51 match self {
52 $(NdArrayTensor::$variant(s) => s.is_borrowed(),)*
53 }
54 };
55 }
56 check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
57 }
58}
59
60pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
61where
62 NdArrayTensor: From<SharedArray<E1>>,
63{
64 fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
65 array.mapv(|a| a.elem()).into_shared()
66 }
67
68 if E1::dtype() == dtype {
69 return array.into();
70 }
71
72 match dtype {
73 DType::F64 => cast::<E1, f64>(array).into(),
74 DType::F32 => cast::<E1, f32>(array).into(),
75 DType::Flex32 => cast::<E1, f32>(array).into(),
76 DType::I64 => cast::<E1, i64>(array).into(),
77 DType::I32 => cast::<E1, i32>(array).into(),
78 DType::I16 => cast::<E1, i16>(array).into(),
79 DType::I8 => cast::<E1, i8>(array).into(),
80 DType::U64 => cast::<E1, u64>(array).into(),
81 DType::U32 => cast::<E1, u32>(array).into(),
82 DType::U16 => cast::<E1, u16>(array).into(),
83 DType::U8 => cast::<E1, u8>(array).into(),
84 DType::Bool => cast::<E1, bool>(array).into(),
85 dtype => panic!("Unsupported dtype: {dtype:?}"),
86 }
87}
88
/// Implements `From<SharedArray<T>>` and `From<NdArrayStorage<T>>` for
/// [`NdArrayTensor`], mapping each element type to its enum variant.
macro_rules! impl_from {
    ($($ty: ty => $dtype: ident),*) => {
        // A raw shared array is wrapped as *owned* storage.
        $(impl From<SharedArray<$ty>> for NdArrayTensor {
            fn from(value: SharedArray<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
            }
        })*

        // Existing storage (owned or borrowed) is wrapped as-is.
        $(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
            fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
                NdArrayTensor::$dtype(value)
            }
        })*
    };
}

impl_from!(
    f64 => F64, f32 => F32,
    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
    bool => Bool
);
113
/// Dispatches an operation over the dtype variant(s) of one or two tensors.
///
/// The matched element type is bound to the `$element` type alias (default
/// `E`) so the operation body can refer to it; the operation's result is
/// wrapped back into an [`NdArrayTensor`] via `Into`.
#[macro_export]
macro_rules! execute_with_dtype {
    // Binary form with an explicit dtype list: both tensors must share the
    // same variant, otherwise this panics with a dtype-mismatch message.
    (($lhs:expr, $rhs:expr),$element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        // Capture the dtypes up front for the error message, since the
        // tensors themselves are moved into the match below.
        let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs);
        let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs);
        match ($lhs, $rhs) {
            $(
                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(lhs.into_shared(), rhs.into_shared()).into()
                }
            )*
            _ => panic!(
                "Data type mismatch (lhs: {:?}, rhs: {:?})",
                lhs_dtype, rhs_dtype
            ),
        }
    }};
    // Binary form, default alias `E`, all dtypes.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary form, custom alias, all dtypes.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};

    // Unary form with an explicit dtype list; dtypes not in the list hit the
    // `unimplemented!` fallback arm.
    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
        match $tensor {
            $(
                $crate::NdArrayTensor::$dtype(storage) => {
                    #[allow(unused)]
                    type $element = $ty;
                    $op(storage.into_shared()).into()
                }
            )*
            // Unreachable when the list covers every variant.
            #[allow(unreachable_patterns)]
            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
        }
    }};
    // Unary form, default alias `E`, all dtypes.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, E, $op)
    }};

    // Unary form, custom alias, all dtypes.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }};
}
185
/// Like [`execute_with_dtype!`], restricted to the float dtypes
/// (`F64`, `F32`). Panics/`unimplemented!`s for any other variant.
#[macro_export]
macro_rules! execute_with_float_dtype {
    // Binary form, default alias `E`.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary form, custom element alias.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};

    // Unary form, default alias `E`.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_float_dtype!($tensor, E, $op)
    }};

    // Unary form, custom element alias.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32
        ])
    }};
}
218
/// Like [`execute_with_dtype!`], restricted to the integer dtypes
/// (signed and unsigned, 8 through 64 bits).
#[macro_export]
macro_rules! execute_with_int_dtype {
    // Binary form, default alias `E`.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary form, custom element alias.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary form, default alias `E`.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_int_dtype!($tensor, E, $op)
    }};

    // Unary form, custom element alias.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
253
/// Like [`execute_with_dtype!`], restricted to numeric dtypes
/// (floats and integers — everything except `Bool`).
#[macro_export]
macro_rules! execute_with_numeric_dtype {
    // Binary form, default alias `E`.
    (($lhs:expr, $rhs:expr), $op:expr) => {{
        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
    }};

    // Binary form, custom element alias.
    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};

    // Unary form, default alias `E`.
    ($tensor:expr, $op:expr) => {{
        $crate::execute_with_numeric_dtype!($tensor, E, $op)
    }};

    // Unary form, custom element alias.
    ($tensor:expr, $element:ident, $op:expr) => {{
        $crate::execute_with_dtype!($tensor, $element, $op, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8
        ])
    }};
}
290
/// Concatenates a slice of tensors along `$dim`, dispatching on the dtype of
/// the first tensor. Panics if the tensors do not all share that dtype, or if
/// the dtype is not in the provided list.
///
/// NOTE(review): expects `NdArrayTensor` and `NdArrayOps` to be in scope at
/// the call site (they are not `$crate::`-qualified here).
#[macro_export]
macro_rules! cat_with_dtype {
    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
        // The first tensor's variant decides which arm runs.
        match &$tensors[0] {
            $(NdArrayTensor::$dtype(_) => {
                let tensors = $tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensor::$dtype(storage) = t {
                            storage.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype())
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayOps::concatenate(&tensors, $dim).into()
            })*
            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
        }
    };
}
320
321impl TensorMetadata for NdArrayTensor {
322 fn dtype(&self) -> DType {
323 match self {
324 NdArrayTensor::F64(_) => DType::F64,
325 NdArrayTensor::F32(_) => DType::F32,
326 NdArrayTensor::I64(_) => DType::I64,
327 NdArrayTensor::I32(_) => DType::I32,
328 NdArrayTensor::I16(_) => DType::I16,
329 NdArrayTensor::I8(_) => DType::I8,
330 NdArrayTensor::U64(_) => DType::U64,
331 NdArrayTensor::U32(_) => DType::U32,
332 NdArrayTensor::U16(_) => DType::U16,
333 NdArrayTensor::U8(_) => DType::U8,
334 NdArrayTensor::Bool(_) => DType::Bool,
335 }
336 }
337
338 fn shape(&self) -> Shape {
339 macro_rules! get_shape {
341 ($($variant:ident),*) => {
342 match self {
343 $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
344 }
345 };
346 }
347 get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
348 }
349
350 fn rank(&self) -> usize {
351 self.shape().num_dims()
352 }
353}
354
/// Convenience shape queries implemented on `&[usize]` dimension slices.
pub(crate) trait ShapeOps {
    /// Number of dimensions (rank).
    fn num_dims(self) -> usize;
    /// Total number of elements (product of all dimensions).
    fn num_elements(self) -> usize;
    /// Converts to a fixed-size array; panics if the rank is not `N`.
    fn dims<const N: usize>(self) -> [usize; N];
    /// Converts the slice into a [`Shape`].
    fn into_shape(self) -> Shape;
}
361
362impl ShapeOps for &[usize] {
363 fn num_dims(self) -> usize {
364 self.len()
365 }
366
367 fn num_elements(self) -> usize {
368 self.iter().product()
369 }
370
371 fn dims<const N: usize>(self) -> [usize; N] {
372 self.try_into().unwrap()
373 }
374
375 fn into_shape(self) -> Shape {
376 Shape {
377 dims: self.to_vec(),
378 }
379 }
380}
381
382mod utils {
383 use burn_std::tensor::is_contiguous;
384
385 use super::*;
386
387 impl NdArrayTensor {
388 pub(crate) fn into_data(self) -> TensorData {
389 let shape = self.shape();
390 let contiguous = self.is_contiguous();
391
392 fn inner<E: Element>(
393 shape: Shape,
394 is_contiguous: bool,
395 array: ArcArray<E, IxDyn>,
396 ) -> TensorData {
397 let vec = if is_contiguous {
398 match array.try_into_owned_nocopy() {
399 Ok(owned) => {
400 let (mut vec, offset) = owned.into_raw_vec_and_offset();
401 if let Some(offset) = offset {
402 vec.drain(..offset);
403 }
404 if vec.len() > shape.num_elements() {
405 vec.drain(shape.num_elements()..vec.len());
406 }
407 vec
408 }
409 Err(array) => array.into_iter().collect(),
410 }
411 } else {
412 array.into_iter().collect()
413 };
414
415 TensorData::new(vec, shape)
416 }
417
418 execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
420 }
421
422 pub(crate) fn is_contiguous(&self) -> bool {
423 macro_rules! check_contiguous {
426 ($($variant:ident),*) => {
427 match self {
428 $(NdArrayTensor::$variant(storage) => {
429 match storage {
430 NdArrayStorage::Borrowed { .. } => {
431 true
434 }
435 NdArrayStorage::Owned(array) => {
436 let shape = array.shape();
437 let mut strides = Vec::with_capacity(array.strides().len());
438 for &stride in array.strides() {
439 if stride <= 0 {
440 return false;
441 }
442 strides.push(stride as usize);
443 }
444 is_contiguous(shape, &strides)
445 }
446 }
447 })*
448 }
449 };
450 }
451 check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
452 }
453 }
454}
455
/// Converts the first `$n` entries of a dimension slice into a typed
/// `ndarray` dimension value (`Dim<[usize; $n]>`).
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
    (
        $n:expr,
        $dims:expr,
        justdim
    ) => {{
        let mut dims = [0; $n];
        for i in 0..$n {
            dims[i] = $dims[i];
        }
        let dim: Dim<[usize; $n]> = Dim(dims);
        dim
    }};
}
472
/// Reshapes an ndarray to an `$n`-dimensional shape, returning a dynamically
/// ranked shared array. The runtime-rank form (`d $D`) dispatches to the
/// fixed-rank form and supports up to 6 dimensions.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
    (
        ty $ty:ty,
        n $n:expr,
        shape $shape:expr,
        array $array:expr
    ) => {{
        let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
        let array = match $array.is_standard_layout() {
            // Standard (row-major, contiguous) layout: reshape without copying.
            true => {
                match $array.to_shape(dim) {
                    Ok(val) => val.into_shared(),
                    Err(err) => {
                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
                    }
                }
            },
            // Otherwise materialize a standard-layout copy first.
            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
        };
        array.into_dyn()
    }};
    (
        ty $ty:ty,
        shape $shape:expr,
        array $array:expr,
        d $D:expr
    ) => {{
        match $D {
            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
        }
    }};
}
513
/// Slices a tensor with the given slice specification, dispatching on its
/// dtype variant. The short form covers every variant.
///
/// NOTE(review): expects `NdArrayTensor` and `NdArrayOps` to be in scope at
/// the call site (they are not `$crate::`-qualified here).
#[macro_export]
macro_rules! slice {
    ($tensor:expr, $slices:expr) => {
        slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
    };
    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
        match $tensor {
            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
        }
    };
}
526
impl NdArrayTensor {
    /// Builds a tensor from [`TensorData`], preferring a zero-copy borrow of
    /// the data's byte buffer and falling back to an owning copy when
    /// borrowing is not possible.
    pub fn from_data(data: TensorData) -> NdArrayTensor {
        match Self::try_from_data_borrowed(data) {
            Ok(tensor) => tensor,
            Err(data) => Self::from_data_owned(data),
        }
    }

    /// Attempts to wrap the data's bytes without copying.
    ///
    /// On failure (unsupported dtype, or the storage layer rejects the
    /// buffer) the original `TensorData` is reassembled and returned in
    /// `Err` so the caller can take the owning path.
    fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {
        let TensorData {
            bytes,
            shape,
            dtype,
        } = data;

        // Returns early on success; on failure `from_borrowed` hands the
        // bytes and shape back so they can be reused below.
        macro_rules! try_borrow {
            ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
                match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
                    Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
                    Err((bytes, shape)) => (bytes, shape),
                }
            };
        }

        let (bytes, shape) = match dtype {
            DType::F64 => try_borrow!(f64, F64, bytes, shape),
            DType::F32 => try_borrow!(f32, F32, bytes, shape),
            DType::I64 => try_borrow!(i64, I64, bytes, shape),
            DType::I32 => try_borrow!(i32, I32, bytes, shape),
            DType::I16 => try_borrow!(i16, I16, bytes, shape),
            DType::I8 => try_borrow!(i8, I8, bytes, shape),
            DType::U64 => try_borrow!(u64, U64, bytes, shape),
            DType::U32 => try_borrow!(u32, U32, bytes, shape),
            DType::U16 => try_borrow!(u16, U16, bytes, shape),
            DType::U8 => try_borrow!(u8, U8, bytes, shape),
            DType::Bool => try_borrow!(bool, Bool, bytes, shape),
            // Dtypes without a borrowable representation here.
            _ => (bytes, shape), };

        Err(TensorData {
            bytes,
            shape,
            dtype,
        })
    }

    /// Copies the data into a freshly-allocated owned array.
    fn from_data_owned(mut data: TensorData) -> NdArrayTensor {
        // Take the shape first; `into_vec` consumes `data` below.
        let shape = mem::take(&mut data.shape);

        macro_rules! execute {
            ($data: expr, [$($dtype: ident => $ty: ty),*]) => {
                match $data.dtype {
                    $(DType::$dtype => {
                        match data.into_vec::<$ty>() {
                            // SAFETY: presumably `TensorData` guarantees the
                            // element count matches `shape` — TODO confirm
                            // that invariant holds for all constructors.
                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
                            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
                        }.into()
                    },)*
                    other => unimplemented!("Unsupported dtype {other:?}"),
                }
            };
        }

        execute!(data, [
            F64 => f64, F32 => f32,
            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
            Bool => bool
        ])
    }
}
621
/// A quantized tensor for the ndarray backend.
#[derive(Clone, Debug)]
pub struct NdArrayQTensor {
    /// The quantized values.
    pub qtensor: NdArrayTensor,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The quantization parameters: a single entry for per-tensor
    /// quantization, one entry per block for per-block quantization.
    pub qparams: Vec<QParams<f32>>,
}
632
impl NdArrayQTensor {
    /// Returns the quantization strategy corresponding to this tensor's
    /// scheme, initialized from the stored scale parameters.
    pub fn strategy(&self) -> QuantizationStrategy {
        match self.scheme {
            // Per-tensor symmetric: a single scale applies to the whole tensor.
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                self.qparams[0].scales,
                self.scheme.value,
            )),
            // Per-block symmetric: one scale per block of `block_size` values.
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::E4M3
                    | QuantValue::E5M2
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::E2M1
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                ..
            } => QuantizationStrategy::PerBlockSymmetric(
                self.qparams
                    .iter()
                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
                    .collect(),
                block_size,
            ),
        }
    }
}
679
impl QTensorPrimitive for NdArrayQTensor {
    fn scheme(&self) -> &QuantScheme {
        &self.scheme
    }

    // Overrides the default quantization store to `Native` for this backend.
    fn default_scheme() -> QuantScheme {
        QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)
    }
}
689
impl TensorMetadata for NdArrayQTensor {
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    // Shape and rank are delegated to the underlying quantized tensor.
    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }

    fn rank(&self) -> usize {
        self.shape().num_dims()
    }
}
703
#[cfg(test)]
mod tests {
    use crate::NdArray;
    use alloc::vec;

    use super::*;
    use burn_backend::{
        Distribution,
        ops::{FloatTensorOps, QTensorOps},
        quantization::{QuantStore, QuantizationParametersPrimitive},
    };
    use burn_std::rand::get_seeded_rng;

    /// Asserts that random f32 data of `shape` survives a
    /// `from_data` -> `into_data` round trip unchanged.
    fn assert_roundtrip(shape: Shape) {
        let data_expected =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut get_seeded_rng());
        let tensor = NdArrayTensor::from_data(data_expected.clone());

        let data_actual = tensor.into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_1d() {
        assert_roundtrip(Shape::new([3]));
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        assert_roundtrip(Shape::new([2, 3]));
    }

    #[test]
    fn should_support_into_and_from_data_3d() {
        assert_roundtrip(Shape::new([2, 3, 4]));
    }

    #[test]
    fn should_support_into_and_from_data_4d() {
        assert_roundtrip(Shape::new([2, 3, 4, 2]));
    }

    #[test]
    fn should_support_qtensor_strategy() {
        type B = NdArray<f32, i64, i8>;
        let scale: f32 = 0.009_019_608;
        let device = Default::default();

        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let qparams = QuantizationParametersPrimitive {
            scales: B::float_from_data(TensorData::from([scale]), &device),
        };
        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                scale,
                QuantValue::Q8S
            ))
        );
    }

    #[test]
    fn zero_copy_creates_borrowed_storage() {
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);

        let tensor = NdArrayTensor::from_data(tensor_data);

        match &tensor {
            NdArrayTensor::F32(storage) => {
                assert!(
                    storage.is_borrowed(),
                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
                     for properly aligned TensorData with Bytes"
                );
                assert!(
                    !storage.is_unique(),
                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
                );
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_data_integrity() {
        use burn_std::Bytes;

        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = Bytes::from_elems(data);
        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);

        let tensor = NdArrayTensor::from_data(tensor_data);

        // The borrowed view must expose the original values in row-major order.
        match &tensor {
            NdArrayTensor::F32(storage) => {
                let view = storage.view();
                assert_eq!(view[[0, 0]], 1.0);
                assert_eq!(view[[0, 1]], 2.0);
                assert_eq!(view[[1, 0]], 3.0);
                assert_eq!(view[[1, 1]], 4.0);
            }
            _ => panic!("Expected F32 tensor"),
        }
    }

    #[test]
    fn zero_copy_fallback_when_bytes_owned() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = NdArrayTensor::from_data(data.clone());
        let result = tensor.into_data();

        assert_eq!(data, result, "Data should round-trip correctly");
    }
}