1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::RngCore;
9use thiserror::Error;
10
11use crate::Scalar;
12use crate::distribution::Distribution;
13use crate::element::{Element, ElementConversion};
14use burn_std::tensor::DType;
15use burn_std::{Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, bf16, f16};
16
17#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
19pub struct TensorData {
20 pub bytes: Bytes,
22
23 pub shape: Vec<usize>,
25
26 pub dtype: DType,
28}
29
30impl TensorData {
31 pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
33 let shape = shape.into();
35 Self::check_data_len(&value, &shape);
36
37 Self {
38 bytes: Bytes::from_elems(value),
39 shape,
40 dtype: E::dtype(),
41 }
42 }
43
44 pub fn quantized<E: Element, S: Into<Vec<usize>>>(
46 value: Vec<E>,
47 shape: S,
48 scheme: QuantScheme,
49 qparams: &[f32],
50 ) -> Self {
51 let shape = shape.into();
52 Self::check_data_len(&value, &shape);
53
54 let q_bytes = QuantizedBytes::new(value, scheme, qparams);
55
56 Self {
57 bytes: q_bytes.bytes,
58 shape,
59 dtype: DType::QFloat(q_bytes.scheme),
60 }
61 }
62
63 pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
65 Self {
66 bytes,
67 shape: shape.into(),
68 dtype,
69 }
70 }
71
72 pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
77 Self {
78 bytes: Bytes::from_bytes_vec(bytes),
79 shape: shape.into(),
80 dtype,
81 }
82 }
83
84 fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
86 let expected_data_len = Self::numel(shape);
87 let num_data = data.len();
88 assert_eq!(
89 expected_data_len, num_data,
90 "Shape {shape:?} is invalid for input of size {num_data:?}",
91 );
92 }
93
94 pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
96 if E::dtype() == self.dtype {
97 match E::dtype() {
98 DType::Bool => {
102 let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
103 .map_err(DataError::CastError)?;
104 Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
105 }
106 _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
107 }
108 } else {
109 Err(DataError::TypeMismatch(format!(
110 "Invalid target element type (expected {:?}, got {:?})",
111 self.dtype,
112 E::dtype()
113 )))
114 }
115 }
116
117 pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
122 if E::dtype() == self.dtype {
123 match E::dtype() {
124 DType::Bool => {
128 let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
129 .map_err(DataError::CastError)?;
130 Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
131 }
132 _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
133 .map_err(DataError::CastError),
134 }
135 } else {
136 Err(DataError::TypeMismatch(format!(
137 "Invalid target element type (expected {:?}, got {:?})",
138 self.dtype,
139 E::dtype()
140 )))
141 }
142 }
143
144 pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
146 Ok(self.as_slice()?.to_vec())
147 }
148
149 pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
151 if E::dtype() != self.dtype {
153 return Err(DataError::TypeMismatch(format!(
154 "Invalid target element type (expected {:?}, got {:?})",
155 self.dtype,
156 E::dtype()
157 )));
158 }
159
160 match E::dtype() {
161 DType::Bool => {
165 let vec = self.into_vec_unchecked::<u8>()?;
166 Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
167 }
168 _ => self.into_vec_unchecked(),
169 }
170 }
171
172 fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
174 let mut me = self;
175 me.bytes = match me.bytes.try_into_vec::<E>() {
176 Ok(elems) => return Ok(elems),
177 Err(bytes) => bytes,
178 };
179
180 Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
183 .map_err(DataError::CastError)?
184 .to_vec())
185 }
186
187 pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
189 if E::dtype() == self.dtype {
190 Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
191 } else {
192 match self.dtype {
193 DType::I8 => Box::new(
194 bytemuck::checked::cast_slice(&self.bytes)
195 .iter()
196 .map(|e: &i8| e.elem::<E>()),
197 ),
198 DType::I16 => Box::new(
199 bytemuck::checked::cast_slice(&self.bytes)
200 .iter()
201 .map(|e: &i16| e.elem::<E>()),
202 ),
203 DType::I32 => Box::new(
204 bytemuck::checked::cast_slice(&self.bytes)
205 .iter()
206 .map(|e: &i32| e.elem::<E>()),
207 ),
208 DType::I64 => Box::new(
209 bytemuck::checked::cast_slice(&self.bytes)
210 .iter()
211 .map(|e: &i64| e.elem::<E>()),
212 ),
213 DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
214 DType::U16 => Box::new(
215 bytemuck::checked::cast_slice(&self.bytes)
216 .iter()
217 .map(|e: &u16| e.elem::<E>()),
218 ),
219 DType::U32 => Box::new(
220 bytemuck::checked::cast_slice(&self.bytes)
221 .iter()
222 .map(|e: &u32| e.elem::<E>()),
223 ),
224 DType::U64 => Box::new(
225 bytemuck::checked::cast_slice(&self.bytes)
226 .iter()
227 .map(|e: &u64| e.elem::<E>()),
228 ),
229 DType::BF16 => Box::new(
230 bytemuck::checked::cast_slice(&self.bytes)
231 .iter()
232 .map(|e: &bf16| e.elem::<E>()),
233 ),
234 DType::F16 => Box::new(
235 bytemuck::checked::cast_slice(&self.bytes)
236 .iter()
237 .map(|e: &f16| e.elem::<E>()),
238 ),
239 DType::F32 | DType::Flex32 => Box::new(
240 bytemuck::checked::cast_slice(&self.bytes)
241 .iter()
242 .map(|e: &f32| e.elem::<E>()),
243 ),
244 DType::F64 => Box::new(
245 bytemuck::checked::cast_slice(&self.bytes)
246 .iter()
247 .map(|e: &f64| e.elem::<E>()),
248 ),
249 DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
251 DType::QFloat(scheme) => match scheme {
252 QuantScheme {
253 level: QuantLevel::Tensor | QuantLevel::Block(_),
254 mode: QuantMode::Symmetric,
255 value:
256 QuantValue::Q8F
257 | QuantValue::Q8S
258 | QuantValue::Q4F
260 | QuantValue::Q4S
261 | QuantValue::Q2F
262 | QuantValue::Q2S,
263 ..
264 } => {
265 let q_bytes = QuantizedBytes {
267 bytes: self.bytes.clone(),
268 scheme,
269 num_elements: self.num_elements(),
270 };
271 let (values, _) = q_bytes.into_vec_i8();
272
273 Box::new(
274 values
275 .iter()
276 .map(|e: &i8| e.elem::<E>())
277 .collect::<Vec<_>>()
278 .into_iter(),
279 )
280 }
281 QuantScheme {
282 level: QuantLevel::Tensor | QuantLevel::Block(_),
283 mode: QuantMode::Symmetric,
284 value:
285 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
286 ..
287 } => {
288 unimplemented!("Not yet implemented for iteration");
289 }
290 },
291 }
292 }
293 }
294
295 pub fn rank(&self) -> usize {
297 self.shape.len()
298 }
299
300 pub fn num_elements(&self) -> usize {
302 Self::numel(&self.shape)
303 }
304
305 fn numel(shape: &[usize]) -> usize {
306 shape.iter().product()
307 }
308
309 pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
311 shape: S,
312 distribution: Distribution,
313 rng: &mut R,
314 ) -> Self {
315 let shape = shape.into();
316 let num_elements = Self::numel(&shape);
317 let mut data = Vec::with_capacity(num_elements);
318
319 for _ in 0..num_elements {
320 data.push(E::random(distribution, rng));
321 }
322
323 TensorData::new(data, shape)
324 }
325
326 pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
328 let shape = shape.into();
329 let num_elements = Self::numel(&shape);
330 let mut data = Vec::<E>::with_capacity(num_elements);
331
332 for _ in 0..num_elements {
333 data.push(0.elem());
334 }
335
336 TensorData::new(data, shape)
337 }
338
339 pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
341 let shape = shape.into();
342 let num_elements = Self::numel(&shape);
343 let mut data = Vec::<E>::with_capacity(num_elements);
344
345 for _ in 0..num_elements {
346 data.push(1.elem());
347 }
348
349 TensorData::new(data, shape)
350 }
351
352 pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
354 let shape = shape.into();
355 let num_elements = Self::numel(&shape);
356 let mut data = Vec::<E>::with_capacity(num_elements);
357 for _ in 0..num_elements {
358 data.push(fill_value)
359 }
360
361 TensorData::new(data, shape)
362 }
363
364 pub fn full_dtype<E: Into<Scalar>, S: Into<Vec<usize>>>(
366 shape: S,
367 fill_value: E,
368 dtype: DType,
369 ) -> TensorData {
370 let fill_value = fill_value.into();
371 match dtype {
372 DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
373 DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
374 DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
375 DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
376 DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
377 DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
378 DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
379 DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
380 DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
381 DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
382 DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
383 DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
384 DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
385 DType::QFloat(_) => unreachable!(),
386 }
387 }
388
389 pub fn convert<E: Element>(self) -> Self {
391 self.convert_dtype(E::dtype())
392 }
393
394 pub fn convert_dtype(self, dtype: DType) -> Self {
396 if dtype == self.dtype {
397 self
398 } else if dtype.size() == self.dtype.size()
399 && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
400 && !matches!(dtype, DType::Bool | DType::QFloat(_))
401 {
402 match self.dtype {
403 DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
404 DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
405 DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
406 DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
407 DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
408 DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
409 DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
410 DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
411 DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
412 DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
413 DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
414 DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
415 DType::Bool | DType::QFloat(_) => unreachable!(),
416 }
417 } else {
418 match self.dtype {
419 DType::F64 => self.convert_clone_dtype::<f64>(dtype),
420 DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
421 DType::F16 => self.convert_clone_dtype::<f16>(dtype),
422 DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
423 DType::I64 => self.convert_clone_dtype::<i64>(dtype),
424 DType::I32 => self.convert_clone_dtype::<i32>(dtype),
425 DType::I16 => self.convert_clone_dtype::<i16>(dtype),
426 DType::I8 => self.convert_clone_dtype::<i8>(dtype),
427 DType::U64 => self.convert_clone_dtype::<u64>(dtype),
428 DType::U32 => self.convert_clone_dtype::<u32>(dtype),
429 DType::U16 => self.convert_clone_dtype::<u16>(dtype),
430 DType::U8 => self.convert_clone_dtype::<u8>(dtype),
431 DType::Bool => self.convert_clone_dtype::<bool>(dtype),
432 DType::QFloat(_) => unreachable!(),
433 }
434 }
435 }
436
437 fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
438 match dtype {
439 DType::F64 => self.convert_inplace::<Current, f64>(),
440 DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
441 DType::F16 => self.convert_inplace::<Current, f16>(),
442 DType::BF16 => self.convert_inplace::<Current, bf16>(),
443 DType::I64 => self.convert_inplace::<Current, i64>(),
444 DType::I32 => self.convert_inplace::<Current, i32>(),
445 DType::I16 => self.convert_inplace::<Current, i16>(),
446 DType::I8 => self.convert_inplace::<Current, i8>(),
447 DType::U64 => self.convert_inplace::<Current, u64>(),
448 DType::U32 => self.convert_inplace::<Current, u32>(),
449 DType::U16 => self.convert_inplace::<Current, u16>(),
450 DType::U8 => self.convert_inplace::<Current, u8>(),
451 DType::Bool | DType::QFloat(_) => unreachable!(),
452 }
453 }
454
455 fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
456 mut self,
457 ) -> Self {
458 for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
459 let t: Target = x.elem();
460 let x = cast_mut::<_, Target>(x);
461 *x = t;
462 }
463
464 self.dtype = Target::dtype();
465
466 self
467 }
468
469 fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
470 match dtype {
471 DType::F64 => self.convert_clone::<Current, f64>(),
472 DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
473 DType::F16 => self.convert_clone::<Current, f16>(),
474 DType::BF16 => self.convert_clone::<Current, bf16>(),
475 DType::I64 => self.convert_clone::<Current, i64>(),
476 DType::I32 => self.convert_clone::<Current, i32>(),
477 DType::I16 => self.convert_clone::<Current, i16>(),
478 DType::I8 => self.convert_clone::<Current, i8>(),
479 DType::U64 => self.convert_clone::<Current, u64>(),
480 DType::U32 => self.convert_clone::<Current, u32>(),
481 DType::U16 => self.convert_clone::<Current, u16>(),
482 DType::U8 => self.convert_clone::<Current, u8>(),
483 DType::Bool => self.convert_clone::<Current, bool>(),
484 DType::QFloat(_) => unreachable!(),
485 }
486 }
487
488 fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
489 self,
490 ) -> Self {
491 let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
492 let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
493
494 for (x, out) in this.iter().zip(&mut out) {
495 *out = x.elem();
496 }
497
498 Self::new(out, self.shape)
499 }
500
501 pub fn as_bytes(&self) -> &[u8] {
503 &self.bytes
504 }
505
506 pub fn into_bytes(self) -> Bytes {
508 self.bytes
509 }
510}
511
512impl<E: Element, const A: usize> From<[E; A]> for TensorData {
513 fn from(elems: [E; A]) -> Self {
514 TensorData::new(elems.to_vec(), [A])
515 }
516}
517
518impl<const A: usize> From<[usize; A]> for TensorData {
519 fn from(elems: [usize; A]) -> Self {
520 TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
521 }
522}
523
524impl From<&[usize]> for TensorData {
525 fn from(elems: &[usize]) -> Self {
526 let mut data = Vec::with_capacity(elems.len());
527 for elem in elems.iter() {
528 data.push(*elem as i64);
529 }
530
531 TensorData::new(data, [elems.len()])
532 }
533}
534
535impl<E: Element> From<&[E]> for TensorData {
536 fn from(elems: &[E]) -> Self {
537 let mut data = Vec::with_capacity(elems.len());
538 for elem in elems.iter() {
539 data.push(*elem);
540 }
541
542 TensorData::new(data, [elems.len()])
543 }
544}
545
546impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
547 fn from(elems: [[E; B]; A]) -> Self {
548 let mut data = Vec::with_capacity(A * B);
549 for elem in elems.into_iter().take(A) {
550 for elem in elem.into_iter().take(B) {
551 data.push(elem);
552 }
553 }
554
555 TensorData::new(data, [A, B])
556 }
557}
558
559impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
560 for TensorData
561{
562 fn from(elems: [[[E; C]; B]; A]) -> Self {
563 let mut data = Vec::with_capacity(A * B * C);
564
565 for elem in elems.into_iter().take(A) {
566 for elem in elem.into_iter().take(B) {
567 for elem in elem.into_iter().take(C) {
568 data.push(elem);
569 }
570 }
571 }
572
573 TensorData::new(data, [A, B, C])
574 }
575}
576
577impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
578 From<[[[[E; D]; C]; B]; A]> for TensorData
579{
580 fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
581 let mut data = Vec::with_capacity(A * B * C * D);
582
583 for elem in elems.into_iter().take(A) {
584 for elem in elem.into_iter().take(B) {
585 for elem in elem.into_iter().take(C) {
586 for elem in elem.into_iter().take(D) {
587 data.push(elem);
588 }
589 }
590 }
591 }
592
593 TensorData::new(data, [A, B, C, D])
594 }
595}
596
597impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
598 From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
599{
600 fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
601 let mut data = Vec::with_capacity(A * B * C * D * E);
602
603 for elem in elems.into_iter().take(A) {
604 for elem in elem.into_iter().take(B) {
605 for elem in elem.into_iter().take(C) {
606 for elem in elem.into_iter().take(D) {
607 for elem in elem.into_iter().take(E) {
608 data.push(elem);
609 }
610 }
611 }
612 }
613 }
614
615 TensorData::new(data, [A, B, C, D, E])
616 }
617}
618impl core::fmt::Display for TensorData {
619 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
620 let fmt = match self.dtype {
621 DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
622 DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
623 DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
624 DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
625 DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
626 DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
627 DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
628 DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
629 DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
630 DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
631 DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
632 DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
633 DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
634 DType::QFloat(scheme) => match scheme {
635 QuantScheme {
636 level: QuantLevel::Tensor | QuantLevel::Block(_),
637 mode: QuantMode::Symmetric,
638 value:
639 QuantValue::Q8F
640 | QuantValue::Q8S
641 | QuantValue::Q4F
643 | QuantValue::Q4S
644 | QuantValue::Q2F
645 | QuantValue::Q2S,
646 ..
647 } => {
648 format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
649 },
650 QuantScheme {
651 level: QuantLevel::Tensor | QuantLevel::Block(_),
652 mode: QuantMode::Symmetric,
653 value:
654 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
655 ..
656 } => {
657 unimplemented!("Can't format yet");
658 }
659 },
660 };
661 f.write_str(fmt.as_str())
662 }
663}
664
665#[derive(Debug, Error)]
667pub enum DataError {
668 #[error("Failed to cast values to the specified element type.\nError:\n {0}")]
670 CastError(CheckedCastError),
671 #[error("{0}")]
673 TypeMismatch(String),
674}
675
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Builds random `f32` tensor data with a fixed seed so that failures are reproducible.
    fn random_f32_data(shape: [usize; 3]) -> TensorData {
        TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::seed_from_u64(0),
        )
    }

    #[test]
    fn should_have_rank() {
        let data = random_f32_data([3, 5, 6]);

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let data = random_f32_data([3, 5, 6]);

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    fn into_vec_should_assert_wrong_dtype() {
        let data = random_f32_data([3, 5, 6]);

        // Pin the specific failure mode instead of accepting any panic.
        let result = data.into_vec::<i32>();
        assert!(
            matches!(result, Err(DataError::TypeMismatch(_))),
            "expected a type mismatch error, got {result:?}"
        );
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = [3, 5, 6];
        let num_elements: usize = shape.iter().product();
        let data = random_f32_data(shape);

        assert_eq!(num_elements, data.bytes.len() / core::mem::size_of::<f32>());
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}