1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::Rng;
9use thiserror::Error;
10
11use crate::Scalar;
12use crate::distribution::Distribution;
13use crate::element::{Element, ElementConversion};
14use burn_std::tensor::DType;
15use burn_std::{
16 BoolStore, Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, Shape, bf16,
17 f16,
18};
19
20use serde::{Deserialize, Serialize};
21
/// Raw tensor contents: a byte buffer together with the shape and element type
/// needed to interpret it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorData {
    /// Underlying storage holding the element values as bytes.
    pub bytes: Bytes,

    /// Dimensions of the tensor; (de)serialized through the `shape_inner` module.
    #[serde(with = "shape_inner")]
    pub shape: Shape,

    /// Element type describing how `bytes` is to be interpreted.
    pub dtype: DType,
}
35
36mod shape_inner {
38 use burn_std::SmallVec;
39
40 use super::*;
41
42 pub fn serialize<S: serde::Serializer>(
43 shape: &Shape,
44 serializer: S,
45 ) -> Result<S::Ok, S::Error> {
46 shape.as_slice().serialize(serializer)
47 }
48
49 pub fn deserialize<'de, D: serde::Deserializer<'de>>(
50 deserializer: D,
51 ) -> Result<Shape, D::Error> {
52 let dims = SmallVec::<[usize; _]>::deserialize(deserializer)?;
53 Ok(Shape::new_raw(dims))
54 }
55}
56
57impl TensorData {
58 pub fn new<E: Element, S: Into<Shape>>(value: Vec<E>, shape: S) -> Self {
60 let shape = shape.into();
62 Self::check_data_len(&value, &shape);
63
64 Self {
65 bytes: Bytes::from_elems(value),
66 shape,
67 dtype: E::dtype(),
68 }
69 }
70
71 pub fn quantized<E: Element, S: Into<Shape>>(
73 value: Vec<E>,
74 shape: S,
75 scheme: QuantScheme,
76 qparams: &[f32],
77 ) -> Self {
78 let shape = shape.into();
79 Self::check_data_len(&value, &shape);
80
81 let q_bytes = QuantizedBytes::new(value, scheme, qparams);
82
83 Self {
84 bytes: q_bytes.bytes,
85 shape,
86 dtype: DType::QFloat(q_bytes.scheme),
87 }
88 }
89
90 pub fn from_bytes<S: Into<Shape>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
92 Self {
93 bytes,
94 shape: shape.into(),
95 dtype,
96 }
97 }
98
99 pub fn from_bytes_vec<S: Into<Shape>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
104 Self {
105 bytes: Bytes::from_bytes_vec(bytes),
106 shape: shape.into(),
107 dtype,
108 }
109 }
110
111 fn check_data_len<E: Element>(data: &[E], shape: &Shape) {
113 let expected_data_len = Self::numel(shape);
114 let num_data = data.len();
115 assert_eq!(
116 expected_data_len, num_data,
117 "Shape {shape:?} is invalid for input of size {num_data:?}",
118 );
119 }
120
121 pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
123 if self.matches_target_dtype::<E>() {
124 match E::dtype() {
125 DType::Bool(BoolStore::Native) => {
129 let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
130 .map_err(DataError::CastError)?;
131 Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
132 }
133 _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
134 }
135 } else {
136 Err(DataError::TypeMismatch(format!(
137 "Invalid target element type (expected {:?}, got {:?})",
138 self.dtype,
139 E::dtype()
140 )))
141 }
142 }
143
144 pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
149 if self.matches_target_dtype::<E>() {
150 match E::dtype() {
151 DType::Bool(BoolStore::Native) => {
155 let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
156 .map_err(DataError::CastError)?;
157 Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
158 }
159 _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
160 .map_err(DataError::CastError),
161 }
162 } else {
163 Err(DataError::TypeMismatch(format!(
164 "Invalid target element type (expected {:?}, got {:?})",
165 self.dtype,
166 E::dtype()
167 )))
168 }
169 }
170
171 pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
173 Ok(self.as_slice()?.to_vec())
174 }
175
176 pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
178 if !self.matches_target_dtype::<E>() {
180 return Err(DataError::TypeMismatch(format!(
181 "Invalid target element type (expected {:?}, got {:?})",
182 self.dtype,
183 E::dtype()
184 )));
185 }
186
187 match E::dtype() {
188 DType::Bool(BoolStore::Native) => {
192 let vec = self.into_vec_unchecked::<u8>()?;
193 Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
194 }
195 _ => self.into_vec_unchecked(),
196 }
197 }
198
199 fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
201 let mut me = self;
202 me.bytes = match me.bytes.try_into_vec::<E>() {
203 Ok(elems) => return Ok(elems),
204 Err(bytes) => bytes,
205 };
206
207 Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
210 .map_err(DataError::CastError)?
211 .to_vec())
212 }
213
214 fn matches_target_dtype<E: Element>(&self) -> bool {
215 let target_dtype = E::dtype();
216 match self.dtype {
217 DType::Bool(BoolStore::U8) => {
218 matches!(target_dtype, DType::U8 | DType::Bool(BoolStore::U8))
219 }
220 DType::Bool(BoolStore::U32) => {
221 matches!(target_dtype, DType::U32 | DType::Bool(BoolStore::U32))
222 }
223 dtype => dtype == target_dtype,
224 }
225 }
226
227 pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
229 if E::dtype() == self.dtype {
230 Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
231 } else {
232 match self.dtype {
233 DType::I8 => Box::new(
234 bytemuck::checked::cast_slice(&self.bytes)
235 .iter()
236 .map(|e: &i8| e.elem::<E>()),
237 ),
238 DType::I16 => Box::new(
239 bytemuck::checked::cast_slice(&self.bytes)
240 .iter()
241 .map(|e: &i16| e.elem::<E>()),
242 ),
243 DType::I32 => Box::new(
244 bytemuck::checked::cast_slice(&self.bytes)
245 .iter()
246 .map(|e: &i32| e.elem::<E>()),
247 ),
248 DType::I64 => Box::new(
249 bytemuck::checked::cast_slice(&self.bytes)
250 .iter()
251 .map(|e: &i64| e.elem::<E>()),
252 ),
253 DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
254 DType::U16 => Box::new(
255 bytemuck::checked::cast_slice(&self.bytes)
256 .iter()
257 .map(|e: &u16| e.elem::<E>()),
258 ),
259 DType::U32 => Box::new(
260 bytemuck::checked::cast_slice(&self.bytes)
261 .iter()
262 .map(|e: &u32| e.elem::<E>()),
263 ),
264 DType::U64 => Box::new(
265 bytemuck::checked::cast_slice(&self.bytes)
266 .iter()
267 .map(|e: &u64| e.elem::<E>()),
268 ),
269 DType::BF16 => Box::new(
270 bytemuck::checked::cast_slice(&self.bytes)
271 .iter()
272 .map(|e: &bf16| e.elem::<E>()),
273 ),
274 DType::F16 => Box::new(
275 bytemuck::checked::cast_slice(&self.bytes)
276 .iter()
277 .map(|e: &f16| e.elem::<E>()),
278 ),
279 DType::F32 | DType::Flex32 => Box::new(
280 bytemuck::checked::cast_slice(&self.bytes)
281 .iter()
282 .map(|e: &f32| e.elem::<E>()),
283 ),
284 DType::F64 => Box::new(
285 bytemuck::checked::cast_slice(&self.bytes)
286 .iter()
287 .map(|e: &f64| e.elem::<E>()),
288 ),
289 DType::Bool(BoolStore::Native) | DType::Bool(BoolStore::U8) => {
291 Box::new(self.bytes.iter().map(|e| e.elem::<E>()))
292 }
293 DType::Bool(BoolStore::U32) => Box::new(
294 bytemuck::checked::cast_slice(&self.bytes)
295 .iter()
296 .map(|e: &u32| e.elem::<E>()),
297 ),
298 DType::QFloat(scheme) => match scheme {
299 QuantScheme {
300 level: QuantLevel::Tensor | QuantLevel::Block(_),
301 mode: QuantMode::Symmetric,
302 value:
303 QuantValue::Q8F
304 | QuantValue::Q8S
305 | QuantValue::Q4F
307 | QuantValue::Q4S
308 | QuantValue::Q2F
309 | QuantValue::Q2S,
310 ..
311 } => {
312 let q_bytes = QuantizedBytes {
314 bytes: self.bytes.clone(),
315 scheme,
316 num_elements: self.num_elements(),
317 };
318 let (values, _) = q_bytes.into_vec_i8();
319
320 Box::new(
321 values
322 .iter()
323 .map(|e: &i8| e.elem::<E>())
324 .collect::<Vec<_>>()
325 .into_iter(),
326 )
327 }
328 QuantScheme {
329 level: QuantLevel::Tensor | QuantLevel::Block(_),
330 mode: QuantMode::Symmetric,
331 value:
332 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
333 ..
334 } => {
335 unimplemented!("Not yet implemented for iteration");
336 }
337 },
338 }
339 }
340 }
341
    /// Returns the rank of the tensor, taken as the length of its shape.
    pub fn rank(&self) -> usize {
        self.shape.len()
    }
346
    /// Returns the total number of elements described by the shape
    /// (the product of its dimensions).
    pub fn num_elements(&self) -> usize {
        Self::numel(&self.shape)
    }
351
352 fn numel(shape: &[usize]) -> usize {
353 shape.iter().product()
354 }
355
356 pub fn random<E: Element, R: Rng, S: Into<Shape>>(
358 shape: S,
359 distribution: Distribution,
360 rng: &mut R,
361 ) -> Self {
362 let shape = shape.into();
363 let num_elements = Self::numel(&shape);
364 let mut data = Vec::with_capacity(num_elements);
365
366 for _ in 0..num_elements {
367 data.push(E::random(distribution, rng));
368 }
369
370 TensorData::new(data, shape)
371 }
372
373 pub fn zeros<E: Element, S: Into<Shape>>(shape: S) -> TensorData {
375 let shape = shape.into();
376 let num_elements = Self::numel(&shape);
377 let mut data = Vec::<E>::with_capacity(num_elements);
378
379 for _ in 0..num_elements {
380 data.push(0.elem());
381 }
382
383 TensorData::new(data, shape)
384 }
385
386 pub fn ones<E: Element, S: Into<Shape>>(shape: S) -> TensorData {
388 let shape = shape.into();
389 let num_elements = Self::numel(&shape);
390 let mut data = Vec::<E>::with_capacity(num_elements);
391
392 for _ in 0..num_elements {
393 data.push(1.elem());
394 }
395
396 TensorData::new(data, shape)
397 }
398
399 pub fn full<E: Element, S: Into<Shape>>(shape: S, fill_value: E) -> TensorData {
401 let shape = shape.into();
402 let num_elements = Self::numel(&shape);
403 let mut data = Vec::<E>::with_capacity(num_elements);
404 for _ in 0..num_elements {
405 data.push(fill_value)
406 }
407
408 TensorData::new(data, shape)
409 }
410
411 pub fn full_dtype<E: Into<Scalar>, S: Into<Shape>>(
413 shape: S,
414 fill_value: E,
415 dtype: DType,
416 ) -> TensorData {
417 let fill_value = fill_value.into();
418 match dtype {
419 DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
420 DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
421 DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
422 DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
423 DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
424 DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
425 DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
426 DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
427 DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
428 DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
429 DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
430 DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
431 DType::Bool(BoolStore::Native) => Self::full::<bool, _>(shape, fill_value.elem()),
432 DType::Bool(BoolStore::U8) => {
433 Self::full::<u8, _>(shape, fill_value.elem()).into_bool_u8()
434 }
435 DType::Bool(BoolStore::U32) => {
436 Self::full::<u32, _>(shape, fill_value.elem()).into_bool_u32()
437 }
438 DType::QFloat(_) => unreachable!(),
439 }
440 }
441
442 fn into_bool_u8(mut self) -> Self {
444 self.dtype = DType::Bool(BoolStore::U8);
445 self
446 }
447
448 fn into_bool_u32(mut self) -> Self {
450 self.dtype = DType::Bool(BoolStore::U32);
451 self
452 }
453
454 pub fn convert<E: Element>(self) -> Self {
456 self.convert_dtype(E::dtype())
457 }
458
459 pub fn convert_dtype(self, dtype: DType) -> Self {
461 if dtype == self.dtype {
462 self
463 } else if dtype.size() == self.dtype.size()
464 && !matches!(
465 self.dtype,
466 DType::Bool(BoolStore::Native) | DType::QFloat(_)
467 )
468 && !matches!(dtype, DType::Bool(BoolStore::Native) | DType::QFloat(_))
469 {
470 match self.dtype {
471 DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
472 DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
473 DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
474 DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
475 DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
476 DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
477 DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
478 DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
479 DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
480 DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
481 DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
482 DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
483 DType::Bool(BoolStore::U8) => self.convert_inplace_dtype::<u8>(dtype),
484 DType::Bool(BoolStore::U32) => self.convert_inplace_dtype::<u32>(dtype),
485 DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(),
486 }
487 } else {
488 match self.dtype {
489 DType::F64 => self.convert_clone_dtype::<f64>(dtype),
490 DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
491 DType::F16 => self.convert_clone_dtype::<f16>(dtype),
492 DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
493 DType::I64 => self.convert_clone_dtype::<i64>(dtype),
494 DType::I32 => self.convert_clone_dtype::<i32>(dtype),
495 DType::I16 => self.convert_clone_dtype::<i16>(dtype),
496 DType::I8 => self.convert_clone_dtype::<i8>(dtype),
497 DType::U64 => self.convert_clone_dtype::<u64>(dtype),
498 DType::U32 => self.convert_clone_dtype::<u32>(dtype),
499 DType::U16 => self.convert_clone_dtype::<u16>(dtype),
500 DType::U8 => self.convert_clone_dtype::<u8>(dtype),
501 DType::Bool(BoolStore::Native) => self.convert_clone_dtype::<bool>(dtype),
502 DType::Bool(BoolStore::U8) => self.convert_clone_dtype::<u8>(dtype),
503 DType::Bool(BoolStore::U32) => self.convert_clone_dtype::<u32>(dtype),
504 DType::QFloat(_) => unreachable!(),
505 }
506 }
507 }
508
    /// Dispatches an in-place conversion from the element type `Current` to the
    /// element type that backs the runtime `dtype`.
    ///
    /// Boolean targets stored as `u8`/`u32` are converted through the matching
    /// unsigned integer type and then re-tagged.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) when `dtype` is a native bool or a quantized dtype;
    /// `convert_dtype` only takes this path when neither is the case.
    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
        match dtype {
            DType::F64 => self.convert_inplace::<Current, f64>(),
            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
            DType::F16 => self.convert_inplace::<Current, f16>(),
            DType::BF16 => self.convert_inplace::<Current, bf16>(),
            DType::I64 => self.convert_inplace::<Current, i64>(),
            DType::I32 => self.convert_inplace::<Current, i32>(),
            DType::I16 => self.convert_inplace::<Current, i16>(),
            DType::I8 => self.convert_inplace::<Current, i8>(),
            DType::U64 => self.convert_inplace::<Current, u64>(),
            DType::U32 => self.convert_inplace::<Current, u32>(),
            DType::U16 => self.convert_inplace::<Current, u16>(),
            DType::U8 => self.convert_inplace::<Current, u8>(),
            DType::Bool(BoolStore::U8) => self.convert_inplace::<Current, u8>().into_bool_u8(),
            DType::Bool(BoolStore::U32) => self.convert_inplace::<Current, u32>().into_bool_u32(),
            DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(),
        }
    }
528
529 fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
530 mut self,
531 ) -> Self {
532 for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
533 let t: Target = x.elem();
534 let x = cast_mut::<_, Target>(x);
535 *x = t;
536 }
537
538 self.dtype = Target::dtype();
539
540 self
541 }
542
    /// Dispatches a copying conversion from the element type `Current` to the
    /// element type that backs the runtime `dtype`.
    ///
    /// Boolean targets stored as `u8`/`u32` are converted through the matching
    /// unsigned integer type and then re-tagged.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) when `dtype` is a quantized dtype.
    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
        match dtype {
            DType::F64 => self.convert_clone::<Current, f64>(),
            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
            DType::F16 => self.convert_clone::<Current, f16>(),
            DType::BF16 => self.convert_clone::<Current, bf16>(),
            DType::I64 => self.convert_clone::<Current, i64>(),
            DType::I32 => self.convert_clone::<Current, i32>(),
            DType::I16 => self.convert_clone::<Current, i16>(),
            DType::I8 => self.convert_clone::<Current, i8>(),
            DType::U64 => self.convert_clone::<Current, u64>(),
            DType::U32 => self.convert_clone::<Current, u32>(),
            DType::U16 => self.convert_clone::<Current, u16>(),
            DType::U8 => self.convert_clone::<Current, u8>(),
            DType::Bool(BoolStore::Native) => self.convert_clone::<Current, bool>(),
            DType::Bool(BoolStore::U8) => self.convert_clone::<Current, u8>().into_bool_u8(),
            DType::Bool(BoolStore::U32) => self.convert_clone::<Current, u32>().into_bool_u32(),
            DType::QFloat(_) => unreachable!(),
        }
    }
563
564 fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
565 self,
566 ) -> Self {
567 let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
568 let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
569
570 for (x, out) in this.iter().zip(&mut out) {
571 *out = x.elem();
572 }
573
574 Self::new(out, self.shape)
575 }
576
    /// Returns the stored data as a byte slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }
581
582 pub fn into_bytes(self) -> Bytes {
584 self.bytes
585 }
586}
587
588impl<E: Element, const A: usize> From<[E; A]> for TensorData {
589 fn from(elems: [E; A]) -> Self {
590 TensorData::new(elems.to_vec(), [A])
591 }
592}
593
594impl<const A: usize> From<[usize; A]> for TensorData {
595 fn from(elems: [usize; A]) -> Self {
596 TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
597 }
598}
599
600impl From<&[usize]> for TensorData {
601 fn from(elems: &[usize]) -> Self {
602 let mut data = Vec::with_capacity(elems.len());
603 for elem in elems.iter() {
604 data.push(*elem as i64);
605 }
606
607 TensorData::new(data, [elems.len()])
608 }
609}
610
611impl<E: Element> From<&[E]> for TensorData {
612 fn from(elems: &[E]) -> Self {
613 let mut data = Vec::with_capacity(elems.len());
614 for elem in elems.iter() {
615 data.push(*elem);
616 }
617
618 TensorData::new(data, [elems.len()])
619 }
620}
621
622impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
623 fn from(elems: [[E; B]; A]) -> Self {
624 let mut data = Vec::with_capacity(A * B);
625 for elem in elems.into_iter().take(A) {
626 for elem in elem.into_iter().take(B) {
627 data.push(elem);
628 }
629 }
630
631 TensorData::new(data, [A, B])
632 }
633}
634
635impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
636 for TensorData
637{
638 fn from(elems: [[[E; C]; B]; A]) -> Self {
639 let mut data = Vec::with_capacity(A * B * C);
640
641 for elem in elems.into_iter().take(A) {
642 for elem in elem.into_iter().take(B) {
643 for elem in elem.into_iter().take(C) {
644 data.push(elem);
645 }
646 }
647 }
648
649 TensorData::new(data, [A, B, C])
650 }
651}
652
653impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
654 From<[[[[E; D]; C]; B]; A]> for TensorData
655{
656 fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
657 let mut data = Vec::with_capacity(A * B * C * D);
658
659 for elem in elems.into_iter().take(A) {
660 for elem in elem.into_iter().take(B) {
661 for elem in elem.into_iter().take(C) {
662 for elem in elem.into_iter().take(D) {
663 data.push(elem);
664 }
665 }
666 }
667 }
668
669 TensorData::new(data, [A, B, C, D])
670 }
671}
672
673impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
674 From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
675{
676 fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
677 let mut data = Vec::with_capacity(A * B * C * D * E);
678
679 for elem in elems.into_iter().take(A) {
680 for elem in elem.into_iter().take(B) {
681 for elem in elem.into_iter().take(C) {
682 for elem in elem.into_iter().take(D) {
683 for elem in elem.into_iter().take(E) {
684 data.push(elem);
685 }
686 }
687 }
688 }
689 }
690
691 TensorData::new(data, [A, B, C, D, E])
692 }
693}
694impl core::fmt::Display for TensorData {
695 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
696 let fmt = match self.dtype {
697 DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
698 DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
699 DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
700 DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
701 DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
702 DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
703 DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
704 DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
705 DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
706 DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
707 DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
708 DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
709 DType::Bool(BoolStore::Native) => format!("{:?}", self.as_slice::<bool>().unwrap()),
710 DType::Bool(BoolStore::U8) => format!("{:?}", self.as_slice::<u8>().unwrap()),
711 DType::Bool(BoolStore::U32) => format!("{:?}", self.as_slice::<u32>().unwrap()),
712 DType::QFloat(scheme) => match scheme {
713 QuantScheme {
714 level: QuantLevel::Tensor | QuantLevel::Block(_),
715 mode: QuantMode::Symmetric,
716 value:
717 QuantValue::Q8F
718 | QuantValue::Q8S
719 | QuantValue::Q4F
721 | QuantValue::Q4S
722 | QuantValue::Q2F
723 | QuantValue::Q2S,
724 ..
725 } => {
726 format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
727 },
728 QuantScheme {
729 level: QuantLevel::Tensor | QuantLevel::Block(_),
730 mode: QuantMode::Symmetric,
731 value:
732 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
733 ..
734 } => {
735 unimplemented!("Can't format yet");
736 }
737 },
738 };
739 f.write_str(fmt.as_str())
740 }
741}
742
/// Errors returned when tensor data cannot be viewed as a requested element type.
#[derive(Debug, Error)]
pub enum DataError {
    /// The underlying bytes could not be cast to the requested element type.
    #[error("Failed to cast values to the specified element type.\nError:\n {0}")]
    CastError(CheckedCastError),
    /// The requested element type is not compatible with the stored dtype; the
    /// payload is a human-readable description.
    #[error("{0}")]
    TypeMismatch(String),
}
753
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use burn_std::shape;
    use rand::{
        SeedableRng,
        rngs::{StdRng, SysRng},
    };

    /// Shape shared by the tests that only need some randomly filled `f32` data.
    const RANDOM_SHAPE: [usize; 3] = [3, 5, 6];

    /// Builds `f32` tensor data of shape [`RANDOM_SHAPE`] from a freshly seeded RNG.
    fn random_f32_data() -> TensorData {
        let mut rng = StdRng::try_from_rng(&mut SysRng).unwrap();
        TensorData::random::<f32, _, _>(RANDOM_SHAPE, Distribution::Default, &mut rng)
    }

    #[test]
    fn should_have_rank() {
        let data = random_f32_data();

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32_data();

        let expected: Vec<f32> = data.iter::<f32>().collect();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let data = random_f32_data();

        // `f32` data cannot be extracted as `i32`; the `unwrap` must panic.
        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let num_elements: usize = RANDOM_SHAPE.iter().product();
        let data = random_f32_data();

        // Four bytes per `f32` element.
        assert_eq!(num_elements, data.bytes.len() / 4);
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let row = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(row.shape, shape![1, 3]);

        let matrix = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(matrix.shape, shape![2, 3]);

        let vector = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(vector.shape, shape![3]);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        // Over-allocate on purpose so that length and capacity differ.
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            let converted = data.clone().convert::<E>().into_vec::<E>().unwrap();

            for (i, val) in converted.into_iter().enumerate() {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    // Generates one test per listed element type, checking that `full_dtype`
    // agrees with the statically typed `full`.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );

    #[test]
    fn should_serialize_deserialize_tensor_data() {
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);
        assert_eq!(
            data.as_bytes(),
            [
                0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192,
                64
            ]
        );

        let serialized = serde_json::to_string(&data).unwrap();
        let deserialized: TensorData = serde_json::from_str(&serialized).unwrap();
        assert_eq!(data, deserialized);
    }

    #[test]
    fn should_deserialize_tensor_data_with_shape_inner() {
        let serialized = r#"{
            "bytes": [0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64],
            "shape": [2, 3],
            "dtype": "F32"
        }"#;

        let data: TensorData = serde_json::from_str(serialized).unwrap();
        assert_eq!(data.shape, shape![2, 3]);
        assert_eq!(
            data.as_slice::<f32>().unwrap(),
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn should_serialize_shape_as_flat_array() {
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);

        let serialized = serde_json::to_string(&data).unwrap();
        let json: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(json["shape"], serde_json::json!([2, 3]));
    }
}