burn_backend/data/
tensor.rs

1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::RngCore;
9
10use crate::distribution::Distribution;
11use crate::element::{Element, ElementConversion};
12use burn_std::tensor::DType;
13use burn_std::{Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, bf16, f16};
14
15/// Data structure for tensors.
16#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
17pub struct TensorData {
18    /// The values of the tensor (as bytes).
19    pub bytes: Bytes,
20
21    /// The shape of the tensor.
22    pub shape: Vec<usize>,
23
24    /// The data type of the tensor.
25    pub dtype: DType,
26}
27
28impl TensorData {
29    /// Creates a new tensor data structure.
30    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
31        // Ensure shape is valid
32        let shape = shape.into();
33        Self::check_data_len(&value, &shape);
34
35        Self {
36            bytes: Bytes::from_elems(value),
37            shape,
38            dtype: E::dtype(),
39        }
40    }
41
42    /// Creates a new quantized tensor data structure.
43    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
44        value: Vec<E>,
45        shape: S,
46        scheme: QuantScheme,
47        qparams: &[f32],
48    ) -> Self {
49        let shape = shape.into();
50        Self::check_data_len(&value, &shape);
51
52        let q_bytes = QuantizedBytes::new(value, scheme, qparams);
53
54        Self {
55            bytes: q_bytes.bytes,
56            shape,
57            dtype: DType::QFloat(q_bytes.scheme),
58        }
59    }
60
61    /// Creates a new tensor data structure from raw bytes.
62    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
63        Self {
64            bytes,
65            shape: shape.into(),
66            dtype,
67        }
68    }
69
70    /// Creates a new tensor data structure from raw bytes stored in a vector.
71    ///
72    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are
73    /// certain that the bytes representation is valid.
74    pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
75        Self {
76            bytes: Bytes::from_bytes_vec(bytes),
77            shape: shape.into(),
78            dtype,
79        }
80    }
81
82    // Check that the input vector contains a correct number of elements
83    fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
84        let expected_data_len = Self::numel(shape);
85        let num_data = data.len();
86        assert_eq!(
87            expected_data_len, num_data,
88            "Shape {shape:?} is invalid for input of size {num_data:?}",
89        );
90    }
91
92    /// Returns the immutable slice view of the tensor data.
93    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
94        if E::dtype() == self.dtype {
95            match E::dtype() {
96                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
97                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
98                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
99                DType::Bool => {
100                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
101                        .map_err(DataError::CastError)?;
102                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
103                }
104                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
105            }
106        } else {
107            Err(DataError::TypeMismatch(format!(
108                "Invalid target element type (expected {:?}, got {:?})",
109                self.dtype,
110                E::dtype()
111            )))
112        }
113    }
114
115    /// Returns the mutable slice view of the tensor data.
116    ///
117    /// # Panics
118    /// If the target element type is different from the stored element type.
119    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
120        if E::dtype() == self.dtype {
121            match E::dtype() {
122                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
123                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
124                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
125                DType::Bool => {
126                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
127                        .map_err(DataError::CastError)?;
128                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
129                }
130                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
131                    .map_err(DataError::CastError),
132            }
133        } else {
134            Err(DataError::TypeMismatch(format!(
135                "Invalid target element type (expected {:?}, got {:?})",
136                self.dtype,
137                E::dtype()
138            )))
139        }
140    }
141
142    /// Returns the tensor data as a vector of scalar values.
143    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
144        Ok(self.as_slice()?.to_vec())
145    }
146
147    /// Returns the tensor data as a vector of scalar values.
148    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
149        // This means we cannot call `into_vec` for QFloat
150        if E::dtype() != self.dtype {
151            return Err(DataError::TypeMismatch(format!(
152                "Invalid target element type (expected {:?}, got {:?})",
153                self.dtype,
154                E::dtype()
155            )));
156        }
157
158        match E::dtype() {
159            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
160            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
161            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
162            DType::Bool => {
163                let vec = self.into_vec_unchecked::<u8>()?;
164                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
165            }
166            _ => self.into_vec_unchecked(),
167        }
168    }
169
170    /// Returns the tensor data as a vector of scalar values. Does not check dtype.
171    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
172        let mut me = self;
173        me.bytes = match me.bytes.try_into_vec::<E>() {
174            Ok(elems) => return Ok(elems),
175            Err(bytes) => bytes,
176        };
177
178        // The bytes might have been deserialized and allocated with a different align.
179        // In that case, we have to memcopy the data into a new vector, more suitably allocated
180        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
181            .map_err(DataError::CastError)?
182            .to_vec())
183    }
184
185    /// Returns an iterator over the values of the tensor data.
186    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
187        if E::dtype() == self.dtype {
188            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
189        } else {
190            match self.dtype {
191                DType::I8 => Box::new(
192                    bytemuck::checked::cast_slice(&self.bytes)
193                        .iter()
194                        .map(|e: &i8| e.elem::<E>()),
195                ),
196                DType::I16 => Box::new(
197                    bytemuck::checked::cast_slice(&self.bytes)
198                        .iter()
199                        .map(|e: &i16| e.elem::<E>()),
200                ),
201                DType::I32 => Box::new(
202                    bytemuck::checked::cast_slice(&self.bytes)
203                        .iter()
204                        .map(|e: &i32| e.elem::<E>()),
205                ),
206                DType::I64 => Box::new(
207                    bytemuck::checked::cast_slice(&self.bytes)
208                        .iter()
209                        .map(|e: &i64| e.elem::<E>()),
210                ),
211                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
212                DType::U16 => Box::new(
213                    bytemuck::checked::cast_slice(&self.bytes)
214                        .iter()
215                        .map(|e: &u16| e.elem::<E>()),
216                ),
217                DType::U32 => Box::new(
218                    bytemuck::checked::cast_slice(&self.bytes)
219                        .iter()
220                        .map(|e: &u32| e.elem::<E>()),
221                ),
222                DType::U64 => Box::new(
223                    bytemuck::checked::cast_slice(&self.bytes)
224                        .iter()
225                        .map(|e: &u64| e.elem::<E>()),
226                ),
227                DType::BF16 => Box::new(
228                    bytemuck::checked::cast_slice(&self.bytes)
229                        .iter()
230                        .map(|e: &bf16| e.elem::<E>()),
231                ),
232                DType::F16 => Box::new(
233                    bytemuck::checked::cast_slice(&self.bytes)
234                        .iter()
235                        .map(|e: &f16| e.elem::<E>()),
236                ),
237                DType::F32 | DType::Flex32 => Box::new(
238                    bytemuck::checked::cast_slice(&self.bytes)
239                        .iter()
240                        .map(|e: &f32| e.elem::<E>()),
241                ),
242                DType::F64 => Box::new(
243                    bytemuck::checked::cast_slice(&self.bytes)
244                        .iter()
245                        .map(|e: &f64| e.elem::<E>()),
246                ),
247                // bool is a byte value equal to either 0 or 1
248                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
249                DType::QFloat(scheme) => match scheme {
250                    QuantScheme {
251                        level: QuantLevel::Tensor | QuantLevel::Block(_),
252                        mode: QuantMode::Symmetric,
253                        value:
254                            QuantValue::Q8F
255                            | QuantValue::Q8S
256                            // Represent sub-byte values as i8
257                            | QuantValue::Q4F
258                            | QuantValue::Q4S
259                            | QuantValue::Q2F
260                            | QuantValue::Q2S,
261                        ..
262                    } => {
263                        // Quantized int8 values
264                        let q_bytes = QuantizedBytes {
265                            bytes: self.bytes.clone(),
266                            scheme,
267                            num_elements: self.num_elements(),
268                        };
269                        let (values, _) = q_bytes.into_vec_i8();
270
271                        Box::new(
272                            values
273                                .iter()
274                                .map(|e: &i8| e.elem::<E>())
275                                .collect::<Vec<_>>()
276                                .into_iter(),
277                        )
278                    }
279                    QuantScheme {
280                        level: QuantLevel::Tensor | QuantLevel::Block(_),
281                        mode: QuantMode::Symmetric,
282                        value:
283                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
284                        ..
285                    } => {
286                        unimplemented!("Not yet implemented for iteration");
287                    }
288                },
289            }
290        }
291    }
292
293    /// Returns the rank (the number of dimensions).
294    pub fn rank(&self) -> usize {
295        self.shape.len()
296    }
297
298    /// Returns the total number of elements of the tensor data.
299    pub fn num_elements(&self) -> usize {
300        Self::numel(&self.shape)
301    }
302
303    fn numel(shape: &[usize]) -> usize {
304        shape.iter().product()
305    }
306
307    /// Populates the data with random values.
308    pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
309        shape: S,
310        distribution: Distribution,
311        rng: &mut R,
312    ) -> Self {
313        let shape = shape.into();
314        let num_elements = Self::numel(&shape);
315        let mut data = Vec::with_capacity(num_elements);
316
317        for _ in 0..num_elements {
318            data.push(E::random(distribution, rng));
319        }
320
321        TensorData::new(data, shape)
322    }
323
324    /// Populates the data with zeros.
325    pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
326        let shape = shape.into();
327        let num_elements = Self::numel(&shape);
328        let mut data = Vec::<E>::with_capacity(num_elements);
329
330        for _ in 0..num_elements {
331            data.push(0.elem());
332        }
333
334        TensorData::new(data, shape)
335    }
336
337    /// Populates the data with ones.
338    pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
339        let shape = shape.into();
340        let num_elements = Self::numel(&shape);
341        let mut data = Vec::<E>::with_capacity(num_elements);
342
343        for _ in 0..num_elements {
344            data.push(1.elem());
345        }
346
347        TensorData::new(data, shape)
348    }
349
350    /// Populates the data with the given value
351    pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
352        let shape = shape.into();
353        let num_elements = Self::numel(&shape);
354        let mut data = Vec::<E>::with_capacity(num_elements);
355        for _ in 0..num_elements {
356            data.push(fill_value)
357        }
358
359        TensorData::new(data, shape)
360    }
361
362    #[allow(dead_code)]
363    /// Populates the data with the given value
364    pub fn full_dtype<E: Element, S: Into<Vec<usize>>>(
365        shape: S,
366        fill_value: E,
367        dtype: DType,
368    ) -> TensorData {
369        match dtype {
370            DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
371            DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
372            DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
373            DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
374            DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
375            DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
376            DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
377            DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
378            DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
379            DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
380            DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
381            DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
382            DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
383            DType::QFloat(_) => unreachable!(),
384        }
385    }
386
387    /// Converts the data to a different element type.
388    pub fn convert<E: Element>(self) -> Self {
389        self.convert_dtype(E::dtype())
390    }
391
392    /// Converts the data to a different element type.
393    pub fn convert_dtype(self, dtype: DType) -> Self {
394        if dtype == self.dtype {
395            self
396        } else if dtype.size() == self.dtype.size()
397            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
398            && !matches!(dtype, DType::Bool | DType::QFloat(_))
399        {
400            match self.dtype {
401                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
402                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
403                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
404                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
405                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
406                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
407                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
408                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
409                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
410                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
411                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
412                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
413                DType::Bool | DType::QFloat(_) => unreachable!(),
414            }
415        } else {
416            match self.dtype {
417                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
418                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
419                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
420                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
421                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
422                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
423                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
424                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
425                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
426                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
427                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
428                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
429                DType::Bool => self.convert_clone_dtype::<bool>(dtype),
430                DType::QFloat(_) => unreachable!(),
431            }
432        }
433    }
434
435    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
436        match dtype {
437            DType::F64 => self.convert_inplace::<Current, f64>(),
438            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
439            DType::F16 => self.convert_inplace::<Current, f16>(),
440            DType::BF16 => self.convert_inplace::<Current, bf16>(),
441            DType::I64 => self.convert_inplace::<Current, i64>(),
442            DType::I32 => self.convert_inplace::<Current, i32>(),
443            DType::I16 => self.convert_inplace::<Current, i16>(),
444            DType::I8 => self.convert_inplace::<Current, i8>(),
445            DType::U64 => self.convert_inplace::<Current, u64>(),
446            DType::U32 => self.convert_inplace::<Current, u32>(),
447            DType::U16 => self.convert_inplace::<Current, u16>(),
448            DType::U8 => self.convert_inplace::<Current, u8>(),
449            DType::Bool | DType::QFloat(_) => unreachable!(),
450        }
451    }
452
453    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
454        mut self,
455    ) -> Self {
456        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
457            let t: Target = x.elem();
458            let x = cast_mut::<_, Target>(x);
459            *x = t;
460        }
461
462        self.dtype = Target::dtype();
463
464        self
465    }
466
467    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
468        match dtype {
469            DType::F64 => self.convert_clone::<Current, f64>(),
470            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
471            DType::F16 => self.convert_clone::<Current, f16>(),
472            DType::BF16 => self.convert_clone::<Current, bf16>(),
473            DType::I64 => self.convert_clone::<Current, i64>(),
474            DType::I32 => self.convert_clone::<Current, i32>(),
475            DType::I16 => self.convert_clone::<Current, i16>(),
476            DType::I8 => self.convert_clone::<Current, i8>(),
477            DType::U64 => self.convert_clone::<Current, u64>(),
478            DType::U32 => self.convert_clone::<Current, u32>(),
479            DType::U16 => self.convert_clone::<Current, u16>(),
480            DType::U8 => self.convert_clone::<Current, u8>(),
481            DType::Bool => self.convert_clone::<Current, bool>(),
482            DType::QFloat(_) => unreachable!(),
483        }
484    }
485
486    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
487        self,
488    ) -> Self {
489        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
490        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
491
492        for (x, out) in this.iter().zip(&mut out) {
493            *out = x.elem();
494        }
495
496        Self::new(out, self.shape)
497    }
498
499    /// Returns the data as a slice of bytes.
500    pub fn as_bytes(&self) -> &[u8] {
501        &self.bytes
502    }
503
504    /// Returns the bytes representation of the data.
505    pub fn into_bytes(self) -> Bytes {
506        self.bytes
507    }
508}
509
510impl<E: Element, const A: usize> From<[E; A]> for TensorData {
511    fn from(elems: [E; A]) -> Self {
512        TensorData::new(elems.to_vec(), [A])
513    }
514}
515
516impl<const A: usize> From<[usize; A]> for TensorData {
517    fn from(elems: [usize; A]) -> Self {
518        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
519    }
520}
521
522impl From<&[usize]> for TensorData {
523    fn from(elems: &[usize]) -> Self {
524        let mut data = Vec::with_capacity(elems.len());
525        for elem in elems.iter() {
526            data.push(*elem as i64);
527        }
528
529        TensorData::new(data, [elems.len()])
530    }
531}
532
533impl<E: Element> From<&[E]> for TensorData {
534    fn from(elems: &[E]) -> Self {
535        let mut data = Vec::with_capacity(elems.len());
536        for elem in elems.iter() {
537            data.push(*elem);
538        }
539
540        TensorData::new(data, [elems.len()])
541    }
542}
543
544impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
545    fn from(elems: [[E; B]; A]) -> Self {
546        let mut data = Vec::with_capacity(A * B);
547        for elem in elems.into_iter().take(A) {
548            for elem in elem.into_iter().take(B) {
549                data.push(elem);
550            }
551        }
552
553        TensorData::new(data, [A, B])
554    }
555}
556
557impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
558    for TensorData
559{
560    fn from(elems: [[[E; C]; B]; A]) -> Self {
561        let mut data = Vec::with_capacity(A * B * C);
562
563        for elem in elems.into_iter().take(A) {
564            for elem in elem.into_iter().take(B) {
565                for elem in elem.into_iter().take(C) {
566                    data.push(elem);
567                }
568            }
569        }
570
571        TensorData::new(data, [A, B, C])
572    }
573}
574
575impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
576    From<[[[[E; D]; C]; B]; A]> for TensorData
577{
578    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
579        let mut data = Vec::with_capacity(A * B * C * D);
580
581        for elem in elems.into_iter().take(A) {
582            for elem in elem.into_iter().take(B) {
583                for elem in elem.into_iter().take(C) {
584                    for elem in elem.into_iter().take(D) {
585                        data.push(elem);
586                    }
587                }
588            }
589        }
590
591        TensorData::new(data, [A, B, C, D])
592    }
593}
594
595impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
596    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
597{
598    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
599        let mut data = Vec::with_capacity(A * B * C * D * E);
600
601        for elem in elems.into_iter().take(A) {
602            for elem in elem.into_iter().take(B) {
603                for elem in elem.into_iter().take(C) {
604                    for elem in elem.into_iter().take(D) {
605                        for elem in elem.into_iter().take(E) {
606                            data.push(elem);
607                        }
608                    }
609                }
610            }
611        }
612
613        TensorData::new(data, [A, B, C, D, E])
614    }
615}
616impl core::fmt::Display for TensorData {
617    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
618        let fmt = match self.dtype {
619            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
620            DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
621            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
622            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
623            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
624            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
625            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
626            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
627            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
628            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
629            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
630            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
631            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
632            DType::QFloat(scheme) => match scheme {
633                QuantScheme {
634                    level: QuantLevel::Tensor | QuantLevel::Block(_),
635                    mode: QuantMode::Symmetric,
636                    value:
637                        QuantValue::Q8F
638                        | QuantValue::Q8S
639                        // Display sub-byte values as i8
640                        | QuantValue::Q4F
641                        | QuantValue::Q4S
642                        | QuantValue::Q2F
643                        | QuantValue::Q2S,
644                    ..
645                } => {
646                    format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
647                },
648                QuantScheme {
649                        level: QuantLevel::Tensor | QuantLevel::Block(_),
650                        mode: QuantMode::Symmetric,
651                        value:
652                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
653                        ..
654                    } => {
655                        unimplemented!("Can't format yet");
656                    }
657            },
658        };
659        f.write_str(fmt.as_str())
660    }
661}
662
663/// The things that can go wrong when manipulating tensor data.
664#[derive(Debug)]
665pub enum DataError {
666    /// Failed to cast the values to a specified element type.
667    CastError(CheckedCastError),
668    /// Invalid target element type.
669    TypeMismatch(String),
670}
671
672impl core::error::Error for DataError {}
673
674impl core::fmt::Display for DataError {
675    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
676        f.write_str(format!("{self:?}").as_str())
677    }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Fixed-seed RNG so that any test failure can be reproduced exactly.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(0x5EED)
    }

    /// Random `f32` tensor data of the given shape, drawn from the seeded RNG.
    fn random_f32(shape: [usize; 3]) -> TensorData {
        TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng())
    }

    #[test]
    fn should_have_rank() {
        let data = random_f32([3, 5, 6]);

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32([3, 5, 6]);

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    fn into_vec_should_assert_wrong_dtype() {
        let data = random_f32([3, 5, 6]);

        // A dtype mismatch is reported as an error value, not a panic.
        let result = data.into_vec::<i32>();
        assert!(matches!(result, Err(DataError::TypeMismatch(_))));
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = [3, 5, 6];
        let num_elements: usize = shape.iter().product();
        let data = random_f32(shape);

        assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    /// Generates one test per dtype checking that `full_dtype` matches `full`.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}