oxionnx_core/tensor.rs

/// Tensor memory layout for image/convolution data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TensorLayout {
    /// Batch, Channels, Height, Width (default for ONNX/PyTorch).
    NCHW,
    /// Batch, Height, Width, Channels (used by TensorFlow, often faster on CPU).
    NHWC,
    /// Generic row-major layout (non-image tensors).
    #[default]
    RowMajor,
}

/// Convert a tensor from NCHW to NHWC layout.
/// Input shape: [N, C, H, W] -> Output shape: [N, H, W, C]
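///
/// A minimal doc-test sketch; the `oxionnx_core::tensor` import path is an
/// assumption based on this file's location.
///
/// ```
/// use oxionnx_core::tensor::{nchw_to_nhwc, Tensor};
///
/// let t = Tensor::zeros(&[1, 2, 3, 4]); // N=1, C=2, H=3, W=4
/// let out = nchw_to_nhwc(&t).unwrap();
/// assert_eq!(out.shape, vec![1, 3, 4, 2]);
/// ```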
pub fn nchw_to_nhwc(tensor: &Tensor) -> Result<Tensor, String> {
    if tensor.shape.len() != 4 {
        return Err(format!(
            "nchw_to_nhwc: expected 4D tensor, got {}D",
            tensor.shape.len()
        ));
    }
    let (n, c, h, w) = (
        tensor.shape[0],
        tensor.shape[1],
        tensor.shape[2],
        tensor.shape[3],
    );
    let mut out = vec![0.0f32; tensor.data.len()];

    for batch in 0..n {
        for ch in 0..c {
            for row in 0..h {
                for col in 0..w {
                    let src_idx = batch * c * h * w + ch * h * w + row * w + col;
                    let dst_idx = batch * h * w * c + row * w * c + col * c + ch;
                    out[dst_idx] = tensor.data[src_idx];
                }
            }
        }
    }

    Ok(Tensor::new(out, vec![n, h, w, c]))
}

/// Convert a tensor from NHWC to NCHW layout.
/// Input shape: [N, H, W, C] -> Output shape: [N, C, H, W]
pub fn nhwc_to_nchw(tensor: &Tensor) -> Result<Tensor, String> {
    if tensor.shape.len() != 4 {
        return Err(format!(
            "nhwc_to_nchw: expected 4D tensor, got {}D",
            tensor.shape.len()
        ));
    }
    let (n, h, w, c) = (
        tensor.shape[0],
        tensor.shape[1],
        tensor.shape[2],
        tensor.shape[3],
    );
    let mut out = vec![0.0f32; tensor.data.len()];

    for batch in 0..n {
        for row in 0..h {
            for col in 0..w {
                for ch in 0..c {
                    let src_idx = batch * h * w * c + row * w * c + col * c + ch;
                    let dst_idx = batch * c * h * w + ch * h * w + row * w + col;
                    out[dst_idx] = tensor.data[src_idx];
                }
            }
        }
    }

    Ok(Tensor::new(out, vec![n, c, h, w]))
}

/// Convert between tensor layouts.
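///
/// Doc-test sketch (same assumed import path as `nchw_to_nhwc` above):
///
/// ```
/// use oxionnx_core::tensor::{convert_layout, Tensor, TensorLayout};
///
/// let t = Tensor::zeros(&[1, 3, 8, 8]);
/// let nhwc = convert_layout(&t, TensorLayout::NCHW, TensorLayout::NHWC).unwrap();
/// assert_eq!(nhwc.shape, vec![1, 8, 8, 3]);
/// ```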
pub fn convert_layout(
    tensor: &Tensor,
    from: TensorLayout,
    to: TensorLayout,
) -> Result<Tensor, String> {
    match (from, to) {
        (TensorLayout::NCHW, TensorLayout::NHWC) => nchw_to_nhwc(tensor),
        (TensorLayout::NHWC, TensorLayout::NCHW) => nhwc_to_nchw(tensor),
        (a, b) if a == b => Ok(tensor.clone()),
        _ => Err(format!(
            "Unsupported layout conversion: {:?} -> {:?}",
            from, to
        )),
    }
}

/// N-dimensional tensor with f32 data and a shape vector.
/// Layout: row-major (C order), last dimension varies fastest.
#[derive(Debug, Clone)]
pub struct Tensor {
    pub data: Vec<f32>,
    pub shape: Vec<usize>,
}

impl Tensor {
    /// Create a tensor from owned data and a shape vector.
    /// Debug builds assert that `data.len()` equals the product of `shape`.
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        debug_assert_eq!(data.len(), shape.iter().product::<usize>());
        Self { data, shape }
    }

    /// Create a zero-filled tensor with the given shape.
    pub fn zeros(shape: &[usize]) -> Self {
        let n: usize = shape.iter().product();
        Self {
            data: vec![0.0f32; n],
            shape: shape.to_vec(),
        }
    }

    /// Create a scalar tensor (shape `[1]`) containing a single value.
    pub fn scalar(val: f32) -> Self {
        Self {
            data: vec![val],
            shape: vec![1],
        }
    }

    /// Total number of elements in this tensor.
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Number of dimensions (rank) of this tensor.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Return a new tensor with the same data but a different shape.
    /// Panics if the element count changes.
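    ///
    /// E.g. `Tensor::zeros(&[2, 3]).reshape(&[6])` keeps the same six elements
    /// under a flat `[6]` shape.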
    pub fn reshape(&self, new_shape: &[usize]) -> Self {
        assert_eq!(
            new_shape.iter().product::<usize>(),
            self.numel(),
            "reshape: element count mismatch"
        );
        Self {
            data: self.data.clone(),
            shape: new_shape.to_vec(),
        }
    }

    /// Compute the broadcast shape of two tensors (NumPy rules).
    /// Returns Err if shapes are incompatible.
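    ///
    /// Shapes are right-aligned and padded with 1s before comparison, so
    /// `[3, 1]` with `[1, 4]` gives `[3, 4]`, and `[3]` with `[4, 3]` gives
    /// `[4, 3]`.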
    pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
        let n = a.len().max(b.len());
        let mut out = vec![0usize; n];
        let a_pad = n - a.len();
        let b_pad = n - b.len();
        for i in 0..n {
            let ai = if i < a_pad { 1 } else { a[i - a_pad] };
            let bi = if i < b_pad { 1 } else { b[i - b_pad] };
            if ai == bi {
                out[i] = ai;
            } else if ai == 1 {
                out[i] = bi;
            } else if bi == 1 {
                out[i] = ai;
            } else {
                return Err(format!("Cannot broadcast {:?} with {:?}", a, b));
            }
        }
        Ok(out)
    }

    /// Retrieve a single element by flat index.
    /// Panics on an out-of-bounds index; slice indexing is bounds-checked in
    /// both debug and release builds.
    #[inline(always)]
    pub fn get(&self, idx: usize) -> f32 {
        self.data[idx]
    }
}

// ---------------------------------------------------------------------------
// TensorView — zero-copy strided view
// ---------------------------------------------------------------------------

/// Compute C-order (row-major) strides from shape.
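///
/// E.g. `compute_strides(&[2, 3, 4])` returns `[12, 4, 1]`: each stride is the
/// number of elements skipped when that index increases by one.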
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let n = shape.len();
    let mut strides = vec![1usize; n];
    for i in (0..n.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// A read-only view into a tensor's data with stride-based indexing.
///
/// Enables zero-copy transpose, slice, squeeze, and unsqueeze operations
/// by manipulating shape, strides, and offset without copying data.
#[derive(Debug, Clone)]
pub struct TensorView<'a> {
    data: &'a [f32],
    shape: Vec<usize>,
    strides: Vec<usize>,
    offset: usize,
}

impl<'a> TensorView<'a> {
    /// Create a view from a data slice, shape, strides, and element offset.
    pub fn new(data: &'a [f32], shape: Vec<usize>, strides: Vec<usize>, offset: usize) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// Shape of this view.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Strides of this view.
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of elements.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Check if the view is contiguous (C-order, row-major).
    ///
    /// Contiguous when `strides[i] == product(shape[i+1..])` for all `i`
    /// and the view's offset is zero.
    pub fn is_contiguous(&self) -> bool {
        let expected = compute_strides(&self.shape);
        self.strides == expected && self.offset == 0
    }

    /// Access a single element by multi-dimensional indices.
    pub fn get(&self, indices: &[usize]) -> Option<f32> {
        if indices.len() != self.shape.len() {
            return None;
        }
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape[i] {
                return None;
            }
        }
        let flat_idx: usize = self.offset
            + indices
                .iter()
                .zip(self.strides.iter())
                .map(|(&i, &s)| i * s)
                .sum::<usize>();
        self.data.get(flat_idx).copied()
    }

    /// Transpose: permute dimensions and their strides.
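    ///
    /// E.g. a contiguous `[2, 3]` view (strides `[3, 1]`) transposed with
    /// `perm = [1, 0]` becomes shape `[3, 2]` with strides `[1, 3]`; no data
    /// is moved.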
    pub fn transpose(&self, perm: &[usize]) -> Self {
        let new_shape: Vec<usize> = perm.iter().map(|&p| self.shape[p]).collect();
        let new_strides: Vec<usize> = perm.iter().map(|&p| self.strides[p]).collect();
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        }
    }

    /// Slice along one axis: select the range `[start, end)` along `axis`.
    pub fn slice(&self, axis: usize, start: usize, end: usize) -> Self {
        debug_assert!(axis < self.shape.len(), "slice: axis out of range");
        debug_assert!(
            start <= end && end <= self.shape[axis],
            "slice: invalid range"
        );
        let mut new_shape = self.shape.clone();
        new_shape[axis] = end - start;
        Self {
            data: self.data,
            shape: new_shape,
            strides: self.strides.clone(),
            offset: self.offset + start * self.strides[axis],
        }
    }

    /// Select a single index along an axis, reducing rank by 1.
    pub fn select(&self, axis: usize, index: usize) -> Self {
        debug_assert!(axis < self.shape.len(), "select: axis out of range");
        debug_assert!(index < self.shape[axis], "select: index out of range");
        let mut new_shape = self.shape.clone();
        let mut new_strides = self.strides.clone();
        new_shape.remove(axis);
        new_strides.remove(axis);
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset + index * self.strides[axis],
        }
    }

    /// Squeeze: remove dimensions of size 1.
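    ///
    /// Only axes that are listed *and* have size 1 are removed; e.g. squeezing
    /// axes `[0, 2]` of a `[1, 3, 1, 4]` view yields `[3, 4]`.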
    pub fn squeeze(&self, axes: &[usize]) -> Self {
        let mut new_shape = Vec::new();
        let mut new_strides = Vec::new();
        for (i, (&s, &st)) in self.shape.iter().zip(self.strides.iter()).enumerate() {
            if axes.contains(&i) && s == 1 {
                continue;
            }
            new_shape.push(s);
            new_strides.push(st);
        }
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        }
    }

    /// Unsqueeze: insert dimensions of size 1.
    pub fn unsqueeze(&self, axes: &[usize]) -> Self {
        // Sort axes so insertions proceed left to right; each axis refers to
        // a position in the *output* shape, so once sorted, earlier
        // insertions are already accounted for.
        let mut sorted_axes: Vec<usize> = axes.to_vec();
        sorted_axes.sort_unstable();

        let mut new_shape = self.shape.clone();
        let mut new_strides = self.strides.clone();
        for &pos in &sorted_axes {
            // Any stride is valid for a size-1 dim (its index is always 0),
            // but using the contiguous choice -- the stride of the dim
            // currently at `pos` times its extent -- keeps `is_contiguous()`
            // true when unsqueezing a contiguous view.
            let stride_val = if pos < new_strides.len() {
                new_strides[pos] * new_shape[pos]
            } else {
                1
            };
            new_shape.insert(pos, 1);
            new_strides.insert(pos, stride_val);
        }
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        }
    }

    /// Materialize to a contiguous Tensor.
    ///
    /// If already contiguous, copies data directly. Otherwise, iterates
    /// through all elements using strided indexing.
    pub fn to_tensor(&self) -> Tensor {
        if self.is_contiguous() {
            let n = self.numel();
            let data = self.data[..n].to_vec();
            return Tensor::new(data, self.shape.clone());
        }
        let data: Vec<f32> = self.iter().collect();
        Tensor::new(data, self.shape.clone())
    }

    /// Iterate over all elements in logical (row-major) order.
    pub fn iter(&self) -> TensorViewIter<'_> {
        let ndim = self.shape.len();
        let exhausted = self.numel() == 0;
        TensorViewIter {
            data: self.data,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
            indices: vec![0; ndim],
            exhausted,
        }
    }
}

/// Iterator over a `TensorView` in logical (row-major) order.
pub struct TensorViewIter<'a> {
    data: &'a [f32],
    shape: Vec<usize>,
    strides: Vec<usize>,
    offset: usize,
    indices: Vec<usize>,
    exhausted: bool,
}

impl TensorViewIter<'_> {
    fn get_at(&self, indices: &[usize]) -> Option<f32> {
        let flat_idx: usize = self.offset
            + indices
                .iter()
                .zip(self.strides.iter())
                .map(|(&i, &s)| i * s)
                .sum::<usize>();
        self.data.get(flat_idx).copied()
    }
}

impl Iterator for TensorViewIter<'_> {
    type Item = f32;

    fn next(&mut self) -> Option<f32> {
        if self.exhausted {
            return None;
        }
        let val = self.get_at(&self.indices);

        // Increment indices (rightmost first, carry over)
        let ndim = self.shape.len();
        let mut carry = true;
        for i in (0..ndim).rev() {
            if carry {
                self.indices[i] += 1;
                if self.indices[i] < self.shape[i] {
                    carry = false;
                } else {
                    self.indices[i] = 0;
                }
            }
        }
        if carry {
            self.exhausted = true;
        }

        val
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.exhausted {
            return (0, Some(0));
        }
        let total: usize = self.shape.iter().product();
        let mut consumed = 0usize;
        let logical_strides = compute_strides(&self.shape);
        for (i, &idx) in self.indices.iter().enumerate() {
            consumed += idx * logical_strides[i];
        }
        let remaining = total.saturating_sub(consumed);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for TensorViewIter<'_> {}

// ---------------------------------------------------------------------------
// Tensor — view methods
// ---------------------------------------------------------------------------

impl Tensor {
    /// Create a contiguous view of this tensor.
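    ///
    /// Doc-test sketch of a zero-copy transpose (assumed import path as above):
    ///
    /// ```
    /// use oxionnx_core::tensor::Tensor;
    ///
    /// let t = Tensor::new(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]);
    /// let v = t.view().transpose(&[1, 0]); // no copy yet
    /// assert_eq!(v.to_tensor().data, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    /// ```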
    pub fn view(&self) -> TensorView<'_> {
        let strides = compute_strides(&self.shape);
        TensorView {
            data: &self.data,
            shape: self.shape.clone(),
            strides,
            offset: 0,
        }
    }

    /// Create a transposed view without copying data.
    pub fn transpose_view(&self, perm: &[usize]) -> TensorView<'_> {
        self.view().transpose(perm)
    }

    /// Create a sliced view without copying data.
    pub fn slice_view(&self, axis: usize, start: usize, end: usize) -> TensorView<'_> {
        self.view().slice(axis, start, end)
    }
}

// ===========================================================================

/// Build a Tensor from raw f16 little-endian bytes (ONNX `raw_data` with float16 dtype).
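///
/// Doc-test sketch: `0x3C00` is 1.0 in IEEE half precision, so its
/// little-endian bytes are `[0x00, 0x3C]` (assumed import path as above).
///
/// ```
/// use oxionnx_core::tensor::from_f16_bytes;
///
/// let t = from_f16_bytes(&[0x00, 0x3C], vec![1]);
/// assert_eq!(t.data, vec![1.0]);
/// ```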
pub fn from_f16_bytes(bytes: &[u8], shape: Vec<usize>) -> Tensor {
    let data: Vec<f32> = bytes
        .chunks_exact(2)
        .map(|b| {
            let bits = u16::from_le_bytes([b[0], b[1]]);
            half::f16::from_bits(bits).to_f32()
        })
        .collect();
    Tensor::new(data, shape)
}

/// Build a Tensor from raw f32 little-endian bytes.
pub fn from_f32_bytes(bytes: &[u8], shape: Vec<usize>) -> Tensor {
    let data: Vec<f32> = bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    Tensor::new(data, shape)
}

/// Build a Tensor from raw i64 little-endian bytes (index tensors).
/// Values are cast to f32, so magnitudes above 2^24 lose precision.
pub fn from_i64_bytes(bytes: &[u8], shape: Vec<usize>) -> Tensor {
    let data: Vec<f32> = bytes
        .chunks_exact(8)
        .map(|b| i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
        .collect();
    Tensor::new(data, shape)
}

/// Build a Tensor from an ONNX `float_data` field (protobuf repeated floats),
/// already decoded into an owned `Vec<f32>`.
pub fn from_f32_vec(floats: Vec<f32>, shape: Vec<usize>) -> Tensor {
    Tensor::new(floats, shape)
}

// ---------------------------------------------------------------------------
// BroadcastIter — zero-allocation broadcasting iterator
// ---------------------------------------------------------------------------

/// Iterator that broadcasts two tensors together, yielding `(a_val, b_val)` pairs.
/// Does NOT allocate an expanded tensor — computes indices on the fly using strides.
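///
/// Doc-test sketch of an element-wise add over broadcast pairs (assumed
/// import path as above):
///
/// ```
/// use oxionnx_core::tensor::{BroadcastIter, Tensor};
///
/// let a = Tensor::new(vec![1.0, 2.0], vec![2, 1]);
/// let b = Tensor::new(vec![10.0, 20.0, 30.0], vec![1, 3]);
/// let it = BroadcastIter::new(&a, &b).unwrap();
/// let sum: Vec<f32> = it.map(|(x, y)| x + y).collect();
/// assert_eq!(sum, vec![11.0, 21.0, 31.0, 12.0, 22.0, 32.0]);
/// ```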
pub struct BroadcastIter<'a> {
    a_data: &'a [f32],
    b_data: &'a [f32],
    a_strides: Vec<usize>,
    b_strides: Vec<usize>,
    output_shape: Vec<usize>,
    output_strides: Vec<usize>,
    total: usize,
    idx: usize,
}

impl<'a> BroadcastIter<'a> {
    /// Create a broadcast iterator for two tensors.
    /// Returns `None` if shapes are not broadcast-compatible.
    pub fn new(a: &'a Tensor, b: &'a Tensor) -> Option<Self> {
        let output_shape = Tensor::broadcast_shape(&a.shape, &b.shape).ok()?;

        let a_strides = broadcast_strides(&a.shape, &output_shape);
        let b_strides = broadcast_strides(&b.shape, &output_shape);
        let output_strides = compute_strides(&output_shape);

        let total: usize = output_shape.iter().product();

        Some(Self {
            a_data: &a.data,
            b_data: &b.data,
            a_strides,
            b_strides,
            output_shape,
            output_strides,
            total,
            idx: 0,
        })
    }

    /// The output shape of the broadcast.
    pub fn output_shape(&self) -> &[usize] {
        &self.output_shape
    }

    /// Total number of elements.
    pub fn len(&self) -> usize {
        self.total
    }

    /// Whether the iterator is empty.
    pub fn is_empty(&self) -> bool {
        self.total == 0
    }
}

impl<'a> Iterator for BroadcastIter<'a> {
    type Item = (f32, f32);

    fn next(&mut self) -> Option<(f32, f32)> {
        if self.idx >= self.total {
            return None;
        }

        // Convert flat index to multi-dimensional indices, then to source indices
        let mut a_flat = 0usize;
        let mut b_flat = 0usize;
        let mut remaining = self.idx;

        for dim in 0..self.output_shape.len() {
            let coord = remaining / self.output_strides[dim];
            remaining %= self.output_strides[dim];
            a_flat += coord * self.a_strides[dim];
            b_flat += coord * self.b_strides[dim];
        }

        self.idx += 1;
        Some((self.a_data[a_flat], self.b_data[b_flat]))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.total - self.idx;
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for BroadcastIter<'_> {}

/// Compute broadcast strides: if the original dim is 1 (broadcasted), stride is 0.
fn broadcast_strides(original_shape: &[usize], broadcast_shape: &[usize]) -> Vec<usize> {
    let ndim = broadcast_shape.len();
    let pad = ndim - original_shape.len();
    let orig_strides = compute_strides(original_shape);

    (0..ndim)
        .map(|i| {
            if i < pad {
                0 // prepended dimension, broadcast
            } else {
                let orig_idx = i - pad;
                if original_shape[orig_idx] == 1 {
                    0 // broadcast this dim
                } else {
                    orig_strides[orig_idx]
                }
            }
        })
        .collect()
}

impl Tensor {
    /// Create a broadcast iterator pairing this tensor with another.
    pub fn broadcast_iter<'a>(&'a self, other: &'a Tensor) -> Option<BroadcastIter<'a>> {
        BroadcastIter::new(self, other)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_broadcast_shape() {
        assert_eq!(
            Tensor::broadcast_shape(&[3, 1], &[1, 4]).expect("broadcast should succeed"),
            vec![3, 4]
        );
        assert_eq!(
            Tensor::broadcast_shape(&[1], &[4, 3]).expect("broadcast should succeed"),
            vec![4, 3]
        );
        assert!(Tensor::broadcast_shape(&[2], &[3]).is_err());
    }

    #[test]
    fn test_reshape() {
        let t = Tensor::zeros(&[2, 3]);
        let r = t.reshape(&[6]);
        assert_eq!(r.shape, vec![6]);
    }

    // -----------------------------------------------------------------------
    // TensorView tests
    // -----------------------------------------------------------------------

    fn make_seq_tensor(shape: &[usize]) -> Tensor {
        let n: usize = shape.iter().product();
        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        Tensor::new(data, shape.to_vec())
    }

    #[test]
    fn test_view_basic() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view();
        assert_eq!(v.shape(), &[2, 3]);
        assert_eq!(v.strides(), &[3, 1]);
        assert_eq!(v.ndim(), 2);
        assert_eq!(v.numel(), 6);
    }

    #[test]
    fn test_view_get() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view();
        // [0,1,2; 3,4,5]
        assert_eq!(v.get(&[0, 0]), Some(0.0));
        assert_eq!(v.get(&[0, 2]), Some(2.0));
        assert_eq!(v.get(&[1, 0]), Some(3.0));
        assert_eq!(v.get(&[1, 2]), Some(5.0));
        // out of bounds
        assert_eq!(v.get(&[2, 0]), None);
        assert_eq!(v.get(&[0]), None);
    }

    #[test]
    fn test_view_is_contiguous() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view();
        assert!(v.is_contiguous());

        let tv = v.transpose(&[1, 0]);
        assert!(!tv.is_contiguous());
    }

    #[test]
    fn test_view_transpose() {
        // Shape [2,3] -> transpose to [3,2]
        let t = make_seq_tensor(&[2, 3]); // [0,1,2,3,4,5]
        let v = t.view().transpose(&[1, 0]);
        assert_eq!(v.shape(), &[3, 2]);
        // Transposed: [[0,3],[1,4],[2,5]]
        assert_eq!(v.get(&[0, 0]), Some(0.0));
        assert_eq!(v.get(&[0, 1]), Some(3.0));
        assert_eq!(v.get(&[1, 0]), Some(1.0));
        assert_eq!(v.get(&[1, 1]), Some(4.0));
        assert_eq!(v.get(&[2, 0]), Some(2.0));
        assert_eq!(v.get(&[2, 1]), Some(5.0));
    }

    #[test]
    fn test_view_transpose_3d() {
        // Shape [2,3,4], perm [2,0,1] -> [4,2,3]
        let t = make_seq_tensor(&[2, 3, 4]);
        let v = t.view().transpose(&[2, 0, 1]);
        assert_eq!(v.shape(), &[4, 2, 3]);
        // Element at original [i,j,k] is at transposed [k,i,j]
        // Original [0,0,0]=0, [0,1,2]=6, [1,2,3]=23
        assert_eq!(v.get(&[0, 0, 0]), Some(0.0));
        assert_eq!(v.get(&[2, 0, 1]), Some(6.0));
        assert_eq!(v.get(&[3, 1, 2]), Some(23.0));
    }

    #[test]
    fn test_view_slice() {
        // Shape [4,3], slice axis=0, [1,3) -> shape [2,3]
        let t = make_seq_tensor(&[4, 3]); // 0..12
        let v = t.view().slice(0, 1, 3);
        assert_eq!(v.shape(), &[2, 3]);
        // Row 1: [3,4,5], Row 2: [6,7,8]
        assert_eq!(v.get(&[0, 0]), Some(3.0));
        assert_eq!(v.get(&[0, 2]), Some(5.0));
        assert_eq!(v.get(&[1, 0]), Some(6.0));
        assert_eq!(v.get(&[1, 2]), Some(8.0));
    }

    #[test]
    fn test_view_select() {
        // Shape [3,4], select axis=0, index=1 -> shape [4]
        let t = make_seq_tensor(&[3, 4]); // 0..12
        let v = t.view().select(0, 1);
        assert_eq!(v.shape(), &[4]);
        assert_eq!(v.get(&[0]), Some(4.0));
        assert_eq!(v.get(&[1]), Some(5.0));
        assert_eq!(v.get(&[2]), Some(6.0));
        assert_eq!(v.get(&[3]), Some(7.0));
    }

    #[test]
    fn test_view_squeeze() {
        // Shape [1,3,1,4] -> squeeze axes [0,2] -> [3,4]
        let t = make_seq_tensor(&[1, 3, 1, 4]);
        let v = t.view().squeeze(&[0, 2]);
        assert_eq!(v.shape(), &[3, 4]);
        assert_eq!(v.numel(), 12);
        assert_eq!(v.get(&[0, 0]), Some(0.0));
        assert_eq!(v.get(&[2, 3]), Some(11.0));
    }

    #[test]
    fn test_view_unsqueeze() {
        // Shape [3,4] -> unsqueeze axis 0 -> [1,3,4]
        let t = make_seq_tensor(&[3, 4]);
        let v = t.view().unsqueeze(&[0]);
        assert_eq!(v.shape(), &[1, 3, 4]);
        assert_eq!(v.numel(), 12);
        assert_eq!(v.get(&[0, 0, 0]), Some(0.0));
        assert_eq!(v.get(&[0, 2, 3]), Some(11.0));
    }
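
    #[test]
    fn test_view_unsqueeze_contiguous() {
        // Added check (not in the original suite): the contiguity-preserving
        // stride choice in `unsqueeze` should keep a fresh view contiguous.
        let t = make_seq_tensor(&[3, 4]);
        let v = t.view().unsqueeze(&[0]);
        assert_eq!(v.strides(), &[12, 4, 1]);
        assert!(v.is_contiguous());
    }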

    #[test]
    fn test_view_to_tensor() {
        // Transpose then materialize
        let t = make_seq_tensor(&[2, 3]); // [0,1,2,3,4,5]
        let v = t.view().transpose(&[1, 0]); // [3,2]
        let mat = v.to_tensor();
        assert_eq!(mat.shape, vec![3, 2]);
        // Transposed row-major: [0,3,1,4,2,5]
        assert_eq!(mat.data, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn test_view_iter() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view().transpose(&[1, 0]); // [3,2]
        let elems: Vec<f32> = v.iter().collect();
        assert_eq!(elems, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn test_view_chained_ops() {
        // Shape [4,6]: transpose to [6,4], then slice axis=0 [1,4) -> [3,4]
        let t = make_seq_tensor(&[4, 6]);
        let v = t.view().transpose(&[1, 0]).slice(0, 1, 4);
        assert_eq!(v.shape(), &[3, 4]);
        let mat = v.to_tensor();
        assert_eq!(mat.shape, vec![3, 4]);
        // Original layout row-major [4,6]:
        //   row0: 0..6, row1: 6..12, row2: 12..18, row3: 18..24
        // Transposed [6,4]: row i of transposed = column i of original
        //   trow0: [0,6,12,18], trow1: [1,7,13,19], ...
        // Slice axis=0 [1,4): trow1,trow2,trow3
        assert_eq!(
            mat.data,
            vec![1.0, 7.0, 13.0, 19.0, 2.0, 8.0, 14.0, 20.0, 3.0, 9.0, 15.0, 21.0,]
        );
    }

    // -----------------------------------------------------------------------
    // BroadcastIter tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_broadcast_iter_same_shape() {
        // [2,3] x [2,3] — no actual broadcasting, just element-wise pairs
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Tensor::new(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![2, 3]);
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[2, 3]);
        assert_eq!(iter.len(), 6);
        assert!(!iter.is_empty());
        let pairs: Vec<(f32, f32)> = iter.collect();
        assert_eq!(
            pairs,
            vec![
                (1.0, 10.0),
                (2.0, 20.0),
                (3.0, 30.0),
                (4.0, 40.0),
                (5.0, 50.0),
                (6.0, 60.0),
            ]
        );
    }

    #[test]
    fn test_broadcast_iter_scalar() {
        // [2,3] x [1] — scalar broadcast
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Tensor::new(vec![100.0], vec![1]);
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[2, 3]);
        let pairs: Vec<(f32, f32)> = iter.collect();
        for (i, (av, bv)) in pairs.iter().enumerate() {
            assert!((*av - (i as f32 + 1.0)).abs() < 1e-6);
            assert!((*bv - 100.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_broadcast_iter_row_col() {
        // [3,1] x [1,4] -> [3,4]
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]);
        let b = Tensor::new(vec![10.0, 20.0, 30.0, 40.0], vec![1, 4]);
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[3, 4]);
        assert_eq!(iter.len(), 12);
        let pairs: Vec<(f32, f32)> = iter.collect();
        // Row 0: a=1, b cycles [10,20,30,40]
        // Row 1: a=2, b cycles [10,20,30,40]
        // Row 2: a=3, b cycles [10,20,30,40]
        let expected = vec![
            (1.0, 10.0),
            (1.0, 20.0),
            (1.0, 30.0),
            (1.0, 40.0),
            (2.0, 10.0),
            (2.0, 20.0),
            (2.0, 30.0),
            (2.0, 40.0),
            (3.0, 10.0),
            (3.0, 20.0),
            (3.0, 30.0),
            (3.0, 40.0),
        ];
        assert_eq!(pairs, expected);
    }

    #[test]
    fn test_broadcast_iter_3d() {
        // [2,1,4] x [1,3,4] -> [2,3,4]
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 1, 4]);
        let b = Tensor::new(
            vec![
                10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
            ],
            vec![1, 3, 4],
        );
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[2, 3, 4]);
        assert_eq!(iter.len(), 24);

        let pairs: Vec<(f32, f32)> = iter.collect();
        // At [0,0,0]: a[0,0,0]=1, b[0,0,0]=10
        assert_eq!(pairs[0], (1.0, 10.0));
        // At [0,1,0]: a[0,0,0]=1 (dim 1 broadcast), b[0,1,0]=50
        assert_eq!(pairs[4], (1.0, 50.0));
        // At [1,0,0]: a[1,0,0]=5, b[0,0,0]=10 (dim 0 broadcast)
        assert_eq!(pairs[12], (5.0, 10.0));
        // At [1,2,3]: a[1,0,3]=8, b[0,2,3]=120
        assert_eq!(pairs[23], (8.0, 120.0));
    }

    #[test]
    fn test_broadcast_iter_incompatible() {
        // [2,3] x [4,3] — incompatible
        let a = Tensor::new(vec![1.0; 6], vec![2, 3]);
        let b = Tensor::new(vec![1.0; 12], vec![4, 3]);
        assert!(BroadcastIter::new(&a, &b).is_none());
    }

    // -----------------------------------------------------------------------
    // Layout conversion tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_nchw_to_nhwc() {
        // [1,2,3,4] tensor: N=1, C=2, H=3, W=4
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::new(data, vec![1, 2, 3, 4]);
        let nhwc = nchw_to_nhwc(&t).expect("conversion should succeed");
        assert_eq!(nhwc.shape, vec![1, 3, 4, 2]);
        // NCHW [0,0,0,0]=0 -> NHWC [0,0,0,0]=0
        assert!((nhwc.data[0] - 0.0).abs() < 1e-6);
        // NCHW [0,1,0,0]=12 -> NHWC [0,0,0,1]=12
        assert!((nhwc.data[1] - 12.0).abs() < 1e-6);
        // NCHW [0,0,0,1]=1 -> NHWC [0,0,1,0]=1
        assert!((nhwc.data[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_nhwc_to_nchw() {
        // Build NHWC [1,3,4,2] and convert back
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::new(data, vec![1, 3, 4, 2]);
        let nchw = nhwc_to_nchw(&t).expect("conversion should succeed");
        assert_eq!(nchw.shape, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_layout_roundtrip() {
        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
        let original = Tensor::new(data.clone(), vec![2, 3, 2, 4]);
        let nhwc = nchw_to_nhwc(&original).expect("nchw_to_nhwc");
        let back = nhwc_to_nchw(&nhwc).expect("nhwc_to_nchw");
        assert_eq!(back.shape, original.shape);
        assert_eq!(back.data, original.data);
    }

    #[test]
    fn test_convert_layout_same() {
        let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
        let result =
            convert_layout(&t, TensorLayout::NCHW, TensorLayout::NCHW).expect("same layout");
        assert_eq!(result.data, t.data);
        assert_eq!(result.shape, t.shape);
    }

    #[test]
    fn test_non_4d_error() {
        let t = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(nchw_to_nhwc(&t).is_err());
        assert!(nhwc_to_nchw(&t).is_err());

        let t3d = Tensor::new(vec![1.0; 12], vec![2, 3, 2]);
        assert!(nchw_to_nhwc(&t3d).is_err());
    }
}