/// Memory-layout tag for 4-D image tensors exchanged between frameworks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TensorLayout {
    /// Batch, Channels, Height, Width (PyTorch-style).
    NCHW,
    /// Batch, Height, Width, Channels (TensorFlow-style).
    NHWC,
    /// Plain row-major data with no axis semantics attached; the default.
    #[default]
    RowMajor,
}
12
13pub fn nchw_to_nhwc(tensor: &Tensor) -> Result<Tensor, String> {
16 if tensor.shape.len() != 4 {
17 return Err(format!(
18 "nchw_to_nhwc: expected 4D tensor, got {}D",
19 tensor.shape.len()
20 ));
21 }
22 let (n, c, h, w) = (
23 tensor.shape[0],
24 tensor.shape[1],
25 tensor.shape[2],
26 tensor.shape[3],
27 );
28 let mut out = vec![0.0f32; tensor.data.len()];
29
30 for batch in 0..n {
31 for ch in 0..c {
32 for row in 0..h {
33 for col in 0..w {
34 let src_idx = batch * c * h * w + ch * h * w + row * w + col;
35 let dst_idx = batch * h * w * c + row * w * c + col * c + ch;
36 out[dst_idx] = tensor.data[src_idx];
37 }
38 }
39 }
40 }
41
42 Ok(Tensor::new(out, vec![n, h, w, c]))
43}
44
45pub fn nhwc_to_nchw(tensor: &Tensor) -> Result<Tensor, String> {
48 if tensor.shape.len() != 4 {
49 return Err(format!(
50 "nhwc_to_nchw: expected 4D tensor, got {}D",
51 tensor.shape.len()
52 ));
53 }
54 let (n, h, w, c) = (
55 tensor.shape[0],
56 tensor.shape[1],
57 tensor.shape[2],
58 tensor.shape[3],
59 );
60 let mut out = vec![0.0f32; tensor.data.len()];
61
62 for batch in 0..n {
63 for row in 0..h {
64 for col in 0..w {
65 for ch in 0..c {
66 let src_idx = batch * h * w * c + row * w * c + col * c + ch;
67 let dst_idx = batch * c * h * w + ch * h * w + row * w + col;
68 out[dst_idx] = tensor.data[src_idx];
69 }
70 }
71 }
72 }
73
74 Ok(Tensor::new(out, vec![n, c, h, w]))
75}
76
77pub fn convert_layout(
79 tensor: &Tensor,
80 from: TensorLayout,
81 to: TensorLayout,
82) -> Result<Tensor, String> {
83 match (from, to) {
84 (TensorLayout::NCHW, TensorLayout::NHWC) => nchw_to_nhwc(tensor),
85 (TensorLayout::NHWC, TensorLayout::NCHW) => nhwc_to_nchw(tensor),
86 (a, b) if a == b => Ok(tensor.clone()),
87 _ => Err(format!(
88 "Unsupported layout conversion: {:?} -> {:?}",
89 from, to
90 )),
91 }
92}
93
/// Dense, owned f32 tensor stored flat in row-major (C) order.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Flat element storage; expected to have length equal to the product of
    /// `shape` (checked only by a `debug_assert` in `Tensor::new`).
    pub data: Vec<f32>,
    /// Dimension sizes, outermost axis first.
    pub shape: Vec<usize>,
}
101
102impl Tensor {
103 pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
106 debug_assert_eq!(data.len(), shape.iter().product::<usize>());
107 Self { data, shape }
108 }
109
110 pub fn zeros(shape: &[usize]) -> Self {
112 let n: usize = shape.iter().product();
113 Self {
114 data: vec![0.0f32; n],
115 shape: shape.to_vec(),
116 }
117 }
118
119 pub fn scalar(val: f32) -> Self {
121 Self {
122 data: vec![val],
123 shape: vec![1],
124 }
125 }
126
127 pub fn numel(&self) -> usize {
129 self.data.len()
130 }
131
132 pub fn ndim(&self) -> usize {
134 self.shape.len()
135 }
136
137 pub fn reshape(&self, new_shape: &[usize]) -> Self {
140 assert_eq!(
141 new_shape.iter().product::<usize>(),
142 self.numel(),
143 "reshape: element count mismatch"
144 );
145 Self {
146 data: self.data.clone(),
147 shape: new_shape.to_vec(),
148 }
149 }
150
151 pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
154 let n = a.len().max(b.len());
155 let mut out = vec![0usize; n];
156 let a_pad = n - a.len();
157 let b_pad = n - b.len();
158 for i in 0..n {
159 let ai = if i < a_pad { 1 } else { a[i - a_pad] };
160 let bi = if i < b_pad { 1 } else { b[i - b_pad] };
161 if ai == bi {
162 out[i] = ai;
163 } else if ai == 1 {
164 out[i] = bi;
165 } else if bi == 1 {
166 out[i] = ai;
167 } else {
168 return Err(format!("Cannot broadcast {:?} with {:?}", a, b));
169 }
170 }
171 Ok(out)
172 }
173
174 #[inline(always)]
176 pub fn get(&self, idx: usize) -> f32 {
177 self.data[idx]
178 }
179}
180
/// Row-major (C-order) element strides for `shape`.
///
/// The last axis gets stride 1 and each earlier axis gets the product of all
/// later dimensions. An empty shape yields an empty stride vector.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Sweep from the innermost axis outward, accumulating the running
    // product of dimensions already passed.
    let mut acc = 1usize;
    for (slot, &dim) in strides.iter_mut().rev().zip(shape.iter().rev()) {
        *slot = acc;
        acc *= dim;
    }
    strides
}
194
/// Borrowed, possibly strided window into a tensor's flat data.
///
/// Transposes, slices, selects, and squeezes are expressed by adjusting
/// `shape`/`strides`/`offset` only — the underlying buffer is never copied.
#[derive(Debug, Clone)]
pub struct TensorView<'a> {
    /// Backing buffer, borrowed from the owning tensor.
    data: &'a [f32],
    /// Logical dimensions of the view.
    shape: Vec<usize>,
    /// Element stride per axis into `data`.
    strides: Vec<usize>,
    /// Flat index of the view's first element.
    offset: usize,
}
206
impl<'a> TensorView<'a> {
    /// Build a view directly from raw parts. No validation is performed, so
    /// callers must keep shape, strides, and offset mutually consistent.
    pub fn new(data: &'a [f32], shape: Vec<usize>, strides: Vec<usize>, offset: usize) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// Dimensions of the view.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Per-axis element strides into the underlying buffer.
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Number of axes.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of elements addressed by the view.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// True when the view reads its buffer densely in row-major order
    /// starting at offset 0, i.e. its strides equal the canonical C-order
    /// strides for its shape.
    pub fn is_contiguous(&self) -> bool {
        let expected = compute_strides(&self.shape);
        self.strides == expected && self.offset == 0
    }

    /// Bounds-checked element access.
    ///
    /// Returns `None` when `indices` has the wrong arity, any index exceeds
    /// its axis extent, or the resolved flat index falls outside the buffer.
    pub fn get(&self, indices: &[usize]) -> Option<f32> {
        if indices.len() != self.shape.len() {
            return None;
        }
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape[i] {
                return None;
            }
        }
        // Dot product of indices with strides, shifted by the view offset.
        let flat_idx: usize = self.offset
            + indices
                .iter()
                .zip(self.strides.iter())
                .map(|(&i, &s)| i * s)
                .sum::<usize>();
        self.data.get(flat_idx).copied()
    }

    /// Permute the axes: output axis `i` becomes input axis `perm[i]`.
    /// Zero-copy — only shape and strides are reordered.
    pub fn transpose(&self, perm: &[usize]) -> Self {
        let new_shape: Vec<usize> = perm.iter().map(|&p| self.shape[p]).collect();
        let new_strides: Vec<usize> = perm.iter().map(|&p| self.strides[p]).collect();
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        }
    }

    /// Restrict `axis` to the half-open range `[start, end)` without copying.
    /// No bounds checking: `end < start` or out-of-range values are the
    /// caller's responsibility.
    pub fn slice(&self, axis: usize, start: usize, end: usize) -> Self {
        let mut new_shape = self.shape.clone();
        new_shape[axis] = end - start;
        Self {
            data: self.data,
            shape: new_shape,
            strides: self.strides.clone(),
            // Skipping `start` steps along `axis` moves the origin forward.
            offset: self.offset + start * self.strides[axis],
        }
    }

    /// Fix `axis` at `index` and drop it, reducing rank by one.
    pub fn select(&self, axis: usize, index: usize) -> Self {
        let mut new_shape = self.shape.clone();
        let mut new_strides = self.strides.clone();
        new_shape.remove(axis);
        new_strides.remove(axis);
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset + index * self.strides[axis],
        }
    }

    /// Drop the listed axes, but only those whose extent is exactly 1;
    /// listed axes with a larger extent are silently kept.
    pub fn squeeze(&self, axes: &[usize]) -> Self {
        let mut new_shape = Vec::new();
        let mut new_strides = Vec::new();
        for (i, (&s, &st)) in self.shape.iter().zip(self.strides.iter()).enumerate() {
            if axes.contains(&i) && s == 1 {
                continue;
            }
            new_shape.push(s);
            new_strides.push(st);
        }
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        }
    }

    /// Insert a size-1 axis at each position in `axes`.
    pub fn unsqueeze(&self, axes: &[usize]) -> Self {
        let mut sorted_axes: Vec<usize> = axes.to_vec();
        sorted_axes.sort_unstable();

        let mut new_shape = self.shape.clone();
        let mut new_strides = self.strides.clone();
        for (offset, &ax) in sorted_axes.iter().enumerate() {
            // The stride of a size-1 axis never affects element lookup (its
            // index is always 0); this borrows a neighboring original stride
            // when one exists and falls back to 1 otherwise.
            // NOTE(review): `pos + 1 - offset` underflows (panics in debug)
            // when the same axis appears three or more times in `axes`, and
            // the chosen strides are not the canonical row-major ones — so
            // `is_contiguous()` can report false for a logically contiguous
            // unsqueezed view. Confirm whether either is intended.
            let pos = ax;
            let stride_val = if pos + 1 - offset < self.strides.len() {
                self.strides[pos + 1 - offset].max(1)
            } else {
                1
            };
            new_shape.insert(pos, 1);
            new_strides.insert(pos, stride_val);
        }
        Self {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        }
    }

    /// Materialize the view into an owned, contiguous `Tensor`.
    pub fn to_tensor(&self) -> Tensor {
        // Fast path: a dense zero-offset view copies with one slice copy.
        if self.is_contiguous() {
            let n = self.numel();
            let data = self.data[..n].to_vec();
            return Tensor::new(data, self.shape.clone());
        }
        // Strided view: walk elements in logical row-major order.
        let data: Vec<f32> = self.iter().collect();
        Tensor::new(data, self.shape.clone())
    }

    /// Iterate elements in logical row-major order (last axis fastest).
    pub fn iter(&self) -> TensorViewIter<'_> {
        let ndim = self.shape.len();
        // An empty view must start exhausted; otherwise the iterator would
        // wrongly yield the element at `offset` once.
        let exhausted = self.numel() == 0;
        TensorViewIter {
            data: self.data,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
            indices: vec![0; ndim],
            exhausted,
        }
    }
}
378
/// Iterator state for [`TensorView::iter`]: a multi-dimensional counter that
/// is advanced odometer-style and resolved through the view's strides.
pub struct TensorViewIter<'a> {
    /// Borrowed backing storage.
    data: &'a [f32],
    /// Logical dimensions being traversed.
    shape: Vec<usize>,
    /// Per-axis element strides into `data`.
    strides: Vec<usize>,
    /// Flat starting position within `data`.
    offset: usize,
    /// Current multi-index; `next()` yields the element here, then advances.
    indices: Vec<usize>,
    /// Set once the counter wraps past the most-significant axis, or
    /// immediately for a zero-element view.
    exhausted: bool,
}
388
389impl TensorViewIter<'_> {
390 fn get_at(&self, indices: &[usize]) -> Option<f32> {
391 let flat_idx: usize = self.offset
392 + indices
393 .iter()
394 .zip(self.strides.iter())
395 .map(|(&i, &s)| i * s)
396 .sum::<usize>();
397 self.data.get(flat_idx).copied()
398 }
399}
400
impl Iterator for TensorViewIter<'_> {
    type Item = f32;

    /// Yield the element at the current multi-index, then advance the index
    /// odometer-style (last axis fastest).
    fn next(&mut self) -> Option<f32> {
        if self.exhausted {
            return None;
        }
        // Read before advancing.
        let val = self.get_at(&self.indices);

        let ndim = self.shape.len();
        // Bump the least-significant axis; on wrap-around, carry into the
        // next-more-significant axis.
        let mut carry = true;
        for i in (0..ndim).rev() {
            if carry {
                self.indices[i] += 1;
                if self.indices[i] < self.shape[i] {
                    carry = false;
                } else {
                    self.indices[i] = 0;
                }
            }
        }
        // A carry surviving past the most-significant axis means the counter
        // wrapped to all zeros: this was the final element.
        if carry {
            self.exhausted = true;
        }

        val
    }

    /// Exact remaining count: total elements minus the linear row-major
    /// position of the current multi-index.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.exhausted {
            return (0, Some(0));
        }
        let total: usize = self.shape.iter().product();
        // Linearize the current indices with the *logical* (row-major)
        // strides — not the view's physical strides — to count consumption.
        let mut consumed = 0usize;
        let logical_strides = compute_strides(&self.shape);
        for (i, &idx) in self.indices.iter().enumerate() {
            consumed += idx * logical_strides[i];
        }
        let remaining = total.saturating_sub(consumed);
        (remaining, Some(remaining))
    }
}
444
445impl ExactSizeIterator for TensorViewIter<'_> {}
446
447impl Tensor {
452 pub fn view(&self) -> TensorView<'_> {
454 let strides = compute_strides(&self.shape);
455 TensorView {
456 data: &self.data,
457 shape: self.shape.clone(),
458 strides,
459 offset: 0,
460 }
461 }
462
463 pub fn transpose_view(&self, perm: &[usize]) -> TensorView<'_> {
465 self.view().transpose(perm)
466 }
467
468 pub fn slice_view(&self, axis: usize, start: usize, end: usize) -> TensorView<'_> {
470 self.view().slice(axis, start, end)
471 }
472}
473
474pub fn from_f16_bytes(bytes: &[u8], shape: Vec<usize>) -> Tensor {
478 let data: Vec<f32> = bytes
479 .chunks_exact(2)
480 .map(|b| {
481 let bits = u16::from_le_bytes([b[0], b[1]]);
482 half::f16::from_bits(bits).to_f32()
483 })
484 .collect();
485 Tensor::new(data, shape)
486}
487
488pub fn from_f32_bytes(bytes: &[u8], shape: Vec<usize>) -> Tensor {
490 let data: Vec<f32> = bytes
491 .chunks_exact(4)
492 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
493 .collect();
494 Tensor::new(data, shape)
495}
496
497pub fn from_i64_bytes(bytes: &[u8], shape: Vec<usize>) -> Tensor {
499 let data: Vec<f32> = bytes
500 .chunks_exact(8)
501 .map(|b| i64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32)
502 .collect();
503 Tensor::new(data, shape)
504}
505
/// Wrap an owned `Vec<f32>` in a tensor without copying.
pub fn from_f32_vec(floats: Vec<f32>, shape: Vec<usize>) -> Tensor {
    Tensor::new(floats, shape)
}
510
/// Lock-step iterator over two tensors broadcast to a common shape, yielding
/// `(a, b)` element pairs in row-major order of the broadcast result.
pub struct BroadcastIter<'a> {
    /// Backing data of the left operand.
    a_data: &'a [f32],
    /// Backing data of the right operand.
    b_data: &'a [f32],
    /// Strides of `a` mapped onto the output shape (0 on broadcast axes).
    a_strides: Vec<usize>,
    /// Strides of `b` mapped onto the output shape (0 on broadcast axes).
    b_strides: Vec<usize>,
    /// Shape both operands broadcast to.
    output_shape: Vec<usize>,
    /// Row-major strides of the output shape, used to decode linear indices.
    output_strides: Vec<usize>,
    /// Total number of pairs to yield.
    total: usize,
    /// Next linear output index.
    idx: usize,
}
527
528impl<'a> BroadcastIter<'a> {
529 pub fn new(a: &'a Tensor, b: &'a Tensor) -> Option<Self> {
532 let output_shape = Tensor::broadcast_shape(&a.shape, &b.shape).ok()?;
533
534 let a_strides = broadcast_strides(&a.shape, &output_shape);
535 let b_strides = broadcast_strides(&b.shape, &output_shape);
536 let output_strides = compute_strides(&output_shape);
537
538 let total: usize = output_shape.iter().product();
539
540 Some(Self {
541 a_data: &a.data,
542 b_data: &b.data,
543 a_strides,
544 b_strides,
545 output_shape,
546 output_strides,
547 total,
548 idx: 0,
549 })
550 }
551
552 pub fn output_shape(&self) -> &[usize] {
554 &self.output_shape
555 }
556
557 pub fn len(&self) -> usize {
559 self.total
560 }
561
562 pub fn is_empty(&self) -> bool {
564 self.total == 0
565 }
566}
567
568impl<'a> Iterator for BroadcastIter<'a> {
569 type Item = (f32, f32);
570
571 fn next(&mut self) -> Option<(f32, f32)> {
572 if self.idx >= self.total {
573 return None;
574 }
575
576 let mut a_flat = 0usize;
578 let mut b_flat = 0usize;
579 let mut remaining = self.idx;
580
581 for dim in 0..self.output_shape.len() {
582 let coord = remaining / self.output_strides[dim];
583 remaining %= self.output_strides[dim];
584 a_flat += coord * self.a_strides[dim];
585 b_flat += coord * self.b_strides[dim];
586 }
587
588 self.idx += 1;
589 Some((self.a_data[a_flat], self.b_data[b_flat]))
590 }
591
592 fn size_hint(&self) -> (usize, Option<usize>) {
593 let remaining = self.total - self.idx;
594 (remaining, Some(remaining))
595 }
596}
597
598impl ExactSizeIterator for BroadcastIter<'_> {}
599
600fn broadcast_strides(original_shape: &[usize], broadcast_shape: &[usize]) -> Vec<usize> {
602 let ndim = broadcast_shape.len();
603 let pad = ndim - original_shape.len();
604 let orig_strides = compute_strides(original_shape);
605
606 (0..ndim)
607 .map(|i| {
608 if i < pad {
609 0 } else {
611 let orig_idx = i - pad;
612 if original_shape[orig_idx] == 1 {
613 0 } else {
615 orig_strides[orig_idx]
616 }
617 }
618 })
619 .collect()
620}
621
impl Tensor {
    /// Convenience wrapper over [`BroadcastIter::new`]; yields `None` when
    /// the two shapes are not broadcast-compatible.
    pub fn broadcast_iter<'a>(&'a self, other: &'a Tensor) -> Option<BroadcastIter<'a>> {
        BroadcastIter::new(self, other)
    }
}
628
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_broadcast_shape() {
        assert_eq!(
            Tensor::broadcast_shape(&[3, 1], &[1, 4]).expect("broadcast should succeed"),
            vec![3, 4]
        );
        assert_eq!(
            Tensor::broadcast_shape(&[1], &[4, 3]).expect("broadcast should succeed"),
            vec![4, 3]
        );
        // Mismatched non-1 extents cannot broadcast.
        assert!(Tensor::broadcast_shape(&[2], &[3]).is_err());
    }

    #[test]
    fn test_reshape() {
        let t = Tensor::zeros(&[2, 3]);
        let r = t.reshape(&[6]);
        assert_eq!(r.shape, vec![6]);
    }

    // Tensor filled with 0..n in row-major order, so every element's value
    // equals its flat index — convenient for spot-checking stride math.
    fn make_seq_tensor(shape: &[usize]) -> Tensor {
        let n: usize = shape.iter().product();
        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        Tensor::new(data, shape.to_vec())
    }

    #[test]
    fn test_view_basic() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view();
        assert_eq!(v.shape(), &[2, 3]);
        assert_eq!(v.strides(), &[3, 1]);
        assert_eq!(v.ndim(), 2);
        assert_eq!(v.numel(), 6);
    }

    #[test]
    fn test_view_get() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view();
        assert_eq!(v.get(&[0, 0]), Some(0.0));
        assert_eq!(v.get(&[0, 2]), Some(2.0));
        assert_eq!(v.get(&[1, 0]), Some(3.0));
        assert_eq!(v.get(&[1, 2]), Some(5.0));
        // Out-of-range index and wrong arity both return None.
        assert_eq!(v.get(&[2, 0]), None);
        assert_eq!(v.get(&[0]), None);
    }

    #[test]
    fn test_view_is_contiguous() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view();
        assert!(v.is_contiguous());

        // Transposition reorders strides, breaking contiguity.
        let tv = v.transpose(&[1, 0]);
        assert!(!tv.is_contiguous());
    }

    #[test]
    fn test_view_transpose() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view().transpose(&[1, 0]);
        assert_eq!(v.shape(), &[3, 2]);
        // Transposed (r, c) maps to original (c, r).
        assert_eq!(v.get(&[0, 0]), Some(0.0));
        assert_eq!(v.get(&[0, 1]), Some(3.0));
        assert_eq!(v.get(&[1, 0]), Some(1.0));
        assert_eq!(v.get(&[1, 1]), Some(4.0));
        assert_eq!(v.get(&[2, 0]), Some(2.0));
        assert_eq!(v.get(&[2, 1]), Some(5.0));
    }

    #[test]
    fn test_view_transpose_3d() {
        let t = make_seq_tensor(&[2, 3, 4]);
        // Axis permutation [2, 0, 1]: output (w, n, h) reads input (n, h, w).
        let v = t.view().transpose(&[2, 0, 1]);
        assert_eq!(v.shape(), &[4, 2, 3]);
        assert_eq!(v.get(&[0, 0, 0]), Some(0.0));
        assert_eq!(v.get(&[2, 0, 1]), Some(6.0));
        assert_eq!(v.get(&[3, 1, 2]), Some(23.0));
    }

    #[test]
    fn test_view_slice() {
        let t = make_seq_tensor(&[4, 3]);
        // Keep rows 1..3 of the 4x3 tensor.
        let v = t.view().slice(0, 1, 3);
        assert_eq!(v.shape(), &[2, 3]);
        assert_eq!(v.get(&[0, 0]), Some(3.0));
        assert_eq!(v.get(&[0, 2]), Some(5.0));
        assert_eq!(v.get(&[1, 0]), Some(6.0));
        assert_eq!(v.get(&[1, 2]), Some(8.0));
    }

    #[test]
    fn test_view_select() {
        let t = make_seq_tensor(&[3, 4]);
        // Row 1 of the 3x4 tensor, with the row axis dropped.
        let v = t.view().select(0, 1);
        assert_eq!(v.shape(), &[4]);
        assert_eq!(v.get(&[0]), Some(4.0));
        assert_eq!(v.get(&[1]), Some(5.0));
        assert_eq!(v.get(&[2]), Some(6.0));
        assert_eq!(v.get(&[3]), Some(7.0));
    }

    #[test]
    fn test_view_squeeze() {
        let t = make_seq_tensor(&[1, 3, 1, 4]);
        let v = t.view().squeeze(&[0, 2]);
        assert_eq!(v.shape(), &[3, 4]);
        assert_eq!(v.numel(), 12);
        assert_eq!(v.get(&[0, 0]), Some(0.0));
        assert_eq!(v.get(&[2, 3]), Some(11.0));
    }

    #[test]
    fn test_view_unsqueeze() {
        let t = make_seq_tensor(&[3, 4]);
        let v = t.view().unsqueeze(&[0]);
        assert_eq!(v.shape(), &[1, 3, 4]);
        assert_eq!(v.numel(), 12);
        assert_eq!(v.get(&[0, 0, 0]), Some(0.0));
        assert_eq!(v.get(&[0, 2, 3]), Some(11.0));
    }

    #[test]
    fn test_view_to_tensor() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view().transpose(&[1, 0]);
        let mat = v.to_tensor();
        assert_eq!(mat.shape, vec![3, 2]);
        // Materialization walks the transposed view in row-major order.
        assert_eq!(mat.data, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn test_view_iter() {
        let t = make_seq_tensor(&[2, 3]);
        let v = t.view().transpose(&[1, 0]);
        let elems: Vec<f32> = v.iter().collect();
        assert_eq!(elems, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn test_view_chained_ops() {
        let t = make_seq_tensor(&[4, 6]);
        // Transpose to 6x4, then keep rows 1..4 of the transposed view.
        let v = t.view().transpose(&[1, 0]).slice(0, 1, 4);
        assert_eq!(v.shape(), &[3, 4]);
        let mat = v.to_tensor();
        assert_eq!(mat.shape, vec![3, 4]);
        assert_eq!(
            mat.data,
            vec![1.0, 7.0, 13.0, 19.0, 2.0, 8.0, 14.0, 20.0, 3.0, 9.0, 15.0, 21.0,]
        );
    }

    #[test]
    fn test_broadcast_iter_same_shape() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Tensor::new(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![2, 3]);
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[2, 3]);
        assert_eq!(iter.len(), 6);
        assert!(!iter.is_empty());
        let pairs: Vec<(f32, f32)> = iter.collect();
        assert_eq!(
            pairs,
            vec![
                (1.0, 10.0),
                (2.0, 20.0),
                (3.0, 30.0),
                (4.0, 40.0),
                (5.0, 50.0),
                (6.0, 60.0),
            ]
        );
    }

    #[test]
    fn test_broadcast_iter_scalar() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Tensor::new(vec![100.0], vec![1]);
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[2, 3]);
        let pairs: Vec<(f32, f32)> = iter.collect();
        // The scalar is re-read for every element of `a`.
        for (i, (av, bv)) in pairs.iter().enumerate() {
            assert!((*av - (i as f32 + 1.0)).abs() < 1e-6);
            assert!((*bv - 100.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_broadcast_iter_row_col() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]);
        let b = Tensor::new(vec![10.0, 20.0, 30.0, 40.0], vec![1, 4]);
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[3, 4]);
        assert_eq!(iter.len(), 12);
        let pairs: Vec<(f32, f32)> = iter.collect();
        // Outer-product pairing: each column value against every row value.
        let expected = vec![
            (1.0, 10.0),
            (1.0, 20.0),
            (1.0, 30.0),
            (1.0, 40.0),
            (2.0, 10.0),
            (2.0, 20.0),
            (2.0, 30.0),
            (2.0, 40.0),
            (3.0, 10.0),
            (3.0, 20.0),
            (3.0, 30.0),
            (3.0, 40.0),
        ];
        assert_eq!(pairs, expected);
    }

    #[test]
    fn test_broadcast_iter_3d() {
        let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 1, 4]);
        let b = Tensor::new(
            vec![
                10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
            ],
            vec![1, 3, 4],
        );
        let iter = BroadcastIter::new(&a, &b).expect("should be compatible");
        assert_eq!(iter.output_shape(), &[2, 3, 4]);
        assert_eq!(iter.len(), 24);

        let pairs: Vec<(f32, f32)> = iter.collect();
        // Spot-check corners and a middle element of the 2x3x4 result.
        assert_eq!(pairs[0], (1.0, 10.0));
        assert_eq!(pairs[4], (1.0, 50.0));
        assert_eq!(pairs[12], (5.0, 10.0));
        assert_eq!(pairs[23], (8.0, 120.0));
    }

    #[test]
    fn test_broadcast_iter_incompatible() {
        let a = Tensor::new(vec![1.0; 6], vec![2, 3]);
        let b = Tensor::new(vec![1.0; 12], vec![4, 3]);
        assert!(BroadcastIter::new(&a, &b).is_none());
    }

    #[test]
    fn test_nchw_to_nhwc() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::new(data, vec![1, 2, 3, 4]);
        let nhwc = nchw_to_nhwc(&t).expect("conversion should succeed");
        assert_eq!(nhwc.shape, vec![1, 3, 4, 2]);
        // NHWC interleaves the channel planes: channel 0 starts at 0.0,
        // channel 1 starts at 12.0.
        assert!((nhwc.data[0] - 0.0).abs() < 1e-6);
        assert!((nhwc.data[1] - 12.0).abs() < 1e-6);
        assert!((nhwc.data[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_nhwc_to_nchw() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::new(data, vec![1, 3, 4, 2]);
        let nchw = nhwc_to_nchw(&t).expect("conversion should succeed");
        assert_eq!(nchw.shape, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_layout_roundtrip() {
        // NCHW -> NHWC -> NCHW must reproduce the original exactly.
        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
        let original = Tensor::new(data.clone(), vec![2, 3, 2, 4]);
        let nhwc = nchw_to_nhwc(&original).expect("nchw_to_nhwc");
        let back = nhwc_to_nchw(&nhwc).expect("nhwc_to_nchw");
        assert_eq!(back.shape, original.shape);
        assert_eq!(back.data, original.data);
    }

    #[test]
    fn test_convert_layout_same() {
        let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 1, 2, 2]);
        let result =
            convert_layout(&t, TensorLayout::NCHW, TensorLayout::NCHW).expect("same layout");
        assert_eq!(result.data, t.data);
        assert_eq!(result.shape, t.shape);
    }

    #[test]
    fn test_non_4d_error() {
        let t = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(nchw_to_nhwc(&t).is_err());
        assert!(nhwc_to_nchw(&t).is_err());

        let t3d = Tensor::new(vec![1.0; 12], vec![2, 3, 2]);
        assert!(nchw_to_nhwc(&t3d).is_err());
    }
}