1use core::f32;
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::String;
6use alloc::vec::Vec;
7use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};
8use rand::RngCore;
9
10use crate::distribution::Distribution;
11use crate::element::{Element, ElementConversion};
12use burn_std::tensor::DType;
13use burn_std::{Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, bf16, f16};
14
/// Type-erased tensor data: a byte buffer together with its shape and element dtype.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// Raw storage of the elements, reinterpreted according to `dtype` when read.
    pub bytes: Bytes,

    /// Size of each dimension; the number of elements is the product of these values.
    pub shape: Vec<usize>,

    /// Element type of the values stored in `bytes`.
    pub dtype: DType,
}
27
28impl TensorData {
29 pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
31 let shape = shape.into();
33 Self::check_data_len(&value, &shape);
34
35 Self {
36 bytes: Bytes::from_elems(value),
37 shape,
38 dtype: E::dtype(),
39 }
40 }
41
42 pub fn quantized<E: Element, S: Into<Vec<usize>>>(
44 value: Vec<E>,
45 shape: S,
46 scheme: QuantScheme,
47 qparams: &[f32],
48 ) -> Self {
49 let shape = shape.into();
50 Self::check_data_len(&value, &shape);
51
52 let q_bytes = QuantizedBytes::new(value, scheme, qparams);
53
54 Self {
55 bytes: q_bytes.bytes,
56 shape,
57 dtype: DType::QFloat(q_bytes.scheme),
58 }
59 }
60
61 pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self {
63 Self {
64 bytes,
65 shape: shape.into(),
66 dtype,
67 }
68 }
69
70 pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {
75 Self {
76 bytes: Bytes::from_bytes_vec(bytes),
77 shape: shape.into(),
78 dtype,
79 }
80 }
81
82 fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) {
84 let expected_data_len = Self::numel(shape);
85 let num_data = data.len();
86 assert_eq!(
87 expected_data_len, num_data,
88 "Shape {shape:?} is invalid for input of size {num_data:?}",
89 );
90 }
91
92 pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
94 if E::dtype() == self.dtype {
95 match E::dtype() {
96 DType::Bool => {
100 let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
101 .map_err(DataError::CastError)?;
102 Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
103 }
104 _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
105 }
106 } else {
107 Err(DataError::TypeMismatch(format!(
108 "Invalid target element type (expected {:?}, got {:?})",
109 self.dtype,
110 E::dtype()
111 )))
112 }
113 }
114
115 pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
120 if E::dtype() == self.dtype {
121 match E::dtype() {
122 DType::Bool => {
126 let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
127 .map_err(DataError::CastError)?;
128 Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
129 }
130 _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
131 .map_err(DataError::CastError),
132 }
133 } else {
134 Err(DataError::TypeMismatch(format!(
135 "Invalid target element type (expected {:?}, got {:?})",
136 self.dtype,
137 E::dtype()
138 )))
139 }
140 }
141
142 pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {
144 Ok(self.as_slice()?.to_vec())
145 }
146
147 pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
149 if E::dtype() != self.dtype {
151 return Err(DataError::TypeMismatch(format!(
152 "Invalid target element type (expected {:?}, got {:?})",
153 self.dtype,
154 E::dtype()
155 )));
156 }
157
158 match E::dtype() {
159 DType::Bool => {
163 let vec = self.into_vec_unchecked::<u8>()?;
164 Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
165 }
166 _ => self.into_vec_unchecked(),
167 }
168 }
169
170 fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
172 let mut me = self;
173 me.bytes = match me.bytes.try_into_vec::<E>() {
174 Ok(elems) => return Ok(elems),
175 Err(bytes) => bytes,
176 };
177
178 Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
181 .map_err(DataError::CastError)?
182 .to_vec())
183 }
184
185 pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
187 if E::dtype() == self.dtype {
188 Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
189 } else {
190 match self.dtype {
191 DType::I8 => Box::new(
192 bytemuck::checked::cast_slice(&self.bytes)
193 .iter()
194 .map(|e: &i8| e.elem::<E>()),
195 ),
196 DType::I16 => Box::new(
197 bytemuck::checked::cast_slice(&self.bytes)
198 .iter()
199 .map(|e: &i16| e.elem::<E>()),
200 ),
201 DType::I32 => Box::new(
202 bytemuck::checked::cast_slice(&self.bytes)
203 .iter()
204 .map(|e: &i32| e.elem::<E>()),
205 ),
206 DType::I64 => Box::new(
207 bytemuck::checked::cast_slice(&self.bytes)
208 .iter()
209 .map(|e: &i64| e.elem::<E>()),
210 ),
211 DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
212 DType::U16 => Box::new(
213 bytemuck::checked::cast_slice(&self.bytes)
214 .iter()
215 .map(|e: &u16| e.elem::<E>()),
216 ),
217 DType::U32 => Box::new(
218 bytemuck::checked::cast_slice(&self.bytes)
219 .iter()
220 .map(|e: &u32| e.elem::<E>()),
221 ),
222 DType::U64 => Box::new(
223 bytemuck::checked::cast_slice(&self.bytes)
224 .iter()
225 .map(|e: &u64| e.elem::<E>()),
226 ),
227 DType::BF16 => Box::new(
228 bytemuck::checked::cast_slice(&self.bytes)
229 .iter()
230 .map(|e: &bf16| e.elem::<E>()),
231 ),
232 DType::F16 => Box::new(
233 bytemuck::checked::cast_slice(&self.bytes)
234 .iter()
235 .map(|e: &f16| e.elem::<E>()),
236 ),
237 DType::F32 | DType::Flex32 => Box::new(
238 bytemuck::checked::cast_slice(&self.bytes)
239 .iter()
240 .map(|e: &f32| e.elem::<E>()),
241 ),
242 DType::F64 => Box::new(
243 bytemuck::checked::cast_slice(&self.bytes)
244 .iter()
245 .map(|e: &f64| e.elem::<E>()),
246 ),
247 DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
249 DType::QFloat(scheme) => match scheme {
250 QuantScheme {
251 level: QuantLevel::Tensor | QuantLevel::Block(_),
252 mode: QuantMode::Symmetric,
253 value:
254 QuantValue::Q8F
255 | QuantValue::Q8S
256 | QuantValue::Q4F
258 | QuantValue::Q4S
259 | QuantValue::Q2F
260 | QuantValue::Q2S,
261 ..
262 } => {
263 let q_bytes = QuantizedBytes {
265 bytes: self.bytes.clone(),
266 scheme,
267 num_elements: self.num_elements(),
268 };
269 let (values, _) = q_bytes.into_vec_i8();
270
271 Box::new(
272 values
273 .iter()
274 .map(|e: &i8| e.elem::<E>())
275 .collect::<Vec<_>>()
276 .into_iter(),
277 )
278 }
279 QuantScheme {
280 level: QuantLevel::Tensor | QuantLevel::Block(_),
281 mode: QuantMode::Symmetric,
282 value:
283 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
284 ..
285 } => {
286 unimplemented!("Not yet implemented for iteration");
287 }
288 },
289 }
290 }
291 }
292
293 pub fn rank(&self) -> usize {
295 self.shape.len()
296 }
297
298 pub fn num_elements(&self) -> usize {
300 Self::numel(&self.shape)
301 }
302
303 fn numel(shape: &[usize]) -> usize {
304 shape.iter().product()
305 }
306
307 pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>(
309 shape: S,
310 distribution: Distribution,
311 rng: &mut R,
312 ) -> Self {
313 let shape = shape.into();
314 let num_elements = Self::numel(&shape);
315 let mut data = Vec::with_capacity(num_elements);
316
317 for _ in 0..num_elements {
318 data.push(E::random(distribution, rng));
319 }
320
321 TensorData::new(data, shape)
322 }
323
324 pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
326 let shape = shape.into();
327 let num_elements = Self::numel(&shape);
328 let mut data = Vec::<E>::with_capacity(num_elements);
329
330 for _ in 0..num_elements {
331 data.push(0.elem());
332 }
333
334 TensorData::new(data, shape)
335 }
336
337 pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData {
339 let shape = shape.into();
340 let num_elements = Self::numel(&shape);
341 let mut data = Vec::<E>::with_capacity(num_elements);
342
343 for _ in 0..num_elements {
344 data.push(1.elem());
345 }
346
347 TensorData::new(data, shape)
348 }
349
350 pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData {
352 let shape = shape.into();
353 let num_elements = Self::numel(&shape);
354 let mut data = Vec::<E>::with_capacity(num_elements);
355 for _ in 0..num_elements {
356 data.push(fill_value)
357 }
358
359 TensorData::new(data, shape)
360 }
361
362 #[allow(dead_code)]
363 pub fn full_dtype<E: Element, S: Into<Vec<usize>>>(
365 shape: S,
366 fill_value: E,
367 dtype: DType,
368 ) -> TensorData {
369 match dtype {
370 DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
371 DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
372 DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
373 DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
374 DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
375 DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
376 DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
377 DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
378 DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
379 DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
380 DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
381 DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
382 DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
383 DType::QFloat(_) => unreachable!(),
384 }
385 }
386
387 pub fn convert<E: Element>(self) -> Self {
389 self.convert_dtype(E::dtype())
390 }
391
392 pub fn convert_dtype(self, dtype: DType) -> Self {
394 if dtype == self.dtype {
395 self
396 } else if dtype.size() == self.dtype.size()
397 && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
398 && !matches!(dtype, DType::Bool | DType::QFloat(_))
399 {
400 match self.dtype {
401 DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
402 DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
403 DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
404 DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
405 DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
406 DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
407 DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
408 DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
409 DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
410 DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
411 DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
412 DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
413 DType::Bool | DType::QFloat(_) => unreachable!(),
414 }
415 } else {
416 match self.dtype {
417 DType::F64 => self.convert_clone_dtype::<f64>(dtype),
418 DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
419 DType::F16 => self.convert_clone_dtype::<f16>(dtype),
420 DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
421 DType::I64 => self.convert_clone_dtype::<i64>(dtype),
422 DType::I32 => self.convert_clone_dtype::<i32>(dtype),
423 DType::I16 => self.convert_clone_dtype::<i16>(dtype),
424 DType::I8 => self.convert_clone_dtype::<i8>(dtype),
425 DType::U64 => self.convert_clone_dtype::<u64>(dtype),
426 DType::U32 => self.convert_clone_dtype::<u32>(dtype),
427 DType::U16 => self.convert_clone_dtype::<u16>(dtype),
428 DType::U8 => self.convert_clone_dtype::<u8>(dtype),
429 DType::Bool => self.convert_clone_dtype::<bool>(dtype),
430 DType::QFloat(_) => unreachable!(),
431 }
432 }
433 }
434
435 fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
436 match dtype {
437 DType::F64 => self.convert_inplace::<Current, f64>(),
438 DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
439 DType::F16 => self.convert_inplace::<Current, f16>(),
440 DType::BF16 => self.convert_inplace::<Current, bf16>(),
441 DType::I64 => self.convert_inplace::<Current, i64>(),
442 DType::I32 => self.convert_inplace::<Current, i32>(),
443 DType::I16 => self.convert_inplace::<Current, i16>(),
444 DType::I8 => self.convert_inplace::<Current, i8>(),
445 DType::U64 => self.convert_inplace::<Current, u64>(),
446 DType::U32 => self.convert_inplace::<Current, u32>(),
447 DType::U16 => self.convert_inplace::<Current, u16>(),
448 DType::U8 => self.convert_inplace::<Current, u8>(),
449 DType::Bool | DType::QFloat(_) => unreachable!(),
450 }
451 }
452
453 fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
454 mut self,
455 ) -> Self {
456 for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
457 let t: Target = x.elem();
458 let x = cast_mut::<_, Target>(x);
459 *x = t;
460 }
461
462 self.dtype = Target::dtype();
463
464 self
465 }
466
467 fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
468 match dtype {
469 DType::F64 => self.convert_clone::<Current, f64>(),
470 DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
471 DType::F16 => self.convert_clone::<Current, f16>(),
472 DType::BF16 => self.convert_clone::<Current, bf16>(),
473 DType::I64 => self.convert_clone::<Current, i64>(),
474 DType::I32 => self.convert_clone::<Current, i32>(),
475 DType::I16 => self.convert_clone::<Current, i16>(),
476 DType::I8 => self.convert_clone::<Current, i8>(),
477 DType::U64 => self.convert_clone::<Current, u64>(),
478 DType::U32 => self.convert_clone::<Current, u32>(),
479 DType::U16 => self.convert_clone::<Current, u16>(),
480 DType::U8 => self.convert_clone::<Current, u8>(),
481 DType::Bool => self.convert_clone::<Current, bool>(),
482 DType::QFloat(_) => unreachable!(),
483 }
484 }
485
486 fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
487 self,
488 ) -> Self {
489 let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
490 let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
491
492 for (x, out) in this.iter().zip(&mut out) {
493 *out = x.elem();
494 }
495
496 Self::new(out, self.shape)
497 }
498
499 pub fn as_bytes(&self) -> &[u8] {
501 &self.bytes
502 }
503
504 pub fn into_bytes(self) -> Bytes {
506 self.bytes
507 }
508}
509
510impl<E: Element, const A: usize> From<[E; A]> for TensorData {
511 fn from(elems: [E; A]) -> Self {
512 TensorData::new(elems.to_vec(), [A])
513 }
514}
515
516impl<const A: usize> From<[usize; A]> for TensorData {
517 fn from(elems: [usize; A]) -> Self {
518 TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])
519 }
520}
521
522impl From<&[usize]> for TensorData {
523 fn from(elems: &[usize]) -> Self {
524 let mut data = Vec::with_capacity(elems.len());
525 for elem in elems.iter() {
526 data.push(*elem as i64);
527 }
528
529 TensorData::new(data, [elems.len()])
530 }
531}
532
533impl<E: Element> From<&[E]> for TensorData {
534 fn from(elems: &[E]) -> Self {
535 let mut data = Vec::with_capacity(elems.len());
536 for elem in elems.iter() {
537 data.push(*elem);
538 }
539
540 TensorData::new(data, [elems.len()])
541 }
542}
543
544impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {
545 fn from(elems: [[E; B]; A]) -> Self {
546 let mut data = Vec::with_capacity(A * B);
547 for elem in elems.into_iter().take(A) {
548 for elem in elem.into_iter().take(B) {
549 data.push(elem);
550 }
551 }
552
553 TensorData::new(data, [A, B])
554 }
555}
556
557impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>
558 for TensorData
559{
560 fn from(elems: [[[E; C]; B]; A]) -> Self {
561 let mut data = Vec::with_capacity(A * B * C);
562
563 for elem in elems.into_iter().take(A) {
564 for elem in elem.into_iter().take(B) {
565 for elem in elem.into_iter().take(C) {
566 data.push(elem);
567 }
568 }
569 }
570
571 TensorData::new(data, [A, B, C])
572 }
573}
574
575impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>
576 From<[[[[E; D]; C]; B]; A]> for TensorData
577{
578 fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {
579 let mut data = Vec::with_capacity(A * B * C * D);
580
581 for elem in elems.into_iter().take(A) {
582 for elem in elem.into_iter().take(B) {
583 for elem in elem.into_iter().take(C) {
584 for elem in elem.into_iter().take(D) {
585 data.push(elem);
586 }
587 }
588 }
589 }
590
591 TensorData::new(data, [A, B, C, D])
592 }
593}
594
595impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>
596 From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData
597{
598 fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {
599 let mut data = Vec::with_capacity(A * B * C * D * E);
600
601 for elem in elems.into_iter().take(A) {
602 for elem in elem.into_iter().take(B) {
603 for elem in elem.into_iter().take(C) {
604 for elem in elem.into_iter().take(D) {
605 for elem in elem.into_iter().take(E) {
606 data.push(elem);
607 }
608 }
609 }
610 }
611 }
612
613 TensorData::new(data, [A, B, C, D, E])
614 }
615}
616impl core::fmt::Display for TensorData {
617 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
618 let fmt = match self.dtype {
619 DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
620 DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
621 DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
622 DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
623 DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
624 DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
625 DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
626 DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
627 DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
628 DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
629 DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
630 DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
631 DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
632 DType::QFloat(scheme) => match scheme {
633 QuantScheme {
634 level: QuantLevel::Tensor | QuantLevel::Block(_),
635 mode: QuantMode::Symmetric,
636 value:
637 QuantValue::Q8F
638 | QuantValue::Q8S
639 | QuantValue::Q4F
641 | QuantValue::Q4S
642 | QuantValue::Q2F
643 | QuantValue::Q2S,
644 ..
645 } => {
646 format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
647 },
648 QuantScheme {
649 level: QuantLevel::Tensor | QuantLevel::Block(_),
650 mode: QuantMode::Symmetric,
651 value:
652 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
653 ..
654 } => {
655 unimplemented!("Can't format yet");
656 }
657 },
658 };
659 f.write_str(fmt.as_str())
660 }
661}
662
663#[derive(Debug)]
665pub enum DataError {
666 CastError(CheckedCastError),
668 TypeMismatch(String),
670}
671
672impl core::error::Error for DataError {}
673
674impl core::fmt::Display for DataError {
675 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
676 f.write_str(format!("{self:?}").as_str())
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};

    /// Fixed-seed RNG: the assertions below hold for any data, and a fixed seed keeps any
    /// failure reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(0)
    }

    #[test]
    fn should_have_rank() {
        let shape = [3, 5, 6];
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        assert_eq!(data.rank(), 3);
    }

    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let shape = [3, 5, 6];
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    fn into_vec_should_assert_wrong_dtype() {
        let shape = [3, 5, 6];
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        // Check the specific error instead of `#[should_panic]`, which would also accept an
        // unrelated panic.
        let result = data.into_vec::<i32>();
        assert!(matches!(result, Err(DataError::TypeMismatch(_))));
    }

    #[test]
    fn should_have_right_num_elements() {
        let shape = [3, 5, 6];
        let num_elements: usize = shape.iter().product();
        let data =
            TensorData::random::<f32, _, _>(shape, Distribution::Default, &mut seeded_rng());

        assert_eq!(num_elements, data.bytes.len() / 4);
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }

    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);

        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);

        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }

    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);

        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }

    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }

    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }

    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}