1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use half::{bf16, f16};
9use num_traits::{Float, ToPrimitive};
10
11use crate::{
12 DType, Distribution, Element, ElementConversion,
13 quantization::{QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes},
14 tensor::bytes::Bytes,
15};
16
17use rand::RngCore;
18
/// Errors returned when reinterpreting or extracting [`TensorData`] buffers.
#[derive(Debug)]
pub enum DataError {
    /// The raw bytes could not be cast to the requested element type
    /// (wrong size, alignment, or invalid bit pattern).
    CastError(CheckedCastError),
    /// The requested element type does not match the stored `dtype`.
    TypeMismatch(String),
}
27
/// Raw tensor contents: a flat byte buffer plus the shape and element type
/// needed to interpret it.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// Flat buffer holding the elements.
    pub bytes: Bytes,
    /// Size of each dimension.
    pub shape: Vec<usize>,
    /// Element type of the values in `bytes`.
    pub dtype: DType,
}
40
41impl TensorData {
    /// Creates tensor data from typed elements and a shape.
    ///
    /// # Panics
    /// Panics when `value.len()` differs from the product of the dimensions.
    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
        let shape = shape.into();
        Self::check_data_len(&value, &shape);

        Self {
            bytes: Bytes::from_elems(value),
            shape,
            dtype: E::dtype(),
        }
    }
54
    /// Creates quantized tensor data from the already-quantized values and the
    /// strategy that produced them.
    ///
    /// # Panics
    /// Panics when `value.len()` differs from the product of the dimensions.
    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
        value: Vec<E>,
        shape: S,
        strategy: QuantizationStrategy,
    ) -> Self {
        let shape = shape.into();
        Self::check_data_len(&value, &shape);

        // QuantizedBytes packs the values together with the quantization parameters.
        let q_bytes = QuantizedBytes::new(value, strategy);

        Self {
            bytes: q_bytes.bytes,
            shape,
            dtype: DType::QFloat(q_bytes.scheme),
        }
    }
72
    /// Creates tensor data directly from raw bytes, trusting the caller's
    /// `shape` and `dtype`.
    ///
    /// NOTE(review): unlike [`Self::new`], no length/shape validation is
    /// performed here — a mismatched buffer surfaces later as a cast error.
    pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
        Self {
            bytes: Bytes::from_bytes_vec(bytes),
            shape: shape.into(),
            dtype,
        }
    }
84
85 fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
87 let expected_data_len = Self::numel(shape);
88 let num_data = data.len();
89 assert_eq!(
90 expected_data_len, num_data,
91 "Shape {:?} is invalid for input of size {:?}",
92 shape, num_data,
93 );
94 }
95
    /// Returns the data as a typed slice without copying.
    ///
    /// # Errors
    /// Returns [`DataError::TypeMismatch`] when `E` differs from `self.dtype`,
    /// or [`DataError::CastError`] when the bytes cannot be reinterpreted
    /// (size/alignment/bit-pattern check fails).
    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
        if E::dtype() == self.dtype {
            match E::dtype() {
                DType::Bool => {
                    // bool cannot be cast directly; go through u8 first.
                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
                        .map_err(DataError::CastError)?;
                    // SAFETY: the dtype check above means `E` is the bool element
                    // type, which matches u8 in size and alignment.
                    // NOTE(review): this is UB if any stored byte is not 0 or 1 —
                    // assumes bool buffers are always written canonically; confirm.
                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
                }
                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
            }
        } else {
            Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )))
        }
    }
118
    /// Returns the data as a mutable typed slice without copying.
    ///
    /// # Errors
    /// Returns [`DataError::TypeMismatch`] when `E` differs from `self.dtype`,
    /// or [`DataError::CastError`] when the bytes cannot be reinterpreted.
    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
        if E::dtype() == self.dtype {
            match E::dtype() {
                DType::Bool => {
                    // bool cannot be cast directly; go through u8 first.
                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
                        .map_err(DataError::CastError)?;
                    // SAFETY: the dtype check above means `E` is the bool element
                    // type (u8-compatible layout).
                    // NOTE(review): UB if any stored byte is not 0/1; writers must
                    // also only store canonical bool bytes — confirm.
                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
                }
                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
                    .map_err(DataError::CastError),
            }
        } else {
            Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )))
        }
    }
145
146 pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
148 Ok(self.as_slice()?.to_vec())
149 }
150
    /// Consumes the data and returns the elements as a `Vec<E>`, reusing the
    /// buffer without a copy when possible.
    ///
    /// # Errors
    /// Returns [`DataError::TypeMismatch`] when `E` differs from `self.dtype`.
    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
        if E::dtype() != self.dtype {
            return Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )));
        }

        match E::dtype() {
            DType::Bool => {
                // bool cannot be cast directly; extract as u8 then reinterpret.
                let vec = self.into_vec_unchecked::<u8>()?;
                // SAFETY: the dtype check above means `E` is the bool element type
                // (u8-compatible layout).
                // NOTE(review): UB if any byte is not 0/1 — assumes canonical bools.
                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
            }
            _ => self.into_vec_unchecked(),
        }
    }
173
    /// Converts into a `Vec<E>` without validating `E` against `self.dtype`
    /// (callers check first).
    ///
    /// Tries a zero-copy reuse of the owned buffer; when that fails (e.g. the
    /// buffer cannot be handed out as a `Vec<E>`), falls back to a checked
    /// cast plus copy.
    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
        let mut me = self;
        me.bytes = match me.bytes.try_into_vec::<E>() {
            // Fast path: buffer reused as-is, no copy.
            Ok(elems) => return Ok(elems),
            // Reuse failed: recover the bytes and fall through to the copy path.
            Err(bytes) => bytes,
        };
        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
            .map_err(DataError::CastError)?
            .to_vec())
    }
187
    /// Returns an iterator over the elements, converting each value to `E`.
    ///
    /// When `E` matches `self.dtype`, values are yielded directly from the
    /// buffer; otherwise each element is converted via `ElementConversion::elem`.
    /// For quantized data the iterator yields the stored `i8` quantized codes
    /// (not dequantized floats).
    ///
    /// # Panics
    /// Panics (inside `bytemuck::checked::cast_slice`) when the bytes cannot be
    /// reinterpreted as the stored dtype.
    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
        if E::dtype() == self.dtype {
            // Fast path: no conversion needed.
            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
        } else {
            match self.dtype {
                DType::I8 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i8| e.elem::<E>()),
                ),
                DType::I16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i16| e.elem::<E>()),
                ),
                DType::I32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i32| e.elem::<E>()),
                ),
                DType::I64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i64| e.elem::<E>()),
                ),
                // u8 needs no cast: the byte buffer is already the element slice.
                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
                DType::U16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u16| e.elem::<E>()),
                ),
                DType::U32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u32| e.elem::<E>()),
                ),
                DType::U64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u64| e.elem::<E>()),
                ),
                DType::BF16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &bf16| e.elem::<E>()),
                ),
                DType::F16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f16| e.elem::<E>()),
                ),
                // Flex32 shares the f32 in-memory representation here.
                DType::F32 | DType::Flex32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f32| e.elem::<E>()),
                ),
                DType::F64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f64| e.elem::<E>()),
                ),
                // Bool is stored one byte per value; converted via u8.
                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
                DType::QFloat(scheme) => match scheme {
                    QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
                        // Clone the payload so the quantized values can be unpacked.
                        let q_bytes = QuantizedBytes {
                            bytes: self.bytes.clone(),
                            scheme,
                            num_elements: self.num_elements(),
                        };
                        let (values, _) = q_bytes.into_vec_i8();

                        // Collected into a Vec so the boxed iterator owns the data.
                        Box::new(
                            values
                                .iter()
                                .map(|e: &i8| e.elem::<E>())
                                .collect::<Vec<_>>()
                                .into_iter(),
                        )
                    }
                },
            }
        }
    }
274
275 pub fn num_elements(&self) -> usize {
277 Self::numel(&self.shape)
278 }
279
280 fn numel(shape: &[usize]) -> usize {
281 shape.iter().product()
282 }
283
284 pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
286 shape: S,
287 distribution: Distribution,
288 rng: &mut R,
289 ) -> Self {
290 let shape = shape.into();
291 let num_elements = Self::numel(&shape);
292 let mut data = Vec::with_capacity(num_elements);
293
294 for _ in 0..num_elements {
295 data.push(E::random(distribution, rng));
296 }
297
298 TensorData::new(data, shape)
299 }
300
301 pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
303 let shape = shape.into();
304 let num_elements = Self::numel(&shape);
305 let mut data = Vec::<E>::with_capacity(num_elements);
306
307 for _ in 0..num_elements {
308 data.push(0.elem());
309 }
310
311 TensorData::new(data, shape)
312 }
313
314 pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
316 let shape = shape.into();
317 let num_elements = Self::numel(&shape);
318 let mut data = Vec::<E>::with_capacity(num_elements);
319
320 for _ in 0..num_elements {
321 data.push(1.elem());
322 }
323
324 TensorData::new(data, shape)
325 }
326
327 pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
329 let shape = shape.into();
330 let num_elements = Self::numel(&shape);
331 let mut data = Vec::<E>::with_capacity(num_elements);
332 for _ in 0..num_elements {
333 data.push(fill_value)
334 }
335
336 TensorData::new(data, shape)
337 }
338
    /// Converts the data to the dtype of element type `E`.
    pub fn convert<E: Element>(self) -> Self {
        self.convert_dtype(E::dtype())
    }
343
    /// Converts the data to the given [`DType`], preserving element values.
    ///
    /// Same-size conversions between non-bool, non-quantized dtypes reuse the
    /// existing buffer in place; all other conversions allocate a new buffer.
    pub fn convert_dtype(self, dtype: DType) -> Self {
        if dtype == self.dtype {
            self
        } else if dtype.size() == self.dtype.size()
            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
            && !matches!(dtype, DType::Bool | DType::QFloat(_))
        {
            // In-place path: element sizes are equal, so each slot can be rewritten.
            // NOTE(review): converting *to* Flex32 records `F32` as the resulting
            // dtype (convert_inplace stores `Target::dtype()`), presumably because
            // Flex32 shares the f32 representation — confirm this is intended.
            match self.dtype {
                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
                // Excluded by the matches! guards above.
                DType::Bool | DType::QFloat(_) => unreachable!(),
            }
        } else {
            match self.dtype {
                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
                DType::Bool => self.convert_clone_dtype::<bool>(dtype),
                // Quantized sources are not convertible through this path.
                DType::QFloat(_) => unreachable!(),
            }
        }
    }
386
    /// Dispatches [`Self::convert_inplace`] on the *target* dtype with the
    /// source type fixed to `Current`.
    ///
    /// Only reached from `convert_dtype` for same-size, non-bool,
    /// non-quantized targets.
    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
        match dtype {
            DType::F64 => self.convert_inplace::<Current, f64>(),
            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
            DType::F16 => self.convert_inplace::<Current, f16>(),
            DType::BF16 => self.convert_inplace::<Current, bf16>(),
            DType::I64 => self.convert_inplace::<Current, i64>(),
            DType::I32 => self.convert_inplace::<Current, i32>(),
            DType::I16 => self.convert_inplace::<Current, i16>(),
            DType::I8 => self.convert_inplace::<Current, i8>(),
            DType::U64 => self.convert_inplace::<Current, u64>(),
            DType::U32 => self.convert_inplace::<Current, u32>(),
            DType::U16 => self.convert_inplace::<Current, u16>(),
            DType::U8 => self.convert_inplace::<Current, u8>(),
            // Excluded by the caller's matches! guards.
            DType::Bool | DType::QFloat(_) => unreachable!(),
        }
    }
404
    /// Rewrites the buffer in place, replacing each `Current` element with its
    /// converted `Target` value, then updates `dtype`.
    ///
    /// Precondition: `Current` and `Target` have the same size — guaranteed by
    /// the `dtype.size()` equality check in `convert_dtype`.
    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
        mut self,
    ) -> Self {
        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
            // Read the old value, convert it, then overwrite the same slot with
            // the Target bit pattern.
            let t: Target = x.elem();
            let x = cast_mut::<_, Target>(x);
            *x = t;
        }

        self.dtype = Target::dtype();

        self
    }
418
    /// Dispatches [`Self::convert_clone`] on the *target* dtype with the source
    /// type fixed to `Current`. Used when sizes differ or bool is involved.
    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
        match dtype {
            DType::F64 => self.convert_clone::<Current, f64>(),
            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
            DType::F16 => self.convert_clone::<Current, f16>(),
            DType::BF16 => self.convert_clone::<Current, bf16>(),
            DType::I64 => self.convert_clone::<Current, i64>(),
            DType::I32 => self.convert_clone::<Current, i32>(),
            DType::I16 => self.convert_clone::<Current, i16>(),
            DType::I8 => self.convert_clone::<Current, i8>(),
            DType::U64 => self.convert_clone::<Current, u64>(),
            DType::U32 => self.convert_clone::<Current, u32>(),
            DType::U16 => self.convert_clone::<Current, u16>(),
            DType::U8 => self.convert_clone::<Current, u8>(),
            DType::Bool => self.convert_clone::<Current, bool>(),
            // Quantized targets are never produced through conversion.
            DType::QFloat(_) => unreachable!(),
        }
    }
437
    /// Converts by allocating a fresh `Target` buffer (element sizes may differ).
    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
        self,
    ) -> Self {
        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
        // Pre-sized, zero-filled output that the zip below writes into.
        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];

        for (x, out) in this.iter().zip(&mut out) {
            *out = x.elem();
        }

        Self::new(out, self.shape)
    }
450
    /// Returns the underlying storage as a byte slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }
455
    /// Consumes the data, returning the underlying byte storage.
    pub fn into_bytes(self) -> Bytes {
        self.bytes
    }
460
    /// Quantizes this `f32` data with the given strategy, returning the
    /// quantized tensor data.
    ///
    /// # Panics
    /// Panics if `self.dtype` is not [`DType::F32`].
    pub fn with_quantization(self, quantization: QuantizationStrategy) -> Self {
        assert_eq!(
            self.dtype,
            DType::F32,
            "Only f32 data type can be quantized"
        );
        let values = quantization.quantize(self.as_slice().unwrap());
        TensorData::quantized(values, self.shape, quantization)
    }
475
476 pub fn dequantize(self) -> Result<Self, DataError> {
478 if let DType::QFloat(scheme) = self.dtype {
479 let num_elements = self.num_elements();
480 let q_bytes = QuantizedBytes {
481 bytes: self.bytes,
482 scheme,
483 num_elements,
484 };
485
486 let values = q_bytes.dequantize().0;
487 Ok(Self::new(values, self.shape))
488 } else {
489 Err(DataError::TypeMismatch(format!(
490 "Expected quantized data, got {:?}",
491 self.dtype
492 )))
493 }
494 }
495
    /// Asserts exact element-wise equality with `other`.
    ///
    /// When `strict` is true the dtypes must match exactly; otherwise `other`
    /// is iterated using `self`'s dtype interpretation.
    ///
    /// # Panics
    /// Panics when dtypes (strict mode), quantization schemes, shapes, or any
    /// elements differ.
    #[track_caller]
    pub fn assert_eq(&self, other: &Self, strict: bool) {
        if strict {
            assert_eq!(
                self.dtype, other.dtype,
                "Data types differ ({:?} != {:?})",
                self.dtype, other.dtype
            );
        }

        match self.dtype {
            DType::F64 => self.assert_eq_elem::<f64>(other),
            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
            DType::F16 => self.assert_eq_elem::<f16>(other),
            DType::BF16 => self.assert_eq_elem::<bf16>(other),
            DType::I64 => self.assert_eq_elem::<i64>(other),
            DType::I32 => self.assert_eq_elem::<i32>(other),
            DType::I16 => self.assert_eq_elem::<i16>(other),
            DType::I8 => self.assert_eq_elem::<i8>(other),
            DType::U64 => self.assert_eq_elem::<u64>(other),
            DType::U32 => self.assert_eq_elem::<u32>(other),
            DType::U16 => self.assert_eq_elem::<u16>(other),
            DType::U8 => self.assert_eq_elem::<u8>(other),
            DType::Bool => self.assert_eq_elem::<bool>(other),
            DType::QFloat(q) => {
                // Both sides must be quantized and share the same scheme before
                // their raw i8 codes can be compared.
                let q_other = if let DType::QFloat(q_other) = other.dtype {
                    q_other
                } else {
                    panic!("Quantized data differs from other not quantized data")
                };
                match (q, q_other) {
                    (
                        QuantizationScheme::PerTensor(mode, QuantizationType::QInt8),
                        QuantizationScheme::PerTensor(mode_other, QuantizationType::QInt8),
                    ) if mode == mode_other => self.assert_eq_elem::<i8>(other),
                    _ => panic!("Quantization schemes differ ({:?} != {:?})", q, q_other),
                }
            }
        }
    }
548
549 #[track_caller]
550 fn assert_eq_elem<E: Element>(&self, other: &Self) {
551 let mut message = String::new();
552 if self.shape != other.shape {
553 message += format!(
554 "\n => Shape is different: {:?} != {:?}",
555 self.shape, other.shape
556 )
557 .as_str();
558 }
559
560 let mut num_diff = 0;
561 let max_num_diff = 5;
562 for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
563 if a.cmp(&b).is_ne() {
564 if num_diff < max_num_diff {
566 message += format!("\n => Position {i}: {a} != {b}").as_str();
567 }
568 num_diff += 1;
569 }
570 }
571
572 if num_diff >= max_num_diff {
573 message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
574 }
575
576 if !message.is_empty() {
577 panic!("Tensors are not eq:{}", message);
578 }
579 }
580
581 #[track_caller]
592 pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
593 let mut message = String::new();
594 if self.shape != other.shape {
595 message += format!(
596 "\n => Shape is different: {:?} != {:?}",
597 self.shape, other.shape
598 )
599 .as_str();
600 }
601
602 let iter = self.iter::<F>().zip(other.iter::<F>());
603
604 let mut num_diff = 0;
605 let max_num_diff = 5;
606
607 for (i, (a, b)) in iter.enumerate() {
608 let both_nan = a.is_nan() && b.is_nan();
610 let both_inf =
612 a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
613
614 if both_nan || both_inf {
615 continue;
616 }
617
618 if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
619 if num_diff < max_num_diff {
621 let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
622 let diff_rel = diff_abs / ToPrimitive::to_f64(&(a + b).abs()).unwrap();
623
624 let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
625 let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
626
627 message += format!(
628 "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
629 )
630 .as_str();
631 }
632 num_diff += 1;
633 }
634 }
635
636 if num_diff >= max_num_diff {
637 message += format!("\n{} more errors...", num_diff - 5).as_str();
638 }
639
640 if !message.is_empty() {
641 panic!("Tensors are not approx eq:{}", message);
642 }
643 }
644
645 pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
656 let start = range.start.elem::<f32>();
657 let end = range.end.elem::<f32>();
658
659 for elem in self.iter::<f32>() {
660 if elem < start || elem >= end {
661 panic!("Element ({elem:?}) is not within range {range:?}");
662 }
663 }
664 }
665
666 pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
676 let start = range.start().elem::<f32>();
677 let end = range.end().elem::<f32>();
678
679 for elem in self.iter::<f32>() {
680 if elem < start || elem > end {
681 panic!("Element ({elem:?}) is not within range {range:?}");
682 }
683 }
684 }
685}
686
687impl<E: Element, const A: usize> From<[E; A]> for TensorData {
688 fn from(elems: [E; A]) -> Self {
689 TensorData::new(elems.to_vec(), [A])
690 }
691}
692
693impl<const A: usize> From<[usize; A]> for TensorData {
694 fn from(elems: [usize; A]) -> Self {
695 TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
696 }
697}
698
699impl From<&[usize]> for TensorData {
700 fn from(elems: &[usize]) -> Self {
701 let mut data = Vec::with_capacity(elems.len());
702 for elem in elems.iter() {
703 data.push(*elem as i64);
704 }
705
706 TensorData::new(data, [elems.len()])
707 }
708}
709
710impl<E: Element> From<&[E]> for TensorData {
711 fn from(elems: &[E]) -> Self {
712 let mut data = Vec::with_capacity(elems.len());
713 for elem in elems.iter() {
714 data.push(*elem);
715 }
716
717 TensorData::new(data, [elems.len()])
718 }
719}
720
721impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
722 fn from(elems: [[E; B]; A]) -> Self {
723 let mut data = Vec::with_capacity(A * B);
724 for elem in elems.into_iter().take(A) {
725 for elem in elem.into_iter().take(B) {
726 data.push(elem);
727 }
728 }
729
730 TensorData::new(data, [A, B])
731 }
732}
733
734impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
735 for TensorData
736{
737 fn from(elems: [[[E; C]; B]; A]) -> Self {
738 let mut data = Vec::with_capacity(A * B * C);
739
740 for elem in elems.into_iter().take(A) {
741 for elem in elem.into_iter().take(B) {
742 for elem in elem.into_iter().take(C) {
743 data.push(elem);
744 }
745 }
746 }
747
748 TensorData::new(data, [A, B, C])
749 }
750}
751
752impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
753 From<[[[[E; D]; C]; B]; A]> for TensorData
754{
755 fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
756 let mut data = Vec::with_capacity(A * B * C * D);
757
758 for elem in elems.into_iter().take(A) {
759 for elem in elem.into_iter().take(B) {
760 for elem in elem.into_iter().take(C) {
761 for elem in elem.into_iter().take(D) {
762 data.push(elem);
763 }
764 }
765 }
766 }
767
768 TensorData::new(data, [A, B, C, D])
769 }
770}
771
772impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
773 From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
774{
775 fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
776 let mut data = Vec::with_capacity(A * B * C * D * E);
777
778 for elem in elems.into_iter().take(A) {
779 for elem in elem.into_iter().take(B) {
780 for elem in elem.into_iter().take(C) {
781 for elem in elem.into_iter().take(D) {
782 for elem in elem.into_iter().take(E) {
783 data.push(elem);
784 }
785 }
786 }
787 }
788 }
789
790 TensorData::new(data, [A, B, C, D, E])
791 }
792}
793
794impl core::fmt::Display for TensorData {
795 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
796 let fmt = match self.dtype {
797 DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
798 DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
799 DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
800 DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
801 DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
802 DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
803 DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
804 DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
805 DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
806 DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
807 DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
808 DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
809 DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
810 DType::QFloat(scheme) => match scheme {
811 QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
812 format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
813 }
814 },
815 };
816 f.write_str(fmt.as_str())
817 }
818}
819
/// Tolerance used by approximate float comparisons: a value pair passes when
/// `|x - y| < max(absolute, relative * |x + y|)` (see `approx_eq`).
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    // Scales with the magnitude of the compared values.
    relative: F,
    // Fixed floor, useful near zero where relative error explodes.
    absolute: F,
}
843
844impl<F: Float> Default for Tolerance<F> {
845 fn default() -> Self {
846 Self {
847 relative: F::from(64).unwrap() * F::epsilon(),
848 absolute: F::from(16).unwrap() * F::min_positive_value(),
849 }
850 }
851}
852
853impl<F: Float> Tolerance<F> {
854 pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
864 let relative = Self::check_relative(relative);
865 let absolute = Self::check_absolute(absolute);
866
867 Self { relative, absolute }
868 }
869
870 pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
880 let relative = Self::check_relative(tolerance);
881
882 Self {
883 relative,
884 absolute: F::from(0.0).unwrap(),
885 }
886 }
887
888 pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
898 let absolute = Self::check_absolute(tolerance);
899
900 Self {
901 relative: F::from(0.0).unwrap(),
902 absolute,
903 }
904 }
905
906 pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
908 self.relative = Self::check_relative(tolerance);
909 self
910 }
911
912 pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
914 if core::mem::size_of::<F>() == 2 {
915 self.relative = Self::check_relative(tolerance);
916 }
917 self
918 }
919
920 pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
922 if core::mem::size_of::<F>() == 4 {
923 self.relative = Self::check_relative(tolerance);
924 }
925 self
926 }
927
928 pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
930 if core::mem::size_of::<F>() == 8 {
931 self.relative = Self::check_relative(tolerance);
932 }
933 self
934 }
935
936 pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
938 self.absolute = Self::check_absolute(tolerance);
939 self
940 }
941
942 pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
944 if core::mem::size_of::<F>() == 2 {
945 self.absolute = Self::check_absolute(tolerance);
946 }
947 self
948 }
949
950 pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
952 if core::mem::size_of::<F>() == 4 {
953 self.absolute = Self::check_absolute(tolerance);
954 }
955 self
956 }
957
958 pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
960 if core::mem::size_of::<F>() == 8 {
961 self.absolute = Self::check_absolute(tolerance);
962 }
963 self
964 }
965
966 pub fn approx_eq(&self, x: F, y: F) -> bool {
968 if x == y {
974 return true;
975 }
976
977 let diff = (x - y).abs();
978
979 let norm = (x + y).abs();
980 let norm = norm.min(F::max_value()); diff < self.absolute.max(self.relative * norm)
983 }
984
985 fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
986 let tolerance = F::from(tolerance).unwrap();
987 assert!(tolerance <= F::one());
988 tolerance
989 }
990
991 fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
992 let tolerance = F::from(tolerance).unwrap();
993 assert!(tolerance >= F::zero());
994 tolerance
995 }
996}
997
#[cfg(test)]
mod tests {
    use crate::{Shape, quantization::SymmetricQuantization};

    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let shape = Shape::new([3, 5, 6]);
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let shape = Shape::new([3, 5, 6]);
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );

        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = Shape::new([3, 5, 6]);
        let num_elements = shape.num_elements();
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );

        // f32 occupies 4 bytes per element.
        assert_eq!(num_elements, data.bytes.len() / 4);
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_assert_approx_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));
        data1.assert_approx_eq::<half::f16>(&data2, Tolerance::absolute(3e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        // The byte buffer keeps both the length and the spare capacity of the
        // source vector.
        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    #[test]
    #[should_panic = "Expected quantized data"]
    fn should_not_dequantize() {
        let data = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        data.dequantize().unwrap();
    }

    #[test]
    fn should_support_dequantize() {
        // Symmetric int8 quantization with scale 0.1.
        let data = TensorData::quantized(
            vec![-127i8, -77, -26, 25, 76, 127],
            [2, 3],
            QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.1)),
        );

        let output = data.dequantize().unwrap();

        output.assert_approx_eq::<f32>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Tolerance::default(),
        );

        output.assert_approx_eq::<f16>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Tolerance::default(),
        );
    }
}