burn_tensor/tensor/
data.rs

1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use cubecl_quant::scheme::QuantScheme;
9use half::{bf16, f16};
10use num_traits::{Float, ToPrimitive};
11
12use crate::{
13    DType, Distribution, Element, ElementConversion,
14    quantization::{QuantValue, QuantizationStrategy, QuantizedBytes},
15    tensor::Bytes,
16};
17
18use rand::RngCore;
19
20use super::quantization::{QuantLevel, QuantMode};
21
/// The things that can go wrong when manipulating tensor data.
///
/// Returned by the fallible accessors of [`TensorData`] (for example
/// [`TensorData::as_slice`] and [`TensorData::into_vec`]).
#[derive(Debug)]
pub enum DataError {
    /// Failed to cast the values to a specified element type.
    ///
    /// Wraps the `bytemuck` error produced when the stored bytes cannot be
    /// reinterpreted as the requested element type.
    CastError(CheckedCastError),
    /// Invalid target element type.
    ///
    /// Carries a human-readable message describing the mismatch.
    TypeMismatch(String),
}
30
/// Data structure for tensors.
///
/// Stores the raw bytes of the values together with the shape and the data type
/// needed to interpret them.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// The values of the tensor (as bytes).
    pub bytes: Bytes,

    /// The shape of the tensor.
    ///
    /// The product of all dimensions is the number of elements.
    pub shape: Vec<usize>,

    /// The data type of the tensor.
    ///
    /// Typed accessors compare the requested element type against this value.
    pub dtype: DType,
}
43
44impl TensorData {
45    /// Creates a new tensor data structure.
46    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
47        // Ensure shape is valid
48        let shape = shape.into();
49        Self::check_data_len(&value, &shape);
50
51        Self {
52            bytes: Bytes::from_elems(value),
53            shape,
54            dtype: E::dtype(),
55        }
56    }
57
58    /// Creates a new quantized tensor data structure.
59    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
60        value: Vec<E>,
61        shape: S,
62        strategy: QuantizationStrategy,
63        scheme: QuantScheme,
64    ) -> Self {
65        let shape = shape.into();
66        Self::check_data_len(&value, &shape);
67
68        let q_bytes = QuantizedBytes::new(value, strategy, scheme);
69
70        Self {
71            bytes: q_bytes.bytes,
72            shape,
73            dtype: DType::QFloat(q_bytes.scheme),
74        }
75    }
76
77    /// Creates a new tensor data structure from raw bytes.
78    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
79        Self {
80            bytes,
81            shape: shape.into(),
82            dtype,
83        }
84    }
85
86    /// Creates a new tensor data structure from raw bytes stored in a vector.
87    ///
88    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are
89    /// certain that the bytes representation is valid.
90    pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
91        Self {
92            bytes: Bytes::from_bytes_vec(bytes),
93            shape: shape.into(),
94            dtype,
95        }
96    }
97
98    // Check that the input vector contains a correct number of elements
99    fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
100        let expected_data_len = Self::numel(shape);
101        let num_data = data.len();
102        assert_eq!(
103            expected_data_len, num_data,
104            "Shape {shape:?} is invalid for input of size {num_data:?}",
105        );
106    }
107
108    /// Returns the immutable slice view of the tensor data.
109    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
110        if E::dtype() == self.dtype {
111            match E::dtype() {
112                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
113                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
114                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
115                DType::Bool => {
116                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
117                        .map_err(DataError::CastError)?;
118                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
119                }
120                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
121            }
122        } else {
123            Err(DataError::TypeMismatch(format!(
124                "Invalid target element type (expected {:?}, got {:?})",
125                self.dtype,
126                E::dtype()
127            )))
128        }
129    }
130
131    /// Returns the mutable slice view of the tensor data.
132    ///
133    /// # Panics
134    /// If the target element type is different from the stored element type.
135    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
136        if E::dtype() == self.dtype {
137            match E::dtype() {
138                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
139                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
140                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
141                DType::Bool => {
142                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
143                        .map_err(DataError::CastError)?;
144                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
145                }
146                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
147                    .map_err(DataError::CastError),
148            }
149        } else {
150            Err(DataError::TypeMismatch(format!(
151                "Invalid target element type (expected {:?}, got {:?})",
152                self.dtype,
153                E::dtype()
154            )))
155        }
156    }
157
158    /// Returns the tensor data as a vector of scalar values.
159    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
160        Ok(self.as_slice()?.to_vec())
161    }
162
163    /// Returns the tensor data as a vector of scalar values.
164    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
165        // This means we cannot call `into_vec` for QFloat
166        if E::dtype() != self.dtype {
167            return Err(DataError::TypeMismatch(format!(
168                "Invalid target element type (expected {:?}, got {:?})",
169                self.dtype,
170                E::dtype()
171            )));
172        }
173
174        match E::dtype() {
175            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
176            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
177            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
178            DType::Bool => {
179                let vec = self.into_vec_unchecked::<u8>()?;
180                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
181            }
182            _ => self.into_vec_unchecked(),
183        }
184    }
185
186    /// Returns the tensor data as a vector of scalar values. Does not check dtype.
187    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
188        let mut me = self;
189        me.bytes = match me.bytes.try_into_vec::<E>() {
190            Ok(elems) => return Ok(elems),
191            Err(bytes) => bytes,
192        };
193        // The bytes might have been deserialized and allocated with a different align.
194        // In that case, we have to memcopy the data into a new vector, more suitably allocated
195        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
196            .map_err(DataError::CastError)?
197            .to_vec())
198    }
199
200    /// Returns an iterator over the values of the tensor data.
201    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
202        if E::dtype() == self.dtype {
203            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
204        } else {
205            match self.dtype {
206                DType::I8 => Box::new(
207                    bytemuck::checked::cast_slice(&self.bytes)
208                        .iter()
209                        .map(|e: &i8| e.elem::<E>()),
210                ),
211                DType::I16 => Box::new(
212                    bytemuck::checked::cast_slice(&self.bytes)
213                        .iter()
214                        .map(|e: &i16| e.elem::<E>()),
215                ),
216                DType::I32 => Box::new(
217                    bytemuck::checked::cast_slice(&self.bytes)
218                        .iter()
219                        .map(|e: &i32| e.elem::<E>()),
220                ),
221                DType::I64 => Box::new(
222                    bytemuck::checked::cast_slice(&self.bytes)
223                        .iter()
224                        .map(|e: &i64| e.elem::<E>()),
225                ),
226                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
227                DType::U16 => Box::new(
228                    bytemuck::checked::cast_slice(&self.bytes)
229                        .iter()
230                        .map(|e: &u16| e.elem::<E>()),
231                ),
232                DType::U32 => Box::new(
233                    bytemuck::checked::cast_slice(&self.bytes)
234                        .iter()
235                        .map(|e: &u32| e.elem::<E>()),
236                ),
237                DType::U64 => Box::new(
238                    bytemuck::checked::cast_slice(&self.bytes)
239                        .iter()
240                        .map(|e: &u64| e.elem::<E>()),
241                ),
242                DType::BF16 => Box::new(
243                    bytemuck::checked::cast_slice(&self.bytes)
244                        .iter()
245                        .map(|e: &bf16| e.elem::<E>()),
246                ),
247                DType::F16 => Box::new(
248                    bytemuck::checked::cast_slice(&self.bytes)
249                        .iter()
250                        .map(|e: &f16| e.elem::<E>()),
251                ),
252                DType::F32 | DType::Flex32 => Box::new(
253                    bytemuck::checked::cast_slice(&self.bytes)
254                        .iter()
255                        .map(|e: &f32| e.elem::<E>()),
256                ),
257                DType::F64 => Box::new(
258                    bytemuck::checked::cast_slice(&self.bytes)
259                        .iter()
260                        .map(|e: &f64| e.elem::<E>()),
261                ),
262                // bool is a byte value equal to either 0 or 1
263                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
264                DType::QFloat(scheme) => match scheme {
265                    QuantScheme {
266                        level: QuantLevel::Tensor | QuantLevel::Block(_),
267                        mode: QuantMode::Symmetric,
268                        value:
269                            QuantValue::Q8F
270                            | QuantValue::Q8S
271                            // Represent sub-byte values as i8
272                            | QuantValue::Q4F
273                            | QuantValue::Q4S
274                            | QuantValue::Q2F
275                            | QuantValue::Q2S,
276                        ..
277                    } => {
278                        // Quantized int8 values
279                        let q_bytes = QuantizedBytes {
280                            bytes: self.bytes.clone(),
281                            scheme,
282                            num_elements: self.num_elements(),
283                        };
284                        let (values, _) = q_bytes.into_vec_i8();
285
286                        Box::new(
287                            values
288                                .iter()
289                                .map(|e: &i8| e.elem::<E>())
290                                .collect::<Vec<_>>()
291                                .into_iter(),
292                        )
293                    }
294                    QuantScheme {
295                        level: QuantLevel::Tensor | QuantLevel::Block(_),
296                        mode: QuantMode::Symmetric,
297                        value:
298                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
299                        ..
300                    } => {
301                        unimplemented!("Not yet implemented for iteration");
302                    }
303                },
304            }
305        }
306    }
307
308    /// Returns the rank (the number of dimensions).
309    pub fn rank(&self) -> usize {
310        self.shape.len()
311    }
312
313    /// Returns the total number of elements of the tensor data.
314    pub fn num_elements(&self) -> usize {
315        Self::numel(&self.shape)
316    }
317
318    fn numel(shape: &[usize]) -> usize {
319        shape.iter().product()
320    }
321
322    /// Populates the data with random values.
323    pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
324        shape: S,
325        distribution: Distribution,
326        rng: &mut R,
327    ) -> Self {
328        let shape = shape.into();
329        let num_elements = Self::numel(&shape);
330        let mut data = Vec::with_capacity(num_elements);
331
332        for _ in 0..num_elements {
333            data.push(E::random(distribution, rng));
334        }
335
336        TensorData::new(data, shape)
337    }
338
339    /// Populates the data with zeros.
340    pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
341        let shape = shape.into();
342        let num_elements = Self::numel(&shape);
343        let mut data = Vec::<E>::with_capacity(num_elements);
344
345        for _ in 0..num_elements {
346            data.push(0.elem());
347        }
348
349        TensorData::new(data, shape)
350    }
351
352    /// Populates the data with ones.
353    pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
354        let shape = shape.into();
355        let num_elements = Self::numel(&shape);
356        let mut data = Vec::<E>::with_capacity(num_elements);
357
358        for _ in 0..num_elements {
359            data.push(1.elem());
360        }
361
362        TensorData::new(data, shape)
363    }
364
365    /// Populates the data with the given value
366    pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
367        let shape = shape.into();
368        let num_elements = Self::numel(&shape);
369        let mut data = Vec::<E>::with_capacity(num_elements);
370        for _ in 0..num_elements {
371            data.push(fill_value)
372        }
373
374        TensorData::new(data, shape)
375    }
376
377    pub(crate) fn full_dtype<E: Element, S: Into<Vec<usize>>>(
378        shape: S,
379        fill_value: E,
380        dtype: DType,
381    ) -> TensorData {
382        match dtype {
383            DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
384            DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
385            DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
386            DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
387            DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
388            DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
389            DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
390            DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
391            DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
392            DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
393            DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
394            DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
395            DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
396            DType::QFloat(_) => unreachable!(),
397        }
398    }
399
400    /// Converts the data to a different element type.
401    pub fn convert<E: Element>(self) -> Self {
402        self.convert_dtype(E::dtype())
403    }
404
405    /// Converts the data to a different element type.
406    pub fn convert_dtype(self, dtype: DType) -> Self {
407        if dtype == self.dtype {
408            self
409        } else if dtype.size() == self.dtype.size()
410            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
411            && !matches!(dtype, DType::Bool | DType::QFloat(_))
412        {
413            match self.dtype {
414                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
415                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
416                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
417                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
418                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
419                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
420                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
421                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
422                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
423                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
424                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
425                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
426                DType::Bool | DType::QFloat(_) => unreachable!(),
427            }
428        } else {
429            match self.dtype {
430                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
431                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
432                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
433                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
434                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
435                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
436                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
437                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
438                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
439                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
440                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
441                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
442                DType::Bool => self.convert_clone_dtype::<bool>(dtype),
443                DType::QFloat(_) => unreachable!(),
444            }
445        }
446    }
447
448    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
449        match dtype {
450            DType::F64 => self.convert_inplace::<Current, f64>(),
451            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
452            DType::F16 => self.convert_inplace::<Current, f16>(),
453            DType::BF16 => self.convert_inplace::<Current, bf16>(),
454            DType::I64 => self.convert_inplace::<Current, i64>(),
455            DType::I32 => self.convert_inplace::<Current, i32>(),
456            DType::I16 => self.convert_inplace::<Current, i16>(),
457            DType::I8 => self.convert_inplace::<Current, i8>(),
458            DType::U64 => self.convert_inplace::<Current, u64>(),
459            DType::U32 => self.convert_inplace::<Current, u32>(),
460            DType::U16 => self.convert_inplace::<Current, u16>(),
461            DType::U8 => self.convert_inplace::<Current, u8>(),
462            DType::Bool | DType::QFloat(_) => unreachable!(),
463        }
464    }
465
466    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
467        mut self,
468    ) -> Self {
469        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
470            let t: Target = x.elem();
471            let x = cast_mut::<_, Target>(x);
472            *x = t;
473        }
474
475        self.dtype = Target::dtype();
476
477        self
478    }
479
480    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
481        match dtype {
482            DType::F64 => self.convert_clone::<Current, f64>(),
483            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
484            DType::F16 => self.convert_clone::<Current, f16>(),
485            DType::BF16 => self.convert_clone::<Current, bf16>(),
486            DType::I64 => self.convert_clone::<Current, i64>(),
487            DType::I32 => self.convert_clone::<Current, i32>(),
488            DType::I16 => self.convert_clone::<Current, i16>(),
489            DType::I8 => self.convert_clone::<Current, i8>(),
490            DType::U64 => self.convert_clone::<Current, u64>(),
491            DType::U32 => self.convert_clone::<Current, u32>(),
492            DType::U16 => self.convert_clone::<Current, u16>(),
493            DType::U8 => self.convert_clone::<Current, u8>(),
494            DType::Bool => self.convert_clone::<Current, bool>(),
495            DType::QFloat(_) => unreachable!(),
496        }
497    }
498
499    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
500        self,
501    ) -> Self {
502        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
503        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
504
505        for (x, out) in this.iter().zip(&mut out) {
506            *out = x.elem();
507        }
508
509        Self::new(out, self.shape)
510    }
511
512    /// Returns the data as a slice of bytes.
513    pub fn as_bytes(&self) -> &[u8] {
514        &self.bytes
515    }
516
517    /// Returns the bytes representation of the data.
518    pub fn into_bytes(self) -> Bytes {
519        self.bytes
520    }
521
522    /// Dequantizes the data according to its quantization scheme.
523    pub fn dequantize(self) -> Result<Self, DataError> {
524        if let DType::QFloat(scheme) = self.dtype {
525            let num_elements = self.num_elements();
526            let q_bytes = QuantizedBytes {
527                bytes: self.bytes,
528                scheme,
529                num_elements,
530            };
531
532            let values = q_bytes.dequantize().0;
533            Ok(Self::new(values, self.shape))
534        } else {
535            Err(DataError::TypeMismatch(format!(
536                "Expected quantized data, got {:?}",
537                self.dtype
538            )))
539        }
540    }
541
542    /// Asserts the data is equal to another data.
543    ///
544    /// # Arguments
545    ///
546    /// * `other` - The other data.
547    /// * `strict` - If true, the data types must the be same.
548    ///   Otherwise, the comparison is done in the current data type.
549    ///
550    /// # Panics
551    ///
552    /// Panics if the data is not equal.
553    #[track_caller]
554    pub fn assert_eq(&self, other: &Self, strict: bool) {
555        if strict {
556            assert_eq!(
557                self.dtype, other.dtype,
558                "Data types differ ({:?} != {:?})",
559                self.dtype, other.dtype
560            );
561        }
562
563        match self.dtype {
564            DType::F64 => self.assert_eq_elem::<f64>(other),
565            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
566            DType::F16 => self.assert_eq_elem::<f16>(other),
567            DType::BF16 => self.assert_eq_elem::<bf16>(other),
568            DType::I64 => self.assert_eq_elem::<i64>(other),
569            DType::I32 => self.assert_eq_elem::<i32>(other),
570            DType::I16 => self.assert_eq_elem::<i16>(other),
571            DType::I8 => self.assert_eq_elem::<i8>(other),
572            DType::U64 => self.assert_eq_elem::<u64>(other),
573            DType::U32 => self.assert_eq_elem::<u32>(other),
574            DType::U16 => self.assert_eq_elem::<u16>(other),
575            DType::U8 => self.assert_eq_elem::<u8>(other),
576            DType::Bool => self.assert_eq_elem::<bool>(other),
577            DType::QFloat(q) => {
578                // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
579                let q_other = if let DType::QFloat(q_other) = other.dtype {
580                    q_other
581                } else {
582                    panic!("Quantized data differs from other not quantized data")
583                };
584
585                // Data equality mostly depends on input quantization type, but we also check level
586                if q.value == q_other.value && q.level == q_other.level {
587                    self.assert_eq_elem::<i8>(other)
588                } else {
589                    panic!("Quantization schemes differ ({q:?} != {q_other:?})")
590                }
591            }
592        }
593    }
594
595    #[track_caller]
596    fn assert_eq_elem<E: Element>(&self, other: &Self) {
597        let mut message = String::new();
598        if self.shape != other.shape {
599            message += format!(
600                "\n  => Shape is different: {:?} != {:?}",
601                self.shape, other.shape
602            )
603            .as_str();
604        }
605
606        let mut num_diff = 0;
607        let max_num_diff = 5;
608        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
609            if a.cmp(&b).is_ne() {
610                // Only print the first 5 different values.
611                if num_diff < max_num_diff {
612                    message += format!("\n  => Position {i}: {a} != {b}").as_str();
613                }
614                num_diff += 1;
615            }
616        }
617
618        if num_diff >= max_num_diff {
619            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
620        }
621
622        if !message.is_empty() {
623            panic!("Tensors are not eq:{message}");
624        }
625    }
626
627    /// Asserts the data is approximately equal to another data.
628    ///
629    /// # Arguments
630    ///
631    /// * `other` - The other data.
632    /// * `tolerance` - The tolerance of the comparison.
633    ///
634    /// # Panics
635    ///
636    /// Panics if the data is not approximately equal.
637    #[track_caller]
638    pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
639        let mut message = String::new();
640        if self.shape != other.shape {
641            message += format!(
642                "\n  => Shape is different: {:?} != {:?}",
643                self.shape, other.shape
644            )
645            .as_str();
646        }
647
648        let iter = self.iter::<F>().zip(other.iter::<F>());
649
650        let mut num_diff = 0;
651        let max_num_diff = 5;
652
653        for (i, (a, b)) in iter.enumerate() {
654            //if they are both nan, then they are equally nan
655            let both_nan = a.is_nan() && b.is_nan();
656            //this works for both infinities
657            let both_inf =
658                a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
659
660            if both_nan || both_inf {
661                continue;
662            }
663
664            if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
665                // Only print the first 5 different values.
666                if num_diff < max_num_diff {
667                    let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
668                    let max = F::max(a.abs(), b.abs());
669                    let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();
670
671                    let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
672                    let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
673
674                    message += format!(
675                        "\n  => Position {i}: {a} != {b}\n     diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
676                    )
677                    .as_str();
678                }
679                num_diff += 1;
680            }
681        }
682
683        if num_diff >= max_num_diff {
684            message += format!("\n{} more errors...", num_diff - 5).as_str();
685        }
686
687        if !message.is_empty() {
688            panic!("Tensors are not approx eq:{message}");
689        }
690    }
691
692    /// Asserts each value is within a given range.
693    ///
694    /// # Arguments
695    ///
696    /// * `range` - The range.
697    ///
698    /// # Panics
699    ///
700    /// If any value is not within the half-open range bounded inclusively below
701    /// and exclusively above (`start..end`).
702    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
703        for elem in self.iter::<E>() {
704            if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
705                panic!("Element ({elem:?}) is not within range {range:?}");
706            }
707        }
708    }
709
710    /// Asserts each value is within a given inclusive range.
711    ///
712    /// # Arguments
713    ///
714    /// * `range` - The range.
715    ///
716    /// # Panics
717    ///
718    /// If any value is not within the half-open range bounded inclusively (`start..=end`).
719    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
720        let start = range.start();
721        let end = range.end();
722
723        for elem in self.iter::<E>() {
724            if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
725                panic!("Element ({elem:?}) is not within range {range:?}");
726            }
727        }
728    }
729}
730
731impl<E: Element, const A: usize> From<[E; A]> for TensorData {
732    fn from(elems: [E; A]) -> Self {
733        TensorData::new(elems.to_vec(), [A])
734    }
735}
736
737impl<const A: usize> From<[usize; A]> for TensorData {
738    fn from(elems: [usize; A]) -> Self {
739        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
740    }
741}
742
743impl From<&[usize]> for TensorData {
744    fn from(elems: &[usize]) -> Self {
745        let mut data = Vec::with_capacity(elems.len());
746        for elem in elems.iter() {
747            data.push(*elem as i64);
748        }
749
750        TensorData::new(data, [elems.len()])
751    }
752}
753
754impl<E: Element> From<&[E]> for TensorData {
755    fn from(elems: &[E]) -> Self {
756        let mut data = Vec::with_capacity(elems.len());
757        for elem in elems.iter() {
758            data.push(*elem);
759        }
760
761        TensorData::new(data, [elems.len()])
762    }
763}
764
765impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
766    fn from(elems: [[E; B]; A]) -> Self {
767        let mut data = Vec::with_capacity(A * B);
768        for elem in elems.into_iter().take(A) {
769            for elem in elem.into_iter().take(B) {
770                data.push(elem);
771            }
772        }
773
774        TensorData::new(data, [A, B])
775    }
776}
777
778impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
779    for TensorData
780{
781    fn from(elems: [[[E; C]; B]; A]) -> Self {
782        let mut data = Vec::with_capacity(A * B * C);
783
784        for elem in elems.into_iter().take(A) {
785            for elem in elem.into_iter().take(B) {
786                for elem in elem.into_iter().take(C) {
787                    data.push(elem);
788                }
789            }
790        }
791
792        TensorData::new(data, [A, B, C])
793    }
794}
795
796impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
797    From<[[[[E; D]; C]; B]; A]> for TensorData
798{
799    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
800        let mut data = Vec::with_capacity(A * B * C * D);
801
802        for elem in elems.into_iter().take(A) {
803            for elem in elem.into_iter().take(B) {
804                for elem in elem.into_iter().take(C) {
805                    for elem in elem.into_iter().take(D) {
806                        data.push(elem);
807                    }
808                }
809            }
810        }
811
812        TensorData::new(data, [A, B, C, D])
813    }
814}
815
816impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
817    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
818{
819    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
820        let mut data = Vec::with_capacity(A * B * C * D * E);
821
822        for elem in elems.into_iter().take(A) {
823            for elem in elem.into_iter().take(B) {
824                for elem in elem.into_iter().take(C) {
825                    for elem in elem.into_iter().take(D) {
826                        for elem in elem.into_iter().take(E) {
827                            data.push(elem);
828                        }
829                    }
830                }
831            }
832        }
833
834        TensorData::new(data, [A, B, C, D, E])
835    }
836}
837
838impl core::fmt::Display for TensorData {
839    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
840        let fmt = match self.dtype {
841            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
842            DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
843            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
844            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
845            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
846            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
847            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
848            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
849            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
850            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
851            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
852            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
853            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
854            DType::QFloat(scheme) => match scheme {
855                QuantScheme {
856                    level: QuantLevel::Tensor | QuantLevel::Block(_),
857                    mode: QuantMode::Symmetric,
858                    value:
859                        QuantValue::Q8F
860                        | QuantValue::Q8S
861                        // Display sub-byte values as i8
862                        | QuantValue::Q4F
863                        | QuantValue::Q4S
864                        | QuantValue::Q2F
865                        | QuantValue::Q2S,
866                    ..
867                } => {
868                    format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
869                },
870                QuantScheme {
871                        level: QuantLevel::Tensor | QuantLevel::Block(_),
872                        mode: QuantMode::Symmetric,
873                        value:
874                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
875                        ..
876                    } => {
877                        unimplemented!("Can't format yet");
878                    }
879            },
880        };
881        f.write_str(fmt.as_str())
882    }
883}
884
/// The tolerance used to compare two floating point numbers.
///
/// Generally, two numbers `x` and `y` are approximately equal if
///
/// ```text
/// |x - y| < max(R * max(|x|, |y|), A)
/// ```
///
/// where `R` is the relative tolerance and `A` is the absolute tolerance.
///
///
/// The most common way to initialize this struct is to use `Tolerance::<F>::default()`.
/// In that case, the `balanced` preset is used: a relative tolerance of `0.005` (0.5%)
/// and an absolute tolerance of `1e-5`, whatever the floating point type `F` is.
///
/// Another common initialization is `Tolerance::<F>::rel_abs(1e-4, 1e-5).set_half_precision_relative(1e-2)`.
/// This will use a sane default to manage values too close to 0.0 and
/// use different relative tolerances depending on the floating point precision.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    /// Relative tolerance `R`; it is scaled by the larger magnitude of the two compared values.
    relative: F,
    /// Absolute tolerance `A`; the accepted difference threshold never drops below it.
    absolute: F,
}
908
909impl<F: Float> Default for Tolerance<F> {
910    fn default() -> Self {
911        Self::balanced()
912    }
913}
914
915impl<F: Float> Tolerance<F> {
916    /// Create a tolerance with strict precision setting.
917    pub fn strict() -> Self {
918        Self {
919            relative: F::from(0.00).unwrap(),
920            absolute: F::from(64).unwrap() * F::min_positive_value(),
921        }
922    }
923    /// Create a tolerance with balanced precision setting.
924    pub fn balanced() -> Self {
925        Self {
926            relative: F::from(0.005).unwrap(), // 0.5%
927            absolute: F::from(1e-5).unwrap(),
928        }
929    }
930
931    /// Create a tolerance with permissive precision setting.
932    pub fn permissive() -> Self {
933        Self {
934            relative: F::from(0.01).unwrap(), // 1.0%
935            absolute: F::from(0.01).unwrap(),
936        }
937    }
938    /// When comparing two numbers, this uses both the relative and absolute differences.
939    ///
940    /// That is, `x` and `y` are approximately equal if
941    ///
942    /// ```text
943    /// |x - y| < max(R * (|x + y|), A)
944    /// ```
945    ///
946    /// where `R` is the `relative` tolerance and `A` is the `absolute` tolerance.
947    pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
948        let relative = Self::check_relative(relative);
949        let absolute = Self::check_absolute(absolute);
950
951        Self { relative, absolute }
952    }
953
954    /// When comparing two numbers, this uses only the relative difference.
955    ///
956    /// That is, `x` and `y` are approximately equal if
957    ///
958    /// ```text
959    /// |x - y| < R * max(|x|, |y|)
960    /// ```
961    ///
962    /// where `R` is the relative `tolerance`.
963    pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
964        let relative = Self::check_relative(tolerance);
965
966        Self {
967            relative,
968            absolute: F::from(0.0).unwrap(),
969        }
970    }
971
972    /// When comparing two numbers, this uses only the absolute difference.
973    ///
974    /// That is, `x` and `y` are approximately equal if
975    ///
976    /// ```text
977    /// |x - y| < A
978    /// ```
979    ///
980    /// where `A` is the absolute `tolerance`.
981    pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
982        let absolute = Self::check_absolute(tolerance);
983
984        Self {
985            relative: F::from(0.0).unwrap(),
986            absolute,
987        }
988    }
989
990    /// Change the relative tolerance to the given one.
991    pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
992        self.relative = Self::check_relative(tolerance);
993        self
994    }
995
996    /// Change the relative tolerance to the given one only if `F` is half precision.
997    pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
998        if core::mem::size_of::<F>() == 2 {
999            self.relative = Self::check_relative(tolerance);
1000        }
1001        self
1002    }
1003
1004    /// Change the relative tolerance to the given one only if `F` is single precision.
1005    pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1006        if core::mem::size_of::<F>() == 4 {
1007            self.relative = Self::check_relative(tolerance);
1008        }
1009        self
1010    }
1011
1012    /// Change the relative tolerance to the given one only if `F` is double precision.
1013    pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1014        if core::mem::size_of::<F>() == 8 {
1015            self.relative = Self::check_relative(tolerance);
1016        }
1017        self
1018    }
1019
1020    /// Change the absolute tolerance to the given one.
1021    pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1022        self.absolute = Self::check_absolute(tolerance);
1023        self
1024    }
1025
1026    /// Change the absolute tolerance to the given one only if `F` is half precision.
1027    pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1028        if core::mem::size_of::<F>() == 2 {
1029            self.absolute = Self::check_absolute(tolerance);
1030        }
1031        self
1032    }
1033
1034    /// Change the absolute tolerance to the given one only if `F` is single precision.
1035    pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1036        if core::mem::size_of::<F>() == 4 {
1037            self.absolute = Self::check_absolute(tolerance);
1038        }
1039        self
1040    }
1041
1042    /// Change the absolute tolerance to the given one only if `F` is double precision.
1043    pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1044        if core::mem::size_of::<F>() == 8 {
1045            self.absolute = Self::check_absolute(tolerance);
1046        }
1047        self
1048    }
1049
1050    /// Checks if `x` and `y` are approximately equal given the tolerance.
1051    pub fn approx_eq(&self, x: F, y: F) -> bool {
1052        // See the accepted answer here
1053        // https://stackoverflow.com/questions/4915462/how-should-i-do-floating-point-comparison
1054
1055        // This also handles the case where both a and b are infinity so that we don't need
1056        // to manage it in the rest of the function.
1057        if x == y {
1058            return true;
1059        }
1060
1061        let diff = (x - y).abs();
1062        let max = F::max(x.abs(), y.abs());
1063
1064        diff < self.absolute.max(self.relative * max)
1065    }
1066
1067    fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
1068        let tolerance = F::from(tolerance).unwrap();
1069        assert!(tolerance <= F::one());
1070        tolerance
1071    }
1072
1073    fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
1074        let tolerance = F::from(tolerance).unwrap();
1075        assert!(tolerance >= F::zero());
1076        tolerance
1077    }
1078}
1079
#[cfg(test)]
mod tests {
    use crate::{Shape, quantization::SymmetricQuantization};

    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Builds `f32` tensor data of the given shape, sampled from the default distribution
    /// with a generator seeded from the operating system.
    fn random_f32(shape: Shape) -> TensorData {
        let mut rng = StdRng::from_os_rng();
        TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut rng)
    }

    #[test]
    fn should_have_rank() {
        let data = random_f32(Shape::new([3, 5, 6]));

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32(Shape::new([3, 5, 6]));

        // Collect through the borrowing iterator first, since `into_vec` consumes the data.
        let expected: Vec<f32> = data.iter::<f32>().collect();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let data = random_f32(Shape::new([3, 5, 6]));

        // The data holds `f32` values, so requesting `i32` must fail.
        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = Shape::new([3, 5, 6]);
        let num_elements = shape.num_elements();
        let data = random_f32(shape);

        assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let single_row = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(single_row.shape, vec![1, 3]);

        let two_rows = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(two_rows.shape, vec![2, 3]);

        let flat = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(flat.shape, vec![3]);
    }

    #[test]
    fn should_assert_appox_eq_limit() {
        let lhs = TensorData::from([[3.0, 5.0, 6.0]]);
        let rhs = TensorData::from([[3.03, 5.0, 6.0]]);

        lhs.assert_approx_eq::<f32>(&rhs, Tolerance::absolute(3e-2));
        lhs.assert_approx_eq::<half::f16>(&rhs, Tolerance::absolute(3e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let lhs = TensorData::from([[3.0, 5.0, 6.0]]);
        let rhs = TensorData::from([[3.031, 5.0, 6.0]]);

        lhs.assert_approx_eq::<f32>(&rhs, Tolerance::absolute(1e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let four_wide = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let three_wide = TensorData::from([[3.0, 5.0, 6.0]]);

        four_wide.assert_approx_eq::<f32>(&three_wide, Tolerance::absolute(1e-2));
    }

    #[test]
    fn should_convert_bytes_correctly() {
        // Over-allocate on purpose so that length and capacity differ.
        let mut vector = Vec::<f32>::with_capacity(5);
        vector.extend([2.0, 3.0]);
        let data = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data.bytes.len(), 2 * factor);
        assert_eq!(data.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        /// Converts the integers `0..32` to `E` and checks that each value survives.
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            let converted = data.clone().convert::<E>().into_vec::<E>().unwrap();

            for (expected, val) in (0u32..).zip(converted) {
                assert_eq!(expected, val.elem::<u32>());
            }
        }

        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    #[test]
    #[should_panic = "Expected quantized data"]
    fn should_not_dequantize() {
        let data = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        data.dequantize().unwrap();
    }

    #[test]
    fn should_support_dequantize() {
        let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
            0.1,
            QuantValue::Q8S,
        ));
        let scheme = QuantScheme {
            level: QuantLevel::Tensor,
            value: QuantValue::Q8S,
            mode: QuantMode::Symmetric,
            ..Default::default()
        };
        let data = TensorData::quantized(
            vec![-127i8, -77, -26, 25, 76, 127],
            [2, 3],
            strategy,
            scheme,
        );

        let output = data.dequantize().unwrap();
        let expected = TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]);

        output.assert_approx_eq::<f32>(&expected, Tolerance::default());
        output.assert_approx_eq::<f16>(&expected, Tolerance::default());
    }

    // Generates one test per listed type, checking that `full_dtype` and `full`
    // produce equal tensor data for the same shape and fill value.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}