use core::{
    any::{Any, TypeId},
    f32,
};

use alloc::boxed::Box;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use bytemuck::{checked::CheckedCastError, AnyBitPattern};
use half::{bf16, f16};

use crate::{
    quantization::{
        Quantization, QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes,
    },
    tensor::{bytes::Bytes, Shape},
    DType, Distribution, Element, ElementConversion,
};

use num_traits::pow::Pow;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;

use rand::RngCore;
/// Data error.
#[derive(Debug)]
pub enum DataError {
    /// Failed to cast the values to the specified element type.
    CastError(CheckedCastError),
    /// Invalid target element type.
    TypeMismatch(String),
}

/// Data structure for tensors.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// The values of the tensor (as bytes).
    bytes: Bytes,

    /// The shape of the tensor.
    pub shape: Vec<usize>,

    /// The data type of the tensor.
    pub dtype: DType,
}

impl TensorData {
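    /// Creates a new tensor data structure.
    ///
    /// A minimal usage sketch (an `ignore`d doc-test, since it assumes `TensorData` is
    /// imported from the crate root):
    ///
    /// ```ignore
    /// // Six f32 values laid out as a 2x3 matrix.
    /// let data = TensorData::new(vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0], [2, 3]);
    /// assert_eq!(data.shape, vec![2, 3]);
    /// ```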
    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
        let shape = shape.into();
        Self::check_data_len(&value, &shape);

        Self {
            bytes: Bytes::from_elems(value),
            shape,
            dtype: E::dtype(),
        }
    }

    /// Creates a new quantized tensor data structure.
    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
        value: Vec<E>,
        shape: S,
        strategy: QuantizationStrategy,
    ) -> Self {
        let shape = shape.into();
        Self::check_data_len(&value, &shape);

        let q_bytes = QuantizedBytes::new(value, strategy);

        Self {
            bytes: q_bytes.bytes,
            shape,
            dtype: DType::QFloat(q_bytes.scheme),
        }
    }

    /// Creates a new tensor data structure from raw bytes.
    ///
    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] where possible, as those
    /// constructors validate the length of the data against the shape.
    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
        Self {
            bytes: Bytes::from_bytes_vec(bytes),
            shape: shape.into(),
            dtype,
        }
    }

    /// Asserts that the data length matches the product of the shape dimensions.
    fn check_data_len<E: Element>(data: &[E], shape: &[usize]) {
        let expected_data_len = Self::numel(shape);
        let num_data = data.len();
        assert_eq!(
            expected_data_len, num_data,
            "Shape {:?} is invalid for input of size {:?}",
            shape, num_data,
        );
    }

    fn try_as_slice<E: Element>(&self) -> Result<&[E], DataError> {
        bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError)
    }

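    /// Returns the data as a slice of the specified element type, which must match
    /// `self.dtype`.
    ///
    /// A minimal sketch of both outcomes (`ignore`d doc-test; import path assumed):
    ///
    /// ```ignore
    /// let data = TensorData::new(vec![1.0f32, 2.0], [2]);
    /// assert_eq!(data.as_slice::<f32>().unwrap(), &[1.0, 2.0]);
    /// // Requesting a different element type yields a `TypeMismatch` error.
    /// assert!(data.as_slice::<i32>().is_err());
    /// ```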
    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
        if E::dtype() == self.dtype {
            self.try_as_slice()
        } else {
            Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )))
        }
    }

    /// Returns the data as a mutable slice of the specified element type, which must
    /// match `self.dtype`.
    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
        if E::dtype() == self.dtype {
            bytemuck::checked::try_cast_slice_mut(&mut self.bytes).map_err(DataError::CastError)
        } else {
            Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )))
        }
    }

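    /// Returns the data as a vector of the specified element type, cloning the values.
    ///
    /// A minimal sketch (`ignore`d doc-test; import path assumed):
    ///
    /// ```ignore
    /// let data = TensorData::new(vec![1i64, 2, 3], [3]);
    /// let values: Vec<i64> = data.to_vec().unwrap();
    /// assert_eq!(values, vec![1, 2, 3]);
    /// ```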
    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
        Ok(self.as_slice()?.to_vec())
    }

    /// Returns the data as a vector of the specified element type, consuming `self` and
    /// reusing the allocation when possible.
    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
        if E::dtype() != self.dtype {
            return Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )));
        }

        let mut me = self;
        me.bytes = match me.bytes.try_into_vec::<E>() {
            Ok(elems) => return Ok(elems),
            Err(bytes) => bytes,
        };
        // The buffer could not be reinterpreted in place (e.g. due to alignment), so
        // fall back to copying the elements.
        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
            .map_err(DataError::CastError)?
            .to_vec())
    }

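    /// Returns an iterator over the elements, converting each value when the requested
    /// element type differs from `self.dtype`.
    ///
    /// A minimal sketch (`ignore`d doc-test; import path assumed):
    ///
    /// ```ignore
    /// let data = TensorData::new(vec![1i32, 2, 3], [3]);
    /// // Iterate as f32 even though the underlying dtype is i32.
    /// let floats: Vec<f32> = data.iter::<f32>().collect();
    /// assert_eq!(floats, vec![1.0, 2.0, 3.0]);
    /// ```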
    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
        if E::dtype() == self.dtype {
            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
        } else {
            match self.dtype {
                DType::I8 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i8| e.elem::<E>()),
                ),
                DType::I16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i16| e.elem::<E>()),
                ),
                DType::I32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i32| e.elem::<E>()),
                ),
                DType::I64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i64| e.elem::<E>()),
                ),
                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
                DType::U16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u16| e.elem::<E>()),
                ),
                DType::U32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u32| e.elem::<E>()),
                ),
                DType::U64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u64| e.elem::<E>()),
                ),
                DType::BF16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &bf16| e.elem::<E>()),
                ),
                DType::F16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f16| e.elem::<E>()),
                ),
                DType::F32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f32| e.elem::<E>()),
                ),
                DType::F64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f64| e.elem::<E>()),
                ),
                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
                DType::QFloat(scheme) => match scheme {
                    QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
                    | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
                        // The quantized values must be unpacked into an owned vector
                        // before they can be iterated.
                        let q_bytes = QuantizedBytes {
                            bytes: self.bytes.clone(),
                            scheme,
                            num_elements: self.num_elements(),
                        };
                        let (values, _) = q_bytes.into_vec_i8();

                        Box::new(
                            values
                                .iter()
                                .map(|e: &i8| e.elem::<E>())
                                .collect::<Vec<_>>()
                                .into_iter(),
                        )
                    }
                },
            }
        }
    }

    /// Returns the total number of elements of the tensor data.
    pub fn num_elements(&self) -> usize {
        Self::numel(&self.shape)
    }

    fn numel(shape: &[usize]) -> usize {
        shape.iter().product()
    }

    /// Populates the data with random values.
    pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
        shape: S,
        distribution: Distribution,
        rng: &mut R,
    ) -> Self {
        let shape = shape.into();
        let num_elements = Self::numel(&shape);
        let mut data = Vec::with_capacity(num_elements);

        for _ in 0..num_elements {
            data.push(E::random(distribution, rng));
        }

        TensorData::new(data, shape)
    }

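    /// Populates the data with zeros.
    ///
    /// A minimal sketch covering this and the sibling `ones`/`full` constructors below
    /// (`ignore`d doc-test; import path assumed):
    ///
    /// ```ignore
    /// let zeros = TensorData::zeros::<f32, _>([2, 2]);
    /// let sevens = TensorData::full([2, 2], 7.0f32);
    /// assert_eq!(zeros.as_slice::<f32>().unwrap(), &[0.0; 4]);
    /// assert_eq!(sevens.as_slice::<f32>().unwrap(), &[7.0; 4]);
    /// ```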
    pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
        let shape = shape.into();
        let num_elements = Self::numel(&shape);
        let mut data = Vec::<E>::with_capacity(num_elements);

        for _ in 0..num_elements {
            data.push(0.elem());
        }

        TensorData::new(data, shape)
    }

    /// Populates the data with ones.
    pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
        let shape = shape.into();
        let num_elements = Self::numel(&shape);
        let mut data = Vec::<E>::with_capacity(num_elements);

        for _ in 0..num_elements {
            data.push(1.elem());
        }

        TensorData::new(data, shape)
    }

    /// Populates the data with the given value.
    pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
        let shape = shape.into();
        let num_elements = Self::numel(&shape);
        let mut data = Vec::<E>::with_capacity(num_elements);
        for _ in 0..num_elements {
            data.push(fill_value)
        }

        TensorData::new(data, shape)
    }

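    /// Converts the data to a different element type.
    ///
    /// A minimal sketch (`ignore`d doc-test; import path assumed):
    ///
    /// ```ignore
    /// let data = TensorData::new(vec![1.0f32, 2.0], [2]);
    /// let converted = data.convert::<i32>();
    /// assert_eq!(converted.as_slice::<i32>().unwrap(), &[1, 2]);
    /// ```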
    pub fn convert<E: Element>(self) -> Self {
        if E::dtype() == self.dtype {
            self
        } else if core::mem::size_of::<E>() == self.dtype.size()
            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
        {
            // Same element size: convert in place without reallocating.
            match self.dtype {
                DType::F64 => self.convert_inplace::<f64, E>(),
                DType::F32 => self.convert_inplace::<f32, E>(),
                DType::F16 => self.convert_inplace::<f16, E>(),
                DType::BF16 => self.convert_inplace::<bf16, E>(),
                DType::I64 => self.convert_inplace::<i64, E>(),
                DType::I32 => self.convert_inplace::<i32, E>(),
                DType::I16 => self.convert_inplace::<i16, E>(),
                DType::I8 => self.convert_inplace::<i8, E>(),
                DType::U64 => self.convert_inplace::<u64, E>(),
                DType::U32 => self.convert_inplace::<u32, E>(),
                DType::U16 => self.convert_inplace::<u16, E>(),
                DType::U8 => self.convert_inplace::<u8, E>(),
                DType::Bool | DType::QFloat(_) => unreachable!(),
            }
        } else {
            TensorData::new(self.iter::<E>().collect(), self.shape)
        }
    }

    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element>(mut self) -> Self {
        let step = core::mem::size_of::<Current>();

        for offset in 0..(self.bytes.len() / step) {
            let start = offset * step;
            let end = start + step;

            // Read the current value, convert it, and write it back over the same bytes.
            let slice_old = &mut self.bytes[start..end];
            let val: Current = *bytemuck::from_bytes(slice_old);
            let val = &val.elem::<Target>();
            let slice_new = bytemuck::bytes_of(val);

            slice_old.clone_from_slice(slice_new);
        }
        self.dtype = Target::dtype();

        self
    }

    /// Returns the bytes representation of the data.
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }

    /// Consumes the data and returns the underlying bytes.
    pub fn into_bytes(self) -> Bytes {
        self.bytes
    }

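    /// Quantizes the f32 data using the given strategy.
    ///
    /// A minimal round-trip sketch, mirroring the quantization used in the tests at the
    /// bottom of this file (`ignore`d doc-test; import paths assumed):
    ///
    /// ```ignore
    /// use burn_tensor::quantization::{AffineQuantization, QuantizationStrategy};
    ///
    /// let data = TensorData::from([0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]);
    /// let strategy =
    ///     QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128));
    /// let dequantized = data.with_quantization(strategy).dequantize().unwrap();
    /// ```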
    /// # Panics
    ///
    /// Panics if the data type is not `f32`.
    pub fn with_quantization(self, quantization: QuantizationStrategy) -> Self {
        assert_eq!(
            self.dtype,
            DType::F32,
            "Only f32 data type can be quantized"
        );
        match &quantization {
            QuantizationStrategy::PerTensorAffineInt8(strategy) => TensorData::quantized(
                strategy.quantize(self.as_slice().unwrap()),
                self.shape,
                quantization,
            ),
            QuantizationStrategy::PerTensorSymmetricInt8(strategy) => TensorData::quantized(
                strategy.quantize(self.as_slice().unwrap()),
                self.shape,
                quantization,
            ),
        }
    }

    /// Dequantizes the data according to its quantization scheme.
    pub fn dequantize(self) -> Result<Self, DataError> {
        if let DType::QFloat(scheme) = self.dtype {
            let num_elements = self.num_elements();
            let q_bytes = QuantizedBytes {
                bytes: self.bytes,
                scheme,
                num_elements,
            };

            let values = q_bytes.dequantize().0;
            Ok(Self::new(values, self.shape))
        } else {
            Err(DataError::TypeMismatch(format!(
                "Expected quantized data, got {:?}",
                self.dtype
            )))
        }
    }

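    /// Asserts the data is approximately equal to `other`, within a tolerance of
    /// `0.1.pow(precision)` (further scaled by dtype precision and value magnitude in
    /// `compare_floats`).
    ///
    /// A minimal sketch, mirroring the tests at the bottom of this file (`ignore`d
    /// doc-test; import path assumed):
    ///
    /// ```ignore
    /// let lhs = TensorData::from([3.0f32, 5.0, 6.0]);
    /// let rhs = TensorData::from([3.03f32, 5.0, 6.0]);
    /// // Passes: the tolerance is scaled by the magnitude of the compared values.
    /// lhs.assert_approx_eq(&rhs, 2);
    /// ```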
    #[track_caller]
    pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
        let tolerance = 0.1.pow(precision as f64);

        self.assert_approx_eq_diff(other, tolerance)
    }

    /// Asserts the data is equal to `other`.
    ///
    /// When `strict` is `true`, the data types must also match.
    #[track_caller]
    pub fn assert_eq(&self, other: &Self, strict: bool) {
        if strict {
            assert_eq!(
                self.dtype, other.dtype,
                "Data types differ ({:?} != {:?})",
                self.dtype, other.dtype
            );
        }

        match self.dtype {
            DType::F64 => self.assert_eq_elem::<f64>(other),
            DType::F32 => self.assert_eq_elem::<f32>(other),
            DType::F16 => self.assert_eq_elem::<f16>(other),
            DType::BF16 => self.assert_eq_elem::<bf16>(other),
            DType::I64 => self.assert_eq_elem::<i64>(other),
            DType::I32 => self.assert_eq_elem::<i32>(other),
            DType::I16 => self.assert_eq_elem::<i16>(other),
            DType::I8 => self.assert_eq_elem::<i8>(other),
            DType::U64 => self.assert_eq_elem::<u64>(other),
            DType::U32 => self.assert_eq_elem::<u32>(other),
            DType::U16 => self.assert_eq_elem::<u16>(other),
            DType::U8 => self.assert_eq_elem::<u8>(other),
            DType::Bool => self.assert_eq_elem::<bool>(other),
            DType::QFloat(q) => {
                let q_other = if let DType::QFloat(q_other) = other.dtype {
                    q_other
                } else {
                    panic!("Quantized data differs from other non-quantized data")
                };
                match (q, q_other) {
                    (
                        QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
                        QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
                    )
                    | (
                        QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
                        QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
                    ) => self.assert_eq_elem::<i8>(other),
                    _ => panic!("Quantization schemes differ ({:?} != {:?})", q, q_other),
                }
            }
        }
    }

    #[track_caller]
    fn assert_eq_elem<E: Element>(&self, other: &Self) {
        let mut message = String::new();
        if self.shape != other.shape {
            message += format!(
                "\n => Shape is different: {:?} != {:?}",
                self.shape, other.shape
            )
            .as_str();
        }

        let mut num_diff = 0;
        let max_num_diff = 5;
        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
            if a.cmp(&b).is_ne() {
                // Only report the first `max_num_diff` differences.
                if num_diff < max_num_diff {
                    message += format!("\n => Position {i}: {a} != {b}").as_str();
                }
                num_diff += 1;
            }
        }

        if num_diff >= max_num_diff {
            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
        }

        if !message.is_empty() {
            panic!("Tensors are not eq:{}", message);
        }
    }

    /// Asserts the data is approximately equal to `other`, within the given tolerance
    /// (adjusted for dtype precision and value magnitude for float data).
    #[track_caller]
    pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) {
        let mut message = String::new();
        if self.shape != other.shape {
            message += format!(
                "\n => Shape is different: {:?} != {:?}",
                self.shape, other.shape
            )
            .as_str();
        }

        let iter = self.iter::<f64>().zip(other.iter::<f64>());

        let mut num_diff = 0;
        let max_num_diff = 5;

        for (i, (a, b)) in iter.enumerate() {
            // NaN values are considered equal to each other, as are infinities of the
            // same sign.
            let both_nan = a.is_nan() && b.is_nan();
            let both_inf = a.is_infinite() && b.is_infinite() && ((a > 0.) == (b > 0.));

            if both_nan || both_inf {
                continue;
            }

            let err = (a - b).abs();

            if self.dtype.is_float() {
                if let Some((err, tolerance)) = compare_floats(a, b, self.dtype, tolerance) {
                    // Only report the first `max_num_diff` differences.
                    if num_diff < max_num_diff {
                        message += format!(
                            "\n => Position {i}: {a} != {b} | difference {err} > tolerance \
                             {tolerance}"
                        )
                        .as_str();
                    }
                    num_diff += 1;
                }
            } else if err > tolerance || err.is_nan() {
                // Only report the first `max_num_diff` differences.
                if num_diff < max_num_diff {
                    message += format!(
                        "\n => Position {i}: {a} != {b} | difference {err} > tolerance \
                         {tolerance}"
                    )
                    .as_str();
                }
                num_diff += 1;
            }
        }

        if num_diff >= max_num_diff {
            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
        }

        if !message.is_empty() {
            panic!("Tensors are not approx eq:{}", message);
        }
    }

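    /// Asserts that each element is within the given half-open range, panicking
    /// otherwise.
    ///
    /// A minimal sketch (`ignore`d doc-test; import path assumed):
    ///
    /// ```ignore
    /// let data = TensorData::from([0.0f32, 0.5, 0.9]);
    /// data.assert_within_range(0..1); // the end of the range is exclusive
    /// ```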
    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
        let start = range.start.elem::<f32>();
        let end = range.end.elem::<f32>();

        for elem in self.iter::<f32>() {
            if elem < start || elem >= end {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }

    /// Asserts that each element is within the given inclusive range, panicking
    /// otherwise.
    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
        let start = range.start().elem::<f32>();
        let end = range.end().elem::<f32>();

        for elem in self.iter::<f32>() {
            if elem < start || elem > end {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }
}

impl<E: Element, const A: usize> From<[E; A]> for TensorData {
    fn from(elems: [E; A]) -> Self {
        TensorData::new(elems.to_vec(), [A])
    }
}

impl<const A: usize> From<[usize; A]> for TensorData {
    fn from(elems: [usize; A]) -> Self {
        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
    }
}

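// Note: `usize` values (arrays above, slices below) are stored as `i64`, the default
// integer element type. A minimal sketch of the resulting dtype (illustrative only):
//
//     let data = TensorData::from([1usize, 2, 3]);
//     assert_eq!(data.dtype, DType::I64);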
impl From<&[usize]> for TensorData {
    fn from(elems: &[usize]) -> Self {
        let mut data = Vec::with_capacity(elems.len());
        for elem in elems.iter() {
            data.push(*elem as i64);
        }

        TensorData::new(data, [elems.len()])
    }
}

impl<E: Element> From<&[E]> for TensorData {
    fn from(elems: &[E]) -> Self {
        let mut data = Vec::with_capacity(elems.len());
        for elem in elems.iter() {
            data.push(*elem);
        }

        TensorData::new(data, [elems.len()])
    }
}

impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
    fn from(elems: [[E; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B);
        for elem in elems.into_iter().take(A) {
            for elem in elem.into_iter().take(B) {
                data.push(elem);
            }
        }

        TensorData::new(data, [A, B])
    }
}

impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
    for TensorData
{
    fn from(elems: [[[E; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C);

        for elem in elems.into_iter().take(A) {
            for elem in elem.into_iter().take(B) {
                for elem in elem.into_iter().take(C) {
                    data.push(elem);
                }
            }
        }

        TensorData::new(data, [A, B, C])
    }
}

impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
    From<[[[[E; D]; C]; B]; A]> for TensorData
{
    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C * D);

        for elem in elems.into_iter().take(A) {
            for elem in elem.into_iter().take(B) {
                for elem in elem.into_iter().take(C) {
                    for elem in elem.into_iter().take(D) {
                        data.push(elem);
                    }
                }
            }
        }

        TensorData::new(data, [A, B, C, D])
    }
}

impl<
        Elem: Element,
        const A: usize,
        const B: usize,
        const C: usize,
        const D: usize,
        const E: usize,
    > From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
{
    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C * D * E);

        for elem in elems.into_iter().take(A) {
            for elem in elem.into_iter().take(B) {
                for elem in elem.into_iter().take(C) {
                    for elem in elem.into_iter().take(D) {
                        for elem in elem.into_iter().take(E) {
                            data.push(elem);
                        }
                    }
                }
            }
        }

        TensorData::new(data, [A, B, C, D, E])
    }
}

impl core::fmt::Display for TensorData {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let fmt = match self.dtype {
            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
            DType::F32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
            DType::QFloat(scheme) => match scheme {
                QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
                | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
                    format!("{:?} {scheme:?}", self.try_as_slice::<i8>().unwrap())
                }
            },
        };
        f.write_str(fmt.as_str())
    }
}

/// Serialized tensor data (deprecated).
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)]
#[deprecated(
    since = "0.14.0",
    note = "the internal data format has changed, please use `TensorData` instead"
)]
pub struct DataSerialize<E> {
    /// The values of the tensor.
    pub value: Vec<E>,
    /// The shape of the tensor.
    pub shape: Vec<usize>,
}

/// Tensor data (deprecated).
#[derive(new, Debug, Clone, PartialEq, Eq)]
#[deprecated(
    since = "0.14.0",
    note = "the internal data format has changed, please use `TensorData` instead"
)]
pub struct Data<E, const D: usize> {
    /// The values of the tensor.
    pub value: Vec<E>,

    /// The shape of the tensor.
    pub shape: Shape,
}

#[allow(deprecated)]
impl<const D: usize, E: Element> Data<E, D> {
    /// Converts the data to a different element type.
    pub fn convert<EOther: Element>(self) -> Data<EOther, D> {
        let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();

        Data {
            value,
            shape: self.shape,
        }
    }

    /// Asserts that each element is within the given half-open range, panicking
    /// otherwise.
    pub fn assert_within_range<EOther: Element>(&self, range: core::ops::Range<EOther>) {
        let start = range.start.elem::<f32>();
        let end = range.end.elem::<f32>();

        for elem in self.value.iter() {
            let elem = elem.elem::<f32>();
            if elem < start || elem >= end {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }
}

#[allow(deprecated)]
impl<E: Element> DataSerialize<E> {
    /// Converts the data to a different element type.
    pub fn convert<EOther: Element>(self) -> DataSerialize<EOther> {
        if TypeId::of::<E>() == TypeId::of::<EOther>() {
            // Same type: reuse the allocation through a type-erased downcast.
            let cast: Box<dyn Any> = Box::new(self);
            let cast: Box<DataSerialize<EOther>> = cast.downcast().unwrap();
            return *cast;
        }

        let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();

        DataSerialize {
            value,
            shape: self.shape,
        }
    }

    /// Converts the serialized data into [`TensorData`].
    pub fn into_tensor_data(self) -> TensorData {
        TensorData::new(self.value, self.shape)
    }
}

#[allow(deprecated)]
impl<E: Element, const D: usize> Data<E, D> {
    /// Populates the data with random values.
    pub fn random<R: RngCore>(shape: Shape, distribution: Distribution, rng: &mut R) -> Self {
        let num_elements = shape.num_elements();
        let mut data = Vec::with_capacity(num_elements);

        for _ in 0..num_elements {
            data.push(E::random(distribution, rng));
        }

        Data::new(data, shape)
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
where
    E: Element,
{
    /// Populates the data with zeros.
    pub fn zeros<S: Into<Shape>>(shape: S) -> Data<E, D> {
        let shape = shape.into();
        let num_elements = shape.num_elements();
        let mut data = Vec::with_capacity(num_elements);

        for _ in 0..num_elements {
            data.push(0.elem());
        }

        Data::new(data, shape)
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
where
    E: Element,
{
    /// Populates the data with ones.
    pub fn ones(shape: Shape) -> Data<E, D> {
        let num_elements = shape.num_elements();
        let mut data = Vec::with_capacity(num_elements);

        for _ in 0..num_elements {
            data.push(1.elem());
        }

        Data::new(data, shape)
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
where
    E: Element,
{
    /// Populates the data with the given value.
    pub fn full(shape: Shape, fill_value: E) -> Data<E, D> {
        let num_elements = shape.num_elements();
        let mut data = Vec::with_capacity(num_elements);
        for _ in 0..num_elements {
            data.push(fill_value)
        }

        Data::new(data, shape)
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
    /// Converts the data into a serializable format.
    pub fn serialize(&self) -> DataSerialize<E> {
        DataSerialize {
            value: self.value.clone(),
            shape: self.shape.dims.to_vec(),
        }
    }
}

#[allow(deprecated)]
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data<E, D> {
    /// Asserts the data is approximately equal to `other`, within a tolerance of
    /// `0.1.pow(precision)`.
    #[track_caller]
    pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
        let tolerance = 0.1.pow(precision as f64);

        self.assert_approx_eq_diff(other, tolerance)
    }

    /// Asserts the data is approximately equal to `other`, within the given tolerance.
    #[track_caller]
    pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) {
        let mut message = String::new();
        if self.shape != other.shape {
            message += format!(
                "\n => Shape is different: {:?} != {:?}",
                self.shape.dims, other.shape.dims
            )
            .as_str();
        }

        let iter = self.value.clone().into_iter().zip(other.value.clone());

        let mut num_diff = 0;
        let max_num_diff = 5;

        for (i, (a, b)) in iter.enumerate() {
            let a: f64 = a.into();
            let b: f64 = b.into();

            // NaN values are considered equal to each other, as are infinities of the
            // same sign.
            let both_nan = a.is_nan() && b.is_nan();
            let both_inf = a.is_infinite() && b.is_infinite() && ((a > 0.) == (b > 0.));

            if both_nan || both_inf {
                continue;
            }

            let err = (a - b).abs();

            if E::dtype().is_float() {
                if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) {
                    // Only report the first `max_num_diff` differences.
                    if num_diff < max_num_diff {
                        message += format!(
                            "\n => Position {i}: {a} != {b} | difference {err} > tolerance \
                             {tolerance}"
                        )
                        .as_str();
                    }
                    num_diff += 1;
                }
            } else if err > tolerance || err.is_nan() {
                // Only report the first `max_num_diff` differences.
                if num_diff < max_num_diff {
                    message += format!(
                        "\n => Position {i}: {a} != {b} | difference {err} > tolerance \
                         {tolerance}"
                    )
                    .as_str();
                }
                num_diff += 1;
            }
        }

        if num_diff >= max_num_diff {
            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
        }

        if !message.is_empty() {
            panic!("Tensors are not approx eq:{}", message);
        }
    }
}

#[allow(deprecated)]
impl<const D: usize> Data<usize, D> {
    /// Converts `usize` data to any numeric type via `num_traits::FromPrimitive`.
    pub fn from_usize<O: num_traits::FromPrimitive>(self) -> Data<O, D> {
        let value: Vec<O> = self
            .value
            .into_iter()
            .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap())
            .collect();

        Data {
            value,
            shape: self.shape,
        }
    }
}

#[allow(deprecated)]
impl<E: Clone, const D: usize> From<&DataSerialize<E>> for Data<E, D> {
    fn from(data: &DataSerialize<E>) -> Self {
        let mut dims = [0; D];
        dims[..D].copy_from_slice(&data.shape[..D]);
        Data::new(data.value.clone(), Shape::new(dims))
    }
}

#[allow(deprecated)]
impl<E, const D: usize> From<DataSerialize<E>> for Data<E, D> {
    fn from(data: DataSerialize<E>) -> Self {
        let mut dims = [0; D];
        dims[..D].copy_from_slice(&data.shape[..D]);
        Data::new(data.value, Shape::new(dims))
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug + Copy, const A: usize> From<[E; A]> for Data<E, 1> {
    fn from(elems: [E; A]) -> Self {
        // Exactly `A` elements are pushed.
        let mut data = Vec::with_capacity(A);
        for elem in elems.into_iter() {
            data.push(elem);
        }

        Data::new(data, Shape::new([A]))
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug + Copy> From<&[E]> for Data<E, 1> {
    fn from(elems: &[E]) -> Self {
        let mut data = Vec::with_capacity(elems.len());
        for elem in elems.iter() {
            data.push(*elem);
        }

        Data::new(data, Shape::new([elems.len()]))
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize> From<[[E; B]; A]> for Data<E, 2> {
    fn from(elems: [[E; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B);
        for elem in elems.into_iter().take(A) {
            for elem in elem.into_iter().take(B) {
                data.push(elem);
            }
        }

        Data::new(data, Shape::new([A, B]))
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug + Copy, const A: usize, const B: usize, const C: usize>
    From<[[[E; C]; B]; A]> for Data<E, 3>
{
    fn from(elems: [[[E; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C);

        for elem in elems.into_iter().take(A) {
            for elem in elem.into_iter().take(B) {
                for elem in elem.into_iter().take(C) {
                    data.push(elem);
                }
            }
        }

        Data::new(data, Shape::new([A, B, C]))
    }
}

#[allow(deprecated)]
impl<
        E: core::fmt::Debug + Copy,
        const A: usize,
        const B: usize,
        const C: usize,
        const D: usize,
    > From<[[[[E; D]; C]; B]; A]> for Data<E, 4>
{
    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
        let mut data = Vec::with_capacity(A * B * C * D);

        for elem in elems.into_iter().take(A) {
            for elem in elem.into_iter().take(B) {
                for elem in elem.into_iter().take(C) {
                    for elem in elem.into_iter().take(D) {
                        data.push(elem);
                    }
                }
            }
        }

        Data::new(data, Shape::new([A, B, C, D]))
    }
}

#[allow(deprecated)]
impl<E: core::fmt::Debug, const D: usize> core::fmt::Display for Data<E, D> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(format!("{:?}", &self.value).as_str())
    }
}

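/// Compares two floats with a tolerance that is relative both to the precision of the
/// source dtype and to the magnitude of the values, returning `Some((error, tolerance))`
/// when they differ by more than the adjusted tolerance.
///
/// A worked sketch of the math (illustrative numbers): with `tolerance = 1e-4`, the
/// allowed deviation is first expressed in f32 epsilons (`1e-4 / f32::EPSILON`, roughly
/// 839), then re-scaled by the epsilon of the dtype under test (so f16 data, with its
/// larger epsilon, is granted a proportionally larger budget), and finally multiplied by
/// `max(|value|, 1.0)` so the check is relative for large values and absolute near zero.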
fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> {
    let epsilon_deviations = tolerance / f32::EPSILON as f64;
    let epsilon = match ty {
        DType::F64 | DType::F32 => f32::EPSILON as f64,
        DType::F16 => half::f16::EPSILON.to_f64(),
        DType::BF16 => half::bf16::EPSILON.to_f64(),
        _ => unreachable!(),
    };
    let tolerance_norm = epsilon_deviations * epsilon;
    // Clamp the magnitude to at least 1.0 so the tolerance is relative for large values
    // and absolute for values near zero.
    let value_abs = value.abs().max(1.0);
    let tolerance_adjusted = tolerance_norm * value_abs;

    let err = (value - other).abs();

    if err > tolerance_adjusted || err.is_nan() {
        Some((err, tolerance_adjusted))
    } else {
        None
    }
}

#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use crate::quantization::AffineQuantization;

    use super::*;
    use alloc::vec;
    use rand::{rngs::StdRng, SeedableRng};

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let shape = Shape::new([3, 5, 6]);
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_entropy(),
        );

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let shape = Shape::new([3, 5, 6]);
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_entropy(),
        );

        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = Shape::new([3, 5, 6]);
        let num_elements = shape.num_elements();
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_entropy(),
        );

        assert_eq!(num_elements, data.bytes.len() / 4); // f32 elements are 4 bytes wide
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_assert_approx_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);

        data1.assert_approx_eq(&data2, 2);
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);

        data1.assert_approx_eq(&data2, 2);
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);

        data1.assert_approx_eq(&data2, 2);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    #[test]
    #[should_panic = "Expected quantized data"]
    fn should_not_dequantize() {
        let data = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        data.dequantize().unwrap();
    }

    #[test]
    fn should_support_dequantize() {
        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] with a per-tensor affine scheme.
        let data = TensorData::quantized(
            vec![-128i8, -77, -26, 25, 76, 127],
            [2, 3],
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.019607844, -128)),
        );

        let output = data.dequantize().unwrap();

        output.assert_approx_eq(&TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), 4);
    }
}