// burn_backend/data/tensor.rs

1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::RngCore;
9
10use crate::distribution::Distribution;
11use crate::element::{Element, ElementConversion};
12use burn_std::tensor::DType;
13use burn_std::tensor::quantization::{
14    QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes,
15};
16use burn_std::{Bytes, bf16, f16};
17
18/// Data structure for tensors.
19#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
20pub struct TensorData {
21    /// The values of the tensor (as bytes).
22    pub bytes: Bytes,
23
24    /// The shape of the tensor.
25    pub shape: Vec<usize>,
26
27    /// The data type of the tensor.
28    pub dtype: DType,
29}
30
31impl TensorData {
32    /// Creates a new tensor data structure.
33    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
34        // Ensure shape is valid
35        let shape = shape.into();
36        Self::check_data_len(&value, &shape);
37
38        Self {
39            bytes: Bytes::from_elems(value),
40            shape,
41            dtype: E::dtype(),
42        }
43    }
44
45    /// Creates a new quantized tensor data structure.
46    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
47        value: Vec<E>,
48        shape: S,
49        scheme: QuantScheme,
50        qparams: &[f32],
51    ) -> Self {
52        let shape = shape.into();
53        Self::check_data_len(&value, &shape);
54
55        let q_bytes = QuantizedBytes::new(value, scheme, qparams);
56
57        Self {
58            bytes: q_bytes.bytes,
59            shape,
60            dtype: DType::QFloat(q_bytes.scheme),
61        }
62    }
63
64    /// Creates a new tensor data structure from raw bytes.
65    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
66        Self {
67            bytes,
68            shape: shape.into(),
69            dtype,
70        }
71    }
72
73    /// Creates a new tensor data structure from raw bytes stored in a vector.
74    ///
75    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are
76    /// certain that the bytes representation is valid.
77    pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
78        Self {
79            bytes: Bytes::from_bytes_vec(bytes),
80            shape: shape.into(),
81            dtype,
82        }
83    }
84
85    // Check that the input vector contains a correct number of elements
86    fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
87        let expected_data_len = Self::numel(shape);
88        let num_data = data.len();
89        assert_eq!(
90            expected_data_len, num_data,
91            "Shape {shape:?} is invalid for input of size {num_data:?}",
92        );
93    }
94
95    /// Returns the immutable slice view of the tensor data.
96    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
97        if E::dtype() == self.dtype {
98            match E::dtype() {
99                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
100                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
101                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
102                DType::Bool => {
103                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
104                        .map_err(DataError::CastError)?;
105                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
106                }
107                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
108            }
109        } else {
110            Err(DataError::TypeMismatch(format!(
111                "Invalid target element type (expected {:?}, got {:?})",
112                self.dtype,
113                E::dtype()
114            )))
115        }
116    }
117
118    /// Returns the mutable slice view of the tensor data.
119    ///
120    /// # Panics
121    /// If the target element type is different from the stored element type.
122    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
123        if E::dtype() == self.dtype {
124            match E::dtype() {
125                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
126                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
127                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
128                DType::Bool => {
129                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
130                        .map_err(DataError::CastError)?;
131                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
132                }
133                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
134                    .map_err(DataError::CastError),
135            }
136        } else {
137            Err(DataError::TypeMismatch(format!(
138                "Invalid target element type (expected {:?}, got {:?})",
139                self.dtype,
140                E::dtype()
141            )))
142        }
143    }
144
145    /// Returns the tensor data as a vector of scalar values.
146    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
147        Ok(self.as_slice()?.to_vec())
148    }
149
150    /// Returns the tensor data as a vector of scalar values.
151    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
152        // This means we cannot call `into_vec` for QFloat
153        if E::dtype() != self.dtype {
154            return Err(DataError::TypeMismatch(format!(
155                "Invalid target element type (expected {:?}, got {:?})",
156                self.dtype,
157                E::dtype()
158            )));
159        }
160
161        match E::dtype() {
162            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
163            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
164            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
165            DType::Bool => {
166                let vec = self.into_vec_unchecked::<u8>()?;
167                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
168            }
169            _ => self.into_vec_unchecked(),
170        }
171    }
172
173    /// Returns the tensor data as a vector of scalar values. Does not check dtype.
174    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
175        let mut me = self;
176        me.bytes = match me.bytes.try_into_vec::<E>() {
177            Ok(elems) => return Ok(elems),
178            Err(bytes) => bytes,
179        };
180        // The bytes might have been deserialized and allocated with a different align.
181        // In that case, we have to memcopy the data into a new vector, more suitably allocated
182        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
183            .map_err(DataError::CastError)?
184            .to_vec())
185    }
186
187    /// Returns an iterator over the values of the tensor data.
188    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
189        if E::dtype() == self.dtype {
190            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
191        } else {
192            match self.dtype {
193                DType::I8 => Box::new(
194                    bytemuck::checked::cast_slice(&self.bytes)
195                        .iter()
196                        .map(|e: &i8| e.elem::<E>()),
197                ),
198                DType::I16 => Box::new(
199                    bytemuck::checked::cast_slice(&self.bytes)
200                        .iter()
201                        .map(|e: &i16| e.elem::<E>()),
202                ),
203                DType::I32 => Box::new(
204                    bytemuck::checked::cast_slice(&self.bytes)
205                        .iter()
206                        .map(|e: &i32| e.elem::<E>()),
207                ),
208                DType::I64 => Box::new(
209                    bytemuck::checked::cast_slice(&self.bytes)
210                        .iter()
211                        .map(|e: &i64| e.elem::<E>()),
212                ),
213                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
214                DType::U16 => Box::new(
215                    bytemuck::checked::cast_slice(&self.bytes)
216                        .iter()
217                        .map(|e: &u16| e.elem::<E>()),
218                ),
219                DType::U32 => Box::new(
220                    bytemuck::checked::cast_slice(&self.bytes)
221                        .iter()
222                        .map(|e: &u32| e.elem::<E>()),
223                ),
224                DType::U64 => Box::new(
225                    bytemuck::checked::cast_slice(&self.bytes)
226                        .iter()
227                        .map(|e: &u64| e.elem::<E>()),
228                ),
229                DType::BF16 => Box::new(
230                    bytemuck::checked::cast_slice(&self.bytes)
231                        .iter()
232                        .map(|e: &bf16| e.elem::<E>()),
233                ),
234                DType::F16 => Box::new(
235                    bytemuck::checked::cast_slice(&self.bytes)
236                        .iter()
237                        .map(|e: &f16| e.elem::<E>()),
238                ),
239                DType::F32 | DType::Flex32 => Box::new(
240                    bytemuck::checked::cast_slice(&self.bytes)
241                        .iter()
242                        .map(|e: &f32| e.elem::<E>()),
243                ),
244                DType::F64 => Box::new(
245                    bytemuck::checked::cast_slice(&self.bytes)
246                        .iter()
247                        .map(|e: &f64| e.elem::<E>()),
248                ),
249                // bool is a byte value equal to either 0 or 1
250                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
251                DType::QFloat(scheme) => match scheme {
252                    QuantScheme {
253                        level: QuantLevel::Tensor | QuantLevel::Block(_),
254                        mode: QuantMode::Symmetric,
255                        value:
256                            QuantValue::Q8F
257                            | QuantValue::Q8S
258                            // Represent sub-byte values as i8
259                            | QuantValue::Q4F
260                            | QuantValue::Q4S
261                            | QuantValue::Q2F
262                            | QuantValue::Q2S,
263                        ..
264                    } => {
265                        // Quantized int8 values
266                        let q_bytes = QuantizedBytes {
267                            bytes: self.bytes.clone(),
268                            scheme,
269                            num_elements: self.num_elements(),
270                        };
271                        let (values, _) = q_bytes.into_vec_i8();
272
273                        Box::new(
274                            values
275                                .iter()
276                                .map(|e: &i8| e.elem::<E>())
277                                .collect::<Vec<_>>()
278                                .into_iter(),
279                        )
280                    }
281                    QuantScheme {
282                        level: QuantLevel::Tensor | QuantLevel::Block(_),
283                        mode: QuantMode::Symmetric,
284                        value:
285                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
286                        ..
287                    } => {
288                        unimplemented!("Not yet implemented for iteration");
289                    }
290                },
291            }
292        }
293    }
294
295    /// Returns the rank (the number of dimensions).
296    pub fn rank(&self) -> usize {
297        self.shape.len()
298    }
299
300    /// Returns the total number of elements of the tensor data.
301    pub fn num_elements(&self) -> usize {
302        Self::numel(&self.shape)
303    }
304
305    fn numel(shape: &[usize]) -> usize {
306        shape.iter().product()
307    }
308
309    /// Populates the data with random values.
310    pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
311        shape: S,
312        distribution: Distribution,
313        rng: &mut R,
314    ) -> Self {
315        let shape = shape.into();
316        let num_elements = Self::numel(&shape);
317        let mut data = Vec::with_capacity(num_elements);
318
319        for _ in 0..num_elements {
320            data.push(E::random(distribution, rng));
321        }
322
323        TensorData::new(data, shape)
324    }
325
326    /// Populates the data with zeros.
327    pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
328        let shape = shape.into();
329        let num_elements = Self::numel(&shape);
330        let mut data = Vec::<E>::with_capacity(num_elements);
331
332        for _ in 0..num_elements {
333            data.push(0.elem());
334        }
335
336        TensorData::new(data, shape)
337    }
338
339    /// Populates the data with ones.
340    pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
341        let shape = shape.into();
342        let num_elements = Self::numel(&shape);
343        let mut data = Vec::<E>::with_capacity(num_elements);
344
345        for _ in 0..num_elements {
346            data.push(1.elem());
347        }
348
349        TensorData::new(data, shape)
350    }
351
352    /// Populates the data with the given value
353    pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
354        let shape = shape.into();
355        let num_elements = Self::numel(&shape);
356        let mut data = Vec::<E>::with_capacity(num_elements);
357        for _ in 0..num_elements {
358            data.push(fill_value)
359        }
360
361        TensorData::new(data, shape)
362    }
363
364    #[allow(dead_code)]
365    /// Populates the data with the given value
366    pub fn full_dtype<E: Element, S: Into<Vec<usize>>>(
367        shape: S,
368        fill_value: E,
369        dtype: DType,
370    ) -> TensorData {
371        match dtype {
372            DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
373            DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
374            DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
375            DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
376            DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
377            DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
378            DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
379            DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
380            DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
381            DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
382            DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
383            DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
384            DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
385            DType::QFloat(_) => unreachable!(),
386        }
387    }
388
389    /// Converts the data to a different element type.
390    pub fn convert<E: Element>(self) -> Self {
391        self.convert_dtype(E::dtype())
392    }
393
394    /// Converts the data to a different element type.
395    pub fn convert_dtype(self, dtype: DType) -> Self {
396        if dtype == self.dtype {
397            self
398        } else if dtype.size() == self.dtype.size()
399            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
400            && !matches!(dtype, DType::Bool | DType::QFloat(_))
401        {
402            match self.dtype {
403                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
404                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
405                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
406                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
407                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
408                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
409                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
410                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
411                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
412                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
413                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
414                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
415                DType::Bool | DType::QFloat(_) => unreachable!(),
416            }
417        } else {
418            match self.dtype {
419                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
420                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
421                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
422                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
423                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
424                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
425                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
426                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
427                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
428                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
429                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
430                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
431                DType::Bool => self.convert_clone_dtype::<bool>(dtype),
432                DType::QFloat(_) => unreachable!(),
433            }
434        }
435    }
436
437    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
438        match dtype {
439            DType::F64 => self.convert_inplace::<Current, f64>(),
440            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
441            DType::F16 => self.convert_inplace::<Current, f16>(),
442            DType::BF16 => self.convert_inplace::<Current, bf16>(),
443            DType::I64 => self.convert_inplace::<Current, i64>(),
444            DType::I32 => self.convert_inplace::<Current, i32>(),
445            DType::I16 => self.convert_inplace::<Current, i16>(),
446            DType::I8 => self.convert_inplace::<Current, i8>(),
447            DType::U64 => self.convert_inplace::<Current, u64>(),
448            DType::U32 => self.convert_inplace::<Current, u32>(),
449            DType::U16 => self.convert_inplace::<Current, u16>(),
450            DType::U8 => self.convert_inplace::<Current, u8>(),
451            DType::Bool | DType::QFloat(_) => unreachable!(),
452        }
453    }
454
455    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
456        mut self,
457    ) -> Self {
458        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
459            let t: Target = x.elem();
460            let x = cast_mut::<_, Target>(x);
461            *x = t;
462        }
463
464        self.dtype = Target::dtype();
465
466        self
467    }
468
469    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
470        match dtype {
471            DType::F64 => self.convert_clone::<Current, f64>(),
472            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
473            DType::F16 => self.convert_clone::<Current, f16>(),
474            DType::BF16 => self.convert_clone::<Current, bf16>(),
475            DType::I64 => self.convert_clone::<Current, i64>(),
476            DType::I32 => self.convert_clone::<Current, i32>(),
477            DType::I16 => self.convert_clone::<Current, i16>(),
478            DType::I8 => self.convert_clone::<Current, i8>(),
479            DType::U64 => self.convert_clone::<Current, u64>(),
480            DType::U32 => self.convert_clone::<Current, u32>(),
481            DType::U16 => self.convert_clone::<Current, u16>(),
482            DType::U8 => self.convert_clone::<Current, u8>(),
483            DType::Bool => self.convert_clone::<Current, bool>(),
484            DType::QFloat(_) => unreachable!(),
485        }
486    }
487
488    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
489        self,
490    ) -> Self {
491        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
492        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
493
494        for (x, out) in this.iter().zip(&mut out) {
495            *out = x.elem();
496        }
497
498        Self::new(out, self.shape)
499    }
500
501    /// Returns the data as a slice of bytes.
502    pub fn as_bytes(&self) -> &[u8] {
503        &self.bytes
504    }
505
506    /// Returns the bytes representation of the data.
507    pub fn into_bytes(self) -> Bytes {
508        self.bytes
509    }
510}
511
512impl<E: Element, const A: usize> From<[E; A]> for TensorData {
513    fn from(elems: [E; A]) -> Self {
514        TensorData::new(elems.to_vec(), [A])
515    }
516}
517
518impl<const A: usize> From<[usize; A]> for TensorData {
519    fn from(elems: [usize; A]) -> Self {
520        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
521    }
522}
523
524impl From<&[usize]> for TensorData {
525    fn from(elems: &[usize]) -> Self {
526        let mut data = Vec::with_capacity(elems.len());
527        for elem in elems.iter() {
528            data.push(*elem as i64);
529        }
530
531        TensorData::new(data, [elems.len()])
532    }
533}
534
535impl<E: Element> From<&[E]> for TensorData {
536    fn from(elems: &[E]) -> Self {
537        let mut data = Vec::with_capacity(elems.len());
538        for elem in elems.iter() {
539            data.push(*elem);
540        }
541
542        TensorData::new(data, [elems.len()])
543    }
544}
545
546impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
547    fn from(elems: [[E; B]; A]) -> Self {
548        let mut data = Vec::with_capacity(A * B);
549        for elem in elems.into_iter().take(A) {
550            for elem in elem.into_iter().take(B) {
551                data.push(elem);
552            }
553        }
554
555        TensorData::new(data, [A, B])
556    }
557}
558
559impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
560    for TensorData
561{
562    fn from(elems: [[[E; C]; B]; A]) -> Self {
563        let mut data = Vec::with_capacity(A * B * C);
564
565        for elem in elems.into_iter().take(A) {
566            for elem in elem.into_iter().take(B) {
567                for elem in elem.into_iter().take(C) {
568                    data.push(elem);
569                }
570            }
571        }
572
573        TensorData::new(data, [A, B, C])
574    }
575}
576
577impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
578    From<[[[[E; D]; C]; B]; A]> for TensorData
579{
580    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
581        let mut data = Vec::with_capacity(A * B * C * D);
582
583        for elem in elems.into_iter().take(A) {
584            for elem in elem.into_iter().take(B) {
585                for elem in elem.into_iter().take(C) {
586                    for elem in elem.into_iter().take(D) {
587                        data.push(elem);
588                    }
589                }
590            }
591        }
592
593        TensorData::new(data, [A, B, C, D])
594    }
595}
596
597impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
598    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
599{
600    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
601        let mut data = Vec::with_capacity(A * B * C * D * E);
602
603        for elem in elems.into_iter().take(A) {
604            for elem in elem.into_iter().take(B) {
605                for elem in elem.into_iter().take(C) {
606                    for elem in elem.into_iter().take(D) {
607                        for elem in elem.into_iter().take(E) {
608                            data.push(elem);
609                        }
610                    }
611                }
612            }
613        }
614
615        TensorData::new(data, [A, B, C, D, E])
616    }
617}
618impl core::fmt::Display for TensorData {
619    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
620        let fmt = match self.dtype {
621            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
622            DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
623            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
624            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
625            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
626            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
627            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
628            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
629            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
630            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
631            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
632            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
633            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
634            DType::QFloat(scheme) => match scheme {
635                QuantScheme {
636                    level: QuantLevel::Tensor | QuantLevel::Block(_),
637                    mode: QuantMode::Symmetric,
638                    value:
639                        QuantValue::Q8F
640                        | QuantValue::Q8S
641                        // Display sub-byte values as i8
642                        | QuantValue::Q4F
643                        | QuantValue::Q4S
644                        | QuantValue::Q2F
645                        | QuantValue::Q2S,
646                    ..
647                } => {
648                    format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
649                },
650                QuantScheme {
651                        level: QuantLevel::Tensor | QuantLevel::Block(_),
652                        mode: QuantMode::Symmetric,
653                        value:
654                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
655                        ..
656                    } => {
657                        unimplemented!("Can't format yet");
658                    }
659            },
660        };
661        f.write_str(fmt.as_str())
662    }
663}
664
665/// The things that can go wrong when manipulating tensor data.
666#[derive(Debug)]
667pub enum DataError {
668    /// Failed to cast the values to a specified element type.
669    CastError(CheckedCastError),
670    /// Invalid target element type.
671    TypeMismatch(String),
672}
673
674impl core::error::Error for DataError {}
675
676impl core::fmt::Display for DataError {
677    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
678        f.write_str(format!("{self:?}").as_str())
679    }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Deterministically seeded RNG so that test failures are reproducible.
    fn test_rng() -> StdRng {
        StdRng::seed_from_u64(0)
    }

    /// Random `f32` tensor data of the given shape.
    fn random_f32(shape: [usize; 3]) -> TensorData {
        TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut test_rng())
    }

    #[test]
    fn should_have_rank() {
        let data = random_f32([3, 5, 6]);

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32([3, 5, 6]);

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    fn into_vec_should_assert_wrong_dtype() {
        let data = random_f32([3, 5, 6]);

        let result = data.into_vec::<i32>();

        // Check the specific error variant rather than just "something panicked".
        assert!(
            matches!(result, Err(DataError::TypeMismatch(_))),
            "expected a type mismatch, got {result:?}"
        );
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = [3, 5, 6];
        let num_elements: usize = shape.iter().product();
        let data = random_f32(shape);

        assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    /// Generates one test per dtype checking that `full_dtype` matches `full`.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}