burn_tensor/tensor/
data.rs

1use core::{
2    any::{Any, TypeId},
3    f32,
4};
5
6use alloc::boxed::Box;
7use alloc::format;
8use alloc::string::String;
9use alloc::vec::Vec;
10use bytemuck::{checked::CheckedCastError, AnyBitPattern};
11use half::{bf16, f16};
12
13use crate::{
14    quantization::{
15        Quantization, QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes,
16    },
17    tensor::{bytes::Bytes, Shape},
18    DType, Distribution, Element, ElementConversion,
19};
20
21use num_traits::pow::Pow;
22
23#[cfg(not(feature = "std"))]
24#[allow(unused_imports)]
25use num_traits::Float;
26
27use rand::RngCore;
28
/// The things that can go wrong when manipulating tensor data.
#[derive(Debug)]
pub enum DataError {
    /// Failed to cast the values to a specified element type.
    ///
    /// Wraps the [`CheckedCastError`] returned by `bytemuck` when the stored bytes cannot be
    /// reinterpreted as a slice of the requested element type.
    CastError(CheckedCastError),
    /// Invalid target element type.
    ///
    /// Carries a human-readable message describing the mismatch.
    TypeMismatch(String),
}
37
/// Data structure for tensors.
///
/// The values are kept as raw bytes and interpreted according to [`TensorData::dtype`];
/// typed access goes through methods such as `as_slice`, `to_vec` and `iter`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// The values of the tensor (as bytes).
    bytes: Bytes,

    /// The shape of the tensor.
    ///
    /// The product of the dimensions is the number of elements.
    pub shape: Vec<usize>,

    /// The data type of the tensor.
    pub dtype: DType,
}
50
51impl TensorData {
52    /// Creates a new tensor data structure.
53    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
54        // Ensure shape is valid
55        let shape = shape.into();
56        Self::check_data_len(&value, &shape);
57
58        Self {
59            bytes: Bytes::from_elems(value),
60            shape,
61            dtype: E::dtype(),
62        }
63    }
64
65    /// Creates a new quantized tensor data structure.
66    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
67        value: Vec<E>,
68        shape: S,
69        strategy: QuantizationStrategy,
70    ) -> Self {
71        let shape = shape.into();
72        Self::check_data_len(&value, &shape);
73
74        let q_bytes = QuantizedBytes::new(value, strategy);
75
76        Self {
77            bytes: q_bytes.bytes,
78            shape,
79            dtype: DType::QFloat(q_bytes.scheme),
80        }
81    }
82
83    /// Creates a new tensor data structure from raw bytes.
84    ///
85    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are
86    /// certain that the bytes representation is valid.
87    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
88        Self {
89            bytes: Bytes::from_bytes_vec(bytes),
90            shape: shape.into(),
91            dtype,
92        }
93    }
94
95    // Check that the input vector contains a correct number of elements
96    fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
97        let expected_data_len = Self::numel(shape);
98        let num_data = data.len();
99        assert_eq!(
100            expected_data_len, num_data,
101            "Shape {:?} is invalid for input of size {:?}",
102            shape, num_data,
103        );
104    }
105
106    fn try_as_slice<E: Element>(&self) -> Result<&[E], DataError> {
107        bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
108    }
109
110    /// Returns the immutable slice view of the tensor data.
111    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
112        if E::dtype() == self.dtype {
113            self.try_as_slice()
114        } else {
115            Err(DataError::TypeMismatch(format!(
116                "Invalid target element type (expected {:?}, got {:?})",
117                self.dtype,
118                E::dtype()
119            )))
120        }
121    }
122
123    /// Returns the mutable slice view of the tensor data.
124    ///
125    /// # Panics
126    /// If the target element type is different from the stored element type.
127    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
128        if E::dtype() == self.dtype {
129            bytemuck::checked::try_cast_slice_mut(&mut self.bytes).map_err(DataError::CastError)
130        } else {
131            Err(DataError::TypeMismatch(format!(
132                "Invalid target element type (expected {:?}, got {:?})",
133                self.dtype,
134                E::dtype()
135            )))
136        }
137    }
138
139    /// Returns the tensor data as a vector of scalar values.
140    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
141        Ok(self.as_slice()?.to_vec())
142    }
143
144    /// Returns the tensor data as a vector of scalar values.
145    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
146        // This means we cannot call `into_vec` for QFloat
147        if E::dtype() != self.dtype {
148            return Err(DataError::TypeMismatch(format!(
149                "Invalid target element type (expected {:?}, got {:?})",
150                self.dtype,
151                E::dtype()
152            )));
153        }
154
155        let mut me = self;
156        me.bytes = match me.bytes.try_into_vec::<E>() {
157            Ok(elems) => return Ok(elems),
158            Err(bytes) => bytes,
159        };
160        // The bytes might have been deserialized and allocated with a different align.
161        // In that case, we have to memcopy the data into a new vector, more suitably allocated
162        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
163            .map_err(DataError::CastError)?
164            .to_vec())
165    }
166
167    /// Returns an iterator over the values of the tensor data.
168    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
169        if E::dtype() == self.dtype {
170            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
171        } else {
172            match self.dtype {
173                DType::I8 => Box::new(
174                    bytemuck::checked::cast_slice(&self.bytes)
175                        .iter()
176                        .map(|e: &i8| e.elem::<E>()),
177                ),
178                DType::I16 => Box::new(
179                    bytemuck::checked::cast_slice(&self.bytes)
180                        .iter()
181                        .map(|e: &i16| e.elem::<E>()),
182                ),
183                DType::I32 => Box::new(
184                    bytemuck::checked::cast_slice(&self.bytes)
185                        .iter()
186                        .map(|e: &i32| e.elem::<E>()),
187                ),
188                DType::I64 => Box::new(
189                    bytemuck::checked::cast_slice(&self.bytes)
190                        .iter()
191                        .map(|e: &i64| e.elem::<E>()),
192                ),
193                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
194                DType::U16 => Box::new(
195                    bytemuck::checked::cast_slice(&self.bytes)
196                        .iter()
197                        .map(|e: &u16| e.elem::<E>()),
198                ),
199                DType::U32 => Box::new(
200                    bytemuck::checked::cast_slice(&self.bytes)
201                        .iter()
202                        .map(|e: &u32| e.elem::<E>()),
203                ),
204                DType::U64 => Box::new(
205                    bytemuck::checked::cast_slice(&self.bytes)
206                        .iter()
207                        .map(|e: &u64| e.elem::<E>()),
208                ),
209                DType::BF16 => Box::new(
210                    bytemuck::checked::cast_slice(&self.bytes)
211                        .iter()
212                        .map(|e: &bf16| e.elem::<E>()),
213                ),
214                DType::F16 => Box::new(
215                    bytemuck::checked::cast_slice(&self.bytes)
216                        .iter()
217                        .map(|e: &f16| e.elem::<E>()),
218                ),
219                DType::F32 => Box::new(
220                    bytemuck::checked::cast_slice(&self.bytes)
221                        .iter()
222                        .map(|e: &f32| e.elem::<E>()),
223                ),
224                DType::F64 => Box::new(
225                    bytemuck::checked::cast_slice(&self.bytes)
226                        .iter()
227                        .map(|e: &f64| e.elem::<E>()),
228                ),
229                // bool is a byte value equal to either 0 or 1
230                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
231                DType::QFloat(scheme) => match scheme {
232                    QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
233                    | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
234                        // Quantized int8 values
235                        let q_bytes = QuantizedBytes {
236                            bytes: self.bytes.clone(),
237                            scheme,
238                            num_elements: self.num_elements(),
239                        };
240                        let (values, _) = q_bytes.into_vec_i8();
241
242                        Box::new(
243                            values
244                                .iter()
245                                .map(|e: &i8| e.elem::<E>())
246                                .collect::<Vec<_>>()
247                                .into_iter(),
248                        )
249                    }
250                },
251            }
252        }
253    }
254
255    /// Returns the total number of elements of the tensor data.
256    pub fn num_elements(&self) -> usize {
257        Self::numel(&self.shape)
258    }
259
260    fn numel(shape: &[usize]) -> usize {
261        shape.iter().product()
262    }
263
264    /// Populates the data with random values.
265    pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
266        shape: S,
267        distribution: Distribution,
268        rng: &mut R,
269    ) -> Self {
270        let shape = shape.into();
271        let num_elements = Self::numel(&shape);
272        let mut data = Vec::with_capacity(num_elements);
273
274        for _ in 0..num_elements {
275            data.push(E::random(distribution, rng));
276        }
277
278        TensorData::new(data, shape)
279    }
280
281    /// Populates the data with zeros.
282    pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
283        let shape = shape.into();
284        let num_elements = Self::numel(&shape);
285        let mut data = Vec::<E>::with_capacity(num_elements);
286
287        for _ in 0..num_elements {
288            data.push(0.elem());
289        }
290
291        TensorData::new(data, shape)
292    }
293
294    /// Populates the data with ones.
295    pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
296        let shape = shape.into();
297        let num_elements = Self::numel(&shape);
298        let mut data = Vec::<E>::with_capacity(num_elements);
299
300        for _ in 0..num_elements {
301            data.push(1.elem());
302        }
303
304        TensorData::new(data, shape)
305    }
306
307    /// Populates the data with the given value
308    pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
309        let shape = shape.into();
310        let num_elements = Self::numel(&shape);
311        let mut data = Vec::<E>::with_capacity(num_elements);
312        for _ in 0..num_elements {
313            data.push(fill_value)
314        }
315
316        TensorData::new(data, shape)
317    }
318
319    /// Converts the data to a different element type.
320    pub fn convert<E: Element>(self) -> Self {
321        if E::dtype() == self.dtype {
322            self
323        } else if core::mem::size_of::<E>() == self.dtype.size()
324            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
325        {
326            match self.dtype {
327                DType::F64 => self.convert_inplace::<f64, E>(),
328                DType::F32 => self.convert_inplace::<f32, E>(),
329                DType::F16 => self.convert_inplace::<f16, E>(),
330                DType::BF16 => self.convert_inplace::<bf16, E>(),
331                DType::I64 => self.convert_inplace::<i64, E>(),
332                DType::I32 => self.convert_inplace::<i32, E>(),
333                DType::I16 => self.convert_inplace::<i16, E>(),
334                DType::I8 => self.convert_inplace::<i8, E>(),
335                DType::U64 => self.convert_inplace::<u64, E>(),
336                DType::U32 => self.convert_inplace::<u32, E>(),
337                DType::U16 => self.convert_inplace::<u16, E>(),
338                DType::U8 => self.convert_inplace::<u8, E>(),
339                DType::Bool | DType::QFloat(_) => unreachable!(),
340            }
341        } else {
342            TensorData::new(self.iter::<E>().collect(), self.shape)
343        }
344    }
345
346    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element>(mut self) -> Self {
347        let step = core::mem::size_of::<Current>();
348
349        for offset in 0..(self.bytes.len() / step) {
350            let start = offset * step;
351            let end = start + step;
352
353            let slice_old = &mut self.bytes[start..end];
354            let val: Current = *bytemuck::from_bytes(slice_old);
355            let val = &val.elem::<Target>();
356            let slice_new = bytemuck::bytes_of(val);
357
358            slice_old.clone_from_slice(slice_new);
359        }
360        self.dtype = Target::dtype();
361
362        self
363    }
364
365    /// Returns the data as a slice of bytes.
366    pub fn as_bytes(&self) -> &[u8] {
367        &self.bytes
368    }
369
370    /// Returns the bytes representation of the data.
371    pub fn into_bytes(self) -> Bytes {
372        self.bytes
373    }
374
375    /// Applies the data quantization strategy.
376    ///
377    /// # Panics
378    ///
379    /// Panics if the data type is not supported for quantization.
380    pub fn with_quantization(self, quantization: QuantizationStrategy) -> Self {
381        assert_eq!(
382            self.dtype,
383            DType::F32,
384            "Only f32 data type can be quantized"
385        );
386        match &quantization {
387            QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::quantized(
388                strategy.quantize(self.as_slice().unwrap()),
389                self.shape,
390                quantization,
391            ),
392            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::quantized(
393                strategy.quantize(self.as_slice().unwrap()),
394                self.shape,
395                quantization,
396            ),
397        }
398    }
399
400    /// Dequantizes the data according to its quantization scheme.
401    pub fn dequantize(self) -> Result<Self, DataError> {
402        if let DType::QFloat(scheme) = self.dtype {
403            let num_elements = self.num_elements();
404            let q_bytes = QuantizedBytes {
405                bytes: self.bytes,
406                scheme,
407                num_elements,
408            };
409
410            let values = q_bytes.dequantize().0;
411            Ok(Self::new(values, self.shape))
412        } else {
413            Err(DataError::TypeMismatch(format!(
414                "Expected quantized data, got {:?}",
415                self.dtype
416            )))
417        }
418    }
419
420    /// Asserts the data is approximately equal to another data.
421    ///
422    /// # Arguments
423    ///
424    /// * `other` - The other data.
425    /// * `precision` - The precision of the comparison.
426    ///
427    /// # Panics
428    ///
429    /// Panics if the data is not approximately equal.
430    #[track_caller]
431    pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
432        let tolerance = 0.1.pow(precision as f64);
433
434        self.assert_approx_eq_diff(other, tolerance)
435    }
436
437    /// Asserts the data is equal to another data.
438    ///
439    /// # Arguments
440    ///
441    /// * `other` - The other data.
442    /// * `strict` - If true, the data types must the be same.
443    ///              Otherwise, the comparison is done in the current data type.
444    ///
445    /// # Panics
446    ///
447    /// Panics if the data is not equal.
448    #[track_caller]
449    pub fn assert_eq(&self, other: &Self, strict: bool) {
450        if strict {
451            assert_eq!(
452                self.dtype, other.dtype,
453                "Data types differ ({:?} != {:?})",
454                self.dtype, other.dtype
455            );
456        }
457
458        match self.dtype {
459            DType::F64 => self.assert_eq_elem::<f64>(other),
460            DType::F32 => self.assert_eq_elem::<f32>(other),
461            DType::F16 => self.assert_eq_elem::<f16>(other),
462            DType::BF16 => self.assert_eq_elem::<bf16>(other),
463            DType::I64 => self.assert_eq_elem::<i64>(other),
464            DType::I32 => self.assert_eq_elem::<i32>(other),
465            DType::I16 => self.assert_eq_elem::<i16>(other),
466            DType::I8 => self.assert_eq_elem::<i8>(other),
467            DType::U64 => self.assert_eq_elem::<u64>(other),
468            DType::U32 => self.assert_eq_elem::<u32>(other),
469            DType::U16 => self.assert_eq_elem::<u16>(other),
470            DType::U8 => self.assert_eq_elem::<u8>(other),
471            DType::Bool => self.assert_eq_elem::<bool>(other),
472            DType::QFloat(q) => {
473                // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
474                let q_other = if let DType::QFloat(q_other) = other.dtype {
475                    q_other
476                } else {
477                    panic!("Quantized data differs from other not quantized data")
478                };
479                match (q, q_other) {
480                    (
481                        QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
482                        QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
483                    )
484                    | (
485                        QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
486                        QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
487                    ) => self.assert_eq_elem::<i8>(other),
488                    _ => panic!("Quantization schemes differ ({:?} != {:?})", q, q_other),
489                }
490            }
491        }
492    }
493
494    #[track_caller]
495    fn assert_eq_elem<E: Element>(&self, other: &Self) {
496        let mut message = String::new();
497        if self.shape != other.shape {
498            message += format!(
499                "\n  => Shape is different: {:?} != {:?}",
500                self.shape, other.shape
501            )
502            .as_str();
503        }
504
505        let mut num_diff = 0;
506        let max_num_diff = 5;
507        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
508            if a.cmp(&b).is_ne() {
509                // Only print the first 5 different values.
510                if num_diff < max_num_diff {
511                    message += format!("\n  => Position {i}: {a} != {b}").as_str();
512                }
513                num_diff += 1;
514            }
515        }
516
517        if num_diff >= max_num_diff {
518            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
519        }
520
521        if !message.is_empty() {
522            panic!("Tensors are not eq:{}", message);
523        }
524    }
525
526    /// Asserts the data is approximately equal to another data.
527    ///
528    /// # Arguments
529    ///
530    /// * `other` - The other data.
531    /// * `tolerance` - The tolerance of the comparison.
532    ///
533    /// # Panics
534    ///
535    /// Panics if the data is not approximately equal.
536    #[track_caller]
537    pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) {
538        let mut message = String::new();
539        if self.shape != other.shape {
540            message += format!(
541                "\n  => Shape is different: {:?} != {:?}",
542                self.shape, other.shape
543            )
544            .as_str();
545        }
546
547        let iter = self.iter::<f64>().zip(other.iter::<f64>());
548
549        let mut num_diff = 0;
550        let max_num_diff = 5;
551
552        for (i, (a, b)) in iter.enumerate() {
553            //if they are both nan, then they are equally nan
554            let both_nan = a.is_nan() && b.is_nan();
555            //this works for both infinities
556            let both_inf = a.is_infinite() && b.is_infinite() && ((a > 0.) == (b > 0.));
557
558            if both_nan || both_inf {
559                continue;
560            }
561
562            let err = (a - b).abs();
563
564            if self.dtype.is_float() {
565                if let Some((err, tolerance)) = compare_floats(a, b, self.dtype, tolerance) {
566                    // Only print the first 5 different values.
567                    if num_diff < max_num_diff {
568                        message += format!(
569                            "\n  => Position {i}: {a} != {b} | difference {err} > tolerance \
570                         {tolerance}"
571                        )
572                        .as_str();
573                    }
574                    num_diff += 1;
575                }
576            } else if err > tolerance || err.is_nan() {
577                // Only print the first 5 different values.
578                if num_diff < max_num_diff {
579                    message += format!(
580                        "\n  => Position {i}: {a} != {b} | difference {err} > tolerance \
581                         {tolerance}"
582                    )
583                    .as_str();
584                }
585                num_diff += 1;
586            }
587        }
588
589        if num_diff >= max_num_diff {
590            message += format!("\n{} more errors...", num_diff - 5).as_str();
591        }
592
593        if !message.is_empty() {
594            panic!("Tensors are not approx eq:{}", message);
595        }
596    }
597
598    /// Asserts each value is within a given range.
599    ///
600    /// # Arguments
601    ///
602    /// * `range` - The range.
603    ///
604    /// # Panics
605    ///
606    /// If any value is not within the half-open range bounded inclusively below
607    /// and exclusively above (`start..end`).
608    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
609        let start = range.start.elem::<f32>();
610        let end = range.end.elem::<f32>();
611
612        for elem in self.iter::<f32>() {
613            if elem < start || elem >= end {
614                panic!("Element ({elem:?}) is not within range {range:?}");
615            }
616        }
617    }
618
619    /// Asserts each value is within a given inclusive range.
620    ///
621    /// # Arguments
622    ///
623    /// * `range` - The range.
624    ///
625    /// # Panics
626    ///
627    /// If any value is not within the half-open range bounded inclusively (`start..=end`).
628    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
629        let start = range.start().elem::<f32>();
630        let end = range.end().elem::<f32>();
631
632        for elem in self.iter::<f32>() {
633            if elem < start || elem > end {
634                panic!("Element ({elem:?}) is not within range {range:?}");
635            }
636        }
637    }
638}
639
640impl<E: Element, const A: usize> From<[E; A]> for TensorData {
641    fn from(elems: [E; A]) -> Self {
642        TensorData::new(elems.to_vec(), [A])
643    }
644}
645
646impl<const A: usize> From<[usize; A]> for TensorData {
647    fn from(elems: [usize; A]) -> Self {
648        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
649    }
650}
651
652impl From<&[usize]> for TensorData {
653    fn from(elems: &[usize]) -> Self {
654        let mut data = Vec::with_capacity(elems.len());
655        for elem in elems.iter() {
656            data.push(*elem as i64);
657        }
658
659        TensorData::new(data, [elems.len()])
660    }
661}
662
663impl<E: Element> From<&[E]> for TensorData {
664    fn from(elems: &[E]) -> Self {
665        let mut data = Vec::with_capacity(elems.len());
666        for elem in elems.iter() {
667            data.push(*elem);
668        }
669
670        TensorData::new(data, [elems.len()])
671    }
672}
673
674impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
675    fn from(elems: [[E; B]; A]) -> Self {
676        let mut data = Vec::with_capacity(A * B);
677        for elem in elems.into_iter().take(A) {
678            for elem in elem.into_iter().take(B) {
679                data.push(elem);
680            }
681        }
682
683        TensorData::new(data, [A, B])
684    }
685}
686
687impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
688    for TensorData
689{
690    fn from(elems: [[[E; C]; B]; A]) -> Self {
691        let mut data = Vec::with_capacity(A * B * C);
692
693        for elem in elems.into_iter().take(A) {
694            for elem in elem.into_iter().take(B) {
695                for elem in elem.into_iter().take(C) {
696                    data.push(elem);
697                }
698            }
699        }
700
701        TensorData::new(data, [A, B, C])
702    }
703}
704
705impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
706    From<[[[[E; D]; C]; B]; A]> for TensorData
707{
708    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
709        let mut data = Vec::with_capacity(A * B * C * D);
710
711        for elem in elems.into_iter().take(A) {
712            for elem in elem.into_iter().take(B) {
713                for elem in elem.into_iter().take(C) {
714                    for elem in elem.into_iter().take(D) {
715                        data.push(elem);
716                    }
717                }
718            }
719        }
720
721        TensorData::new(data, [A, B, C, D])
722    }
723}
724
725impl<
726        Elem: Element,
727        const A: usize,
728        const B: usize,
729        const C: usize,
730        const D: usize,
731        const E: usize,
732    > From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
733{
734    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
735        let mut data = Vec::with_capacity(A * B * C * D * E);
736
737        for elem in elems.into_iter().take(A) {
738            for elem in elem.into_iter().take(B) {
739                for elem in elem.into_iter().take(C) {
740                    for elem in elem.into_iter().take(D) {
741                        for elem in elem.into_iter().take(E) {
742                            data.push(elem);
743                        }
744                    }
745                }
746            }
747        }
748
749        TensorData::new(data, [A, B, C, D, E])
750    }
751}
752
753impl core::fmt::Display for TensorData {
754    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
755        let fmt = match self.dtype {
756            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
757            DType::F32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
758            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
759            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
760            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
761            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
762            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
763            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
764            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
765            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
766            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
767            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
768            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
769            DType::QFloat(scheme) => match scheme {
770                QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
771                | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
772                    format!("{:?} {scheme:?}", self.try_as_slice::<i8>().unwrap())
773                }
774            },
775        };
776        f.write_str(fmt.as_str())
777    }
778}
779
/// Data structure for serializing and deserializing tensor data.
///
/// Deprecated since 0.14.0 in favor of [`TensorData`].
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)]
#[deprecated(
    since = "0.14.0",
    note = "the internal data format has changed, please use `TensorData` instead"
)]
pub struct DataSerialize<E> {
    /// The values of the tensor.
    pub value: Vec<E>,
    /// The shape of the tensor.
    pub shape: Vec<usize>,
}
792
/// Data structure for tensors.
///
/// Deprecated since 0.14.0 in favor of [`TensorData`].
#[derive(new, Debug, Clone, PartialEq, Eq)]
#[deprecated(
    since = "0.14.0",
    note = "the internal data format has changed, please use `TensorData` instead"
)]
pub struct Data<E, const D: usize> {
    /// The values of the tensor.
    pub value: Vec<E>,

    /// The shape of the tensor.
    pub shape: Shape,
}
806
807#[allow(deprecated)]
808impl<const D: usize, E: Element> Data<E, D> {
809    /// Converts the data to a different element type.
810    pub fn convert<EOther: Element>(self) -> Data<EOther, D> {
811        let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
812
813        Data {
814            value,
815            shape: self.shape,
816        }
817    }
818
819    /// Asserts each value is within a given range.
820    ///
821    /// # Arguments
822    ///
823    /// * `range` - The range.
824    ///
825    /// # Panics
826    ///
827    /// If any value is not within the half-open range bounded inclusively below
828    /// and exclusively above (`start..end`).
829    pub fn assert_within_range<EOther: Element>(&self, range: core::ops::Range<EOther>) {
830        let start = range.start.elem::<f32>();
831        let end = range.end.elem::<f32>();
832
833        for elem in self.value.iter() {
834            let elem = elem.elem::<f32>();
835            if elem < start || elem >= end {
836                panic!("Element ({elem:?}) is not within range {range:?}");
837            }
838        }
839    }
840}
841
842#[allow(deprecated)]
843impl<E: Element> DataSerialize<E> {
844    /// Converts the data to a different element type.
845    pub fn convert<EOther: Element>(self) -> DataSerialize<EOther> {
846        if TypeId::of::<E>() == TypeId::of::<EOther>() {
847            let cast: Box<dyn Any> = Box::new(self);
848            let cast: Box<DataSerialize<EOther>> = cast.downcast().unwrap();
849            return *cast;
850        }
851
852        let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
853
854        DataSerialize {
855            value,
856            shape: self.shape,
857        }
858    }
859
860    /// Converts the data to the new [TensorData] format.
861    pub fn into_tensor_data(self) -> TensorData {
862        TensorData::new(self.value, self.shape)
863    }
864}
865
866#[allow(deprecated)]
867impl<E: Element, const D: usize> Data<E, D> {
868    /// Populates the data with random values.
869    pub fn random<R: RngCore>(shape: Shape, distribution: Distribution, rng: &mut R) -> Self {
870        let num_elements = shape.num_elements();
871        let mut data = Vec::with_capacity(num_elements);
872
873        for _ in 0..num_elements {
874            data.push(E::random(distribution, rng));
875        }
876
877        Data::new(data, shape)
878    }
879}
880
881#[allow(deprecated)]
882impl<E: core::fmt::Debug, const D: usize> Data<E, D>
883where
884    E: Element,
885{
886    /// Populates the data with zeros.
887    pub fn zeros<S: Into<Shape>>(shape: S) -> Data<E, D> {
888        let shape = shape.into();
889        let num_elements = shape.num_elements();
890        let mut data = Vec::with_capacity(num_elements);
891
892        for _ in 0..num_elements {
893            data.push(0.elem());
894        }
895
896        Data::new(data, shape)
897    }
898}
899
900#[allow(deprecated)]
901impl<E: core::fmt::Debug, const D: usize> Data<E, D>
902where
903    E: Element,
904{
905    /// Populates the data with ones.
906    pub fn ones(shape: Shape) -> Data<E, D> {
907        let num_elements = shape.num_elements();
908        let mut data = Vec::with_capacity(num_elements);
909
910        for _ in 0..num_elements {
911            data.push(1.elem());
912        }
913
914        Data::new(data, shape)
915    }
916}
917
918#[allow(deprecated)]
919impl<E: core::fmt::Debug, const D: usize> Data<E, D>
920where
921    E: Element,
922{
923    /// Populates the data with the given value
924    pub fn full(shape: Shape, fill_value: E) -> Data<E, D> {
925        let num_elements = shape.num_elements();
926        let mut data = Vec::with_capacity(num_elements);
927        for _ in 0..num_elements {
928            data.push(fill_value)
929        }
930
931        Data::new(data, shape)
932    }
933}
934
935#[allow(deprecated)]
936impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
937    /// Serializes the data.
938    ///
939    /// # Returns
940    ///
941    /// The serialized data.
942    pub fn serialize(&self) -> DataSerialize<E> {
943        DataSerialize {
944            value: self.value.clone(),
945            shape: self.shape.dims.to_vec(),
946        }
947    }
948}
949
950#[allow(deprecated)]
951impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data<E, D> {
952    /// Asserts the data is approximately equal to another data.
953    ///
954    /// # Arguments
955    ///
956    /// * `other` - The other data.
957    /// * `precision` - The precision of the comparison.
958    ///
959    /// # Panics
960    ///
961    /// Panics if the data is not approximately equal.
962    #[track_caller]
963    pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
964        let tolerance = 0.1.pow(precision as f64);
965
966        self.assert_approx_eq_diff(other, tolerance)
967    }
968
969    /// Asserts the data is approximately equal to another data.
970    ///
971    /// # Arguments
972    ///
973    /// * `other` - The other data.
974    /// * `tolerance` - The tolerance of the comparison.
975    ///
976    /// # Panics
977    ///
978    /// Panics if the data is not approximately equal.
979    #[track_caller]
980    pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) {
981        let mut message = String::new();
982        if self.shape != other.shape {
983            message += format!(
984                "\n  => Shape is different: {:?} != {:?}",
985                self.shape.dims, other.shape.dims
986            )
987            .as_str();
988        }
989
990        let iter = self.value.clone().into_iter().zip(other.value.clone());
991
992        let mut num_diff = 0;
993        let max_num_diff = 5;
994
995        for (i, (a, b)) in iter.enumerate() {
996            let a: f64 = a.into();
997            let b: f64 = b.into();
998
999            //if they are both nan, then they are equally nan
1000            let both_nan = a.is_nan() && b.is_nan();
1001            //this works for both infinities
1002            let both_inf = a.is_infinite() && b.is_infinite() && ((a > 0.) == (b > 0.));
1003
1004            if both_nan || both_inf {
1005                continue;
1006            }
1007
1008            let err = (a - b).abs();
1009
1010            if E::dtype().is_float() {
1011                if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) {
1012                    // Only print the first 5 different values.
1013                    if num_diff < max_num_diff {
1014                        message += format!(
1015                            "\n  => Position {i}: {a} != {b} | difference {err} > tolerance \
1016                         {tolerance}"
1017                        )
1018                        .as_str();
1019                    }
1020                    num_diff += 1;
1021                }
1022            } else if err > tolerance || err.is_nan() {
1023                // Only print the first 5 different values.
1024                if num_diff < max_num_diff {
1025                    message += format!(
1026                        "\n  => Position {i}: {a} != {b} | difference {err} > tolerance \
1027                         {tolerance}"
1028                    )
1029                    .as_str();
1030                }
1031                num_diff += 1;
1032            }
1033        }
1034
1035        if num_diff >= max_num_diff {
1036            message += format!("\n{} more errors...", num_diff - 5).as_str();
1037        }
1038
1039        if !message.is_empty() {
1040            panic!("Tensors are not approx eq:{}", message);
1041        }
1042    }
1043}
1044
1045#[allow(deprecated)]
1046impl<const D: usize> Data<usize, D> {
1047    /// Converts the usize data to a different element type.
1048    pub fn from_usize<O: num_traits::FromPrimitive>(self) -> Data<O, D> {
1049        let value: Vec<O> = self
1050            .value
1051            .into_iter()
1052            .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap())
1053            .collect();
1054
1055        Data {
1056            value,
1057            shape: self.shape,
1058        }
1059    }
1060}
1061
1062#[allow(deprecated)]
1063impl<E: Clone, const D: usize> From<&DataSerialize<E>> for Data<E, D> {
1064    fn from(data: &DataSerialize<E>) -> Self {
1065        let mut dims = [0; D];
1066        dims[..D].copy_from_slice(&data.shape[..D]);
1067        Data::new(data.value.clone(), Shape::new(dims))
1068    }
1069}
1070
1071#[allow(deprecated)]
1072impl<E, const D: usize> From<DataSerialize<E>> for Data<E, D> {
1073    fn from(data: DataSerialize<E>) -> Self {
1074        let mut dims = [0; D];
1075        dims[..D].copy_from_slice(&data.shape[..D]);
1076        Data::new(data.value, Shape::new(dims))
1077    }
1078}
1079
1080#[allow(deprecated)]
1081impl<E: core::fmt::Debug + Copy, const A: usize> From<[E; A]> for Data<E, 1> {
1082    fn from(elems: [E; A]) -> Self {
1083        let mut data = Vec::with_capacity(2 * A);
1084        for elem in elems.into_iter() {
1085            data.push(elem);
1086        }
1087
1088        Data::new(data, Shape::new([A]))
1089    }
1090}
1091
1092#[allow(deprecated)]
1093impl<E: core::fmt::Debug + Copy> From<&[E]> for Data<E, 1> {
1094    fn from(elems: &[E]) -> Self {
1095        let mut data = Vec::with_capacity(elems.len());
1096        for elem in elems.iter() {
1097            data.push(*elem);
1098        }
1099
1100        Data::new(data, Shape::new([elems.len()]))
1101    }
1102}
1103
1104#[allow(deprecated)]
1105impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[E; B]; A]> for Data<E, 2> {
1106    fn from(elems: [[E; B]; A]) -> Self {
1107        let mut data = Vec::with_capacity(A * B);
1108        for elem in elems.into_iter().take(A) {
1109            for elem in elem.into_iter().take(B) {
1110                data.push(elem);
1111            }
1112        }
1113
1114        Data::new(data, Shape::new([A, B]))
1115    }
1116}
1117
1118#[allow(deprecated)]
1119impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
1120    From<[[[E; C]; B]; A]> for Data<E, 3>
1121{
1122    fn from(elems: [[[E; C]; B]; A]) -> Self {
1123        let mut data = Vec::with_capacity(A * B * C);
1124
1125        for elem in elems.into_iter().take(A) {
1126            for elem in elem.into_iter().take(B) {
1127                for elem in elem.into_iter().take(C) {
1128                    data.push(elem);
1129                }
1130            }
1131        }
1132
1133        Data::new(data, Shape::new([A, B, C]))
1134    }
1135}
1136
1137#[allow(deprecated)]
1138impl<
1139        E: core::fmt::Debug + Copy,
1140        const A: usize,
1141        const B: usize,
1142        const C: usize,
1143        const D: usize,
1144    > From<[[[[E; D]; C]; B]; A]> for Data<E, 4>
1145{
1146    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
1147        let mut data = Vec::with_capacity(A * B * C * D);
1148
1149        for elem in elems.into_iter().take(A) {
1150            for elem in elem.into_iter().take(B) {
1151                for elem in elem.into_iter().take(C) {
1152                    for elem in elem.into_iter().take(D) {
1153                        data.push(elem);
1154                    }
1155                }
1156            }
1157        }
1158
1159        Data::new(data, Shape::new([A, B, C, D]))
1160    }
1161}
1162
1163#[allow(deprecated)]
1164impl<E: core::fmt::Debug, const D: usize> core::fmt::Display for Data<E, D> {
1165    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1166        f.write_str(format!("{:?}", &self.value).as_str())
1167    }
1168}
1169
1170fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> {
1171    let epsilon_deviations = tolerance / f32::EPSILON as f64;
1172    let epsilon = match ty {
1173        DType::F64 => f32::EPSILON as f64, // Don't increase precision beyond `f32`, see below
1174        DType::F32 => f32::EPSILON as f64,
1175        DType::F16 => half::f16::EPSILON.to_f64(),
1176        DType::BF16 => half::bf16::EPSILON.to_f64(),
1177        _ => unreachable!(),
1178    };
1179    let tolerance_norm = epsilon_deviations * epsilon;
1180    // Clamp to 1.0 so we don't require more precision than `tolerance`. This is because literals
1181    // have a fixed number of digits, so increasing precision breaks things
1182    let value_abs = value.abs().max(1.0);
1183    let tolerance_adjusted = tolerance_norm * value_abs;
1184
1185    let err = (value - other).abs();
1186
1187    if err > tolerance_adjusted || err.is_nan() {
1188        Some((err, tolerance_adjusted))
1189    } else {
1190        None
1191    }
1192}
1193
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use crate::quantization::AffineQuantization;

    use super::*;
    use alloc::vec;
    use rand::{rngs::StdRng, SeedableRng};

    /// Random `f32` tensor data of the given shape, sampled with an entropy-seeded RNG.
    fn random_f32(shape: Shape) -> TensorData {
        let mut rng = StdRng::from_entropy();
        TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut rng)
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32(Shape::new([3, 5, 6]));

        let from_iter: Vec<f32> = data.iter::<f32>().collect();
        let from_vec = data.into_vec::<f32>().unwrap();

        assert_eq!(from_iter, from_vec);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let data = random_f32(Shape::new([3, 5, 6]));

        // The data holds `f32` values, so asking for `i32` must fail.
        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = Shape::new([3, 5, 6]);
        let expected = shape.num_elements();
        let data = random_f32(shape);

        // f32 stored as u8s
        assert_eq!(expected, data.bytes.len() / 4);
        assert_eq!(expected, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let single_row = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(single_row.shape, vec![1, 3]);

        let two_rows = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(two_rows.shape, vec![2, 3]);

        let flat = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(flat.shape, vec![3]);
    }

    #[test]
    fn should_assert_appox_eq_limit() {
        let expected = TensorData::from([[3.0, 5.0, 6.0]]);
        let actual = TensorData::from([[3.03, 5.0, 6.0]]);

        expected.assert_approx_eq(&actual, 2);
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let expected = TensorData::from([[3.0, 5.0, 6.0]]);
        let actual = TensorData::from([[3.031, 5.0, 6.0]]);

        expected.assert_approx_eq(&actual, 2);
    }

    #[test]
    #[should_panic]
    fn should_assert_appox_eq_check_shape() {
        let longer = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let shorter = TensorData::from([[3.0, 5.0, 6.0]]);

        longer.assert_approx_eq(&shorter, 2);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        // Two values in a vector that has room for five.
        let mut values = Vec::<f32>::with_capacity(5);
        values.push(2.0);
        values.push(3.0);
        let data = TensorData::new(values, vec![2]);

        let bytes_per_elem = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data.bytes.len(), 2 * bytes_per_elem);
        assert_eq!(data.bytes.capacity(), 5 * bytes_per_elem);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        /// Converts `0..32` to `E` and checks every value survives the conversion.
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            let converted = data.clone().convert::<E>().into_vec::<E>().unwrap();

            for (i, val) in converted.into_iter().enumerate() {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }

        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    #[test]
    #[should_panic = "Expected quantized data"]
    fn should_not_dequantize() {
        let not_quantized = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        not_quantized.dequantize().unwrap();
    }

    #[test]
    fn should_support_dequantize() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
        let strategy =
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128));
        let quantized = TensorData::quantized(vec![-128i8, -77, -26, 25, 76, 127], [2, 3], strategy);

        let dequantized = quantized.dequantize().unwrap();
        let expected = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

        dequantized.assert_approx_eq(&expected, 4);
    }
}
1334        output.assert_approx_eq(&TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), 4);
1335    }
1336}