1use core::mem;
13
14use crate::memory::BufferDescriptor;
15use crate::message::payload::Payload;
16use crate::types::{DType, DataType};
17
18#[derive(Clone, Copy)]
50pub struct Tensor<T: Copy + Default + DType, const N: usize, const R: usize> {
51 data: [T; N],
52 len: usize,
53 shape: [usize; R],
54}
55
56impl<T: Copy + Default + DType, const N: usize, const R: usize> Tensor<T, N, R> {
61 #[inline]
67 pub fn from_shape(shape: [usize; R], data: &[T]) -> Self {
68 assert!(
69 data.len() <= N,
70 "Tensor: data length {} exceeds capacity {}",
71 data.len(),
72 N
73 );
74 debug_assert_eq!(
75 checked_product(&shape),
76 Some(data.len()),
77 "Tensor: shape product != data length"
78 );
79
80 let mut buf = [T::default(); N];
81 buf[..data.len()].copy_from_slice(data);
82
83 Self {
84 data: buf,
85 len: data.len(),
86 shape,
87 }
88 }
89
90 #[inline]
95 pub fn filled(shape: [usize; R], value: T) -> Self {
96 let count = checked_product(&shape).expect("Tensor: shape overflow");
97 assert!(count <= N, "Tensor: count {} exceeds capacity {}", count, N);
98
99 let mut buf = [T::default(); N];
100 let mut i = 0;
101 while i < count {
102 buf[i] = value;
103 i += 1;
104 }
105
106 Self {
107 data: buf,
108 len: count,
109 shape,
110 }
111 }
112
113 #[inline]
115 pub fn zeros(shape: [usize; R]) -> Self {
116 Self::filled(shape, T::default())
117 }
118
119 #[inline]
121 pub fn data_type(&self) -> DataType {
122 T::DATA_TYPE
123 }
124
125 #[inline]
127 pub fn len(&self) -> usize {
128 self.len
129 }
130
131 #[inline]
133 pub fn is_empty(&self) -> bool {
134 self.len == 0
135 }
136
137 #[inline]
139 pub const fn capacity(&self) -> usize {
140 N
141 }
142
143 #[inline]
145 pub const fn rank(&self) -> usize {
146 R
147 }
148
149 #[inline]
151 pub fn shape(&self) -> &[usize; R] {
152 &self.shape
153 }
154
155 #[inline]
157 pub fn as_slice(&self) -> &[T] {
158 &self.data[..self.len]
159 }
160
161 #[inline]
163 pub fn as_mut_slice(&mut self) -> &mut [T] {
164 &mut self.data[..self.len]
165 }
166
167 #[inline]
169 pub fn byte_len(&self) -> usize {
170 self.len.saturating_mul(mem::size_of::<T>())
171 }
172
173 #[inline]
178 pub fn reshape(&mut self, new_shape: [usize; R]) {
179 debug_assert_eq!(
180 checked_product(&new_shape),
181 Some(self.len),
182 "Tensor::reshape: shape product != len"
183 );
184 self.shape = new_shape;
185 }
186
187 #[inline]
189 pub fn is_compatible(&self) -> bool {
190 checked_product(&self.shape) == Some(self.len)
191 }
192
193 #[inline]
198 pub fn at(&self, index: [usize; R]) -> T {
199 self.data[self.flat_index(index)]
200 }
201
202 #[inline]
207 pub fn set(&mut self, index: [usize; R], value: T) {
208 let i = self.flat_index(index);
209 self.data[i] = value;
210 }
211
212 #[inline]
214 fn flat_index(&self, index: [usize; R]) -> usize {
215 let mut flat = 0usize;
216 let mut stride = 1usize;
217 let mut d = R;
218 while d > 0 {
219 d -= 1;
220 flat += index[d] * stride;
221 stride *= self.shape[d];
222 }
223 assert!(flat < self.len, "tensor index out of bounds");
224 flat
225 }
226}
227
228impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 0> {
233 #[inline]
235 pub fn scalar(value: T) -> Self {
236 assert!(N >= 1, "Tensor::scalar: capacity must be >= 1");
237 let mut buf = [T::default(); N];
238 buf[0] = value;
239 Self {
240 data: buf,
241 len: 1,
242 shape: [],
243 }
244 }
245
246 #[inline]
248 pub fn value(&self) -> T {
249 self.data[0]
250 }
251}
252
253impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 1> {
258 #[inline]
260 pub fn from_slice(data: &[T]) -> Self {
261 Self::from_shape([data.len()], data)
262 }
263}
264
265impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 2> {
270 #[inline]
272 pub fn nc(batch: usize, classes: usize, data: &[T]) -> Self {
273 Self::from_shape([batch, classes], data)
274 }
275
276 #[inline]
278 pub fn matrix(rows: usize, cols: usize, data: &[T]) -> Self {
279 Self::from_shape([rows, cols], data)
280 }
281}
282
283impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 3> {
288 #[inline]
291 pub fn sequence(batch: usize, time_steps: usize, features: usize, data: &[T]) -> Self {
292 Self::from_shape([batch, time_steps, features], data)
293 }
294
295 #[inline]
298 pub fn hwc(height: usize, width: usize, channels: usize, data: &[T]) -> Self {
299 Self::from_shape([height, width, channels], data)
300 }
301}
302
303impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 4> {
308 #[inline]
310 pub fn nhwc(batch: usize, height: usize, width: usize, channels: usize, data: &[T]) -> Self {
311 Self::from_shape([batch, height, width, channels], data)
312 }
313
314 #[inline]
316 pub fn nchw(batch: usize, channels: usize, height: usize, width: usize, data: &[T]) -> Self {
317 Self::from_shape([batch, channels, height, width], data)
318 }
319}
320
321impl<T: Copy + Default + DType, const N: usize, const R: usize> Default for Tensor<T, N, R> {
326 #[inline]
327 fn default() -> Self {
328 Self {
329 data: [T::default(); N],
330 len: 0,
331 shape: [0; R],
332 }
333 }
334}
335
336impl<T: Copy + Default + DType + core::fmt::Debug, const N: usize, const R: usize> core::fmt::Debug
337 for Tensor<T, N, R>
338{
339 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
340 f.debug_struct("Tensor")
342 .field("data_type", &self.data_type())
343 .field("len", &self.len)
344 .field("shape", &&self.shape[..])
345 .field("capacity", &N)
346 .field("data", &self.as_slice())
347 .finish()
348 }
349}
350
351impl<T: Copy + Default + DType + PartialEq, const N: usize, const R: usize> PartialEq
352 for Tensor<T, N, R>
353{
354 fn eq(&self, other: &Self) -> bool {
355 self.len == other.len
356 && self.shape == other.shape
357 && self.data[..self.len] == other.data[..other.len]
358 }
359}
360
361impl<T: Copy + Default + DType + Eq, const N: usize, const R: usize> Eq for Tensor<T, N, R> {}
362
363impl<T: Copy + Default + DType, const N: usize, const R: usize> Payload for Tensor<T, N, R> {
364 #[inline]
365 fn buffer_descriptor(&self) -> BufferDescriptor {
366 BufferDescriptor::new(self.byte_len())
367 }
368}
369
/// Multiplies all extents in `shape` together.
///
/// Returns `Some(1)` for an empty shape (rank 0), and `None` as soon as an
/// intermediate product overflows `usize`.
#[inline]
fn checked_product(shape: &[usize]) -> Option<usize> {
    shape
        .iter()
        .try_fold(1usize, |running, &extent| running.checked_mul(extent))
}
379
#[cfg(any(test, feature = "bench"))]
/// Shape of [`TestTensor`] instances: a 3x3 matrix (rows, then columns).
pub const TEST_TENSOR_SHAPE: [usize; 2] = [3, 3];
386
#[cfg(any(test, feature = "bench"))]
/// Number of elements in a [`TestTensor`], which is also its capacity.
///
/// Derived from [`TEST_TENSOR_SHAPE`] rather than hard-coded, so that
/// editing the shape cannot leave this count out of sync with it.
pub const TEST_TENSOR_ELEMENT_COUNT: usize = TEST_TENSOR_SHAPE[0] * TEST_TENSOR_SHAPE[1];
392
#[cfg(any(test, feature = "bench"))]
/// Size in bytes of a full [`TestTensor`]: its element count times the
/// width of its `u32` elements.
pub const TEST_TENSOR_BYTE_COUNT: usize = mem::size_of::<u32>() * TEST_TENSOR_ELEMENT_COUNT;
399
#[cfg(any(test, feature = "bench"))]
/// Concrete tensor type shared by tests and benchmarks: a rank-2 `u32`
/// tensor whose capacity is exactly [`TEST_TENSOR_ELEMENT_COUNT`].
pub type TestTensor = Tensor<u32, TEST_TENSOR_ELEMENT_COUNT, 2>;
406
#[cfg(any(test, feature = "bench"))]
/// Builds a [`TestTensor`] of shape [`TEST_TENSOR_SHAPE`] in which every
/// element equals `value`.
#[inline]
pub fn create_test_tensor_filled_with(value: u32) -> TestTensor {
    TestTensor::filled(TEST_TENSOR_SHAPE, value)
}
415
#[cfg(any(test, feature = "bench"))]
/// Builds a [`TestTensor`] from a 3x3 nested array.
///
/// `values[row][col]` lands at index `[row, col]`: the rows are flattened
/// in order, which matches the tensor's row-major layout.
#[inline]
pub fn create_test_tensor_from_array(values: [[u32; 3]; 3]) -> TestTensor {
    let mut flat = [0u32; 9];
    for (dst_row, src_row) in flat.chunks_exact_mut(3).zip(values.iter()) {
        dst_row.copy_from_slice(src_row);
    }
    Tensor::from_shape(TEST_TENSOR_SHAPE, &flat)
}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// A default tensor has no live elements and zero-extent axes, while
    /// rank and capacity still reflect the type parameters.
    #[test]
    fn default_is_empty() {
        let t = Tensor::<f32, 16, 2>::default();
        assert!(t.is_empty());
        assert_eq!(t.len(), 0);
        assert_eq!(t.shape(), &[0, 0]);
        assert_eq!(t.byte_len(), 0);
        assert_eq!(t.rank(), 2);
        assert_eq!(t.capacity(), 16);
    }

    #[test]
    fn from_shape_basic() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::<f32, 8, 2>::from_shape([2, 3], &data);
        assert_eq!(t.len(), 6);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.as_slice(), &data);
        assert_eq!(t.byte_len(), 6 * 4);
        assert!(t.is_compatible());
    }

    #[test]
    fn from_shape_partial_capacity() {
        let data = [1u8, 2, 3];
        let t = Tensor::<u8, 256, 1>::from_shape([3], &data);
        assert_eq!(t.len(), 3);
        assert_eq!(t.capacity(), 256);
        assert_eq!(t.as_slice(), &[1, 2, 3]);
    }

    #[test]
    #[should_panic(expected = "exceeds capacity")]
    fn from_shape_panics_on_overflow() {
        let data = [0u8; 32];
        let _ = Tensor::<u8, 16, 1>::from_shape([32], &data);
    }

    #[test]
    fn filled_creates_uniform_tensor() {
        let t = Tensor::<i8, 12, 2>::filled([3, 4], 42i8);
        assert_eq!(t.len(), 12);
        assert_eq!(t.shape(), &[3, 4]);
        assert!(t.as_slice().iter().all(|&v| v == 42));
    }

    #[test]
    #[should_panic(expected = "exceeds capacity")]
    fn filled_panics_on_overflow() {
        let _ = Tensor::<f32, 4, 2>::filled([3, 3], 0.0);
    }

    #[test]
    fn zeros_is_all_default() {
        let t = Tensor::<u32, 64, 3>::zeros([2, 4, 8]);
        assert_eq!(t.len(), 64);
        assert!(t.as_slice().iter().all(|&v| v == 0));
    }

    /// Rank-0 round trip. Uses 2.5 (exactly representable) rather than
    /// 3.14, which clippy's deny-by-default `approx_constant` lint rejects
    /// as an approximation of pi.
    #[test]
    fn scalar_round_trip() {
        let t = Tensor::<f32, 1, 0>::scalar(2.5);
        assert_eq!(t.value(), 2.5);
        assert_eq!(t.len(), 1);
        assert_eq!(t.rank(), 0);
        assert_eq!(t.shape(), &[]);
        assert_eq!(t.byte_len(), 4);
        assert!(t.is_compatible());
    }

    #[test]
    fn scalar_with_excess_capacity() {
        let t = Tensor::<u8, 4, 0>::scalar(7);
        assert_eq!(t.value(), 7);
        assert_eq!(t.capacity(), 4);
        assert_eq!(t.len(), 1);
    }

    #[test]
    fn from_slice_rank1() {
        let t = Tensor::<f32, 8, 1>::from_slice(&[1.0, 2.0, 3.0]);
        assert_eq!(t.len(), 3);
        assert_eq!(t.shape(), &[3]);
        assert_eq!(t.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn from_slice_empty() {
        let t = Tensor::<u8, 8, 1>::from_slice(&[]);
        assert!(t.is_empty());
        assert_eq!(t.shape(), &[0]);
    }

    #[test]
    fn nc_constructor() {
        let data = [10i8, 20, 30];
        let t = Tensor::<i8, 4, 2>::nc(1, 3, &data);
        assert_eq!(t.shape(), &[1, 3]);
        assert_eq!(t.len(), 3);
        assert_eq!(t.as_slice(), &data);
    }

    #[test]
    fn matrix_constructor() {
        let data: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::<f32, 8, 2>::matrix(2, 3, &data);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.len(), 6);
    }

    #[test]
    fn sequence_constructor() {
        let data = [0i8; 1960];
        let t = Tensor::<i8, 1960, 3>::sequence(1, 49, 40, &data);
        assert_eq!(t.shape(), &[1, 49, 40]);
        assert_eq!(t.len(), 1960);
    }

    #[test]
    fn hwc_constructor() {
        let data = [0u8; 48];
        let t = Tensor::<u8, 48, 3>::hwc(4, 4, 3, &data);
        assert_eq!(t.shape(), &[4, 4, 3]);
        assert_eq!(t.len(), 48);
    }

    #[test]
    fn nhwc_constructor() {
        let data = [0u8; 9216];
        let t = Tensor::<u8, 9216, 4>::nhwc(1, 96, 96, 1, &data);
        assert_eq!(t.shape(), &[1, 96, 96, 1]);
        assert_eq!(t.len(), 9216);
        assert!(t.is_compatible());
    }

    #[test]
    fn nchw_constructor() {
        let data = [0.0f32; 48];
        let t = Tensor::<f32, 48, 4>::nchw(1, 3, 4, 4, &data);
        assert_eq!(t.shape(), &[1, 3, 4, 4]);
    }

    /// Row-major indexing on a 2x3 matrix, plus an in-place `set`.
    #[test]
    fn at_and_set_rank2() {
        let data = [1u32, 2, 3, 4, 5, 6];
        let mut t = Tensor::<u32, 8, 2>::matrix(2, 3, &data);

        assert_eq!(t.at([0, 0]), 1);
        assert_eq!(t.at([0, 2]), 3);
        assert_eq!(t.at([1, 0]), 4);
        assert_eq!(t.at([1, 2]), 6);

        t.set([1, 1], 99);
        assert_eq!(t.at([1, 1]), 99);
    }

    /// Channel is the fastest-varying axis in an NHWC layout.
    #[test]
    fn at_rank4_nhwc() {
        let data: [u8; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let t = Tensor::<u8, 12, 4>::nhwc(1, 2, 2, 3, &data);

        assert_eq!(t.at([0, 0, 0, 0]), 0);
        assert_eq!(t.at([0, 0, 0, 2]), 2);
        assert_eq!(t.at([0, 0, 1, 0]), 3);
        assert_eq!(t.at([0, 1, 0, 0]), 6);
        assert_eq!(t.at([0, 1, 1, 2]), 11);
    }

    #[test]
    fn at_rank3_sequence() {
        let data = [10i8, 20, 30, 40, 50, 60];
        let t = Tensor::<i8, 8, 3>::sequence(1, 3, 2, &data);

        assert_eq!(t.at([0, 0, 0]), 10);
        assert_eq!(t.at([0, 0, 1]), 20);
        assert_eq!(t.at([0, 1, 0]), 30);
        assert_eq!(t.at([0, 2, 1]), 60);
    }

    #[test]
    fn at_rank1() {
        let t = Tensor::<f32, 4, 1>::from_slice(&[10.0, 20.0, 30.0]);
        assert_eq!(t.at([0]), 10.0);
        assert_eq!(t.at([2]), 30.0);
    }

    #[test]
    #[should_panic]
    fn at_out_of_bounds_panics() {
        let t = Tensor::<u8, 4, 1>::from_slice(&[1, 2, 3]);
        let _ = t.at([3]);
    }

    /// Reshaping only changes how indices map onto the same flat data.
    #[test]
    fn reshape_preserves_data() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut t = Tensor::<f32, 8, 2>::matrix(2, 3, &data);

        t.reshape([3, 2]);
        assert_eq!(t.shape(), &[3, 2]);
        assert_eq!(t.len(), 6);
        assert_eq!(t.as_slice(), &data);
        assert!(t.is_compatible());

        assert_eq!(t.at([0, 0]), 1.0);
        assert_eq!(t.at([0, 1]), 2.0);
        assert_eq!(t.at([1, 0]), 3.0);
        assert_eq!(t.at([2, 1]), 6.0);
    }

    #[test]
    fn as_mut_slice_modification() {
        let mut t = Tensor::<u8, 8, 1>::from_slice(&[0, 0, 0]);
        t.as_mut_slice().copy_from_slice(&[10, 20, 30]);
        assert_eq!(t.as_slice(), &[10, 20, 30]);
    }

    /// Each supported element type maps to its `DataType` tag.
    #[test]
    fn data_type_reflects_element() {
        assert_eq!(
            Tensor::<f32, 4, 1>::default().data_type(),
            DataType::Float32
        );
        assert_eq!(
            Tensor::<f64, 4, 1>::default().data_type(),
            DataType::Float64
        );
        assert_eq!(
            Tensor::<u8, 4, 1>::default().data_type(),
            DataType::Unsigned8
        );
        assert_eq!(Tensor::<i8, 4, 1>::default().data_type(), DataType::Signed8);
        assert_eq!(
            Tensor::<u16, 4, 1>::default().data_type(),
            DataType::Unsigned16
        );
        assert_eq!(
            Tensor::<i16, 4, 1>::default().data_type(),
            DataType::Signed16
        );
        assert_eq!(
            Tensor::<u32, 4, 1>::default().data_type(),
            DataType::Unsigned32
        );
        assert_eq!(
            Tensor::<i32, 4, 1>::default().data_type(),
            DataType::Signed32
        );
        assert_eq!(
            Tensor::<u64, 4, 1>::default().data_type(),
            DataType::Unsigned64
        );
        assert_eq!(
            Tensor::<i64, 4, 1>::default().data_type(),
            DataType::Signed64
        );
        assert_eq!(
            Tensor::<bool, 4, 1>::default().data_type(),
            DataType::Boolean
        );
    }

    /// `byte_len` counts only live elements, scaled by the element width.
    #[test]
    fn byte_len_correct_for_types() {
        let t_f32 = Tensor::<f32, 8, 1>::from_slice(&[0.0; 6]);
        assert_eq!(t_f32.byte_len(), 24);

        let t_u8 = Tensor::<u8, 8, 1>::from_slice(&[0u8; 5]);
        assert_eq!(t_u8.byte_len(), 5);

        let t_f64 = Tensor::<f64, 4, 1>::from_slice(&[0.0f64; 3]);
        assert_eq!(t_f64.byte_len(), 24);

        let t_empty = Tensor::<f32, 4, 1>::default();
        assert_eq!(t_empty.byte_len(), 0);
    }

    #[test]
    fn payload_buffer_descriptor() {
        let t = Tensor::<f32, 8, 2>::matrix(2, 3, &[0.0; 6]);
        let bd = t.buffer_descriptor();
        assert_eq!(*bd.bytes(), 24);
    }

    #[test]
    fn eq_same_data_same_shape() {
        let a = Tensor::<u8, 8, 2>::matrix(2, 3, &[1, 2, 3, 4, 5, 6]);
        let b = Tensor::<u8, 8, 2>::matrix(2, 3, &[1, 2, 3, 4, 5, 6]);
        assert_eq!(a, b);
    }

    #[test]
    fn ne_different_data() {
        let a = Tensor::<u8, 4, 1>::from_slice(&[1, 2, 3]);
        let b = Tensor::<u8, 4, 1>::from_slice(&[1, 2, 4]);
        assert_ne!(a, b);
    }

    #[test]
    fn ne_different_shape_same_data() {
        let data = [1u8, 2, 3, 4, 5, 6];
        let a = Tensor::<u8, 8, 2>::matrix(2, 3, &data);
        let b = Tensor::<u8, 8, 2>::matrix(3, 2, &data);
        assert_ne!(a, b);
    }

    #[test]
    fn ne_different_len() {
        let a = Tensor::<u8, 8, 1>::from_slice(&[1, 2, 3]);
        let b = Tensor::<u8, 8, 1>::from_slice(&[1, 2]);
        assert_ne!(a, b);
    }

    /// Padding slots beyond `len` must not affect equality.
    #[test]
    fn eq_ignores_padding_beyond_len() {
        let a = Tensor::<u8, 8, 1>::from_slice(&[1, 2, 3]);
        // Build `c` by hand so that its padding differs from `a`'s zeros.
        let mut c = Tensor::<u8, 8, 1>::default();
        c.data[0] = 1;
        c.data[1] = 2;
        c.data[2] = 3;
        c.data[7] = 99;
        c.len = 3;
        c.shape = [3];
        assert_eq!(a, c);
    }

    #[test]
    fn tensor_is_copy() {
        let a = Tensor::<f32, 4, 1>::from_slice(&[1.0, 2.0]);
        // Reading `a` twice by value only compiles because `Tensor` is `Copy`.
        let b = a;
        let c = a;
        assert_eq!(b, c);
    }

    #[test]
    // `clone()` is exercised deliberately even though the type is `Copy`.
    #[allow(clippy::clone_on_copy)]
    fn tensor_clone_equals_original() {
        let a = Tensor::<i32, 16, 3>::from_shape([2, 2, 2], &[1, 2, 3, 4, 5, 6, 7, 8]);
        let b = a.clone();
        assert_eq!(a, b);
    }

    /// Formats into a fixed stack buffer (no allocation) and checks the
    /// key fields appear in the `Debug` output.
    #[test]
    fn debug_format_does_not_panic() {
        use core::fmt::Write;
        struct StackBuf([u8; 256], usize);
        impl Write for StackBuf {
            fn write_str(&mut self, s: &str) -> core::fmt::Result {
                for &b in s.as_bytes() {
                    if self.1 >= self.0.len() {
                        // Silently truncate; the assertions only inspect the prefix.
                        return Ok(());
                    }
                    self.0[self.1] = b;
                    self.1 += 1;
                }
                Ok(())
            }
        }
        let t = Tensor::<f32, 4, 2>::matrix(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let mut buf = StackBuf([0u8; 256], 0);
        write!(buf, "{:?}", t).unwrap();
        let s = core::str::from_utf8(&buf.0[..buf.1]).unwrap();
        assert!(s.contains("Tensor"));
        assert!(s.contains("Float32"));
        assert!(s.contains("len: 4"));
    }

    #[test]
    fn is_compatible_valid() {
        let t = Tensor::<u8, 8, 2>::matrix(2, 3, &[0; 6]);
        assert!(t.is_compatible());
    }

    /// A `[0, 0]` shape has product 0, which matches the default `len` of 0.
    #[test]
    fn is_compatible_default_zero_shape() {
        let t = Tensor::<u8, 8, 2>::default();
        assert!(t.is_compatible());
    }

    /// Exhaustively checks that `[i, j, k]` on a 2x3x4 tensor maps to
    /// `i * 12 + j * 4 + k`.
    #[test]
    fn flat_index_row_major_rank3() {
        let data: [u32; 24] = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
        ];
        let t = Tensor::<u32, 24, 3>::from_shape([2, 3, 4], &data);

        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    let expected = (i * 12 + j * 4 + k) as u32;
                    assert_eq!(t.at([i, j, k]), expected);
                }
            }
        }
    }

    /// Walking every index in nested row-major order visits 0, 1, 2, ...
    #[test]
    fn flat_index_rank4_exhaustive_small() {
        let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
        let t = Tensor::<u8, 8, 4>::nhwc(1, 2, 2, 2, &data);

        let mut idx = 0u8;
        for b in 0..1 {
            for h in 0..2 {
                for w in 0..2 {
                    for c in 0..2 {
                        assert_eq!(t.at([b, h, w, c]), idx);
                        idx += 1;
                    }
                }
            }
        }
    }

    /// Shapes exercised here: a 1x96x96x1 `u8` NHWC input and a 1x3 `i8` output.
    #[test]
    fn tflm_person_detection_shapes() {
        let input = Tensor::<u8, 9216, 4>::nhwc(1, 96, 96, 1, &[128u8; 9216]);
        assert_eq!(input.len(), 9216);
        assert_eq!(input.byte_len(), 9216);
        let [b, h, w, c] = *input.shape();
        assert_eq!((b, h, w, c), (1, 96, 96, 1));

        let output = Tensor::<i8, 3, 2>::nc(1, 3, &[10, -5, 30]);
        assert_eq!(output.len(), 3);
        assert_eq!(output.at([0, 2]), 30);
    }

    /// Shapes exercised here: a 1x49x40 `i8` sequence input and a 1x12 `i8` output.
    #[test]
    fn tflm_keyword_spotting_shapes() {
        let input = Tensor::<i8, 1960, 3>::sequence(1, 49, 40, &[0i8; 1960]);
        assert_eq!(input.len(), 1960);
        let [b, t, f] = *input.shape();
        assert_eq!((b, t, f), (1, 49, 40));

        let output = Tensor::<i8, 12, 2>::nc(1, 12, &[0i8; 12]);
        assert_eq!(output.len(), 12);
    }

    /// A 1x3x224x224 `f32` NCHW input. Each `Tensor<f32, 150528, 4>` is
    /// roughly 588 KiB inline; NOTE(review): in unoptimised builds several
    /// such copies may live on the test thread's stack at once, so this
    /// test may need a larger `RUST_MIN_STACK` on some targets.
    #[test]
    fn tract_mobilenet_shapes() {
        let input = Tensor::<f32, 150528, 4>::nchw(1, 3, 224, 224, &[0.0f32; 150528]);
        assert_eq!(input.len(), 150528);
        assert_eq!(input.byte_len(), 150528 * 4);
        let [b, c, h, w] = *input.shape();
        assert_eq!((b, c, h, w), (1, 3, 224, 224));
    }
}