1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use cubecl_quant::scheme::QuantScheme;
9use half::{bf16, f16};
10use num_traits::{Float, ToPrimitive};
11
12use crate::{
13 DType, Distribution, Element, ElementConversion,
14 quantization::{QuantValue, QuantizationStrategy, QuantizedBytes},
15 tensor::Bytes,
16};
17
18use rand::RngCore;
19
20use super::quantization::{QuantLevel, QuantMode};
21
/// Errors returned when the contents of a [`TensorData`] cannot be viewed as requested.
#[derive(Debug)]
pub enum DataError {
    /// A checked `bytemuck` cast of the underlying bytes failed.
    CastError(CheckedCastError),
    /// The requested element type (or kind of data) does not match the stored `dtype`;
    /// the string holds the human-readable description.
    TypeMismatch(String),
}
30
/// Tensor contents stored as raw bytes together with the shape and element type
/// needed to interpret them.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// Byte storage holding the tensor elements.
    pub bytes: Bytes,

    /// Size of each dimension.
    pub shape: Vec<usize>,

    /// Element type used to interpret `bytes`.
    pub dtype: DType,
}
43
44impl TensorData {
45 pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
47 let shape = shape.into();
49 Self::check_data_len(&value, &shape);
50
51 Self {
52 bytes: Bytes::from_elems(value),
53 shape,
54 dtype: E::dtype(),
55 }
56 }
57
58 pub fn quantized<E: Element, S: Into<Vec<usize>>>(
60 value: Vec<E>,
61 shape: S,
62 strategy: QuantizationStrategy,
63 scheme: QuantScheme,
64 ) -> Self {
65 let shape = shape.into();
66 Self::check_data_len(&value, &shape);
67
68 let q_bytes = QuantizedBytes::new(value, strategy, scheme);
69
70 Self {
71 bytes: q_bytes.bytes,
72 shape,
73 dtype: DType::QFloat(q_bytes.scheme),
74 }
75 }
76
77 pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
79 Self {
80 bytes,
81 shape: shape.into(),
82 dtype,
83 }
84 }
85
86 pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
91 Self {
92 bytes: Bytes::from_bytes_vec(bytes),
93 shape: shape.into(),
94 dtype,
95 }
96 }
97
98 fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
100 let expected_data_len = Self::numel(shape);
101 let num_data = data.len();
102 assert_eq!(
103 expected_data_len, num_data,
104 "Shape {shape:?} is invalid for input of size {num_data:?}",
105 );
106 }
107
108 pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
110 if E::dtype() == self.dtype {
111 match E::dtype() {
112 DType::Bool => {
116 let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
117 .map_err(DataError::CastError)?;
118 Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
119 }
120 _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
121 }
122 } else {
123 Err(DataError::TypeMismatch(format!(
124 "Invalid target element type (expected {:?}, got {:?})",
125 self.dtype,
126 E::dtype()
127 )))
128 }
129 }
130
131 pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
136 if E::dtype() == self.dtype {
137 match E::dtype() {
138 DType::Bool => {
142 let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
143 .map_err(DataError::CastError)?;
144 Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
145 }
146 _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
147 .map_err(DataError::CastError),
148 }
149 } else {
150 Err(DataError::TypeMismatch(format!(
151 "Invalid target element type (expected {:?}, got {:?})",
152 self.dtype,
153 E::dtype()
154 )))
155 }
156 }
157
158 pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
160 Ok(self.as_slice()?.to_vec())
161 }
162
163 pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
165 if E::dtype() != self.dtype {
167 return Err(DataError::TypeMismatch(format!(
168 "Invalid target element type (expected {:?}, got {:?})",
169 self.dtype,
170 E::dtype()
171 )));
172 }
173
174 match E::dtype() {
175 DType::Bool => {
179 let vec = self.into_vec_unchecked::<u8>()?;
180 Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
181 }
182 _ => self.into_vec_unchecked(),
183 }
184 }
185
186 fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
188 let mut me = self;
189 me.bytes = match me.bytes.try_into_vec::<E>() {
190 Ok(elems) => return Ok(elems),
191 Err(bytes) => bytes,
192 };
193 Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
196 .map_err(DataError::CastError)?
197 .to_vec())
198 }
199
200 pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
202 if E::dtype() == self.dtype {
203 Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
204 } else {
205 match self.dtype {
206 DType::I8 => Box::new(
207 bytemuck::checked::cast_slice(&self.bytes)
208 .iter()
209 .map(|e: &i8| e.elem::<E>()),
210 ),
211 DType::I16 => Box::new(
212 bytemuck::checked::cast_slice(&self.bytes)
213 .iter()
214 .map(|e: &i16| e.elem::<E>()),
215 ),
216 DType::I32 => Box::new(
217 bytemuck::checked::cast_slice(&self.bytes)
218 .iter()
219 .map(|e: &i32| e.elem::<E>()),
220 ),
221 DType::I64 => Box::new(
222 bytemuck::checked::cast_slice(&self.bytes)
223 .iter()
224 .map(|e: &i64| e.elem::<E>()),
225 ),
226 DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
227 DType::U16 => Box::new(
228 bytemuck::checked::cast_slice(&self.bytes)
229 .iter()
230 .map(|e: &u16| e.elem::<E>()),
231 ),
232 DType::U32 => Box::new(
233 bytemuck::checked::cast_slice(&self.bytes)
234 .iter()
235 .map(|e: &u32| e.elem::<E>()),
236 ),
237 DType::U64 => Box::new(
238 bytemuck::checked::cast_slice(&self.bytes)
239 .iter()
240 .map(|e: &u64| e.elem::<E>()),
241 ),
242 DType::BF16 => Box::new(
243 bytemuck::checked::cast_slice(&self.bytes)
244 .iter()
245 .map(|e: &bf16| e.elem::<E>()),
246 ),
247 DType::F16 => Box::new(
248 bytemuck::checked::cast_slice(&self.bytes)
249 .iter()
250 .map(|e: &f16| e.elem::<E>()),
251 ),
252 DType::F32 | DType::Flex32 => Box::new(
253 bytemuck::checked::cast_slice(&self.bytes)
254 .iter()
255 .map(|e: &f32| e.elem::<E>()),
256 ),
257 DType::F64 => Box::new(
258 bytemuck::checked::cast_slice(&self.bytes)
259 .iter()
260 .map(|e: &f64| e.elem::<E>()),
261 ),
262 DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
264 DType::QFloat(scheme) => match scheme {
265 QuantScheme {
266 level: QuantLevel::Tensor | QuantLevel::Block(_),
267 mode: QuantMode::Symmetric,
268 value:
269 QuantValue::Q8F
270 | QuantValue::Q8S
271 | QuantValue::Q4F
273 | QuantValue::Q4S
274 | QuantValue::Q2F
275 | QuantValue::Q2S,
276 ..
277 } => {
278 let q_bytes = QuantizedBytes {
280 bytes: self.bytes.clone(),
281 scheme,
282 num_elements: self.num_elements(),
283 };
284 let (values, _) = q_bytes.into_vec_i8();
285
286 Box::new(
287 values
288 .iter()
289 .map(|e: &i8| e.elem::<E>())
290 .collect::<Vec<_>>()
291 .into_iter(),
292 )
293 }
294 QuantScheme {
295 level: QuantLevel::Tensor | QuantLevel::Block(_),
296 mode: QuantMode::Symmetric,
297 value:
298 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
299 ..
300 } => {
301 unimplemented!("Not yet implemented for iteration");
302 }
303 },
304 }
305 }
306 }
307
308 pub fn rank(&self) -> usize {
310 self.shape.len()
311 }
312
313 pub fn num_elements(&self) -> usize {
315 Self::numel(&self.shape)
316 }
317
318 fn numel(shape: &[usize]) -> usize {
319 shape.iter().product()
320 }
321
322 pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
324 shape: S,
325 distribution: Distribution,
326 rng: &mut R,
327 ) -> Self {
328 let shape = shape.into();
329 let num_elements = Self::numel(&shape);
330 let mut data = Vec::with_capacity(num_elements);
331
332 for _ in 0..num_elements {
333 data.push(E::random(distribution, rng));
334 }
335
336 TensorData::new(data, shape)
337 }
338
339 pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
341 let shape = shape.into();
342 let num_elements = Self::numel(&shape);
343 let mut data = Vec::<E>::with_capacity(num_elements);
344
345 for _ in 0..num_elements {
346 data.push(0.elem());
347 }
348
349 TensorData::new(data, shape)
350 }
351
352 pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
354 let shape = shape.into();
355 let num_elements = Self::numel(&shape);
356 let mut data = Vec::<E>::with_capacity(num_elements);
357
358 for _ in 0..num_elements {
359 data.push(1.elem());
360 }
361
362 TensorData::new(data, shape)
363 }
364
365 pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
367 let shape = shape.into();
368 let num_elements = Self::numel(&shape);
369 let mut data = Vec::<E>::with_capacity(num_elements);
370 for _ in 0..num_elements {
371 data.push(fill_value)
372 }
373
374 TensorData::new(data, shape)
375 }
376
377 pub(crate) fn full_dtype<E: Element, S: Into<Vec<usize>>>(
378 shape: S,
379 fill_value: E,
380 dtype: DType,
381 ) -> TensorData {
382 match dtype {
383 DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
384 DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
385 DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
386 DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
387 DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
388 DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
389 DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
390 DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
391 DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
392 DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
393 DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
394 DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
395 DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
396 DType::QFloat(_) => unreachable!(),
397 }
398 }
399
400 pub fn convert<E: Element>(self) -> Self {
402 self.convert_dtype(E::dtype())
403 }
404
405 pub fn convert_dtype(self, dtype: DType) -> Self {
407 if dtype == self.dtype {
408 self
409 } else if dtype.size() == self.dtype.size()
410 && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
411 && !matches!(dtype, DType::Bool | DType::QFloat(_))
412 {
413 match self.dtype {
414 DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
415 DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
416 DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
417 DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
418 DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
419 DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
420 DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
421 DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
422 DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
423 DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
424 DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
425 DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
426 DType::Bool | DType::QFloat(_) => unreachable!(),
427 }
428 } else {
429 match self.dtype {
430 DType::F64 => self.convert_clone_dtype::<f64>(dtype),
431 DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
432 DType::F16 => self.convert_clone_dtype::<f16>(dtype),
433 DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
434 DType::I64 => self.convert_clone_dtype::<i64>(dtype),
435 DType::I32 => self.convert_clone_dtype::<i32>(dtype),
436 DType::I16 => self.convert_clone_dtype::<i16>(dtype),
437 DType::I8 => self.convert_clone_dtype::<i8>(dtype),
438 DType::U64 => self.convert_clone_dtype::<u64>(dtype),
439 DType::U32 => self.convert_clone_dtype::<u32>(dtype),
440 DType::U16 => self.convert_clone_dtype::<u16>(dtype),
441 DType::U8 => self.convert_clone_dtype::<u8>(dtype),
442 DType::Bool => self.convert_clone_dtype::<bool>(dtype),
443 DType::QFloat(_) => unreachable!(),
444 }
445 }
446 }
447
448 fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
449 match dtype {
450 DType::F64 => self.convert_inplace::<Current, f64>(),
451 DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
452 DType::F16 => self.convert_inplace::<Current, f16>(),
453 DType::BF16 => self.convert_inplace::<Current, bf16>(),
454 DType::I64 => self.convert_inplace::<Current, i64>(),
455 DType::I32 => self.convert_inplace::<Current, i32>(),
456 DType::I16 => self.convert_inplace::<Current, i16>(),
457 DType::I8 => self.convert_inplace::<Current, i8>(),
458 DType::U64 => self.convert_inplace::<Current, u64>(),
459 DType::U32 => self.convert_inplace::<Current, u32>(),
460 DType::U16 => self.convert_inplace::<Current, u16>(),
461 DType::U8 => self.convert_inplace::<Current, u8>(),
462 DType::Bool | DType::QFloat(_) => unreachable!(),
463 }
464 }
465
466 fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
467 mut self,
468 ) -> Self {
469 for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
470 let t: Target = x.elem();
471 let x = cast_mut::<_, Target>(x);
472 *x = t;
473 }
474
475 self.dtype = Target::dtype();
476
477 self
478 }
479
480 fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
481 match dtype {
482 DType::F64 => self.convert_clone::<Current, f64>(),
483 DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
484 DType::F16 => self.convert_clone::<Current, f16>(),
485 DType::BF16 => self.convert_clone::<Current, bf16>(),
486 DType::I64 => self.convert_clone::<Current, i64>(),
487 DType::I32 => self.convert_clone::<Current, i32>(),
488 DType::I16 => self.convert_clone::<Current, i16>(),
489 DType::I8 => self.convert_clone::<Current, i8>(),
490 DType::U64 => self.convert_clone::<Current, u64>(),
491 DType::U32 => self.convert_clone::<Current, u32>(),
492 DType::U16 => self.convert_clone::<Current, u16>(),
493 DType::U8 => self.convert_clone::<Current, u8>(),
494 DType::Bool => self.convert_clone::<Current, bool>(),
495 DType::QFloat(_) => unreachable!(),
496 }
497 }
498
499 fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
500 self,
501 ) -> Self {
502 let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
503 let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
504
505 for (x, out) in this.iter().zip(&mut out) {
506 *out = x.elem();
507 }
508
509 Self::new(out, self.shape)
510 }
511
512 pub fn as_bytes(&self) -> &[u8] {
514 &self.bytes
515 }
516
517 pub fn into_bytes(self) -> Bytes {
519 self.bytes
520 }
521
522 pub fn dequantize(self) -> Result<Self, DataError> {
524 if let DType::QFloat(scheme) = self.dtype {
525 let num_elements = self.num_elements();
526 let q_bytes = QuantizedBytes {
527 bytes: self.bytes,
528 scheme,
529 num_elements,
530 };
531
532 let values = q_bytes.dequantize().0;
533 Ok(Self::new(values, self.shape))
534 } else {
535 Err(DataError::TypeMismatch(format!(
536 "Expected quantized data, got {:?}",
537 self.dtype
538 )))
539 }
540 }
541
542 #[track_caller]
554 pub fn assert_eq(&self, other: &Self, strict: bool) {
555 if strict {
556 assert_eq!(
557 self.dtype, other.dtype,
558 "Data types differ ({:?} != {:?})",
559 self.dtype, other.dtype
560 );
561 }
562
563 match self.dtype {
564 DType::F64 => self.assert_eq_elem::<f64>(other),
565 DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
566 DType::F16 => self.assert_eq_elem::<f16>(other),
567 DType::BF16 => self.assert_eq_elem::<bf16>(other),
568 DType::I64 => self.assert_eq_elem::<i64>(other),
569 DType::I32 => self.assert_eq_elem::<i32>(other),
570 DType::I16 => self.assert_eq_elem::<i16>(other),
571 DType::I8 => self.assert_eq_elem::<i8>(other),
572 DType::U64 => self.assert_eq_elem::<u64>(other),
573 DType::U32 => self.assert_eq_elem::<u32>(other),
574 DType::U16 => self.assert_eq_elem::<u16>(other),
575 DType::U8 => self.assert_eq_elem::<u8>(other),
576 DType::Bool => self.assert_eq_elem::<bool>(other),
577 DType::QFloat(q) => {
578 let q_other = if let DType::QFloat(q_other) = other.dtype {
580 q_other
581 } else {
582 panic!("Quantized data differs from other not quantized data")
583 };
584
585 if q.value == q_other.value && q.level == q_other.level {
587 self.assert_eq_elem::<i8>(other)
588 } else {
589 panic!("Quantization schemes differ ({q:?} != {q_other:?})")
590 }
591 }
592 }
593 }
594
595 #[track_caller]
596 fn assert_eq_elem<E: Element>(&self, other: &Self) {
597 let mut message = String::new();
598 if self.shape != other.shape {
599 message += format!(
600 "\n => Shape is different: {:?} != {:?}",
601 self.shape, other.shape
602 )
603 .as_str();
604 }
605
606 let mut num_diff = 0;
607 let max_num_diff = 5;
608 for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
609 if a.cmp(&b).is_ne() {
610 if num_diff < max_num_diff {
612 message += format!("\n => Position {i}: {a} != {b}").as_str();
613 }
614 num_diff += 1;
615 }
616 }
617
618 if num_diff >= max_num_diff {
619 message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
620 }
621
622 if !message.is_empty() {
623 panic!("Tensors are not eq:{message}");
624 }
625 }
626
627 #[track_caller]
638 pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
639 let mut message = String::new();
640 if self.shape != other.shape {
641 message += format!(
642 "\n => Shape is different: {:?} != {:?}",
643 self.shape, other.shape
644 )
645 .as_str();
646 }
647
648 let iter = self.iter::<F>().zip(other.iter::<F>());
649
650 let mut num_diff = 0;
651 let max_num_diff = 5;
652
653 for (i, (a, b)) in iter.enumerate() {
654 let both_nan = a.is_nan() && b.is_nan();
656 let both_inf =
658 a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
659
660 if both_nan || both_inf {
661 continue;
662 }
663
664 if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
665 if num_diff < max_num_diff {
667 let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
668 let max = F::max(a.abs(), b.abs());
669 let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();
670
671 let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
672 let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
673
674 message += format!(
675 "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
676 )
677 .as_str();
678 }
679 num_diff += 1;
680 }
681 }
682
683 if num_diff >= max_num_diff {
684 message += format!("\n{} more errors...", num_diff - 5).as_str();
685 }
686
687 if !message.is_empty() {
688 panic!("Tensors are not approx eq:{message}");
689 }
690 }
691
692 pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
703 for elem in self.iter::<E>() {
704 if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
705 panic!("Element ({elem:?}) is not within range {range:?}");
706 }
707 }
708 }
709
710 pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
720 let start = range.start();
721 let end = range.end();
722
723 for elem in self.iter::<E>() {
724 if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
725 panic!("Element ({elem:?}) is not within range {range:?}");
726 }
727 }
728 }
729}
730
731impl<E: Element, const A: usize> From<[E; A]> for TensorData {
732 fn from(elems: [E; A]) -> Self {
733 TensorData::new(elems.to_vec(), [A])
734 }
735}
736
737impl<const A: usize> From<[usize; A]> for TensorData {
738 fn from(elems: [usize; A]) -> Self {
739 TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
740 }
741}
742
743impl From<&[usize]> for TensorData {
744 fn from(elems: &[usize]) -> Self {
745 let mut data = Vec::with_capacity(elems.len());
746 for elem in elems.iter() {
747 data.push(*elem as i64);
748 }
749
750 TensorData::new(data, [elems.len()])
751 }
752}
753
754impl<E: Element> From<&[E]> for TensorData {
755 fn from(elems: &[E]) -> Self {
756 let mut data = Vec::with_capacity(elems.len());
757 for elem in elems.iter() {
758 data.push(*elem);
759 }
760
761 TensorData::new(data, [elems.len()])
762 }
763}
764
765impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
766 fn from(elems: [[E; B]; A]) -> Self {
767 let mut data = Vec::with_capacity(A * B);
768 for elem in elems.into_iter().take(A) {
769 for elem in elem.into_iter().take(B) {
770 data.push(elem);
771 }
772 }
773
774 TensorData::new(data, [A, B])
775 }
776}
777
778impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
779 for TensorData
780{
781 fn from(elems: [[[E; C]; B]; A]) -> Self {
782 let mut data = Vec::with_capacity(A * B * C);
783
784 for elem in elems.into_iter().take(A) {
785 for elem in elem.into_iter().take(B) {
786 for elem in elem.into_iter().take(C) {
787 data.push(elem);
788 }
789 }
790 }
791
792 TensorData::new(data, [A, B, C])
793 }
794}
795
796impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
797 From<[[[[E; D]; C]; B]; A]> for TensorData
798{
799 fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
800 let mut data = Vec::with_capacity(A * B * C * D);
801
802 for elem in elems.into_iter().take(A) {
803 for elem in elem.into_iter().take(B) {
804 for elem in elem.into_iter().take(C) {
805 for elem in elem.into_iter().take(D) {
806 data.push(elem);
807 }
808 }
809 }
810 }
811
812 TensorData::new(data, [A, B, C, D])
813 }
814}
815
816impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
817 From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
818{
819 fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
820 let mut data = Vec::with_capacity(A * B * C * D * E);
821
822 for elem in elems.into_iter().take(A) {
823 for elem in elem.into_iter().take(B) {
824 for elem in elem.into_iter().take(C) {
825 for elem in elem.into_iter().take(D) {
826 for elem in elem.into_iter().take(E) {
827 data.push(elem);
828 }
829 }
830 }
831 }
832 }
833
834 TensorData::new(data, [A, B, C, D, E])
835 }
836}
837
838impl core::fmt::Display for TensorData {
839 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
840 let fmt = match self.dtype {
841 DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
842 DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
843 DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
844 DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
845 DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
846 DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
847 DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
848 DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
849 DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
850 DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
851 DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
852 DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
853 DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
854 DType::QFloat(scheme) => match scheme {
855 QuantScheme {
856 level: QuantLevel::Tensor | QuantLevel::Block(_),
857 mode: QuantMode::Symmetric,
858 value:
859 QuantValue::Q8F
860 | QuantValue::Q8S
861 | QuantValue::Q4F
863 | QuantValue::Q4S
864 | QuantValue::Q2F
865 | QuantValue::Q2S,
866 ..
867 } => {
868 format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
869 },
870 QuantScheme {
871 level: QuantLevel::Tensor | QuantLevel::Block(_),
872 mode: QuantMode::Symmetric,
873 value:
874 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
875 ..
876 } => {
877 unimplemented!("Can't format yet");
878 }
879 },
880 };
881 f.write_str(fmt.as_str())
882 }
883}
884
/// Pair of tolerances used for approximate comparison of floating-point values.
///
/// Two values compare as approximately equal when they are equal, or when their
/// absolute difference is below the larger of `absolute` and `relative` times the
/// larger magnitude of the two values.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    /// Tolerance scaled by the larger magnitude of the two compared values.
    relative: F,
    /// Tolerance applied directly to the absolute difference.
    absolute: F,
}
908
909impl<F: Float> Default for Tolerance<F> {
910 fn default() -> Self {
911 Self::balanced()
912 }
913}
914
915impl<F: Float> Tolerance<F> {
916 pub fn strict() -> Self {
918 Self {
919 relative: F::from(0.00).unwrap(),
920 absolute: F::from(64).unwrap() * F::min_positive_value(),
921 }
922 }
923 pub fn balanced() -> Self {
925 Self {
926 relative: F::from(0.005).unwrap(), absolute: F::from(1e-5).unwrap(),
928 }
929 }
930
931 pub fn permissive() -> Self {
933 Self {
934 relative: F::from(0.01).unwrap(), absolute: F::from(0.01).unwrap(),
936 }
937 }
938 pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
948 let relative = Self::check_relative(relative);
949 let absolute = Self::check_absolute(absolute);
950
951 Self { relative, absolute }
952 }
953
954 pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
964 let relative = Self::check_relative(tolerance);
965
966 Self {
967 relative,
968 absolute: F::from(0.0).unwrap(),
969 }
970 }
971
972 pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
982 let absolute = Self::check_absolute(tolerance);
983
984 Self {
985 relative: F::from(0.0).unwrap(),
986 absolute,
987 }
988 }
989
990 pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
992 self.relative = Self::check_relative(tolerance);
993 self
994 }
995
996 pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
998 if core::mem::size_of::<F>() == 2 {
999 self.relative = Self::check_relative(tolerance);
1000 }
1001 self
1002 }
1003
1004 pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1006 if core::mem::size_of::<F>() == 4 {
1007 self.relative = Self::check_relative(tolerance);
1008 }
1009 self
1010 }
1011
1012 pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1014 if core::mem::size_of::<F>() == 8 {
1015 self.relative = Self::check_relative(tolerance);
1016 }
1017 self
1018 }
1019
1020 pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1022 self.absolute = Self::check_absolute(tolerance);
1023 self
1024 }
1025
1026 pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1028 if core::mem::size_of::<F>() == 2 {
1029 self.absolute = Self::check_absolute(tolerance);
1030 }
1031 self
1032 }
1033
1034 pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1036 if core::mem::size_of::<F>() == 4 {
1037 self.absolute = Self::check_absolute(tolerance);
1038 }
1039 self
1040 }
1041
1042 pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
1044 if core::mem::size_of::<F>() == 8 {
1045 self.absolute = Self::check_absolute(tolerance);
1046 }
1047 self
1048 }
1049
1050 pub fn approx_eq(&self, x: F, y: F) -> bool {
1052 if x == y {
1058 return true;
1059 }
1060
1061 let diff = (x - y).abs();
1062 let max = F::max(x.abs(), y.abs());
1063
1064 diff < self.absolute.max(self.relative * max)
1065 }
1066
1067 fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
1068 let tolerance = F::from(tolerance).unwrap();
1069 assert!(tolerance <= F::one());
1070 tolerance
1071 }
1072
1073 fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
1074 let tolerance = F::from(tolerance).unwrap();
1075 assert!(tolerance >= F::zero());
1076 tolerance
1077 }
1078}
1079
#[cfg(test)]
mod tests {
    use crate::{Shape, quantization::SymmetricQuantization};

    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Random `f32` data of the given shape, drawn from the default distribution with an
    /// OS-seeded generator.
    fn random_f32_data(shape: Shape) -> TensorData {
        let mut rng = StdRng::from_os_rng();
        TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut rng)
    }

    #[test]
    fn should_have_rank() {
        let data = random_f32_data(Shape::new([3, 5, 6]));

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32_data(Shape::new([3, 5, 6]));

        let expected: Vec<f32> = data.iter::<f32>().collect();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let data = random_f32_data(Shape::new([3, 5, 6]));

        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = Shape::new([3, 5, 6]);
        let num_elements = shape.num_elements();
        let data = random_f32_data(shape);

        // Each `f32` occupies four bytes.
        assert_eq!(num_elements, data.bytes.len() / 4);
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let row = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(row.shape, vec![1, 3]);

        let matrix = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(matrix.shape, vec![2, 3]);

        let vector = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(vector.shape, vec![3]);
    }

    #[test]
    fn should_assert_appox_eq_limit() {
        let lhs = TensorData::from([[3.0, 5.0, 6.0]]);
        let rhs = TensorData::from([[3.03, 5.0, 6.0]]);

        lhs.assert_approx_eq::<f32>(&rhs, Tolerance::absolute(3e-2));
        lhs.assert_approx_eq::<half::f16>(&rhs, Tolerance::absolute(3e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let lhs = TensorData::from([[3.0, 5.0, 6.0]]);
        let rhs = TensorData::from([[3.031, 5.0, 6.0]]);

        lhs.assert_approx_eq::<f32>(&rhs, Tolerance::absolute(1e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let lhs = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let rhs = TensorData::from([[3.0, 5.0, 6.0]]);

        lhs.assert_approx_eq::<f32>(&rhs, Tolerance::absolute(1e-2));
    }

    #[test]
    fn should_convert_bytes_correctly() {
        // Capacity (5) is deliberately larger than the length (2).
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data.bytes.len(), 2 * factor);
        assert_eq!(data.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            let converted = data.clone().convert::<E>().into_vec::<E>().unwrap();

            for (i, val) in converted.into_iter().enumerate() {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }

        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    #[test]
    #[should_panic = "Expected quantized data"]
    fn should_not_dequantize() {
        let data = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        data.dequantize().unwrap();
    }

    #[test]
    fn should_support_dequantize() {
        let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
            0.1,
            QuantValue::Q8S,
        ));
        let scheme = QuantScheme {
            level: QuantLevel::Tensor,
            value: QuantValue::Q8S,
            mode: QuantMode::Symmetric,
            ..Default::default()
        };
        let data = TensorData::quantized(
            vec![-127i8, -77, -26, 25, 76, 127],
            [2, 3],
            strategy,
            scheme,
        );

        let output = data.dequantize().unwrap();

        output.assert_approx_eq::<f32>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Tolerance::default(),
        );

        output.assert_approx_eq::<f16>(
            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),
            Tolerance::default(),
        );
    }

    // Generates one test per listed type, checking that `full_dtype` and `full` agree.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}