burn_tensor/tensor/
data.rs

1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use half::{bf16, f16};
9use num_traits::{Float, ToPrimitive};
10
11use crate::{
12    DType, Distribution, Element, ElementConversion,
13    quantization::{QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes},
14    tensor::bytes::Bytes,
15};
16
17use rand::RngCore;
18
19/// The things that can go wrong when manipulating tensor data.
20#[derive(Debug)]
21pub enum DataError {
22    /// Failed to cast the values to a specified element type.
23    CastError(CheckedCastError),
24    /// Invalid target element type.
25    TypeMismatch(String),
26}
27
28/// Data structure for tensors.
29#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
30pub struct TensorData {
31    /// The values of the tensor (as bytes).
32    pub bytes: Bytes,
33
34    /// The shape of the tensor.
35    pub shape: Vec<usize>,
36
37    /// The data type of the tensor.
38    pub dtype: DType,
39}
40
41impl TensorData {
42    /// Creates a new tensor data structure.
43    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
44        // Ensure shape is valid
45        let shape = shape.into();
46        Self::check_data_len(&value, &shape);
47
48        Self {
49            bytes: Bytes::from_elems(value),
50            shape,
51            dtype: E::dtype(),
52        }
53    }
54
55    /// Creates a new quantized tensor data structure.
56    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
57        value: Vec<E>,
58        shape: S,
59        strategy: QuantizationStrategy,
60    ) -> Self {
61        let shape = shape.into();
62        Self::check_data_len(&value, &shape);
63
64        let q_bytes = QuantizedBytes::new(value, strategy);
65
66        Self {
67            bytes: q_bytes.bytes,
68            shape,
69            dtype: DType::QFloat(q_bytes.scheme),
70        }
71    }
72
73    /// Creates a new tensor data structure from raw bytes.
74    ///
75    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are
76    /// certain that the bytes representation is valid.
77    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
78        Self {
79            bytes: Bytes::from_bytes_vec(bytes),
80            shape: shape.into(),
81            dtype,
82        }
83    }
84
85    // Check that the input vector contains a correct number of elements
86    fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
87        let expected_data_len = Self::numel(shape);
88        let num_data = data.len();
89        assert_eq!(
90            expected_data_len, num_data,
91            "Shape {:?} is invalid for input of size {:?}",
92            shape, num_data,
93        );
94    }
95
96    /// Returns the immutable slice view of the tensor data.
97    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
98        if E::dtype() == self.dtype {
99            match E::dtype() {
100                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
101                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
102                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
103                DType::Bool => {
104                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
105                        .map_err(DataError::CastError)?;
106                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
107                }
108                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
109            }
110        } else {
111            Err(DataError::TypeMismatch(format!(
112                "Invalid target element type (expected {:?}, got {:?})",
113                self.dtype,
114                E::dtype()
115            )))
116        }
117    }
118
119    /// Returns the mutable slice view of the tensor data.
120    ///
121    /// # Panics
122    /// If the target element type is different from the stored element type.
123    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
124        if E::dtype() == self.dtype {
125            match E::dtype() {
126                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
127                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
128                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
129                DType::Bool => {
130                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
131                        .map_err(DataError::CastError)?;
132                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
133                }
134                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
135                    .map_err(DataError::CastError),
136            }
137        } else {
138            Err(DataError::TypeMismatch(format!(
139                "Invalid target element type (expected {:?}, got {:?})",
140                self.dtype,
141                E::dtype()
142            )))
143        }
144    }
145
146    /// Returns the tensor data as a vector of scalar values.
147    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
148        Ok(self.as_slice()?.to_vec())
149    }
150
151    /// Returns the tensor data as a vector of scalar values.
152    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
153        // This means we cannot call `into_vec` for QFloat
154        if E::dtype() != self.dtype {
155            return Err(DataError::TypeMismatch(format!(
156                "Invalid target element type (expected {:?}, got {:?})",
157                self.dtype,
158                E::dtype()
159            )));
160        }
161
162        match E::dtype() {
163            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
164            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
165            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
166            DType::Bool => {
167                let vec = self.into_vec_unchecked::<u8>()?;
168                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
169            }
170            _ => self.into_vec_unchecked(),
171        }
172    }
173
174    /// Returns the tensor data as a vector of scalar values. Does not check dtype.
175    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
176        let mut me = self;
177        me.bytes = match me.bytes.try_into_vec::<E>() {
178            Ok(elems) => return Ok(elems),
179            Err(bytes) => bytes,
180        };
181        // The bytes might have been deserialized and allocated with a different align.
182        // In that case, we have to memcopy the data into a new vector, more suitably allocated
183        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
184            .map_err(DataError::CastError)?
185            .to_vec())
186    }
187
188    /// Returns an iterator over the values of the tensor data.
189    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
190        if E::dtype() == self.dtype {
191            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
192        } else {
193            match self.dtype {
194                DType::I8 => Box::new(
195                    bytemuck::checked::cast_slice(&self.bytes)
196                        .iter()
197                        .map(|e: &i8| e.elem::<E>()),
198                ),
199                DType::I16 => Box::new(
200                    bytemuck::checked::cast_slice(&self.bytes)
201                        .iter()
202                        .map(|e: &i16| e.elem::<E>()),
203                ),
204                DType::I32 => Box::new(
205                    bytemuck::checked::cast_slice(&self.bytes)
206                        .iter()
207                        .map(|e: &i32| e.elem::<E>()),
208                ),
209                DType::I64 => Box::new(
210                    bytemuck::checked::cast_slice(&self.bytes)
211                        .iter()
212                        .map(|e: &i64| e.elem::<E>()),
213                ),
214                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
215                DType::U16 => Box::new(
216                    bytemuck::checked::cast_slice(&self.bytes)
217                        .iter()
218                        .map(|e: &u16| e.elem::<E>()),
219                ),
220                DType::U32 => Box::new(
221                    bytemuck::checked::cast_slice(&self.bytes)
222                        .iter()
223                        .map(|e: &u32| e.elem::<E>()),
224                ),
225                DType::U64 => Box::new(
226                    bytemuck::checked::cast_slice(&self.bytes)
227                        .iter()
228                        .map(|e: &u64| e.elem::<E>()),
229                ),
230                DType::BF16 => Box::new(
231                    bytemuck::checked::cast_slice(&self.bytes)
232                        .iter()
233                        .map(|e: &bf16| e.elem::<E>()),
234                ),
235                DType::F16 => Box::new(
236                    bytemuck::checked::cast_slice(&self.bytes)
237                        .iter()
238                        .map(|e: &f16| e.elem::<E>()),
239                ),
240                DType::F32 | DType::Flex32 => Box::new(
241                    bytemuck::checked::cast_slice(&self.bytes)
242                        .iter()
243                        .map(|e: &f32| e.elem::<E>()),
244                ),
245                DType::F64 => Box::new(
246                    bytemuck::checked::cast_slice(&self.bytes)
247                        .iter()
248                        .map(|e: &f64| e.elem::<E>()),
249                ),
250                // bool is a byte value equal to either 0 or 1
251                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
252                DType::QFloat(scheme) => match scheme {
253                    QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
254                        // Quantized int8 values
255                        let q_bytes = QuantizedBytes {
256                            bytes: self.bytes.clone(),
257                            scheme,
258                            num_elements: self.num_elements(),
259                        };
260                        let (values, _) = q_bytes.into_vec_i8();
261
262                        Box::new(
263                            values
264                                .iter()
265                                .map(|e: &i8| e.elem::<E>())
266                                .collect::<Vec<_>>()
267                                .into_iter(),
268                        )
269                    }
270                },
271            }
272        }
273    }
274
275    /// Returns the total number of elements of the tensor data.
276    pub fn num_elements(&self) -> usize {
277        Self::numel(&self.shape)
278    }
279
280    fn numel(shape: &[usize]) -> usize {
281        shape.iter().product()
282    }
283
284    /// Populates the data with random values.
285    pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
286        shape: S,
287        distribution: Distribution,
288        rng: &mut R,
289    ) -> Self {
290        let shape = shape.into();
291        let num_elements = Self::numel(&shape);
292        let mut data = Vec::with_capacity(num_elements);
293
294        for _ in 0..num_elements {
295            data.push(E::random(distribution, rng));
296        }
297
298        TensorData::new(data, shape)
299    }
300
301    /// Populates the data with zeros.
302    pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
303        let shape = shape.into();
304        let num_elements = Self::numel(&shape);
305        let mut data = Vec::<E>::with_capacity(num_elements);
306
307        for _ in 0..num_elements {
308            data.push(0.elem());
309        }
310
311        TensorData::new(data, shape)
312    }
313
314    /// Populates the data with ones.
315    pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
316        let shape = shape.into();
317        let num_elements = Self::numel(&shape);
318        let mut data = Vec::<E>::with_capacity(num_elements);
319
320        for _ in 0..num_elements {
321            data.push(1.elem());
322        }
323
324        TensorData::new(data, shape)
325    }
326
327    /// Populates the data with the given value
328    pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
329        let shape = shape.into();
330        let num_elements = Self::numel(&shape);
331        let mut data = Vec::<E>::with_capacity(num_elements);
332        for _ in 0..num_elements {
333            data.push(fill_value)
334        }
335
336        TensorData::new(data, shape)
337    }
338
339    /// Converts the data to a different element type.
340    pub fn convert<E: Element>(self) -> Self {
341        self.convert_dtype(E::dtype())
342    }
343
344    /// Converts the data to a different element type.
345    pub fn convert_dtype(self, dtype: DType) -> Self {
346        if dtype == self.dtype {
347            self
348        } else if dtype.size() == self.dtype.size()
349            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
350            && !matches!(dtype, DType::Bool | DType::QFloat(_))
351        {
352            match self.dtype {
353                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
354                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
355                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
356                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
357                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
358                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
359                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
360                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
361                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
362                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
363                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
364                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
365                DType::Bool | DType::QFloat(_) => unreachable!(),
366            }
367        } else {
368            match self.dtype {
369                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
370                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
371                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
372                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
373                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
374                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
375                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
376                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
377                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
378                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
379                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
380                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
381                DType::Bool => self.convert_clone_dtype::<bool>(dtype),
382                DType::QFloat(_) => unreachable!(),
383            }
384        }
385    }
386
387    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
388        match dtype {
389            DType::F64 => self.convert_inplace::<Current, f64>(),
390            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
391            DType::F16 => self.convert_inplace::<Current, f16>(),
392            DType::BF16 => self.convert_inplace::<Current, bf16>(),
393            DType::I64 => self.convert_inplace::<Current, i64>(),
394            DType::I32 => self.convert_inplace::<Current, i32>(),
395            DType::I16 => self.convert_inplace::<Current, i16>(),
396            DType::I8 => self.convert_inplace::<Current, i8>(),
397            DType::U64 => self.convert_inplace::<Current, u64>(),
398            DType::U32 => self.convert_inplace::<Current, u32>(),
399            DType::U16 => self.convert_inplace::<Current, u16>(),
400            DType::U8 => self.convert_inplace::<Current, u8>(),
401            DType::Bool | DType::QFloat(_) => unreachable!(),
402        }
403    }
404
405    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
406        mut self,
407    ) -> Self {
408        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
409            let t: Target = x.elem();
410            let x = cast_mut::<_, Target>(x);
411            *x = t;
412        }
413
414        self.dtype = Target::dtype();
415
416        self
417    }
418
419    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
420        match dtype {
421            DType::F64 => self.convert_clone::<Current, f64>(),
422            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
423            DType::F16 => self.convert_clone::<Current, f16>(),
424            DType::BF16 => self.convert_clone::<Current, bf16>(),
425            DType::I64 => self.convert_clone::<Current, i64>(),
426            DType::I32 => self.convert_clone::<Current, i32>(),
427            DType::I16 => self.convert_clone::<Current, i16>(),
428            DType::I8 => self.convert_clone::<Current, i8>(),
429            DType::U64 => self.convert_clone::<Current, u64>(),
430            DType::U32 => self.convert_clone::<Current, u32>(),
431            DType::U16 => self.convert_clone::<Current, u16>(),
432            DType::U8 => self.convert_clone::<Current, u8>(),
433            DType::Bool => self.convert_clone::<Current, bool>(),
434            DType::QFloat(_) => unreachable!(),
435        }
436    }
437
438    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
439        self,
440    ) -> Self {
441        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
442        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
443
444        for (x, out) in this.iter().zip(&mut out) {
445            *out = x.elem();
446        }
447
448        Self::new(out, self.shape)
449    }
450
451    /// Returns the data as a slice of bytes.
452    pub fn as_bytes(&self) -> &[u8] {
453        &self.bytes
454    }
455
456    /// Returns the bytes representation of the data.
457    pub fn into_bytes(self) -> Bytes {
458        self.bytes
459    }
460
461    /// Applies the data quantization strategy.
462    ///
463    /// # Panics
464    ///
465    /// Panics if the data type is not supported for quantization.
466    pub fn with_quantization(self, quantization: QuantizationStrategy) -> Self {
467        assert_eq!(
468            self.dtype,
469            DType::F32,
470            "Only f32 data type can be quantized"
471        );
472        let values = quantization.quantize(self.as_slice().unwrap());
473        TensorData::quantized(values, self.shape, quantization)
474    }
475
476    /// Dequantizes the data according to its quantization scheme.
477    pub fn dequantize(self) -> Result<Self, DataError> {
478        if let DType::QFloat(scheme) = self.dtype {
479            let num_elements = self.num_elements();
480            let q_bytes = QuantizedBytes {
481                bytes: self.bytes,
482                scheme,
483                num_elements,
484            };
485
486            let values = q_bytes.dequantize().0;
487            Ok(Self::new(values, self.shape))
488        } else {
489            Err(DataError::TypeMismatch(format!(
490                "Expected quantized data, got {:?}",
491                self.dtype
492            )))
493        }
494    }
495
496    /// Asserts the data is equal to another data.
497    ///
498    /// # Arguments
499    ///
500    /// * `other` - The other data.
501    /// * `strict` - If true, the data types must the be same.
502    ///   Otherwise, the comparison is done in the current data type.
503    ///
504    /// # Panics
505    ///
506    /// Panics if the data is not equal.
507    #[track_caller]
508    pub fn assert_eq(&self, other: &Self, strict: bool) {
509        if strict {
510            assert_eq!(
511                self.dtype, other.dtype,
512                "Data types differ ({:?} != {:?})",
513                self.dtype, other.dtype
514            );
515        }
516
517        match self.dtype {
518            DType::F64 => self.assert_eq_elem::<f64>(other),
519            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
520            DType::F16 => self.assert_eq_elem::<f16>(other),
521            DType::BF16 => self.assert_eq_elem::<bf16>(other),
522            DType::I64 => self.assert_eq_elem::<i64>(other),
523            DType::I32 => self.assert_eq_elem::<i32>(other),
524            DType::I16 => self.assert_eq_elem::<i16>(other),
525            DType::I8 => self.assert_eq_elem::<i8>(other),
526            DType::U64 => self.assert_eq_elem::<u64>(other),
527            DType::U32 => self.assert_eq_elem::<u32>(other),
528            DType::U16 => self.assert_eq_elem::<u16>(other),
529            DType::U8 => self.assert_eq_elem::<u8>(other),
530            DType::Bool => self.assert_eq_elem::<bool>(other),
531            DType::QFloat(q) => {
532                // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
533                let q_other = if let DType::QFloat(q_other) = other.dtype {
534                    q_other
535                } else {
536                    panic!("Quantized data differs from other not quantized data")
537                };
538                match (q, q_other) {
539                    (
540                        QuantizationScheme::PerTensor(mode, QuantizationType::QInt8),
541                        QuantizationScheme::PerTensor(mode_other, QuantizationType::QInt8),
542                    ) if mode == mode_other => self.assert_eq_elem::<i8>(other),
543                    _ => panic!("Quantization schemes differ ({:?} != {:?})", q, q_other),
544                }
545            }
546        }
547    }
548
549    #[track_caller]
550    fn assert_eq_elem<E: Element>(&self, other: &Self) {
551        let mut message = String::new();
552        if self.shape != other.shape {
553            message += format!(
554                "\n  => Shape is different: {:?} != {:?}",
555                self.shape, other.shape
556            )
557            .as_str();
558        }
559
560        let mut num_diff = 0;
561        let max_num_diff = 5;
562        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
563            if a.cmp(&b).is_ne() {
564                // Only print the first 5 different values.
565                if num_diff < max_num_diff {
566                    message += format!("\n  => Position {i}: {a} != {b}").as_str();
567                }
568                num_diff += 1;
569            }
570        }
571
572        if num_diff >= max_num_diff {
573            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
574        }
575
576        if !message.is_empty() {
577            panic!("Tensors are not eq:{}", message);
578        }
579    }
580
581    /// Asserts the data is approximately equal to another data.
582    ///
583    /// # Arguments
584    ///
585    /// * `other` - The other data.
586    /// * `tolerance` - The tolerance of the comparison.
587    ///
588    /// # Panics
589    ///
590    /// Panics if the data is not approximately equal.
591    #[track_caller]
592    pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
593        let mut message = String::new();
594        if self.shape != other.shape {
595            message += format!(
596                "\n  => Shape is different: {:?} != {:?}",
597                self.shape, other.shape
598            )
599            .as_str();
600        }
601
602        let iter = self.iter::<F>().zip(other.iter::<F>());
603
604        let mut num_diff = 0;
605        let max_num_diff = 5;
606
607        for (i, (a, b)) in iter.enumerate() {
608            //if they are both nan, then they are equally nan
609            let both_nan = a.is_nan() && b.is_nan();
610            //this works for both infinities
611            let both_inf =
612                a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
613
614            if both_nan || both_inf {
615                continue;
616            }
617
618            if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
619                // Only print the first 5 different values.
620                if num_diff < max_num_diff {
621                    let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
622                    let diff_rel = diff_abs / ToPrimitive::to_f64(&(a + b).abs()).unwrap();
623
624                    let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
625                    let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
626
627                    message += format!(
628                        "\n  => Position {i}: {a} != {b}\n     diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
629                    )
630                    .as_str();
631                }
632                num_diff += 1;
633            }
634        }
635
636        if num_diff >= max_num_diff {
637            message += format!("\n{} more errors...", num_diff - 5).as_str();
638        }
639
640        if !message.is_empty() {
641            panic!("Tensors are not approx eq:{}", message);
642        }
643    }
644
645    /// Asserts each value is within a given range.
646    ///
647    /// # Arguments
648    ///
649    /// * `range` - The range.
650    ///
651    /// # Panics
652    ///
653    /// If any value is not within the half-open range bounded inclusively below
654    /// and exclusively above (`start..end`).
655    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
656        let start = range.start.elem::<f32>();
657        let end = range.end.elem::<f32>();
658
659        for elem in self.iter::<f32>() {
660            if elem < start || elem >= end {
661                panic!("Element ({elem:?}) is not within range {range:?}");
662            }
663        }
664    }
665
666    /// Asserts each value is within a given inclusive range.
667    ///
668    /// # Arguments
669    ///
670    /// * `range` - The range.
671    ///
672    /// # Panics
673    ///
674    /// If any value is not within the half-open range bounded inclusively (`start..=end`).
675    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
676        let start = range.start().elem::<f32>();
677        let end = range.end().elem::<f32>();
678
679        for elem in self.iter::<f32>() {
680            if elem < start || elem > end {
681                panic!("Element ({elem:?}) is not within range {range:?}");
682            }
683        }
684    }
685}
686
687impl<E: Element, const A: usize> From<[E; A]> for TensorData {
688    fn from(elems: [E; A]) -> Self {
689        TensorData::new(elems.to_vec(), [A])
690    }
691}
692
693impl<const A: usize> From<[usize; A]> for TensorData {
694    fn from(elems: [usize; A]) -> Self {
695        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
696    }
697}
698
699impl From<&[usize]> for TensorData {
700    fn from(elems: &[usize]) -> Self {
701        let mut data = Vec::with_capacity(elems.len());
702        for elem in elems.iter() {
703            data.push(*elem as i64);
704        }
705
706        TensorData::new(data, [elems.len()])
707    }
708}
709
710impl<E: Element> From<&[E]> for TensorData {
711    fn from(elems: &[E]) -> Self {
712        let mut data = Vec::with_capacity(elems.len());
713        for elem in elems.iter() {
714            data.push(*elem);
715        }
716
717        TensorData::new(data, [elems.len()])
718    }
719}
720
721impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
722    fn from(elems: [[E; B]; A]) -> Self {
723        let mut data = Vec::with_capacity(A * B);
724        for elem in elems.into_iter().take(A) {
725            for elem in elem.into_iter().take(B) {
726                data.push(elem);
727            }
728        }
729
730        TensorData::new(data, [A, B])
731    }
732}
733
734impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
735    for TensorData
736{
737    fn from(elems: [[[E; C]; B]; A]) -> Self {
738        let mut data = Vec::with_capacity(A * B * C);
739
740        for elem in elems.into_iter().take(A) {
741            for elem in elem.into_iter().take(B) {
742                for elem in elem.into_iter().take(C) {
743                    data.push(elem);
744                }
745            }
746        }
747
748        TensorData::new(data, [A, B, C])
749    }
750}
751
752impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
753    From<[[[[E; D]; C]; B]; A]> for TensorData
754{
755    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
756        let mut data = Vec::with_capacity(A * B * C * D);
757
758        for elem in elems.into_iter().take(A) {
759            for elem in elem.into_iter().take(B) {
760                for elem in elem.into_iter().take(C) {
761                    for elem in elem.into_iter().take(D) {
762                        data.push(elem);
763                    }
764                }
765            }
766        }
767
768        TensorData::new(data, [A, B, C, D])
769    }
770}
771
772impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
773    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
774{
775    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
776        let mut data = Vec::with_capacity(A * B * C * D * E);
777
778        for elem in elems.into_iter().take(A) {
779            for elem in elem.into_iter().take(B) {
780                for elem in elem.into_iter().take(C) {
781                    for elem in elem.into_iter().take(D) {
782                        for elem in elem.into_iter().take(E) {
783                            data.push(elem);
784                        }
785                    }
786                }
787            }
788        }
789
790        TensorData::new(data, [A, B, C, D, E])
791    }
792}
793
794impl core::fmt::Display for TensorData {
795    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
796        let fmt = match self.dtype {
797            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
798            DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
799            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
800            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
801            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
802            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
803            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
804            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
805            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
806            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
807            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
808            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
809            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
810            DType::QFloat(scheme) => match scheme {
811                QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
812                    format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
813                }
814            },
815        };
816        f.write_str(fmt.as_str())
817    }
818}
819
/// The tolerance used to compare two floating point numbers.
///
/// Generally, two numbers `x` and `y` are approximately equal if
///
/// ```text
/// |x - y| < max(R * (|x + y|), A)
/// ```
///
/// where `R` is the relative tolerance and `A` is the absolute tolerance.
///
/// The most common way to initialize this struct is to use `Tolerance::<F>::default()`.
/// In that case, the relative and absolute tolerances are computed using a heuristic based
/// on the EPSILON and MIN_POSITIVE values of the given floating point type `F`.
///
/// Another common initialization is `Tolerance::<F>::rel_abs(1e-4, 1e-5).set_half_precision_relative(1e-2)`.
/// Here the absolute tolerance handles values close to 0.0, where the relative bound shrinks
/// toward zero, and the relative tolerance is loosened only when `F` is a half precision type.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    // Relative tolerance `R`; multiplied by `|x + y|` when comparing.
    relative: F,
    // Absolute tolerance `A`; the floor on the allowed difference `|x - y|`.
    absolute: F,
}
843
844impl<F: Float> Default for Tolerance<F> {
845    fn default() -> Self {
846        Self {
847            relative: F::from(64).unwrap() * F::epsilon(),
848            absolute: F::from(16).unwrap() * F::min_positive_value(),
849        }
850    }
851}
852
853impl<F: Float> Tolerance<F> {
854    /// When comparing two numbers, this uses both the relative and absolute differences.
855    ///
856    /// That is, `x` and `y` are approximately equal if
857    ///
858    /// ```text
859    /// |x - y| < max(R * (|x + y|), A)
860    /// ```
861    ///
862    /// where `R` is the `relative` tolerance and `A` is the `absolute` tolerance.
863    pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
864        let relative = Self::check_relative(relative);
865        let absolute = Self::check_absolute(absolute);
866
867        Self { relative, absolute }
868    }
869
870    /// When comparing two numbers, this uses only the relative difference.
871    ///
872    /// That is, `x` and `y` are approximately equal if
873    ///
874    /// ```text
875    /// |x - y| < R * (|x + y|)
876    /// ```
877    ///
878    /// where `R` is the relative `tolerance`.
879    pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
880        let relative = Self::check_relative(tolerance);
881
882        Self {
883            relative,
884            absolute: F::from(0.0).unwrap(),
885        }
886    }
887
888    /// When comparing two numbers, this uses only the absolute difference.
889    ///
890    /// That is, `x` and `y` are approximately equal if
891    ///
892    /// ```text
893    /// |x - y| < A
894    /// ```
895    ///
896    /// where `A` is the absolute `tolerance`.
897    pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
898        let absolute = Self::check_absolute(tolerance);
899
900        Self {
901            relative: F::from(0.0).unwrap(),
902            absolute,
903        }
904    }
905
906    /// Change the relative tolerance to the given one.
907    pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
908        self.relative = Self::check_relative(tolerance);
909        self
910    }
911
912    /// Change the relative tolerance to the given one only if `F` is half precision.
913    pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
914        if core::mem::size_of::<F>() == 2 {
915            self.relative = Self::check_relative(tolerance);
916        }
917        self
918    }
919
920    /// Change the relative tolerance to the given one only if `F` is single precision.
921    pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
922        if core::mem::size_of::<F>() == 4 {
923            self.relative = Self::check_relative(tolerance);
924        }
925        self
926    }
927
928    /// Change the relative tolerance to the given one only if `F` is double precision.
929    pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
930        if core::mem::size_of::<F>() == 8 {
931            self.relative = Self::check_relative(tolerance);
932        }
933        self
934    }
935
936    /// Change the absolute tolerance to the given one.
937    pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
938        self.absolute = Self::check_absolute(tolerance);
939        self
940    }
941
942    /// Change the absolute tolerance to the given one only if `F` is half precision.
943    pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
944        if core::mem::size_of::<F>() == 2 {
945            self.absolute = Self::check_absolute(tolerance);
946        }
947        self
948    }
949
950    /// Change the absolute tolerance to the given one only if `F` is single precision.
951    pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
952        if core::mem::size_of::<F>() == 4 {
953            self.absolute = Self::check_absolute(tolerance);
954        }
955        self
956    }
957
958    /// Change the absolute tolerance to the given one only if `F` is double precision.
959    pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
960        if core::mem::size_of::<F>() == 8 {
961            self.absolute = Self::check_absolute(tolerance);
962        }
963        self
964    }
965
966    /// Checks if `x` and `y` are approximately equal given the tolerance.
967    pub fn approx_eq(&self, x: F, y: F) -> bool {
968        // See the accepted answer here
969        // https://stackoverflow.com/questions/4915462/how-should-i-do-floating-point-comparison
970
971        // This also handles the case where both a and b are infinity so that we don't need
972        // to manage it in the rest of the function.
973        if x == y {
974            return true;
975        }
976
977        let diff = (x - y).abs();
978
979        let norm = (x + y).abs();
980        let norm = norm.min(F::max_value()); // In case |a + b| -> inf.
981
982        diff < self.absolute.max(self.relative * norm)
983    }
984
985    fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
986        let tolerance = F::from(tolerance).unwrap();
987        assert!(tolerance <= F::one());
988        tolerance
989    }
990
991    fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
992        let tolerance = F::from(tolerance).unwrap();
993        assert!(tolerance >= F::zero());
994        tolerance
995    }
996}
997
#[cfg(test)]
mod tests {
    use crate::{Shape, quantization::SymmetricQuantization};

    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Fixed seed so that randomly generated test data is reproducible across runs.
    const SEED: u64 = 42;

    /// Returns a deterministically seeded RNG for generating test data.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(SEED)
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let shape = Shape::new([3, 5, 6]);
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let shape = Shape::new([3, 5, 6]);
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = Shape::new([3, 5, 6]);
        let num_elements = shape.num_elements();
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_assert_approx_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));
        data1.assert_approx_eq::<half::f16>(&data2, Tolerance::absolute(3e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    fn should_convert_bytes_correctly() {
        // Deliberately over-allocate to check that the spare capacity is carried over.
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    #[test]
    #[should_panic = "Expected quantized data"]
    fn should_not_dequantize() {
        let data = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        data.dequantize().unwrap();
    }

    #[test]
    fn should_support_dequantize() {
        let data = TensorData::quantized(
            vec![-127i8, -77, -26, 25, 76, 127],
            [2, 3],
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)),
        );

        let output = data.dequantize().unwrap();

        output.assert_approx_eq::<f32>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Tolerance::default(),
        );

        output.assert_approx_eq::<f16>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Tolerance::default(),
        );
    }
}