Skip to main content

limen_core/message/
tensor.rs

1//! Owned, fixed-capacity inline tensor (`no_std`, `no_alloc`).
2//!
3//! [`Tensor<T, N, R>`] stores up to `N` elements of scalar type `T` with
4//! a compile-time rank `R`. Data is held inline (stack / static storage),
5//! so no heap allocation is required. A live-element count (`len`) tracks
6//! how many of the `N` slots are in use; the shape array `[usize; R]` records
7//! the logical dimensions.
8//!
9//! Use `&Tensor<T, N, R>` for a borrowed, zero-copy view into an existing
10//! buffer (e.g. a frame from a DMA ring).
11
12use core::mem;
13
14use crate::memory::BufferDescriptor;
15use crate::message::payload::Payload;
16use crate::types::{DType, DataType};
17
18// ---------------------------------------------------------------------------
19// Owned, fixed-capacity tensor
20// ---------------------------------------------------------------------------
21
22/// Owned, fixed-capacity tensor stored inline.
23///
24/// - `T` — scalar element type (`Copy + Default + DType`).
25/// - `N` — maximum element capacity (compile-time).
26/// - `R` — tensor rank / number of dimensions (compile-time).
27///
28/// `len` tracks live elements (≤ `N`). Shape is `[usize; R]` with all
29/// dimensions active. Rank-0 scalars use `R = 0`.
30///
31/// Byte size is computed, not stored — one multiply is cheaper than 8 bytes
32/// of struct overhead on a constrained MCU.
33///
34/// # Examples
35/// ```text
36/// // TF Lite Micro person detection input.
37/// let img = Tensor::<u8, 9216, 4>::nhwc(1, 96, 96, 1, &pixels);
38///
39/// // Classifier output.
40/// let out = Tensor::<i8, 3, 2>::nc(1, 3, &scores);
41///
42/// // Flat feature vector.
43/// let v = Tensor::<f32, 128, 1>::from_slice(&features);
44///
45/// // Destructure shape for processing.
46/// let [batch, height, width, channels] = *img.shape();
47/// let pixel = img.at([0, 10, 20, 0]);
48/// ```
#[derive(Clone, Copy)]
pub struct Tensor<T: Copy + Default + DType, const N: usize, const R: usize> {
    // Inline backing storage; slots at index >= `len` are unused padding
    // (kept at `T::default()`) and are never exposed through the slice APIs.
    data: [T; N],
    // Number of live elements; invariant: `len <= N`.
    len: usize,
    // Logical dimensions; their product should equal `len` (see `is_compatible`).
    shape: [usize; R],
}
55
56// ---------------------------------------------------------------------------
57// Core API (all ranks)
58// ---------------------------------------------------------------------------
59
60impl<T: Copy + Default + DType, const N: usize, const R: usize> Tensor<T, N, R> {
61    /// Create from a shape and data slice.
62    ///
63    /// # Panics
64    /// - If `data.len() > N`.
65    /// - In debug: if the product of `shape` != `data.len()`.
66    #[inline]
67    pub fn from_shape(shape: [usize; R], data: &[T]) -> Self {
68        assert!(
69            data.len() <= N,
70            "Tensor: data length {} exceeds capacity {}",
71            data.len(),
72            N
73        );
74        debug_assert_eq!(
75            checked_product(&shape),
76            Some(data.len()),
77            "Tensor: shape product != data length"
78        );
79
80        let mut buf = [T::default(); N];
81        buf[..data.len()].copy_from_slice(data);
82
83        Self {
84            data: buf,
85            len: data.len(),
86            shape,
87        }
88    }
89
90    /// Create filled with `value`.
91    ///
92    /// # Panics
93    /// If the shape product exceeds `N`.
94    #[inline]
95    pub fn filled(shape: [usize; R], value: T) -> Self {
96        let count = checked_product(&shape).expect("Tensor: shape overflow");
97        assert!(count <= N, "Tensor: count {} exceeds capacity {}", count, N);
98
99        let mut buf = [T::default(); N];
100        let mut i = 0;
101        while i < count {
102            buf[i] = value;
103            i += 1;
104        }
105
106        Self {
107            data: buf,
108            len: count,
109            shape,
110        }
111    }
112
113    /// Create zeroed with the given shape.
114    #[inline]
115    pub fn zeros(shape: [usize; R]) -> Self {
116        Self::filled(shape, T::default())
117    }
118
119    /// The `DataType` of the scalar elements.
120    #[inline]
121    pub fn data_type(&self) -> DataType {
122        T::DATA_TYPE
123    }
124
125    /// Number of live elements.
126    #[inline]
127    pub fn len(&self) -> usize {
128        self.len
129    }
130
131    /// Whether the tensor has zero live elements.
132    #[inline]
133    pub fn is_empty(&self) -> bool {
134        self.len == 0
135    }
136
137    /// Compile-time element capacity.
138    #[inline]
139    pub const fn capacity(&self) -> usize {
140        N
141    }
142
143    /// Rank (compile-time constant).
144    #[inline]
145    pub const fn rank(&self) -> usize {
146        R
147    }
148
149    /// Active shape dimensions.
150    #[inline]
151    pub fn shape(&self) -> &[usize; R] {
152        &self.shape
153    }
154
155    /// Borrow the live data.
156    #[inline]
157    pub fn as_slice(&self) -> &[T] {
158        &self.data[..self.len]
159    }
160
161    /// Mutably borrow the live data.
162    #[inline]
163    pub fn as_mut_slice(&mut self) -> &mut [T] {
164        &mut self.data[..self.len]
165    }
166
167    /// Total bytes of live data (computed, not stored).
168    #[inline]
169    pub fn byte_len(&self) -> usize {
170        self.len.saturating_mul(mem::size_of::<T>())
171    }
172
173    /// Reshape in place (metadata only).
174    ///
175    /// # Panics
176    /// In debug: if the shape product != `self.len`.
177    #[inline]
178    pub fn reshape(&mut self, new_shape: [usize; R]) {
179        debug_assert_eq!(
180            checked_product(&new_shape),
181            Some(self.len),
182            "Tensor::reshape: shape product != len"
183        );
184        self.shape = new_shape;
185    }
186
187    /// Validate that the shape product matches the element count.
188    #[inline]
189    pub fn is_compatible(&self) -> bool {
190        checked_product(&self.shape) == Some(self.len)
191    }
192
193    /// Read element by multi-dimensional index.
194    ///
195    /// # Panics
196    /// If the flat index is out of bounds.
197    #[inline]
198    pub fn at(&self, index: [usize; R]) -> T {
199        self.data[self.flat_index(index)]
200    }
201
202    /// Write element by multi-dimensional index.
203    ///
204    /// # Panics
205    /// If the flat index is out of bounds.
206    #[inline]
207    pub fn set(&mut self, index: [usize; R], value: T) {
208        let i = self.flat_index(index);
209        self.data[i] = value;
210    }
211
212    /// Convert multi-dimensional index to flat offset (row-major).
213    #[inline]
214    fn flat_index(&self, index: [usize; R]) -> usize {
215        let mut flat = 0usize;
216        let mut stride = 1usize;
217        let mut d = R;
218        while d > 0 {
219            d -= 1;
220            flat += index[d] * stride;
221            stride *= self.shape[d];
222        }
223        assert!(flat < self.len, "tensor index out of bounds");
224        flat
225    }
226}
227
228// ---------------------------------------------------------------------------
229// Rank-0: scalar
230// ---------------------------------------------------------------------------
231
232impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 0> {
233    /// Create a rank-0 scalar.
234    #[inline]
235    pub fn scalar(value: T) -> Self {
236        assert!(N >= 1, "Tensor::scalar: capacity must be >= 1");
237        let mut buf = [T::default(); N];
238        buf[0] = value;
239        Self {
240            data: buf,
241            len: 1,
242            shape: [],
243        }
244    }
245
246    /// Read the scalar value.
247    #[inline]
248    pub fn value(&self) -> T {
249        self.data[0]
250    }
251}
252
253// ---------------------------------------------------------------------------
254// Rank-1: flat vector
255// ---------------------------------------------------------------------------
256
257impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 1> {
258    /// Create from a data slice (shape inferred as `[data.len()]`).
259    #[inline]
260    pub fn from_slice(data: &[T]) -> Self {
261        Self::from_shape([data.len()], data)
262    }
263}
264
265// ---------------------------------------------------------------------------
266// Rank-2: common constructors
267// ---------------------------------------------------------------------------
268
269impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 2> {
270    /// `[batch, classes]` — classifier output (TF Lite Micro softmax, etc.).
271    #[inline]
272    pub fn nc(batch: usize, classes: usize, data: &[T]) -> Self {
273        Self::from_shape([batch, classes], data)
274    }
275
276    /// `[rows, cols]` — 2-D matrix.
277    #[inline]
278    pub fn matrix(rows: usize, cols: usize, data: &[T]) -> Self {
279        Self::from_shape([rows, cols], data)
280    }
281}
282
283// ---------------------------------------------------------------------------
284// Rank-3: common constructors
285// ---------------------------------------------------------------------------
286
287impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 3> {
288    /// `[batch, time_steps, features]` — sequence / spectrogram
289    /// (keyword spotting, audio models).
290    #[inline]
291    pub fn sequence(batch: usize, time_steps: usize, features: usize, data: &[T]) -> Self {
292        Self::from_shape([batch, time_steps, features], data)
293    }
294
295    /// `[height, width, channels]` — single image, no batch dim
296    /// (common in resource-constrained pipelines).
297    #[inline]
298    pub fn hwc(height: usize, width: usize, channels: usize, data: &[T]) -> Self {
299        Self::from_shape([height, width, channels], data)
300    }
301}
302
303// ---------------------------------------------------------------------------
304// Rank-4: common constructors
305// ---------------------------------------------------------------------------
306
307impl<T: Copy + Default + DType, const N: usize> Tensor<T, N, 4> {
308    /// `[batch, height, width, channels]` — TF Lite / TF Lite Micro standard.
309    #[inline]
310    pub fn nhwc(batch: usize, height: usize, width: usize, channels: usize, data: &[T]) -> Self {
311        Self::from_shape([batch, height, width, channels], data)
312    }
313
314    /// `[batch, channels, height, width]` — PyTorch / tract default.
315    #[inline]
316    pub fn nchw(batch: usize, channels: usize, height: usize, width: usize, data: &[T]) -> Self {
317        Self::from_shape([batch, channels, height, width], data)
318    }
319}
320
321// ---------------------------------------------------------------------------
322// Trait impls
323// ---------------------------------------------------------------------------
324
325impl<T: Copy + Default + DType, const N: usize, const R: usize> Default for Tensor<T, N, R> {
326    #[inline]
327    fn default() -> Self {
328        Self {
329            data: [T::default(); N],
330            len: 0,
331            shape: [0; R],
332        }
333    }
334}
335
336impl<T: Copy + Default + DType + core::fmt::Debug, const N: usize, const R: usize> core::fmt::Debug
337    for Tensor<T, N, R>
338{
339    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
340        // show metadata and the live data slice
341        f.debug_struct("Tensor")
342            .field("data_type", &self.data_type())
343            .field("len", &self.len)
344            .field("shape", &&self.shape[..])
345            .field("capacity", &N)
346            .field("data", &self.as_slice())
347            .finish()
348    }
349}
350
351impl<T: Copy + Default + DType + PartialEq, const N: usize, const R: usize> PartialEq
352    for Tensor<T, N, R>
353{
354    fn eq(&self, other: &Self) -> bool {
355        self.len == other.len
356            && self.shape == other.shape
357            && self.data[..self.len] == other.data[..other.len]
358    }
359}
360
// `Eq` is sound: equality is reflexive because it compares only `len`,
// `shape`, and the live data slice, never the unused padding.
impl<T: Copy + Default + DType + Eq, const N: usize, const R: usize> Eq for Tensor<T, N, R> {}
362
363impl<T: Copy + Default + DType, const N: usize, const R: usize> Payload for Tensor<T, N, R> {
364    #[inline]
365    fn buffer_descriptor(&self) -> BufferDescriptor {
366        BufferDescriptor::new(self.byte_len())
367    }
368}
369
/// Checked product of a shape array. Returns `None` on overflow.
///
/// An empty shape (rank 0) yields `Some(1)`, matching the convention that a
/// scalar holds exactly one element.
#[inline]
fn checked_product(shape: &[usize]) -> Option<usize> {
    shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
}
379
/// Canonical shape used by shared test and benchmark tensor fixtures.
///
/// This shape is fixed at `3 × 3` so all common test payloads use the same
/// rank-2 layout across memory, message, and benchmark code.
#[cfg(any(test, feature = "bench"))]
pub const TEST_TENSOR_SHAPE: [usize; 2] = [3, 3];

/// Total live element count of [`TestTensor`].
///
/// Derived from [`TEST_TENSOR_SHAPE`] (`3 × 3 = 9`) so the two constants
/// cannot drift apart if the fixture shape ever changes.
#[cfg(any(test, feature = "bench"))]
pub const TEST_TENSOR_ELEMENT_COUNT: usize = TEST_TENSOR_SHAPE[0] * TEST_TENSOR_SHAPE[1];

/// Total live payload byte count of [`TestTensor`].
///
/// This reflects the tensor payload size reported through [`Payload`], not the
/// full in-memory size of the `Tensor` struct itself.
#[cfg(any(test, feature = "bench"))]
pub const TEST_TENSOR_BYTE_COUNT: usize = TEST_TENSOR_ELEMENT_COUNT * mem::size_of::<u32>();

/// Shared rank-2 tensor payload used by tests and benchmarks.
///
/// This alias standardizes on a `3 × 3` tensor of `u32` values with exactly
/// `9` live elements.
#[cfg(any(test, feature = "bench"))]
pub type TestTensor = Tensor<u32, TEST_TENSOR_ELEMENT_COUNT, 2>;
406
/// Create a shared test tensor with every element set to the same value.
///
/// The returned tensor always has shape [`TEST_TENSOR_SHAPE`].
#[cfg(any(test, feature = "bench"))]
#[inline]
pub fn create_test_tensor_filled_with(value: u32) -> TestTensor {
    TestTensor::filled(TEST_TENSOR_SHAPE, value)
}
415
/// Create a shared test tensor from explicit `3 × 3` element values.
///
/// The input array is flattened in row-major order into a tensor with shape
/// [`TEST_TENSOR_SHAPE`], so `values[0][0]` becomes the first element and
/// `values[2][2]` becomes the last.
#[cfg(any(test, feature = "bench"))]
#[inline]
pub fn create_test_tensor_from_array(values: [[u32; 3]; 3]) -> TestTensor {
    // Flatten row-major with a loop instead of spelling out all nine indices
    // by hand — less error-prone and still allocation-free (`no_alloc`).
    let mut flat = [0u32; TEST_TENSOR_ELEMENT_COUNT];
    let mut next = 0;
    for row in &values {
        for &element in row {
            flat[next] = element;
            next += 1;
        }
    }
    Tensor::from_shape(TEST_TENSOR_SHAPE, &flat)
}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    // Coverage map: construction, rank-specific constructors, indexing,
    // reshape, trait impls (Default/Debug/PartialEq/Copy/Clone/Payload),
    // and smoke tests against real-world model tensor shapes.

    // -----------------------------------------------------------------
    // Construction
    // -----------------------------------------------------------------

    #[test]
    fn default_is_empty() {
        let t = Tensor::<f32, 16, 2>::default();
        assert!(t.is_empty());
        assert_eq!(t.len(), 0);
        assert_eq!(t.shape(), &[0, 0]);
        assert_eq!(t.byte_len(), 0);
        assert_eq!(t.rank(), 2);
        assert_eq!(t.capacity(), 16);
    }

    #[test]
    fn from_shape_basic() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::<f32, 8, 2>::from_shape([2, 3], &data);
        assert_eq!(t.len(), 6);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.as_slice(), &data);
        assert_eq!(t.byte_len(), 6 * 4);
        assert!(t.is_compatible());
    }

    #[test]
    fn from_shape_partial_capacity() {
        let data = [1u8, 2, 3];
        let t = Tensor::<u8, 256, 1>::from_shape([3], &data);
        assert_eq!(t.len(), 3);
        assert_eq!(t.capacity(), 256);
        assert_eq!(t.as_slice(), &[1, 2, 3]);
    }

    #[test]
    #[should_panic(expected = "exceeds capacity")]
    fn from_shape_panics_on_overflow() {
        let data = [0u8; 32];
        let _ = Tensor::<u8, 16, 1>::from_shape([32], &data);
    }

    #[test]
    fn filled_creates_uniform_tensor() {
        let t = Tensor::<i8, 12, 2>::filled([3, 4], 42i8);
        assert_eq!(t.len(), 12);
        assert_eq!(t.shape(), &[3, 4]);
        assert!(t.as_slice().iter().all(|&v| v == 42));
    }

    #[test]
    #[should_panic(expected = "exceeds capacity")]
    fn filled_panics_on_overflow() {
        let _ = Tensor::<f32, 4, 2>::filled([3, 3], 0.0);
    }

    #[test]
    fn zeros_is_all_default() {
        let t = Tensor::<u32, 64, 3>::zeros([2, 4, 8]);
        assert_eq!(t.len(), 64);
        assert!(t.as_slice().iter().all(|&v| v == 0));
    }

    // -----------------------------------------------------------------
    // Rank-0: scalar
    // -----------------------------------------------------------------

    #[test]
    fn scalar_round_trip() {
        // Exact float equality is fine: the value is stored and read back
        // without any arithmetic.
        let t = Tensor::<f32, 1, 0>::scalar(3.14);
        assert_eq!(t.value(), 3.14);
        assert_eq!(t.len(), 1);
        assert_eq!(t.rank(), 0);
        assert_eq!(t.shape(), &[]);
        assert_eq!(t.byte_len(), 4);
        assert!(t.is_compatible());
    }

    #[test]
    fn scalar_with_excess_capacity() {
        let t = Tensor::<u8, 4, 0>::scalar(7);
        assert_eq!(t.value(), 7);
        assert_eq!(t.capacity(), 4);
        assert_eq!(t.len(), 1);
    }

    // -----------------------------------------------------------------
    // Rank-1: from_slice
    // -----------------------------------------------------------------

    #[test]
    fn from_slice_rank1() {
        let t = Tensor::<f32, 8, 1>::from_slice(&[1.0, 2.0, 3.0]);
        assert_eq!(t.len(), 3);
        assert_eq!(t.shape(), &[3]);
        assert_eq!(t.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn from_slice_empty() {
        let t = Tensor::<u8, 8, 1>::from_slice(&[]);
        assert!(t.is_empty());
        assert_eq!(t.shape(), &[0]);
    }

    // -----------------------------------------------------------------
    // Rank-2: named constructors
    // -----------------------------------------------------------------

    #[test]
    fn nc_constructor() {
        let data = [10i8, 20, 30];
        let t = Tensor::<i8, 4, 2>::nc(1, 3, &data);
        assert_eq!(t.shape(), &[1, 3]);
        assert_eq!(t.len(), 3);
        assert_eq!(t.as_slice(), &data);
    }

    #[test]
    fn matrix_constructor() {
        let data: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::<f32, 8, 2>::matrix(2, 3, &data);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.len(), 6);
    }

    // -----------------------------------------------------------------
    // Rank-3: named constructors
    // -----------------------------------------------------------------

    #[test]
    fn sequence_constructor() {
        let data = [0i8; 1960];
        let t = Tensor::<i8, 1960, 3>::sequence(1, 49, 40, &data);
        assert_eq!(t.shape(), &[1, 49, 40]);
        assert_eq!(t.len(), 1960);
    }

    #[test]
    fn hwc_constructor() {
        let data = [0u8; 48];
        let t = Tensor::<u8, 48, 3>::hwc(4, 4, 3, &data);
        assert_eq!(t.shape(), &[4, 4, 3]);
        assert_eq!(t.len(), 48);
    }

    // -----------------------------------------------------------------
    // Rank-4: named constructors
    // -----------------------------------------------------------------

    #[test]
    fn nhwc_constructor() {
        let data = [0u8; 9216];
        let t = Tensor::<u8, 9216, 4>::nhwc(1, 96, 96, 1, &data);
        assert_eq!(t.shape(), &[1, 96, 96, 1]);
        assert_eq!(t.len(), 9216);
        assert!(t.is_compatible());
    }

    #[test]
    fn nchw_constructor() {
        // 1×3×4×4 = 48 elements — small enough for the test thread stack.
        let data = [0.0f32; 48];
        let t = Tensor::<f32, 48, 4>::nchw(1, 3, 4, 4, &data);
        assert_eq!(t.shape(), &[1, 3, 4, 4]);
    }

    // -----------------------------------------------------------------
    // Indexing: at / set
    // -----------------------------------------------------------------

    #[test]
    fn at_and_set_rank2() {
        let data = [1u32, 2, 3, 4, 5, 6];
        let mut t = Tensor::<u32, 8, 2>::matrix(2, 3, &data);

        // Row-major: [0,0]=1, [0,1]=2, [0,2]=3, [1,0]=4, [1,1]=5, [1,2]=6
        assert_eq!(t.at([0, 0]), 1);
        assert_eq!(t.at([0, 2]), 3);
        assert_eq!(t.at([1, 0]), 4);
        assert_eq!(t.at([1, 2]), 6);

        t.set([1, 1], 99);
        assert_eq!(t.at([1, 1]), 99);
    }

    #[test]
    fn at_rank4_nhwc() {
        // 1×2×2×3 NHWC tensor with sequential values.
        let data: [u8; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let t = Tensor::<u8, 12, 4>::nhwc(1, 2, 2, 3, &data);

        // [0, 0, 0, 0] = 0, [0, 0, 0, 1] = 1, [0, 0, 0, 2] = 2
        // [0, 0, 1, 0] = 3, [0, 0, 1, 1] = 4, [0, 0, 1, 2] = 5
        // [0, 1, 0, 0] = 6, ...
        assert_eq!(t.at([0, 0, 0, 0]), 0);
        assert_eq!(t.at([0, 0, 0, 2]), 2);
        assert_eq!(t.at([0, 0, 1, 0]), 3);
        assert_eq!(t.at([0, 1, 0, 0]), 6);
        assert_eq!(t.at([0, 1, 1, 2]), 11);
    }

    #[test]
    fn at_rank3_sequence() {
        // 1×3×2 sequence tensor.
        let data = [10i8, 20, 30, 40, 50, 60];
        let t = Tensor::<i8, 8, 3>::sequence(1, 3, 2, &data);

        assert_eq!(t.at([0, 0, 0]), 10);
        assert_eq!(t.at([0, 0, 1]), 20);
        assert_eq!(t.at([0, 1, 0]), 30);
        assert_eq!(t.at([0, 2, 1]), 60);
    }

    #[test]
    fn at_rank1() {
        let t = Tensor::<f32, 4, 1>::from_slice(&[10.0, 20.0, 30.0]);
        assert_eq!(t.at([0]), 10.0);
        assert_eq!(t.at([2]), 30.0);
    }

    #[test]
    #[should_panic]
    fn at_out_of_bounds_panics() {
        // Shape is [3]; index 3 is one past the end.
        let t = Tensor::<u8, 4, 1>::from_slice(&[1, 2, 3]);
        let _ = t.at([3]);
    }

    // -----------------------------------------------------------------
    // Reshape
    // -----------------------------------------------------------------

    #[test]
    fn reshape_preserves_data() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut t = Tensor::<f32, 8, 2>::matrix(2, 3, &data);

        t.reshape([3, 2]);
        assert_eq!(t.shape(), &[3, 2]);
        assert_eq!(t.len(), 6);
        assert_eq!(t.as_slice(), &data);
        assert!(t.is_compatible());

        // Indexing uses new shape.
        assert_eq!(t.at([0, 0]), 1.0);
        assert_eq!(t.at([0, 1]), 2.0);
        assert_eq!(t.at([1, 0]), 3.0);
        assert_eq!(t.at([2, 1]), 6.0);
    }

    // -----------------------------------------------------------------
    // Mutation via as_mut_slice
    // -----------------------------------------------------------------

    #[test]
    fn as_mut_slice_modification() {
        let mut t = Tensor::<u8, 8, 1>::from_slice(&[0, 0, 0]);
        t.as_mut_slice().copy_from_slice(&[10, 20, 30]);
        assert_eq!(t.as_slice(), &[10, 20, 30]);
    }

    // -----------------------------------------------------------------
    // DataType
    // -----------------------------------------------------------------

    #[test]
    fn data_type_reflects_element() {
        assert_eq!(
            Tensor::<f32, 4, 1>::default().data_type(),
            DataType::Float32
        );
        assert_eq!(
            Tensor::<f64, 4, 1>::default().data_type(),
            DataType::Float64
        );
        assert_eq!(
            Tensor::<u8, 4, 1>::default().data_type(),
            DataType::Unsigned8
        );
        assert_eq!(Tensor::<i8, 4, 1>::default().data_type(), DataType::Signed8);
        assert_eq!(
            Tensor::<u16, 4, 1>::default().data_type(),
            DataType::Unsigned16
        );
        assert_eq!(
            Tensor::<i16, 4, 1>::default().data_type(),
            DataType::Signed16
        );
        assert_eq!(
            Tensor::<u32, 4, 1>::default().data_type(),
            DataType::Unsigned32
        );
        assert_eq!(
            Tensor::<i32, 4, 1>::default().data_type(),
            DataType::Signed32
        );
        assert_eq!(
            Tensor::<u64, 4, 1>::default().data_type(),
            DataType::Unsigned64
        );
        assert_eq!(
            Tensor::<i64, 4, 1>::default().data_type(),
            DataType::Signed64
        );
        assert_eq!(
            Tensor::<bool, 4, 1>::default().data_type(),
            DataType::Boolean
        );
    }

    // -----------------------------------------------------------------
    // Byte length
    // -----------------------------------------------------------------

    #[test]
    fn byte_len_correct_for_types() {
        let t_f32 = Tensor::<f32, 8, 1>::from_slice(&[0.0; 6]);
        assert_eq!(t_f32.byte_len(), 24); // 6 × 4

        let t_u8 = Tensor::<u8, 8, 1>::from_slice(&[0u8; 5]);
        assert_eq!(t_u8.byte_len(), 5); // 5 × 1

        let t_f64 = Tensor::<f64, 4, 1>::from_slice(&[0.0f64; 3]);
        assert_eq!(t_f64.byte_len(), 24); // 3 × 8

        let t_empty = Tensor::<f32, 4, 1>::default();
        assert_eq!(t_empty.byte_len(), 0);
    }

    // -----------------------------------------------------------------
    // Payload trait
    // -----------------------------------------------------------------

    #[test]
    fn payload_buffer_descriptor() {
        let t = Tensor::<f32, 8, 2>::matrix(2, 3, &[0.0; 6]);
        let bd = t.buffer_descriptor();
        assert_eq!(*bd.bytes(), 24);
    }

    // -----------------------------------------------------------------
    // Equality
    // -----------------------------------------------------------------

    #[test]
    fn eq_same_data_same_shape() {
        let a = Tensor::<u8, 8, 2>::matrix(2, 3, &[1, 2, 3, 4, 5, 6]);
        let b = Tensor::<u8, 8, 2>::matrix(2, 3, &[1, 2, 3, 4, 5, 6]);
        assert_eq!(a, b);
    }

    #[test]
    fn ne_different_data() {
        let a = Tensor::<u8, 4, 1>::from_slice(&[1, 2, 3]);
        let b = Tensor::<u8, 4, 1>::from_slice(&[1, 2, 4]);
        assert_ne!(a, b);
    }

    #[test]
    fn ne_different_shape_same_data() {
        let data = [1u8, 2, 3, 4, 5, 6];
        let a = Tensor::<u8, 8, 2>::matrix(2, 3, &data);
        let b = Tensor::<u8, 8, 2>::matrix(3, 2, &data);
        assert_ne!(a, b);
    }

    #[test]
    fn ne_different_len() {
        let a = Tensor::<u8, 8, 1>::from_slice(&[1, 2, 3]);
        let b = Tensor::<u8, 8, 1>::from_slice(&[1, 2]);
        assert_ne!(a, b);
    }

    #[test]
    fn eq_ignores_padding_beyond_len() {
        // Two tensors with same live data but different capacity padding.
        let a = Tensor::<u8, 8, 1>::from_slice(&[1, 2, 3]);
        let _b = Tensor::<u8, 16, 1>::from_slice(&[1, 2, 3]);
        // Different N means different types — can't compare directly.
        // But same-N tensors with same live data are equal regardless of padding.
        // Fields are accessible here because tests live in the same module tree.
        let mut c = Tensor::<u8, 8, 1>::default();
        c.data[0] = 1;
        c.data[1] = 2;
        c.data[2] = 3;
        c.data[7] = 99; // padding differs
        c.len = 3;
        c.shape = [3];
        assert_eq!(a, c);
    }

    // -----------------------------------------------------------------
    // Copy semantics
    // -----------------------------------------------------------------

    #[test]
    fn tensor_is_copy() {
        let a = Tensor::<f32, 4, 1>::from_slice(&[1.0, 2.0]);
        let b = a; // copy
        let c = a; // still valid — Copy
        assert_eq!(b, c);
    }

    // -----------------------------------------------------------------
    // Clone
    // -----------------------------------------------------------------

    #[test]
    fn tensor_clone_equals_original() {
        let a = Tensor::<i32, 16, 3>::from_shape([2, 2, 2], &[1, 2, 3, 4, 5, 6, 7, 8]);
        let b = a.clone();
        assert_eq!(a, b);
    }

    // -----------------------------------------------------------------
    // Debug
    // -----------------------------------------------------------------

    #[test]
    fn debug_format_does_not_panic() {
        use core::fmt::Write;
        // Minimal no_std-friendly sink: writes into a fixed stack buffer.
        struct StackBuf([u8; 256], usize);
        impl Write for StackBuf {
            fn write_str(&mut self, s: &str) -> core::fmt::Result {
                for &b in s.as_bytes() {
                    if self.1 >= self.0.len() {
                        return Ok(()); // silently truncate
                    }
                    self.0[self.1] = b;
                    self.1 += 1;
                }
                Ok(())
            }
        }
        let t = Tensor::<f32, 4, 2>::matrix(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let mut buf = StackBuf([0u8; 256], 0);
        write!(buf, "{:?}", t).unwrap();
        let s = core::str::from_utf8(&buf.0[..buf.1]).unwrap();
        assert!(s.contains("Tensor"));
        assert!(s.contains("Float32"));
        assert!(s.contains("len: 4"));
    }

    // -----------------------------------------------------------------
    // is_compatible
    // -----------------------------------------------------------------

    #[test]
    fn is_compatible_valid() {
        let t = Tensor::<u8, 8, 2>::matrix(2, 3, &[0; 6]);
        assert!(t.is_compatible());
    }

    #[test]
    fn is_compatible_default_zero_shape() {
        // Default has len=0, shape=[0, 0]. Product of [0, 0] = 0 = len. Compatible.
        let t = Tensor::<u8, 8, 2>::default();
        assert!(t.is_compatible());
    }

    // -----------------------------------------------------------------
    // flat_index consistency
    // -----------------------------------------------------------------

    #[test]
    fn flat_index_row_major_rank3() {
        // 2×3×4 tensor, sequential data.
        let data: [u32; 24] = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
        ];
        let t = Tensor::<u32, 24, 3>::from_shape([2, 3, 4], &data);

        // Row-major: flat = i*12 + j*4 + k
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    let expected = (i * 12 + j * 4 + k) as u32;
                    assert_eq!(t.at([i, j, k]), expected);
                }
            }
        }
    }

    #[test]
    fn flat_index_rank4_exhaustive_small() {
        // 1×2×2×2 = 8 elements.
        let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
        let t = Tensor::<u8, 8, 4>::nhwc(1, 2, 2, 2, &data);

        let mut idx = 0u8;
        for b in 0..1 {
            for h in 0..2 {
                for w in 0..2 {
                    for c in 0..2 {
                        assert_eq!(t.at([b, h, w, c]), idx);
                        idx += 1;
                    }
                }
            }
        }
    }

    // -----------------------------------------------------------------
    // Real-world shapes smoke tests
    // -----------------------------------------------------------------

    #[test]
    fn tflm_person_detection_shapes() {
        // Input: [1, 96, 96, 1] u8
        let input = Tensor::<u8, 9216, 4>::nhwc(1, 96, 96, 1, &[128u8; 9216]);
        assert_eq!(input.len(), 9216);
        assert_eq!(input.byte_len(), 9216);
        let [b, h, w, c] = *input.shape();
        assert_eq!((b, h, w, c), (1, 96, 96, 1));

        // Output: [1, 3] i8
        let output = Tensor::<i8, 3, 2>::nc(1, 3, &[10, -5, 30]);
        assert_eq!(output.len(), 3);
        assert_eq!(output.at([0, 2]), 30);
    }

    #[test]
    fn tflm_keyword_spotting_shapes() {
        // Input: [1, 49, 40] i8
        let input = Tensor::<i8, 1960, 3>::sequence(1, 49, 40, &[0i8; 1960]);
        assert_eq!(input.len(), 1960);
        let [b, t, f] = *input.shape();
        assert_eq!((b, t, f), (1, 49, 40));

        // Output: [1, 12] i8
        let output = Tensor::<i8, 12, 2>::nc(1, 12, &[0i8; 12]);
        assert_eq!(output.len(), 12);
    }

    #[test]
    fn tract_mobilenet_shapes() {
        // tract typically uses NCHW.
        // MobileNet v2: [1, 3, 224, 224] f32
        let input = Tensor::<f32, 150528, 4>::nchw(1, 3, 224, 224, &[0.0f32; 150528]);
        assert_eq!(input.len(), 150528);
        assert_eq!(input.byte_len(), 150528 * 4);
        let [b, c, h, w] = *input.shape();
        assert_eq!((b, c, h, w), (1, 3, 224, 224));
    }
}