1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::RngCore;
9
10use crate::distribution::Distribution;
11use crate::element::{Element, ElementConversion};
12use burn_std::tensor::DType;
13use burn_std::tensor::quantization::{
14 QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes,
15};
16use burn_std::{Bytes, bf16, f16};
17
/// Host-side storage for tensor values.
///
/// The values live in an untyped byte buffer; `shape` and `dtype` describe how to
/// interpret those bytes (see `TensorData::as_slice` / `TensorData::iter`).
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// The tensor values, stored as raw bytes encoded according to `dtype`.
    pub bytes: Bytes,

    /// The size of each dimension; the element count is the product of these values.
    pub shape: Vec<usize>,

    /// The element type used to encode `bytes`.
    pub dtype: DType,
}
30
31impl TensorData {
32 pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
34 let shape = shape.into();
36 Self::check_data_len(&value, &shape);
37
38 Self {
39 bytes: Bytes::from_elems(value),
40 shape,
41 dtype: E::dtype(),
42 }
43 }
44
45 pub fn quantized<E: Element, S: Into<Vec<usize>>>(
47 value: Vec<E>,
48 shape: S,
49 scheme: QuantScheme,
50 qparams: &[f32],
51 ) -> Self {
52 let shape = shape.into();
53 Self::check_data_len(&value, &shape);
54
55 let q_bytes = QuantizedBytes::new(value, scheme, qparams);
56
57 Self {
58 bytes: q_bytes.bytes,
59 shape,
60 dtype: DType::QFloat(q_bytes.scheme),
61 }
62 }
63
64 pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
66 Self {
67 bytes,
68 shape: shape.into(),
69 dtype,
70 }
71 }
72
73 pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
78 Self {
79 bytes: Bytes::from_bytes_vec(bytes),
80 shape: shape.into(),
81 dtype,
82 }
83 }
84
85 fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
87 let expected_data_len = Self::numel(shape);
88 let num_data = data.len();
89 assert_eq!(
90 expected_data_len, num_data,
91 "Shape {shape:?} is invalid for input of size {num_data:?}",
92 );
93 }
94
95 pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
97 if E::dtype() == self.dtype {
98 match E::dtype() {
99 DType::Bool => {
103 let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
104 .map_err(DataError::CastError)?;
105 Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
106 }
107 _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
108 }
109 } else {
110 Err(DataError::TypeMismatch(format!(
111 "Invalid target element type (expected {:?}, got {:?})",
112 self.dtype,
113 E::dtype()
114 )))
115 }
116 }
117
118 pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
123 if E::dtype() == self.dtype {
124 match E::dtype() {
125 DType::Bool => {
129 let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
130 .map_err(DataError::CastError)?;
131 Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
132 }
133 _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
134 .map_err(DataError::CastError),
135 }
136 } else {
137 Err(DataError::TypeMismatch(format!(
138 "Invalid target element type (expected {:?}, got {:?})",
139 self.dtype,
140 E::dtype()
141 )))
142 }
143 }
144
145 pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
147 Ok(self.as_slice()?.to_vec())
148 }
149
150 pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
152 if E::dtype() != self.dtype {
154 return Err(DataError::TypeMismatch(format!(
155 "Invalid target element type (expected {:?}, got {:?})",
156 self.dtype,
157 E::dtype()
158 )));
159 }
160
161 match E::dtype() {
162 DType::Bool => {
166 let vec = self.into_vec_unchecked::<u8>()?;
167 Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
168 }
169 _ => self.into_vec_unchecked(),
170 }
171 }
172
173 fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
175 let mut me = self;
176 me.bytes = match me.bytes.try_into_vec::<E>() {
177 Ok(elems) => return Ok(elems),
178 Err(bytes) => bytes,
179 };
180 Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
183 .map_err(DataError::CastError)?
184 .to_vec())
185 }
186
187 pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
189 if E::dtype() == self.dtype {
190 Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
191 } else {
192 match self.dtype {
193 DType::I8 => Box::new(
194 bytemuck::checked::cast_slice(&self.bytes)
195 .iter()
196 .map(|e: &i8| e.elem::<E>()),
197 ),
198 DType::I16 => Box::new(
199 bytemuck::checked::cast_slice(&self.bytes)
200 .iter()
201 .map(|e: &i16| e.elem::<E>()),
202 ),
203 DType::I32 => Box::new(
204 bytemuck::checked::cast_slice(&self.bytes)
205 .iter()
206 .map(|e: &i32| e.elem::<E>()),
207 ),
208 DType::I64 => Box::new(
209 bytemuck::checked::cast_slice(&self.bytes)
210 .iter()
211 .map(|e: &i64| e.elem::<E>()),
212 ),
213 DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
214 DType::U16 => Box::new(
215 bytemuck::checked::cast_slice(&self.bytes)
216 .iter()
217 .map(|e: &u16| e.elem::<E>()),
218 ),
219 DType::U32 => Box::new(
220 bytemuck::checked::cast_slice(&self.bytes)
221 .iter()
222 .map(|e: &u32| e.elem::<E>()),
223 ),
224 DType::U64 => Box::new(
225 bytemuck::checked::cast_slice(&self.bytes)
226 .iter()
227 .map(|e: &u64| e.elem::<E>()),
228 ),
229 DType::BF16 => Box::new(
230 bytemuck::checked::cast_slice(&self.bytes)
231 .iter()
232 .map(|e: &bf16| e.elem::<E>()),
233 ),
234 DType::F16 => Box::new(
235 bytemuck::checked::cast_slice(&self.bytes)
236 .iter()
237 .map(|e: &f16| e.elem::<E>()),
238 ),
239 DType::F32 | DType::Flex32 => Box::new(
240 bytemuck::checked::cast_slice(&self.bytes)
241 .iter()
242 .map(|e: &f32| e.elem::<E>()),
243 ),
244 DType::F64 => Box::new(
245 bytemuck::checked::cast_slice(&self.bytes)
246 .iter()
247 .map(|e: &f64| e.elem::<E>()),
248 ),
249 DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
251 DType::QFloat(scheme) => match scheme {
252 QuantScheme {
253 level: QuantLevel::Tensor | QuantLevel::Block(_),
254 mode: QuantMode::Symmetric,
255 value:
256 QuantValue::Q8F
257 | QuantValue::Q8S
258 | QuantValue::Q4F
260 | QuantValue::Q4S
261 | QuantValue::Q2F
262 | QuantValue::Q2S,
263 ..
264 } => {
265 let q_bytes = QuantizedBytes {
267 bytes: self.bytes.clone(),
268 scheme,
269 num_elements: self.num_elements(),
270 };
271 let (values, _) = q_bytes.into_vec_i8();
272
273 Box::new(
274 values
275 .iter()
276 .map(|e: &i8| e.elem::<E>())
277 .collect::<Vec<_>>()
278 .into_iter(),
279 )
280 }
281 QuantScheme {
282 level: QuantLevel::Tensor | QuantLevel::Block(_),
283 mode: QuantMode::Symmetric,
284 value:
285 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
286 ..
287 } => {
288 unimplemented!("Not yet implemented for iteration");
289 }
290 },
291 }
292 }
293 }
294
295 pub fn rank(&self) -> usize {
297 self.shape.len()
298 }
299
300 pub fn num_elements(&self) -> usize {
302 Self::numel(&self.shape)
303 }
304
305 fn numel(shape: &[usize]) -> usize {
306 shape.iter().product()
307 }
308
309 pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
311 shape: S,
312 distribution: Distribution,
313 rng: &mut R,
314 ) -> Self {
315 let shape = shape.into();
316 let num_elements = Self::numel(&shape);
317 let mut data = Vec::with_capacity(num_elements);
318
319 for _ in 0..num_elements {
320 data.push(E::random(distribution, rng));
321 }
322
323 TensorData::new(data, shape)
324 }
325
326 pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
328 let shape = shape.into();
329 let num_elements = Self::numel(&shape);
330 let mut data = Vec::<E>::with_capacity(num_elements);
331
332 for _ in 0..num_elements {
333 data.push(0.elem());
334 }
335
336 TensorData::new(data, shape)
337 }
338
339 pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
341 let shape = shape.into();
342 let num_elements = Self::numel(&shape);
343 let mut data = Vec::<E>::with_capacity(num_elements);
344
345 for _ in 0..num_elements {
346 data.push(1.elem());
347 }
348
349 TensorData::new(data, shape)
350 }
351
352 pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
354 let shape = shape.into();
355 let num_elements = Self::numel(&shape);
356 let mut data = Vec::<E>::with_capacity(num_elements);
357 for _ in 0..num_elements {
358 data.push(fill_value)
359 }
360
361 TensorData::new(data, shape)
362 }
363
364 #[allow(dead_code)]
365 pub fn full_dtype<E: Element, S: Into<Vec<usize>>>(
367 shape: S,
368 fill_value: E,
369 dtype: DType,
370 ) -> TensorData {
371 match dtype {
372 DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
373 DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
374 DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
375 DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
376 DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
377 DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
378 DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
379 DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
380 DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
381 DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
382 DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
383 DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
384 DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
385 DType::QFloat(_) => unreachable!(),
386 }
387 }
388
389 pub fn convert<E: Element>(self) -> Self {
391 self.convert_dtype(E::dtype())
392 }
393
394 pub fn convert_dtype(self, dtype: DType) -> Self {
396 if dtype == self.dtype {
397 self
398 } else if dtype.size() == self.dtype.size()
399 && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
400 && !matches!(dtype, DType::Bool | DType::QFloat(_))
401 {
402 match self.dtype {
403 DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
404 DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
405 DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
406 DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
407 DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
408 DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
409 DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
410 DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
411 DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
412 DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
413 DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
414 DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
415 DType::Bool | DType::QFloat(_) => unreachable!(),
416 }
417 } else {
418 match self.dtype {
419 DType::F64 => self.convert_clone_dtype::<f64>(dtype),
420 DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
421 DType::F16 => self.convert_clone_dtype::<f16>(dtype),
422 DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
423 DType::I64 => self.convert_clone_dtype::<i64>(dtype),
424 DType::I32 => self.convert_clone_dtype::<i32>(dtype),
425 DType::I16 => self.convert_clone_dtype::<i16>(dtype),
426 DType::I8 => self.convert_clone_dtype::<i8>(dtype),
427 DType::U64 => self.convert_clone_dtype::<u64>(dtype),
428 DType::U32 => self.convert_clone_dtype::<u32>(dtype),
429 DType::U16 => self.convert_clone_dtype::<u16>(dtype),
430 DType::U8 => self.convert_clone_dtype::<u8>(dtype),
431 DType::Bool => self.convert_clone_dtype::<bool>(dtype),
432 DType::QFloat(_) => unreachable!(),
433 }
434 }
435 }
436
437 fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
438 match dtype {
439 DType::F64 => self.convert_inplace::<Current, f64>(),
440 DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
441 DType::F16 => self.convert_inplace::<Current, f16>(),
442 DType::BF16 => self.convert_inplace::<Current, bf16>(),
443 DType::I64 => self.convert_inplace::<Current, i64>(),
444 DType::I32 => self.convert_inplace::<Current, i32>(),
445 DType::I16 => self.convert_inplace::<Current, i16>(),
446 DType::I8 => self.convert_inplace::<Current, i8>(),
447 DType::U64 => self.convert_inplace::<Current, u64>(),
448 DType::U32 => self.convert_inplace::<Current, u32>(),
449 DType::U16 => self.convert_inplace::<Current, u16>(),
450 DType::U8 => self.convert_inplace::<Current, u8>(),
451 DType::Bool | DType::QFloat(_) => unreachable!(),
452 }
453 }
454
455 fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
456 mut self,
457 ) -> Self {
458 for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
459 let t: Target = x.elem();
460 let x = cast_mut::<_, Target>(x);
461 *x = t;
462 }
463
464 self.dtype = Target::dtype();
465
466 self
467 }
468
469 fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
470 match dtype {
471 DType::F64 => self.convert_clone::<Current, f64>(),
472 DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
473 DType::F16 => self.convert_clone::<Current, f16>(),
474 DType::BF16 => self.convert_clone::<Current, bf16>(),
475 DType::I64 => self.convert_clone::<Current, i64>(),
476 DType::I32 => self.convert_clone::<Current, i32>(),
477 DType::I16 => self.convert_clone::<Current, i16>(),
478 DType::I8 => self.convert_clone::<Current, i8>(),
479 DType::U64 => self.convert_clone::<Current, u64>(),
480 DType::U32 => self.convert_clone::<Current, u32>(),
481 DType::U16 => self.convert_clone::<Current, u16>(),
482 DType::U8 => self.convert_clone::<Current, u8>(),
483 DType::Bool => self.convert_clone::<Current, bool>(),
484 DType::QFloat(_) => unreachable!(),
485 }
486 }
487
488 fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
489 self,
490 ) -> Self {
491 let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
492 let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
493
494 for (x, out) in this.iter().zip(&mut out) {
495 *out = x.elem();
496 }
497
498 Self::new(out, self.shape)
499 }
500
501 pub fn as_bytes(&self) -> &[u8] {
503 &self.bytes
504 }
505
506 pub fn into_bytes(self) -> Bytes {
508 self.bytes
509 }
510}
511
512impl<E: Element, const A: usize> From<[E; A]> for TensorData {
513 fn from(elems: [E; A]) -> Self {
514 TensorData::new(elems.to_vec(), [A])
515 }
516}
517
518impl<const A: usize> From<[usize; A]> for TensorData {
519 fn from(elems: [usize; A]) -> Self {
520 TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
521 }
522}
523
524impl From<&[usize]> for TensorData {
525 fn from(elems: &[usize]) -> Self {
526 let mut data = Vec::with_capacity(elems.len());
527 for elem in elems.iter() {
528 data.push(*elem as i64);
529 }
530
531 TensorData::new(data, [elems.len()])
532 }
533}
534
535impl<E: Element> From<&[E]> for TensorData {
536 fn from(elems: &[E]) -> Self {
537 let mut data = Vec::with_capacity(elems.len());
538 for elem in elems.iter() {
539 data.push(*elem);
540 }
541
542 TensorData::new(data, [elems.len()])
543 }
544}
545
546impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
547 fn from(elems: [[E; B]; A]) -> Self {
548 let mut data = Vec::with_capacity(A * B);
549 for elem in elems.into_iter().take(A) {
550 for elem in elem.into_iter().take(B) {
551 data.push(elem);
552 }
553 }
554
555 TensorData::new(data, [A, B])
556 }
557}
558
559impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
560 for TensorData
561{
562 fn from(elems: [[[E; C]; B]; A]) -> Self {
563 let mut data = Vec::with_capacity(A * B * C);
564
565 for elem in elems.into_iter().take(A) {
566 for elem in elem.into_iter().take(B) {
567 for elem in elem.into_iter().take(C) {
568 data.push(elem);
569 }
570 }
571 }
572
573 TensorData::new(data, [A, B, C])
574 }
575}
576
577impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
578 From<[[[[E; D]; C]; B]; A]> for TensorData
579{
580 fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
581 let mut data = Vec::with_capacity(A * B * C * D);
582
583 for elem in elems.into_iter().take(A) {
584 for elem in elem.into_iter().take(B) {
585 for elem in elem.into_iter().take(C) {
586 for elem in elem.into_iter().take(D) {
587 data.push(elem);
588 }
589 }
590 }
591 }
592
593 TensorData::new(data, [A, B, C, D])
594 }
595}
596
597impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
598 From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
599{
600 fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
601 let mut data = Vec::with_capacity(A * B * C * D * E);
602
603 for elem in elems.into_iter().take(A) {
604 for elem in elem.into_iter().take(B) {
605 for elem in elem.into_iter().take(C) {
606 for elem in elem.into_iter().take(D) {
607 for elem in elem.into_iter().take(E) {
608 data.push(elem);
609 }
610 }
611 }
612 }
613 }
614
615 TensorData::new(data, [A, B, C, D, E])
616 }
617}
618impl core::fmt::Display for TensorData {
619 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
620 let fmt = match self.dtype {
621 DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
622 DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
623 DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
624 DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
625 DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
626 DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
627 DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
628 DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
629 DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
630 DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
631 DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
632 DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
633 DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
634 DType::QFloat(scheme) => match scheme {
635 QuantScheme {
636 level: QuantLevel::Tensor | QuantLevel::Block(_),
637 mode: QuantMode::Symmetric,
638 value:
639 QuantValue::Q8F
640 | QuantValue::Q8S
641 | QuantValue::Q4F
643 | QuantValue::Q4S
644 | QuantValue::Q2F
645 | QuantValue::Q2S,
646 ..
647 } => {
648 format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
649 },
650 QuantScheme {
651 level: QuantLevel::Tensor | QuantLevel::Block(_),
652 mode: QuantMode::Symmetric,
653 value:
654 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
655 ..
656 } => {
657 unimplemented!("Can't format yet");
658 }
659 },
660 };
661 f.write_str(fmt.as_str())
662 }
663}
664
/// Errors returned when reading [`TensorData`] values as a typed slice or vector.
#[derive(Debug)]
pub enum DataError {
    /// A `bytemuck` checked cast of the byte buffer failed
    /// (for example, misalignment, a length mismatch, or an invalid bit pattern).
    CastError(CheckedCastError),
    /// The requested element type does not match the stored dtype; the string describes both.
    TypeMismatch(String),
}
673
// Uses the default `Error` methods (no `source`); the message comes from the `Display` impl.
impl core::error::Error for DataError {}
675
676impl core::fmt::Display for DataError {
677 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
678 f.write_str(format!("{self:?}").as_str())
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Fixed seed so tests that use random data are reproducible from run to run.
    const SEED: u64 = 0x5EED;

    /// Returns a deterministic RNG for tests.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(SEED)
    }

    #[test]
    fn should_have_rank() {
        let shape = [3, 5, 6];
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let shape = [3, 5, 6];
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let shape = [3, 5, 6];
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        data.into_vec::<i32>().unwrap();
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = [3, 5, 6];
        let num_elements: usize = shape.iter().product();
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        assert_eq!(
            num_elements,
            data.bytes.len() / core::mem::size_of::<f32>()
        );
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        // The spare capacity of the source vector is expected to carry over into `Bytes`.
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    /// Generates one test per dtype, checking that `full_dtype` matches the typed `full`.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}