Skip to main content

burn_backend/data/
tensor.rs

1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::RngCore;
9use thiserror::Error;
10
11use crate::Scalar;
12use crate::distribution::Distribution;
13use crate::element::{Element, ElementConversion};
14use burn_std::tensor::DType;
15use burn_std::{Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, bf16, f16};
16
/// Data structure for tensors.
///
/// The values are kept as raw bytes: `dtype` says how those bytes are to be interpreted and
/// `shape` gives the size of each dimension. [`TensorData::new`] and
/// [`TensorData::quantized`] check that the number of values matches `shape`;
/// [`TensorData::from_bytes`] and [`TensorData::from_bytes_vec`] perform no such check.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// The values of the tensor (as bytes), to be interpreted according to `dtype`.
    pub bytes: Bytes,

    /// The shape of the tensor (one entry per dimension).
    pub shape: Vec<usize>,

    /// The data type of the tensor.
    pub dtype: DType,
}
29
30impl TensorData {
31    /// Creates a new tensor data structure.
32    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
33        // Ensure shape is valid
34        let shape = shape.into();
35        Self::check_data_len(&value, &shape);
36
37        Self {
38            bytes: Bytes::from_elems(value),
39            shape,
40            dtype: E::dtype(),
41        }
42    }
43
44    /// Creates a new quantized tensor data structure.
45    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
46        value: Vec<E>,
47        shape: S,
48        scheme: QuantScheme,
49        qparams: &[f32],
50    ) -> Self {
51        let shape = shape.into();
52        Self::check_data_len(&value, &shape);
53
54        let q_bytes = QuantizedBytes::new(value, scheme, qparams);
55
56        Self {
57            bytes: q_bytes.bytes,
58            shape,
59            dtype: DType::QFloat(q_bytes.scheme),
60        }
61    }
62
63    /// Creates a new tensor data structure from raw bytes.
64    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
65        Self {
66            bytes,
67            shape: shape.into(),
68            dtype,
69        }
70    }
71
72    /// Creates a new tensor data structure from raw bytes stored in a vector.
73    ///
74    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are
75    /// certain that the bytes representation is valid.
76    pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
77        Self {
78            bytes: Bytes::from_bytes_vec(bytes),
79            shape: shape.into(),
80            dtype,
81        }
82    }
83
84    // Check that the input vector contains a correct number of elements
85    fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
86        let expected_data_len = Self::numel(shape);
87        let num_data = data.len();
88        assert_eq!(
89            expected_data_len, num_data,
90            "Shape {shape:?} is invalid for input of size {num_data:?}",
91        );
92    }
93
94    /// Returns the immutable slice view of the tensor data.
95    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
96        if E::dtype() == self.dtype {
97            match E::dtype() {
98                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
99                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
100                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
101                DType::Bool => {
102                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
103                        .map_err(DataError::CastError)?;
104                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
105                }
106                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
107            }
108        } else {
109            Err(DataError::TypeMismatch(format!(
110                "Invalid target element type (expected {:?}, got {:?})",
111                self.dtype,
112                E::dtype()
113            )))
114        }
115    }
116
117    /// Returns the mutable slice view of the tensor data.
118    ///
119    /// # Panics
120    /// If the target element type is different from the stored element type.
121    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
122        if E::dtype() == self.dtype {
123            match E::dtype() {
124                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
125                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
126                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
127                DType::Bool => {
128                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
129                        .map_err(DataError::CastError)?;
130                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
131                }
132                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
133                    .map_err(DataError::CastError),
134            }
135        } else {
136            Err(DataError::TypeMismatch(format!(
137                "Invalid target element type (expected {:?}, got {:?})",
138                self.dtype,
139                E::dtype()
140            )))
141        }
142    }
143
144    /// Returns the tensor data as a vector of scalar values.
145    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
146        Ok(self.as_slice()?.to_vec())
147    }
148
149    /// Returns the tensor data as a vector of scalar values.
150    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
151        // This means we cannot call `into_vec` for QFloat
152        if E::dtype() != self.dtype {
153            return Err(DataError::TypeMismatch(format!(
154                "Invalid target element type (expected {:?}, got {:?})",
155                self.dtype,
156                E::dtype()
157            )));
158        }
159
160        match E::dtype() {
161            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
162            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
163            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
164            DType::Bool => {
165                let vec = self.into_vec_unchecked::<u8>()?;
166                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
167            }
168            _ => self.into_vec_unchecked(),
169        }
170    }
171
172    /// Returns the tensor data as a vector of scalar values. Does not check dtype.
173    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
174        let mut me = self;
175        me.bytes = match me.bytes.try_into_vec::<E>() {
176            Ok(elems) => return Ok(elems),
177            Err(bytes) => bytes,
178        };
179
180        // The bytes might have been deserialized and allocated with a different align.
181        // In that case, we have to memcopy the data into a new vector, more suitably allocated
182        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
183            .map_err(DataError::CastError)?
184            .to_vec())
185    }
186
187    /// Returns an iterator over the values of the tensor data.
188    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
189        if E::dtype() == self.dtype {
190            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
191        } else {
192            match self.dtype {
193                DType::I8 => Box::new(
194                    bytemuck::checked::cast_slice(&self.bytes)
195                        .iter()
196                        .map(|e: &i8| e.elem::<E>()),
197                ),
198                DType::I16 => Box::new(
199                    bytemuck::checked::cast_slice(&self.bytes)
200                        .iter()
201                        .map(|e: &i16| e.elem::<E>()),
202                ),
203                DType::I32 => Box::new(
204                    bytemuck::checked::cast_slice(&self.bytes)
205                        .iter()
206                        .map(|e: &i32| e.elem::<E>()),
207                ),
208                DType::I64 => Box::new(
209                    bytemuck::checked::cast_slice(&self.bytes)
210                        .iter()
211                        .map(|e: &i64| e.elem::<E>()),
212                ),
213                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
214                DType::U16 => Box::new(
215                    bytemuck::checked::cast_slice(&self.bytes)
216                        .iter()
217                        .map(|e: &u16| e.elem::<E>()),
218                ),
219                DType::U32 => Box::new(
220                    bytemuck::checked::cast_slice(&self.bytes)
221                        .iter()
222                        .map(|e: &u32| e.elem::<E>()),
223                ),
224                DType::U64 => Box::new(
225                    bytemuck::checked::cast_slice(&self.bytes)
226                        .iter()
227                        .map(|e: &u64| e.elem::<E>()),
228                ),
229                DType::BF16 => Box::new(
230                    bytemuck::checked::cast_slice(&self.bytes)
231                        .iter()
232                        .map(|e: &bf16| e.elem::<E>()),
233                ),
234                DType::F16 => Box::new(
235                    bytemuck::checked::cast_slice(&self.bytes)
236                        .iter()
237                        .map(|e: &f16| e.elem::<E>()),
238                ),
239                DType::F32 | DType::Flex32 => Box::new(
240                    bytemuck::checked::cast_slice(&self.bytes)
241                        .iter()
242                        .map(|e: &f32| e.elem::<E>()),
243                ),
244                DType::F64 => Box::new(
245                    bytemuck::checked::cast_slice(&self.bytes)
246                        .iter()
247                        .map(|e: &f64| e.elem::<E>()),
248                ),
249                // bool is a byte value equal to either 0 or 1
250                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
251                DType::QFloat(scheme) => match scheme {
252                    QuantScheme {
253                        level: QuantLevel::Tensor | QuantLevel::Block(_),
254                        mode: QuantMode::Symmetric,
255                        value:
256                            QuantValue::Q8F
257                            | QuantValue::Q8S
258                            // Represent sub-byte values as i8
259                            | QuantValue::Q4F
260                            | QuantValue::Q4S
261                            | QuantValue::Q2F
262                            | QuantValue::Q2S,
263                        ..
264                    } => {
265                        // Quantized int8 values
266                        let q_bytes = QuantizedBytes {
267                            bytes: self.bytes.clone(),
268                            scheme,
269                            num_elements: self.num_elements(),
270                        };
271                        let (values, _) = q_bytes.into_vec_i8();
272
273                        Box::new(
274                            values
275                                .iter()
276                                .map(|e: &i8| e.elem::<E>())
277                                .collect::<Vec<_>>()
278                                .into_iter(),
279                        )
280                    }
281                    QuantScheme {
282                        level: QuantLevel::Tensor | QuantLevel::Block(_),
283                        mode: QuantMode::Symmetric,
284                        value:
285                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
286                        ..
287                    } => {
288                        unimplemented!("Not yet implemented for iteration");
289                    }
290                },
291            }
292        }
293    }
294
295    /// Returns the rank (the number of dimensions).
296    pub fn rank(&self) -> usize {
297        self.shape.len()
298    }
299
300    /// Returns the total number of elements of the tensor data.
301    pub fn num_elements(&self) -> usize {
302        Self::numel(&self.shape)
303    }
304
305    fn numel(shape: &[usize]) -> usize {
306        shape.iter().product()
307    }
308
309    /// Populates the data with random values.
310    pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
311        shape: S,
312        distribution: Distribution,
313        rng: &mut R,
314    ) -> Self {
315        let shape = shape.into();
316        let num_elements = Self::numel(&shape);
317        let mut data = Vec::with_capacity(num_elements);
318
319        for _ in 0..num_elements {
320            data.push(E::random(distribution, rng));
321        }
322
323        TensorData::new(data, shape)
324    }
325
326    /// Populates the data with zeros.
327    pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
328        let shape = shape.into();
329        let num_elements = Self::numel(&shape);
330        let mut data = Vec::<E>::with_capacity(num_elements);
331
332        for _ in 0..num_elements {
333            data.push(0.elem());
334        }
335
336        TensorData::new(data, shape)
337    }
338
339    /// Populates the data with ones.
340    pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
341        let shape = shape.into();
342        let num_elements = Self::numel(&shape);
343        let mut data = Vec::<E>::with_capacity(num_elements);
344
345        for _ in 0..num_elements {
346            data.push(1.elem());
347        }
348
349        TensorData::new(data, shape)
350    }
351
352    /// Populates the data with the given value
353    pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
354        let shape = shape.into();
355        let num_elements = Self::numel(&shape);
356        let mut data = Vec::<E>::with_capacity(num_elements);
357        for _ in 0..num_elements {
358            data.push(fill_value)
359        }
360
361        TensorData::new(data, shape)
362    }
363
364    /// Populates the data with the given value
365    pub fn full_dtype<E: Into<Scalar>, S: Into<Vec<usize>>>(
366        shape: S,
367        fill_value: E,
368        dtype: DType,
369    ) -> TensorData {
370        let fill_value = fill_value.into();
371        match dtype {
372            DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
373            DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
374            DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
375            DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
376            DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
377            DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
378            DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
379            DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
380            DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
381            DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
382            DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
383            DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
384            DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
385            DType::QFloat(_) => unreachable!(),
386        }
387    }
388
389    /// Converts the data to a different element type.
390    pub fn convert<E: Element>(self) -> Self {
391        self.convert_dtype(E::dtype())
392    }
393
394    /// Converts the data to a different element type.
395    pub fn convert_dtype(self, dtype: DType) -> Self {
396        if dtype == self.dtype {
397            self
398        } else if dtype.size() == self.dtype.size()
399            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
400            && !matches!(dtype, DType::Bool | DType::QFloat(_))
401        {
402            match self.dtype {
403                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
404                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
405                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
406                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
407                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
408                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
409                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
410                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
411                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
412                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
413                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
414                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
415                DType::Bool | DType::QFloat(_) => unreachable!(),
416            }
417        } else {
418            match self.dtype {
419                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
420                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
421                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
422                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
423                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
424                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
425                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
426                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
427                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
428                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
429                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
430                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
431                DType::Bool => self.convert_clone_dtype::<bool>(dtype),
432                DType::QFloat(_) => unreachable!(),
433            }
434        }
435    }
436
437    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
438        match dtype {
439            DType::F64 => self.convert_inplace::<Current, f64>(),
440            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
441            DType::F16 => self.convert_inplace::<Current, f16>(),
442            DType::BF16 => self.convert_inplace::<Current, bf16>(),
443            DType::I64 => self.convert_inplace::<Current, i64>(),
444            DType::I32 => self.convert_inplace::<Current, i32>(),
445            DType::I16 => self.convert_inplace::<Current, i16>(),
446            DType::I8 => self.convert_inplace::<Current, i8>(),
447            DType::U64 => self.convert_inplace::<Current, u64>(),
448            DType::U32 => self.convert_inplace::<Current, u32>(),
449            DType::U16 => self.convert_inplace::<Current, u16>(),
450            DType::U8 => self.convert_inplace::<Current, u8>(),
451            DType::Bool | DType::QFloat(_) => unreachable!(),
452        }
453    }
454
455    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
456        mut self,
457    ) -> Self {
458        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
459            let t: Target = x.elem();
460            let x = cast_mut::<_, Target>(x);
461            *x = t;
462        }
463
464        self.dtype = Target::dtype();
465
466        self
467    }
468
469    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
470        match dtype {
471            DType::F64 => self.convert_clone::<Current, f64>(),
472            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
473            DType::F16 => self.convert_clone::<Current, f16>(),
474            DType::BF16 => self.convert_clone::<Current, bf16>(),
475            DType::I64 => self.convert_clone::<Current, i64>(),
476            DType::I32 => self.convert_clone::<Current, i32>(),
477            DType::I16 => self.convert_clone::<Current, i16>(),
478            DType::I8 => self.convert_clone::<Current, i8>(),
479            DType::U64 => self.convert_clone::<Current, u64>(),
480            DType::U32 => self.convert_clone::<Current, u32>(),
481            DType::U16 => self.convert_clone::<Current, u16>(),
482            DType::U8 => self.convert_clone::<Current, u8>(),
483            DType::Bool => self.convert_clone::<Current, bool>(),
484            DType::QFloat(_) => unreachable!(),
485        }
486    }
487
488    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
489        self,
490    ) -> Self {
491        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
492        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
493
494        for (x, out) in this.iter().zip(&mut out) {
495            *out = x.elem();
496        }
497
498        Self::new(out, self.shape)
499    }
500
501    /// Returns the data as a slice of bytes.
502    pub fn as_bytes(&self) -> &[u8] {
503        &self.bytes
504    }
505
506    /// Returns the bytes representation of the data.
507    pub fn into_bytes(self) -> Bytes {
508        self.bytes
509    }
510}
511
512impl<E: Element, const A: usize> From<[E; A]> for TensorData {
513    fn from(elems: [E; A]) -> Self {
514        TensorData::new(elems.to_vec(), [A])
515    }
516}
517
518impl<const A: usize> From<[usize; A]> for TensorData {
519    fn from(elems: [usize; A]) -> Self {
520        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
521    }
522}
523
524impl From<&[usize]> for TensorData {
525    fn from(elems: &[usize]) -> Self {
526        let mut data = Vec::with_capacity(elems.len());
527        for elem in elems.iter() {
528            data.push(*elem as i64);
529        }
530
531        TensorData::new(data, [elems.len()])
532    }
533}
534
535impl<E: Element> From<&[E]> for TensorData {
536    fn from(elems: &[E]) -> Self {
537        let mut data = Vec::with_capacity(elems.len());
538        for elem in elems.iter() {
539            data.push(*elem);
540        }
541
542        TensorData::new(data, [elems.len()])
543    }
544}
545
546impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
547    fn from(elems: [[E; B]; A]) -> Self {
548        let mut data = Vec::with_capacity(A * B);
549        for elem in elems.into_iter().take(A) {
550            for elem in elem.into_iter().take(B) {
551                data.push(elem);
552            }
553        }
554
555        TensorData::new(data, [A, B])
556    }
557}
558
559impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
560    for TensorData
561{
562    fn from(elems: [[[E; C]; B]; A]) -> Self {
563        let mut data = Vec::with_capacity(A * B * C);
564
565        for elem in elems.into_iter().take(A) {
566            for elem in elem.into_iter().take(B) {
567                for elem in elem.into_iter().take(C) {
568                    data.push(elem);
569                }
570            }
571        }
572
573        TensorData::new(data, [A, B, C])
574    }
575}
576
577impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
578    From<[[[[E; D]; C]; B]; A]> for TensorData
579{
580    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
581        let mut data = Vec::with_capacity(A * B * C * D);
582
583        for elem in elems.into_iter().take(A) {
584            for elem in elem.into_iter().take(B) {
585                for elem in elem.into_iter().take(C) {
586                    for elem in elem.into_iter().take(D) {
587                        data.push(elem);
588                    }
589                }
590            }
591        }
592
593        TensorData::new(data, [A, B, C, D])
594    }
595}
596
597impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
598    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
599{
600    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
601        let mut data = Vec::with_capacity(A * B * C * D * E);
602
603        for elem in elems.into_iter().take(A) {
604            for elem in elem.into_iter().take(B) {
605                for elem in elem.into_iter().take(C) {
606                    for elem in elem.into_iter().take(D) {
607                        for elem in elem.into_iter().take(E) {
608                            data.push(elem);
609                        }
610                    }
611                }
612            }
613        }
614
615        TensorData::new(data, [A, B, C, D, E])
616    }
617}
618impl core::fmt::Display for TensorData {
619    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
620        let fmt = match self.dtype {
621            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
622            DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
623            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
624            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
625            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
626            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
627            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
628            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
629            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
630            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
631            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
632            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
633            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
634            DType::QFloat(scheme) => match scheme {
635                QuantScheme {
636                    level: QuantLevel::Tensor | QuantLevel::Block(_),
637                    mode: QuantMode::Symmetric,
638                    value:
639                        QuantValue::Q8F
640                        | QuantValue::Q8S
641                        // Display sub-byte values as i8
642                        | QuantValue::Q4F
643                        | QuantValue::Q4S
644                        | QuantValue::Q2F
645                        | QuantValue::Q2S,
646                    ..
647                } => {
648                    format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
649                },
650                QuantScheme {
651                        level: QuantLevel::Tensor | QuantLevel::Block(_),
652                        mode: QuantMode::Symmetric,
653                        value:
654                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
655                        ..
656                    } => {
657                        unimplemented!("Can't format yet");
658                    }
659            },
660        };
661        f.write_str(fmt.as_str())
662    }
663}
664
/// The things that can go wrong when manipulating tensor data.
#[derive(Debug, Error)]
pub enum DataError {
    /// Failed to cast the values to a specified element type.
    ///
    /// Wraps the error returned by `bytemuck`'s checked casts (for example a size or
    /// alignment mismatch, or a bit pattern that is invalid for the target type).
    #[error("Failed to cast values to the specified element type.\nError:\n  {0}")]
    CastError(CheckedCastError),
    /// Invalid target element type.
    ///
    /// The message names both the stored dtype and the requested one.
    #[error("{0}")]
    TypeMismatch(String),
}
675
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Shape shared by the randomized tests below.
    const SHAPE: [usize; 3] = [3, 5, 6];

    /// Random `f32` data of shape [`SHAPE`], seeded from the operating system.
    fn random_f32() -> TensorData {
        let mut rng = StdRng::from_os_rng();
        TensorData::random::<f32, _, _>(SHAPE, Distribution::Default, &mut rng)
    }

    #[test]
    fn should_have_rank() {
        assert_eq!(random_f32().rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32();

        let from_iter: Vec<f32> = data.iter::<f32>().collect();
        let from_into_vec = data.into_vec::<f32>().unwrap();

        assert_eq!(from_iter, from_into_vec);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        random_f32().into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let expected: usize = SHAPE.iter().product();
        let data = random_f32();

        // Each f32 occupies four bytes of storage.
        assert_eq!(expected, data.bytes.len() / 4);
        assert_eq!(expected, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let cases = [
            (TensorData::from([[3.0, 5.0, 6.0]]), vec![1, 3]),
            (TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]), vec![2, 3]),
            (TensorData::from([3.0, 5.0, 6.0]), vec![3]),
        ];

        for (data, expected) in cases {
            assert_eq!(data.shape, expected);
        }
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut values: Vec<f32> = Vec::with_capacity(5);
        for value in [2.0, 3.0] {
            values.push(value);
        }
        let data = TensorData::new(values, vec![2]);

        let bytes_per_value = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data.bytes.len(), 2 * bytes_per_value);
        assert_eq!(data.bytes.capacity(), 5 * bytes_per_value);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        /// Converts `0..32` (as `i32`) to `E` and checks every value survives the round trip.
        fn check_roundtrip<E: Element>() {
            let source = TensorData::new((0..32).collect::<Vec<i32>>(), [32]);
            let converted = source.convert::<E>().into_vec::<E>().unwrap();

            for (expected, value) in converted.into_iter().enumerate() {
                assert_eq!(expected as u32, value.elem::<u32>())
            }
        }

        check_roundtrip::<f32>();
        check_roundtrip::<f16>();
        check_roundtrip::<i64>();
        check_roundtrip::<i32>();
    }

    /// Generates one test per dtype checking that `full_dtype` matches the typed `full`.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}