Skip to main content

burn_backend/data/
tensor.rs

1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::Rng;
9use thiserror::Error;
10
11use crate::Scalar;
12use crate::distribution::Distribution;
13use crate::element::{Element, ElementConversion};
14use burn_std::tensor::DType;
15use burn_std::{
16    BoolStore, Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, Shape, bf16,
17    f16,
18};
19
20use serde::{Deserialize, Serialize};
21
22/// Data structure for tensors.
23#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
24pub struct TensorData {
25    /// The values of the tensor (as bytes).
26    pub bytes: Bytes,
27
28    /// The shape of the tensor.
29    #[serde(with = "shape_inner")]
30    pub shape: Shape,
31
32    /// The data type of the tensor.
33    pub dtype: DType,
34}
35
36// For backward compatibility with shape `Vec<usize>`
37mod shape_inner {
38    use burn_std::SmallVec;
39
40    use super::*;
41
42    pub fn serialize<S: serde::Serializer>(
43        shape: &Shape,
44        serializer: S,
45    ) -> Result<S::Ok, S::Error> {
46        shape.as_slice().serialize(serializer)
47    }
48
49    pub fn deserialize<'de, D: serde::Deserializer<'de>>(
50        deserializer: D,
51    ) -> Result<Shape, D::Error> {
52        let dims = SmallVec::<[usize; _]>::deserialize(deserializer)?;
53        Ok(Shape::new_raw(dims))
54    }
55}
56
57impl TensorData {
58    /// Creates a new tensor data structure.
59    pub fn new<E: Element, S: Into<Shape>>(value: Vec<E>, shape: S) -> Self {
60        // Ensure shape is valid
61        let shape = shape.into();
62        Self::check_data_len(&value, &shape);
63
64        Self {
65            bytes: Bytes::from_elems(value),
66            shape,
67            dtype: E::dtype(),
68        }
69    }
70
71    /// Creates a new quantized tensor data structure.
72    pub fn quantized<E: Element, S: Into<Shape>>(
73        value: Vec<E>,
74        shape: S,
75        scheme: QuantScheme,
76        qparams: &[f32],
77    ) -> Self {
78        let shape = shape.into();
79        Self::check_data_len(&value, &shape);
80
81        let q_bytes = QuantizedBytes::new(value, scheme, qparams);
82
83        Self {
84            bytes: q_bytes.bytes,
85            shape,
86            dtype: DType::QFloat(q_bytes.scheme),
87        }
88    }
89
90    /// Creates a new tensor data structure from raw bytes.
91    pub fn from_bytes<S: Into<Shape>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
92        Self {
93            bytes,
94            shape: shape.into(),
95            dtype,
96        }
97    }
98
99    /// Creates a new tensor data structure from raw bytes stored in a vector.
100    ///
101    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are
102    /// certain that the bytes representation is valid.
103    pub fn from_bytes_vec<S: Into<Shape>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
104        Self {
105            bytes: Bytes::from_bytes_vec(bytes),
106            shape: shape.into(),
107            dtype,
108        }
109    }
110
111    // Check that the input vector contains a correct number of elements
112    fn check_data_len<E: Element>(data: &[E], shape: &Shape) {
113        let expected_data_len = Self::numel(shape);
114        let num_data = data.len();
115        assert_eq!(
116            expected_data_len, num_data,
117            "Shape {shape:?} is invalid for input of size {num_data:?}",
118        );
119    }
120
121    /// Returns the immutable slice view of the tensor data.
122    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
123        if self.matches_target_dtype::<E>() {
124            match E::dtype() {
125                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
126                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
127                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
128                DType::Bool(BoolStore::Native) => {
129                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
130                        .map_err(DataError::CastError)?;
131                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
132                }
133                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
134            }
135        } else {
136            Err(DataError::TypeMismatch(format!(
137                "Invalid target element type (expected {:?}, got {:?})",
138                self.dtype,
139                E::dtype()
140            )))
141        }
142    }
143
144    /// Returns the mutable slice view of the tensor data.
145    ///
146    /// # Panics
147    /// If the target element type is different from the stored element type.
148    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
149        if self.matches_target_dtype::<E>() {
150            match E::dtype() {
151                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
152                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
153                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
154                DType::Bool(BoolStore::Native) => {
155                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
156                        .map_err(DataError::CastError)?;
157                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
158                }
159                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
160                    .map_err(DataError::CastError),
161            }
162        } else {
163            Err(DataError::TypeMismatch(format!(
164                "Invalid target element type (expected {:?}, got {:?})",
165                self.dtype,
166                E::dtype()
167            )))
168        }
169    }
170
171    /// Returns the tensor data as a vector of scalar values.
172    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
173        Ok(self.as_slice()?.to_vec())
174    }
175
176    /// Returns the tensor data as a vector of scalar values.
177    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
178        // This means we cannot call `into_vec` for QFloat
179        if !self.matches_target_dtype::<E>() {
180            return Err(DataError::TypeMismatch(format!(
181                "Invalid target element type (expected {:?}, got {:?})",
182                self.dtype,
183                E::dtype()
184            )));
185        }
186
187        match E::dtype() {
188            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
189            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
190            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
191            DType::Bool(BoolStore::Native) => {
192                let vec = self.into_vec_unchecked::<u8>()?;
193                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
194            }
195            _ => self.into_vec_unchecked(),
196        }
197    }
198
199    /// Returns the tensor data as a vector of scalar values. Does not check dtype.
200    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
201        let mut me = self;
202        me.bytes = match me.bytes.try_into_vec::<E>() {
203            Ok(elems) => return Ok(elems),
204            Err(bytes) => bytes,
205        };
206
207        // The bytes might have been deserialized and allocated with a different align.
208        // In that case, we have to memcopy the data into a new vector, more suitably allocated
209        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
210            .map_err(DataError::CastError)?
211            .to_vec())
212    }
213
214    fn matches_target_dtype<E: Element>(&self) -> bool {
215        let target_dtype = E::dtype();
216        match self.dtype {
217            DType::Bool(BoolStore::U8) => {
218                matches!(target_dtype, DType::U8 | DType::Bool(BoolStore::U8))
219            }
220            DType::Bool(BoolStore::U32) => {
221                matches!(target_dtype, DType::U32 | DType::Bool(BoolStore::U32))
222            }
223            dtype => dtype == target_dtype,
224        }
225    }
226
227    /// Returns an iterator over the values of the tensor data.
228    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
229        if E::dtype() == self.dtype {
230            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
231        } else {
232            match self.dtype {
233                DType::I8 => Box::new(
234                    bytemuck::checked::cast_slice(&self.bytes)
235                        .iter()
236                        .map(|e: &i8| e.elem::<E>()),
237                ),
238                DType::I16 => Box::new(
239                    bytemuck::checked::cast_slice(&self.bytes)
240                        .iter()
241                        .map(|e: &i16| e.elem::<E>()),
242                ),
243                DType::I32 => Box::new(
244                    bytemuck::checked::cast_slice(&self.bytes)
245                        .iter()
246                        .map(|e: &i32| e.elem::<E>()),
247                ),
248                DType::I64 => Box::new(
249                    bytemuck::checked::cast_slice(&self.bytes)
250                        .iter()
251                        .map(|e: &i64| e.elem::<E>()),
252                ),
253                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
254                DType::U16 => Box::new(
255                    bytemuck::checked::cast_slice(&self.bytes)
256                        .iter()
257                        .map(|e: &u16| e.elem::<E>()),
258                ),
259                DType::U32 => Box::new(
260                    bytemuck::checked::cast_slice(&self.bytes)
261                        .iter()
262                        .map(|e: &u32| e.elem::<E>()),
263                ),
264                DType::U64 => Box::new(
265                    bytemuck::checked::cast_slice(&self.bytes)
266                        .iter()
267                        .map(|e: &u64| e.elem::<E>()),
268                ),
269                DType::BF16 => Box::new(
270                    bytemuck::checked::cast_slice(&self.bytes)
271                        .iter()
272                        .map(|e: &bf16| e.elem::<E>()),
273                ),
274                DType::F16 => Box::new(
275                    bytemuck::checked::cast_slice(&self.bytes)
276                        .iter()
277                        .map(|e: &f16| e.elem::<E>()),
278                ),
279                DType::F32 | DType::Flex32 => Box::new(
280                    bytemuck::checked::cast_slice(&self.bytes)
281                        .iter()
282                        .map(|e: &f32| e.elem::<E>()),
283                ),
284                DType::F64 => Box::new(
285                    bytemuck::checked::cast_slice(&self.bytes)
286                        .iter()
287                        .map(|e: &f64| e.elem::<E>()),
288                ),
289                // bool is a byte value equal to either 0 or 1
290                DType::Bool(BoolStore::Native) | DType::Bool(BoolStore::U8) => {
291                    Box::new(self.bytes.iter().map(|e| e.elem::<E>()))
292                }
293                DType::Bool(BoolStore::U32) => Box::new(
294                    bytemuck::checked::cast_slice(&self.bytes)
295                        .iter()
296                        .map(|e: &u32| e.elem::<E>()),
297                ),
298                DType::QFloat(scheme) => match scheme {
299                    QuantScheme {
300                        level: QuantLevel::Tensor | QuantLevel::Block(_),
301                        mode: QuantMode::Symmetric,
302                        value:
303                            QuantValue::Q8F
304                            | QuantValue::Q8S
305                            // Represent sub-byte values as i8
306                            | QuantValue::Q4F
307                            | QuantValue::Q4S
308                            | QuantValue::Q2F
309                            | QuantValue::Q2S,
310                        ..
311                    } => {
312                        // Quantized int8 values
313                        let q_bytes = QuantizedBytes {
314                            bytes: self.bytes.clone(),
315                            scheme,
316                            num_elements: self.num_elements(),
317                        };
318                        let (values, _) = q_bytes.into_vec_i8();
319
320                        Box::new(
321                            values
322                                .iter()
323                                .map(|e: &i8| e.elem::<E>())
324                                .collect::<Vec<_>>()
325                                .into_iter(),
326                        )
327                    }
328                    QuantScheme {
329                        level: QuantLevel::Tensor | QuantLevel::Block(_),
330                        mode: QuantMode::Symmetric,
331                        value:
332                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
333                        ..
334                    } => {
335                        unimplemented!("Not yet implemented for iteration");
336                    }
337                },
338            }
339        }
340    }
341
342    /// Returns the rank (the number of dimensions).
343    pub fn rank(&self) -> usize {
344        self.shape.len()
345    }
346
347    /// Returns the total number of elements of the tensor data.
348    pub fn num_elements(&self) -> usize {
349        Self::numel(&self.shape)
350    }
351
352    fn numel(shape: &[usize]) -> usize {
353        shape.iter().product()
354    }
355
356    /// Populates the data with random values.
357    pub fn random<E: Element, R: Rng, S: Into<Shape>>(
358        shape: S,
359        distribution: Distribution,
360        rng: &mut R,
361    ) -> Self {
362        let shape = shape.into();
363        let num_elements = Self::numel(&shape);
364        let mut data = Vec::with_capacity(num_elements);
365
366        for _ in 0..num_elements {
367            data.push(E::random(distribution, rng));
368        }
369
370        TensorData::new(data, shape)
371    }
372
373    /// Populates the data with zeros.
374    pub fn zeros<E: Element, S: Into<Shape>>(shape: S) -> TensorData {
375        let shape = shape.into();
376        let num_elements = Self::numel(&shape);
377        let mut data = Vec::<E>::with_capacity(num_elements);
378
379        for _ in 0..num_elements {
380            data.push(0.elem());
381        }
382
383        TensorData::new(data, shape)
384    }
385
386    /// Populates the data with ones.
387    pub fn ones<E: Element, S: Into<Shape>>(shape: S) -> TensorData {
388        let shape = shape.into();
389        let num_elements = Self::numel(&shape);
390        let mut data = Vec::<E>::with_capacity(num_elements);
391
392        for _ in 0..num_elements {
393            data.push(1.elem());
394        }
395
396        TensorData::new(data, shape)
397    }
398
399    /// Populates the data with the given value
400    pub fn full<E: Element, S: Into<Shape>>(shape: S, fill_value: E) -> TensorData {
401        let shape = shape.into();
402        let num_elements = Self::numel(&shape);
403        let mut data = Vec::<E>::with_capacity(num_elements);
404        for _ in 0..num_elements {
405            data.push(fill_value)
406        }
407
408        TensorData::new(data, shape)
409    }
410
411    /// Populates the data with the given value
412    pub fn full_dtype<E: Into<Scalar>, S: Into<Shape>>(
413        shape: S,
414        fill_value: E,
415        dtype: DType,
416    ) -> TensorData {
417        let fill_value = fill_value.into();
418        match dtype {
419            DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
420            DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
421            DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
422            DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
423            DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
424            DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
425            DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
426            DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
427            DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
428            DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
429            DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
430            DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
431            DType::Bool(BoolStore::Native) => Self::full::<bool, _>(shape, fill_value.elem()),
432            DType::Bool(BoolStore::U8) => {
433                Self::full::<u8, _>(shape, fill_value.elem()).into_bool_u8()
434            }
435            DType::Bool(BoolStore::U32) => {
436                Self::full::<u32, _>(shape, fill_value.elem()).into_bool_u32()
437            }
438            DType::QFloat(_) => unreachable!(),
439        }
440    }
441
442    // Unchecked, used to overwrite the dtype
443    fn into_bool_u8(mut self) -> Self {
444        self.dtype = DType::Bool(BoolStore::U8);
445        self
446    }
447
448    // Unchecked, used to overwrite the dtype
449    fn into_bool_u32(mut self) -> Self {
450        self.dtype = DType::Bool(BoolStore::U32);
451        self
452    }
453
454    /// Converts the data to a different element type.
455    pub fn convert<E: Element>(self) -> Self {
456        self.convert_dtype(E::dtype())
457    }
458
459    /// Converts the data to a different element type.
460    pub fn convert_dtype(self, dtype: DType) -> Self {
461        if dtype == self.dtype {
462            self
463        } else if dtype.size() == self.dtype.size()
464            && !matches!(
465                self.dtype,
466                DType::Bool(BoolStore::Native) | DType::QFloat(_)
467            )
468            && !matches!(dtype, DType::Bool(BoolStore::Native) | DType::QFloat(_))
469        {
470            match self.dtype {
471                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
472                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
473                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
474                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
475                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
476                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
477                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
478                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
479                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
480                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
481                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
482                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
483                DType::Bool(BoolStore::U8) => self.convert_inplace_dtype::<u8>(dtype),
484                DType::Bool(BoolStore::U32) => self.convert_inplace_dtype::<u32>(dtype),
485                DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(),
486            }
487        } else {
488            match self.dtype {
489                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
490                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
491                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
492                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
493                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
494                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
495                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
496                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
497                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
498                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
499                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
500                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
501                DType::Bool(BoolStore::Native) => self.convert_clone_dtype::<bool>(dtype),
502                DType::Bool(BoolStore::U8) => self.convert_clone_dtype::<u8>(dtype),
503                DType::Bool(BoolStore::U32) => self.convert_clone_dtype::<u32>(dtype),
504                DType::QFloat(_) => unreachable!(),
505            }
506        }
507    }
508
509    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
510        match dtype {
511            DType::F64 => self.convert_inplace::<Current, f64>(),
512            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
513            DType::F16 => self.convert_inplace::<Current, f16>(),
514            DType::BF16 => self.convert_inplace::<Current, bf16>(),
515            DType::I64 => self.convert_inplace::<Current, i64>(),
516            DType::I32 => self.convert_inplace::<Current, i32>(),
517            DType::I16 => self.convert_inplace::<Current, i16>(),
518            DType::I8 => self.convert_inplace::<Current, i8>(),
519            DType::U64 => self.convert_inplace::<Current, u64>(),
520            DType::U32 => self.convert_inplace::<Current, u32>(),
521            DType::U16 => self.convert_inplace::<Current, u16>(),
522            DType::U8 => self.convert_inplace::<Current, u8>(),
523            DType::Bool(BoolStore::U8) => self.convert_inplace::<Current, u8>().into_bool_u8(),
524            DType::Bool(BoolStore::U32) => self.convert_inplace::<Current, u32>().into_bool_u32(),
525            DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(),
526        }
527    }
528
529    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
530        mut self,
531    ) -> Self {
532        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
533            let t: Target = x.elem();
534            let x = cast_mut::<_, Target>(x);
535            *x = t;
536        }
537
538        self.dtype = Target::dtype();
539
540        self
541    }
542
543    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
544        match dtype {
545            DType::F64 => self.convert_clone::<Current, f64>(),
546            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
547            DType::F16 => self.convert_clone::<Current, f16>(),
548            DType::BF16 => self.convert_clone::<Current, bf16>(),
549            DType::I64 => self.convert_clone::<Current, i64>(),
550            DType::I32 => self.convert_clone::<Current, i32>(),
551            DType::I16 => self.convert_clone::<Current, i16>(),
552            DType::I8 => self.convert_clone::<Current, i8>(),
553            DType::U64 => self.convert_clone::<Current, u64>(),
554            DType::U32 => self.convert_clone::<Current, u32>(),
555            DType::U16 => self.convert_clone::<Current, u16>(),
556            DType::U8 => self.convert_clone::<Current, u8>(),
557            DType::Bool(BoolStore::Native) => self.convert_clone::<Current, bool>(),
558            DType::Bool(BoolStore::U8) => self.convert_clone::<Current, u8>().into_bool_u8(),
559            DType::Bool(BoolStore::U32) => self.convert_clone::<Current, u32>().into_bool_u32(),
560            DType::QFloat(_) => unreachable!(),
561        }
562    }
563
564    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
565        self,
566    ) -> Self {
567        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
568        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
569
570        for (x, out) in this.iter().zip(&mut out) {
571            *out = x.elem();
572        }
573
574        Self::new(out, self.shape)
575    }
576
577    /// Returns the data as a slice of bytes.
578    pub fn as_bytes(&self) -> &[u8] {
579        &self.bytes
580    }
581
582    /// Returns the bytes representation of the data.
583    pub fn into_bytes(self) -> Bytes {
584        self.bytes
585    }
586}
587
588impl<E: Element, const A: usize> From<[E; A]> for TensorData {
589    fn from(elems: [E; A]) -> Self {
590        TensorData::new(elems.to_vec(), [A])
591    }
592}
593
594impl<const A: usize> From<[usize; A]> for TensorData {
595    fn from(elems: [usize; A]) -> Self {
596        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
597    }
598}
599
600impl From<&[usize]> for TensorData {
601    fn from(elems: &[usize]) -> Self {
602        let mut data = Vec::with_capacity(elems.len());
603        for elem in elems.iter() {
604            data.push(*elem as i64);
605        }
606
607        TensorData::new(data, [elems.len()])
608    }
609}
610
611impl<E: Element> From<&[E]> for TensorData {
612    fn from(elems: &[E]) -> Self {
613        let mut data = Vec::with_capacity(elems.len());
614        for elem in elems.iter() {
615            data.push(*elem);
616        }
617
618        TensorData::new(data, [elems.len()])
619    }
620}
621
622impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
623    fn from(elems: [[E; B]; A]) -> Self {
624        let mut data = Vec::with_capacity(A * B);
625        for elem in elems.into_iter().take(A) {
626            for elem in elem.into_iter().take(B) {
627                data.push(elem);
628            }
629        }
630
631        TensorData::new(data, [A, B])
632    }
633}
634
635impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
636    for TensorData
637{
638    fn from(elems: [[[E; C]; B]; A]) -> Self {
639        let mut data = Vec::with_capacity(A * B * C);
640
641        for elem in elems.into_iter().take(A) {
642            for elem in elem.into_iter().take(B) {
643                for elem in elem.into_iter().take(C) {
644                    data.push(elem);
645                }
646            }
647        }
648
649        TensorData::new(data, [A, B, C])
650    }
651}
652
653impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
654    From<[[[[E; D]; C]; B]; A]> for TensorData
655{
656    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
657        let mut data = Vec::with_capacity(A * B * C * D);
658
659        for elem in elems.into_iter().take(A) {
660            for elem in elem.into_iter().take(B) {
661                for elem in elem.into_iter().take(C) {
662                    for elem in elem.into_iter().take(D) {
663                        data.push(elem);
664                    }
665                }
666            }
667        }
668
669        TensorData::new(data, [A, B, C, D])
670    }
671}
672
673impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
674    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
675{
676    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
677        let mut data = Vec::with_capacity(A * B * C * D * E);
678
679        for elem in elems.into_iter().take(A) {
680            for elem in elem.into_iter().take(B) {
681                for elem in elem.into_iter().take(C) {
682                    for elem in elem.into_iter().take(D) {
683                        for elem in elem.into_iter().take(E) {
684                            data.push(elem);
685                        }
686                    }
687                }
688            }
689        }
690
691        TensorData::new(data, [A, B, C, D, E])
692    }
693}
694impl core::fmt::Display for TensorData {
695    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
696        let fmt = match self.dtype {
697            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
698            DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
699            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
700            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
701            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
702            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
703            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
704            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
705            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
706            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
707            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
708            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
709            DType::Bool(BoolStore::Native) => format!("{:?}", self.as_slice::<bool>().unwrap()),
710            DType::Bool(BoolStore::U8) => format!("{:?}", self.as_slice::<u8>().unwrap()),
711            DType::Bool(BoolStore::U32) => format!("{:?}", self.as_slice::<u32>().unwrap()),
712            DType::QFloat(scheme) => match scheme {
713                QuantScheme {
714                    level: QuantLevel::Tensor | QuantLevel::Block(_),
715                    mode: QuantMode::Symmetric,
716                    value:
717                        QuantValue::Q8F
718                        | QuantValue::Q8S
719                        // Display sub-byte values as i8
720                        | QuantValue::Q4F
721                        | QuantValue::Q4S
722                        | QuantValue::Q2F
723                        | QuantValue::Q2S,
724                    ..
725                } => {
726                    format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
727                },
728                QuantScheme {
729                        level: QuantLevel::Tensor | QuantLevel::Block(_),
730                        mode: QuantMode::Symmetric,
731                        value:
732                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
733                        ..
734                    } => {
735                        unimplemented!("Can't format yet");
736                    }
737            },
738        };
739        f.write_str(fmt.as_str())
740    }
741}
742
743/// The things that can go wrong when manipulating tensor data.
744#[derive(Debug, Error)]
745pub enum DataError {
746    /// Failed to cast the values to a specified element type.
747    #[error("Failed to cast values to the specified element type.\nError:\n  {0}")]
748    CastError(CheckedCastError),
749    /// Invalid target element type.
750    #[error("{0}")]
751    TypeMismatch(String),
752}
753
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use burn_std::shape;
    use rand::{
        SeedableRng,
        rngs::{StdRng, SysRng},
    };

    /// Dimensions shared by the randomly generated fixtures below.
    const DIMS: [usize; 3] = [3, 5, 6];

    /// Builds `f32` tensor data of shape [`DIMS`], sampled from the default distribution
    /// with an RNG seeded from the system source.
    fn random_f32_data() -> TensorData {
        let mut rng = StdRng::try_from_rng(&mut SysRng).unwrap();
        TensorData::random::<f32, _, _>(DIMS, Distribution::Default, &mut rng)
    }

    #[test]
    fn should_have_rank() {
        assert_eq!(random_f32_data().rank(), DIMS.len());
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32_data();

        // Collect through the iterator first, since `into_vec` consumes the data.
        let via_iter: Vec<f32> = data.iter::<f32>().collect();
        let via_into_vec = data.into_vec::<f32>().unwrap();

        assert_eq!(via_iter, via_into_vec);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        // The data holds `f32` values, so requesting `i32` must fail and the unwrap panics.
        random_f32_data().into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let expected: usize = DIMS.iter().product();
        let data = random_f32_data();

        // The values are stored as raw bytes, one `f32` per `size_of::<f32>()` bytes.
        assert_eq!(expected, data.bytes.len() / core::mem::size_of::<f32>());
        assert_eq!(expected, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let row = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(row.shape, shape![1, 3]);

        let matrix = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(matrix.shape, shape![2, 3]);

        let flat = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(flat.shape, shape![3]);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        // Reserve more room than is used so that length and capacity differ.
        let mut values: Vec<f32> = Vec::with_capacity(5);
        values.push(2.0);
        values.push(3.0);
        let data = TensorData::new(values, vec![2]);

        let bytes_per_elem = core::mem::size_of::<f32>();
        assert_eq!(data.bytes.len(), 2 * bytes_per_elem);
        assert_eq!(data.bytes.capacity(), 5 * bytes_per_elem);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        /// Converts the integers `0..32` to `E` and checks every value survives unchanged.
        fn test_precision<E: Element>() {
            let values: Vec<i32> = (0..32).collect();
            let source = TensorData::new(values, [32]);
            let converted = source.clone().convert::<E>().into_vec::<E>().unwrap();

            for (index, value) in converted.into_iter().enumerate() {
                assert_eq!(index as u32, value.elem::<u32>())
            }
        }

        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    /// Generates one test per element type, named `<prefix>_<type>`, checking that
    /// `TensorData::full_dtype` and `TensorData::full` produce equal data.
    macro_rules! test_dtypes {
        ($prefix:ident, $($elem:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$prefix _ $elem:snake>]() {
                        let from_runtime_dtype =
                            TensorData::full_dtype([2, 16], 4, <$elem>::dtype());
                        let from_static_type = TensorData::full::<$elem, _>([2, 16], 4.elem());
                        assert_eq!(from_runtime_dtype, from_static_type);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );

    #[test]
    fn should_serialize_deserialize_tensor_data() {
        let original = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);

        // Expected raw bytes of the six `f32` values.
        let expected_bytes: [u8; 24] = [
            0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64,
        ];
        assert_eq!(original.as_bytes(), expected_bytes);

        // A JSON round trip must give back equal data.
        let json = serde_json::to_string(&original).unwrap();
        let restored: TensorData = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn should_deserialize_tensor_data_with_shape_inner() {
        // `shape` used to be a `Vec<usize>`; payloads in that form must still load.
        let serialized = r#"{
        "bytes": [0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64],
        "shape": [2, 3],
        "dtype": "F32"
    }"#;

        let data: TensorData = serde_json::from_str(serialized).unwrap();
        assert_eq!(data.shape, shape![2, 3]);

        let values = data.as_slice::<f32>().unwrap();
        assert_eq!(values, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn should_serialize_shape_as_flat_array() {
        // The shape must be written as a flat JSON array, exactly as `Vec<usize>` was,
        // and not as an object such as `{"dims": [2, 3]}`.
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);
        let encoded = serde_json::to_string(&data).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&encoded).unwrap();

        assert_eq!(parsed["shape"], serde_json::json!([2, 3]));
    }
}