1use std::sync::Arc;
2use crate::{AutogradMetaT, Dim, Dims, Error, Layout, Result, Shape, Storage, WithDType, D};
3use super::{Tensor, TensorId, TensorImpl, Slice};
4
5impl<T: WithDType> Tensor<T> {
6 pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
16 let dims = self.dims();
17 let dim = dim.to_index(self.shape(), "squeeze")?;
18 if dims[dim] == 1 {
19 let mut dims = dims.to_vec();
20 let mut strides = self.layout().stride().to_vec();
21 dims.remove(dim);
22 strides.remove(dim);
23 let tensor_ = TensorImpl {
24 id: TensorId::new(),
25 storage: self.0.storage.clone(),
26 layout: Layout::new(dims, strides, self.layout().start_offset()),
27 meta: T::AutogradMeta::on_reshape_op(self)
28 };
29 Ok(Self(Arc::new(tensor_)))
30 } else {
31 Err( Error::SqueezeDimNot1 { shape: self.shape().clone(), dim } )?
32 }
33 }
34
35 pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
48 let mut dims = self.dims().to_vec();
49 let mut strides = self.layout().stride().to_vec();
50 let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
51 dims.insert(dim, 1);
52 let stride = if dim < strides.len() { strides[dim] } else { 1 };
53 strides.insert(dim, stride);
54 let tensor_ = TensorImpl {
55 id: TensorId::new(),
56 storage: self.0.storage.clone(),
57 layout: Layout::new(dims, strides, self.layout().start_offset()),
58 meta: T::AutogradMeta::on_reshape_op(self),
59 };
60 Ok(Self(Arc::new(tensor_)))
61 }
62
63 pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
80 let dims = self.dims();
81 let dim = dim.to_index(self.shape(), "narrow")?;
82 let err = |msg| {
83 Err::<(), _>(Error::NarrowInvalidArgs {
84 shape: self.shape().clone(),
85 dim,
86 start,
87 len,
88 msg,
89 })
90 };
91
92 if start > dims[dim] {
93 err("start > dim_len")?;
94 }
95 if start.saturating_add(len) > dims[dim] {
96 err("start + len > dim_len")?
97 }
98 if start == 0 && dims[dim] == len {
99 Ok(self.clone())
100 } else {
101 let meta = T::AutogradMeta::on_narrow_op(self, dim, start, len);
102 let layout = self.layout().narrow(dim, start, len)?;
103 let tensor_ = TensorImpl {
104 id: TensorId::new(),
105 storage: self.0.storage.clone(),
106 layout,
107 meta
108 };
109 Ok(Self(Arc::new(tensor_)))
110 }
111 }
112
113 pub fn slice<D: Dim>(&self, dim: D, slice: &Slice) -> Result<Self> {
127 let dims = self.dims();
128 let dim = dim.to_index(self.shape(), "narrow")?;
129 let err = |msg| {
130 Err::<(), _>(Error::SliceInvalidArgs {
131 shape: self.shape().clone(),
132 dim,
133 slice: slice.clone(),
134 msg,
135 })
136 };
137
138 let end = match slice.end {
139 Some(end) if end >= 0 => end as usize,
140 Some(end) => {
141 let dis = -end as usize;
142 if dis > dims[dim] {
143 0
144 } else {
145 dims[dim] - dis
146 }
147 }
148 None => dims[dim],
149 };
150 if slice.start > dims[dim] {
151 err("start > dim_len")?;
152 }
153 if end > dims[dim] {
154 err("end > dim_len")?
155 }
156 if slice.start == 0 && dims[dim] == end && slice.step == 1 {
157 Ok(self.clone())
158 } else {
159 let meta = T::AutogradMeta::on_slice_op(self, dim, slice.start, end, slice.step);
160 let layout = self.layout().slice(dim, slice.start, end, slice.step)?;
161 Ok(self.share_storage(layout, meta))
162 }
163 }
164
165 pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
182 let shape = shape.into();
183 if shape.element_count() != self.element_count() {
184 return Err(Error::ShapeMismatchBinaryOp {
185 lhs: self.shape().clone(),
186 rhs: shape,
187 op: "reshape",
188 })?;
189 }
190
191 let meta = T::AutogradMeta::on_reshape_op(self);
192 if self.is_contiguous() {
193 let layout = Layout::contiguous_with_offset(shape, self.layout().start_offset());
194 Ok(self.share_storage(layout, meta))
195 } else {
196 let storage = self.storage_read()?.copy(self.layout());
197 Ok(Self::from_storage(storage, shape, meta))
198 }
199 }
200
201 pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Self> {
203 let dim1 = dim1.to_index(self.shape(), "transpose")?;
204 let dim2 = dim2.to_index(self.shape(), "transpose")?;
205 if dim1 == dim2 {
206 return Ok(self.clone());
207 }
208
209 let meta = T::AutogradMeta::on_transpose_op(self, dim1, dim2);
210 let layout = self.layout().transpose(dim1, dim2)?;
211 Ok(self.share_storage(layout, meta))
212 }
213
    /// Swaps the last two dimensions, i.e. the usual matrix transpose applied
    /// to the trailing 2-D view of the tensor.
    pub fn transpose_last(&self) -> Result<Self> {
        self.transpose(D::Minus1, D::Minus2)
    }
217
218 pub fn permute<D: Dims>(&self, dims: D) -> Result<Self> {
229 let dims = dims.to_indexes(self.shape(), "permute")?;
230 let is_permutation =
232 dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
233 if !is_permutation {
234 crate::bail!(
235 "dimension mismatch in permute, tensor {:?}, dims: {:?}",
236 self.dims(),
237 dims
238 )
239 }
240 let layout = self.layout().permute(&dims)?;
242 let meta = T::AutogradMeta::on_permute_op(self, dims);
243 Ok(self.share_storage(layout, meta))
244 }
245
246 pub fn cat<A: AsRef<Tensor<T>>, D: Dim>(arrs: &[A], dim: D) -> Result<Self> {
263 if arrs.is_empty() {
265 Err(Error::OpRequiresAtLeastOneTensor { op: "cat" })?
266 }
267
268 let arr0 = &arrs[0];
270 let rank0 = arr0.as_ref().rank();
271
272 let cat_dim = dim.to_index(arr0.as_ref().shape(), "cat")?;
274 let mut target_dims = arr0.as_ref().dims().to_vec();
275 target_dims[cat_dim] = 0;
276 let mut dim_offsets = vec![];
277
278 for (_arr_index, arr) in arrs.iter().enumerate() {
279 let rank = arr.as_ref().rank();
281 if rank != rank0 {
282 Err(Error::UnexpectedNumberOfDims {
283 expected: rank,
284 got: arr.as_ref().rank(),
285 shape: arr.as_ref().shape().clone(),
286 })?
287 }
288
289 for (dim_index, (v1, v2)) in arr0.as_ref().dims().iter()
291 .zip(arr.as_ref().dims().iter())
292 .enumerate()
293 {
294 if dim_index == cat_dim {
296 dim_offsets.push(target_dims[cat_dim]);
297 target_dims[cat_dim] += v2;
298 }
299
300 if dim_index != cat_dim && v1 != v2 {
302 Err(Error::ShapeMismatchCat {
303 dim: dim_index,
304 first_shape: arr0.as_ref().shape().clone(),
305 n: dim_index + 1,
306 nth_shape: arr0.as_ref().shape().clone(),
307 })?
308 }
309 }
310 }
311
312 let target_shape: Shape = target_dims.into();
317
318 let mut dst: Vec<T> = Vec::with_capacity(target_shape.element_count());
320 unsafe { dst.set_len(target_shape.element_count()) };
321
322 let meta = T::AutogradMeta::on_cat_op(arrs, cat_dim);
323 let res_arr = Self::from_storage(Storage::new(dst), target_shape, meta);
324
325 for (arr_index, arr) in arrs.iter().enumerate() {
326 let sub_res_arr = res_arr.narrow(cat_dim, dim_offsets[arr_index], arr.as_ref().dims()[cat_dim])?;
328 assert_eq!(sub_res_arr.shape(), arr.as_ref().shape());
329 sub_res_arr.copy_from(arr.as_ref())?;
331 }
332
333 Ok(res_arr)
334 }
335
336 pub fn stack<A: AsRef<Tensor<T>>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
352 if args.is_empty() {
353 Err(Error::OpRequiresAtLeastOneTensor { op: "stack" })?
354 }
355 let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
356 let args = args
357 .iter()
358 .map(|t| t.as_ref().unsqueeze(dim))
359 .collect::<Result<Vec<_>>>()?;
360 Self::cat(&args, dim)
361 }
362
363 pub fn split<D: Dim>(&self, dim: D) -> Result<Vec<Self>> {
396 let split_index = dim.to_index(self.shape(), "split")?;
397 let split_dim_size = self.dims()[split_index];
398 let mut splited_shape = self.dims().to_vec();
399 splited_shape.remove(split_index);
400
401 let mut vec = vec![];
402 for i in 0..split_dim_size {
403 let sub_tensor = self.narrow(split_index, i, 1)?.squeeze(split_index)?;
404 vec.push(sub_tensor);
405 }
406
407 Ok(vec)
408 }
409
410 pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
413 let dim = dim.to_index(self.shape(), "chunk")?;
414 let size = self.dim(dim)?;
415 if size < chunks {
416 (0..size).map(|i| self.narrow(dim, i, 1)).collect()
417 } else {
418 let chunk_size = size / chunks;
419 let cnt_additional = size % chunks;
420 let mut tensors = vec![];
421 let mut sum_chunk_size = 0;
422 for i in 0..chunks {
423 let chunk_size = if i < cnt_additional {
424 chunk_size + 1
425 } else {
426 chunk_size
427 };
428 let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?;
429 tensors.push(tensor);
430 sum_chunk_size += chunk_size
431 }
432 Ok(tensors)
433 }
434 }
435
    /// Flattens dimensions `start_dim..=end_dim` (inclusive) into one.
    pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Self> {
        self.flatten_(Some(start_dim), Some(end_dim))
    }
441
    /// Flattens all dimensions from the first up to and including `end_dim`.
    pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Self> {
        self.flatten_(None::<usize>, Some(end_dim))
    }
446
    /// Flattens all dimensions from `start_dim` through the last one.
    pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Self> {
        self.flatten_(Some(start_dim), None::<usize>)
    }
452
    /// Flattens the whole tensor into a single 1-D tensor.
    pub fn flatten_all(&self) -> Result<Self> {
        self.flatten_(None::<usize>, None::<usize>)
    }
466
467 pub fn repeat<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
469 let repeats: Shape = shape.into();
470 let mut repeats = repeats.dims().to_vec();
471
472 if repeats.len() > self.rank() {
473 Err(Error::RepeatRankOutOfRange { repeats: repeats.clone().into(), shape: self.shape().into() })?;
474 } else if repeats.len() > self.rank() {
475 for _ in 0..(repeats.len() - self.rank()) {
476 repeats.push(1);
477 }
478 }
479
480 let mut arr = self.clone();
481
482 for (idx, &repeat) in repeats.iter().enumerate() {
483 if repeat > 1 {
484 arr = Tensor::cat(&vec![&arr; repeat], idx)?
485 }
486 }
487 Ok(arr)
488 }
489
490 pub fn repeat_dim<D: Dim>(&self, dim: D, times: usize) -> Result<Self> {
492 if times == 0 {
493 self.squeeze(dim)
494 } else if times == 1 {
495 Ok(self.clone())
496 } else {
497 Tensor::cat(&vec![self; times], dim)
498 }
499 }
500
501 fn flatten_<D1: Dim, D2: Dim>(
502 &self,
503 start_dim: Option<D1>,
504 end_dim: Option<D2>,
505 ) -> Result<Self> {
506 if self.rank() == 0 {
507 self.reshape(1)
508 } else {
509 let start_dim = match start_dim {
510 None => 0,
511 Some(dim) => dim.to_index(self.shape(), "flatten")?,
512 };
513 let end_dim = match end_dim {
514 None => self.rank() - 1,
515 Some(dim) => dim.to_index(self.shape(), "flatten")?,
516 };
517 if start_dim < end_dim {
518 let dims = self.dims();
519 let mut dst_dims = dims[..start_dim].to_vec();
520 dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
521 if end_dim + 1 < dims.len() {
522 dst_dims.extend(&dims[end_dim + 1..]);
523 }
524 self.reshape(dst_dims)
525 } else {
526 Ok(self.clone())
527 }
528 }
529 }
530}
531
/// Lets `Tensor<T>` values be used directly where `AsRef<Tensor<T>>` is
/// expected (e.g. the input slices of `cat` and `stack`).
impl<T: WithDType> AsRef<Tensor<T>> for Tensor<T> {
    fn as_ref(&self) -> &Tensor<T> {
        self
    }
}
537
#[cfg(test)]
#[allow(unused)]
mod test {
    // Tests for the shape-manipulation ops above. Concrete integer/bool
    // tensors are used throughout so results can be checked exactly.
    use super::*;

    #[test]
    fn test_unsqueeze_basic() -> Result<()> {
        let t = Tensor::<i32>::zeros((2, 3))?;

        // Inserting a size-1 axis never changes the flattened element order.
        let unsq0 = t.unsqueeze(0)?;
        assert_eq!(unsq0.dims(), &[1, 2, 3]);
        assert_eq!(unsq0.to_vec()?, t.to_vec()?);

        let unsq1 = t.unsqueeze(1)?;
        assert_eq!(unsq1.dims(), &[2, 1, 3]);
        assert_eq!(unsq1.to_vec()?, t.to_vec()?);

        let unsq2 = t.unsqueeze(2)?;
        assert_eq!(unsq2.dims(), &[2, 3, 1]);
        assert_eq!(unsq2.to_vec()?, t.to_vec()?);

        Ok(())
    }

    #[test]
    fn test_squeeze_basic() -> Result<()> {
        let t = Tensor::<i32>::zeros((2, 1, 3))?;

        let sq = t.squeeze(1)?;
        assert_eq!(sq.dims(), &[2, 3]);
        assert_eq!(sq.to_vec()?, t.to_vec()?);

        let t2 = Tensor::<i32>::zeros((1, 5))?;
        let sq2 = t2.squeeze(0)?;
        assert_eq!(sq2.dims(), &[5]);

        Ok(())
    }

    #[test]
    fn test_squeeze_unsqueeze_consistency() -> Result<()> {
        // unsqueeze followed by squeeze on the same axis is the identity.
        let t = Tensor::new(&[[1, 2, 3], [4, 5, 6]])?;
        let unsq = t.unsqueeze(0)?;
        let sq = unsq.squeeze(0)?;
        assert_eq!(t.dims(), sq.dims());
        assert_eq!(t.to_vec()?, sq.to_vec()?);
        Ok(())
    }

    // NOTE(review): misnamed — this is a squeeze-then-unsqueeze round trip,
    // not an unsqueeze-specific test; consider renaming. The `println!`
    // calls look like leftover debug output.
    #[test]
    fn test_unsqueeze() -> Result<()> {
        let t = Tensor::<i32>::zeros((2, 1, 3))?;
        let sq = t.squeeze(1)?;
        println!("{}", sq);
        assert_eq!(sq.dims(), vec![2, 3]);

        let unsq = sq.unsqueeze(0)?;
        println!("{}", unsq);
        assert_eq!(unsq.dims(), vec![1, 2, 3]);

        Ok(())
    }

    #[test]
    fn test_cat_3d() -> Result<()> {
        let a = Tensor::full((2, 2, 2), 1)?;
        let b = Tensor::full((2, 2, 2), 2)?;

        let c = Tensor::cat(&[a, b], 0)?;
        assert_eq!(c.dims(), [4, 2, 2]);

        let c2 = Tensor::cat(&[c.clone(), c.clone()], 1)?;
        assert_eq!(c2.dims(), [4, 4, 2]);

        Ok(())
    }

    #[test]
    fn test_cat_1d() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[4, 5, 6])?;
        let c = Tensor::new(&[7])?;

        // Inputs of different lengths along the cat dim are allowed.
        let res = Tensor::cat(&[a, b, c], 0)?;
        assert_eq!(res.dims(), &[7]);
        assert_eq!(res.to_vec()?, &[1, 2, 3, 4, 5, 6, 7]);

        Ok(())
    }

    #[test]
    fn test_cat_2d_axis0() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5, 6]])?;
        let c = Tensor::cat(&[a, b], 0)?;
        assert_eq!(c.dims(), &[3, 2]);
        assert_eq!(c.to_vec()?, &[1, 2, 3, 4, 5, 6]);

        Ok(())
    }

    #[test]
    fn test_cat_2d_axis1() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5], [6]])?;
        let c = Tensor::cat(&[a, b], 1)?;
        assert_eq!(c.dims(), &[2, 3]);
        // Row-major order: each result row interleaves a's row with b's element.
        assert_eq!(c.to_vec()?, &[1, 2, 5, 3, 4, 6]);

        Ok(())
    }

    #[test]
    fn test_cat_shape_mismatch() {
        // Non-cat dimensions disagree (2 columns vs 3 columns).
        let a = Tensor::new(&[[1, 2], [3, 4]]).unwrap();
        let b = Tensor::new(&[[1, 2, 3]]).unwrap();
        let res = Tensor::cat(&[a, b], 0);
        assert!(res.is_err());
    }

    #[test]
    fn test_cat_empty_list_error() {
        let res = Tensor::<i32>::cat::<Tensor<i32>, usize>(&[], 0);
        assert!(res.is_err(), "Concatenating an empty list should return an error");
    }

    #[test]
    fn test_cat_bool() -> Result<()> {
        let a = Tensor::new(&[[true, false]])?;
        let b = Tensor::new(&[[false, true]])?;

        let c = Tensor::cat(&[a, b], 0)?;
        assert_eq!(c.dims(), [2, 2]);
        assert_eq!(c.to_vec().unwrap(), [true, false, false, true]);

        Ok(())
    }

    #[test]
    fn test_stack_1d_axis0() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[4, 5, 6])?;

        let c = Tensor::stack(&[a, b], 0)?;
        assert_eq!(c.dims(), [2, 3]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 3, 4, 5, 6]);

        Ok(())
    }

    #[test]
    fn test_stack_1d_axis1() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[4, 5, 6])?;

        // Stacking on axis 1 interleaves the two inputs element-wise.
        let c = Tensor::stack(&[a, b], 1)?;
        assert_eq!(c.dims(), [3, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 4, 2, 5, 3, 6]);

        Ok(())
    }

    #[test]
    fn test_stack_2d_axis0() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5, 6], [7, 8]])?;

        let c = Tensor::stack(&[a, b], 0)?;
        assert_eq!(c.dims(), [2, 2, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 3, 4, 5, 6, 7, 8]);

        Ok(())
    }

    #[test]
    fn test_stack_2d_axis1() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5, 6], [7, 8]])?;

        let c = Tensor::stack(&[a, b], 1)?;
        assert_eq!(c.dims(), [2, 2, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 5, 6, 3, 4, 7, 8]);

        Ok(())
    }

    #[test]
    fn test_stack_2d_axis2() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[5, 6], [7, 8]])?;

        let c = Tensor::stack(&[a, b], 2)?;
        assert_eq!(c.dims(), [2, 2, 2]);
        assert_eq!(c.to_vec().unwrap(), [1, 5, 2, 6, 3, 7, 4, 8]);

        Ok(())
    }

    #[test]
    fn test_stack_shape_mismatch() {
        let a = Tensor::new(&[1, 2, 3]).unwrap();
        let b = Tensor::new(&[4, 5]).unwrap();

        let res = Tensor::stack(&[a, b], 0);
        assert!(res.is_err());
    }

    #[test]
    fn test_split_1d() -> Result<()> {
        let a = Tensor::new(&[10, 20, 30, 40])?;
        let splits = a.split(0)?;
        assert_eq!(splits.len(), 4);
        assert_eq!(splits[0].to_vec().unwrap(), [10]);
        assert_eq!(splits[1].to_vec().unwrap(), [20]);
        assert_eq!(splits[2].to_vec().unwrap(), [30]);
        assert_eq!(splits[3].to_vec().unwrap(), [40]);

        Ok(())
    }

    #[test]
    fn test_split_2d_axis0() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4], [5, 6], [7, 8]])?;
        let splits = a.split(0)?;

        assert_eq!(splits.len(), 4);
        assert_eq!(splits[0].to_vec().unwrap(), [1, 2]);
        assert_eq!(splits[1].to_vec().unwrap(), [3, 4]);
        assert_eq!(splits[2].to_vec().unwrap(), [5, 6]);
        assert_eq!(splits[3].to_vec().unwrap(), [7, 8]);

        Ok(())
    }

    #[test]
    fn test_split_2d_axis1() -> Result<()> {
        let a = Tensor::new(&[[1, 2, 3], [4, 5, 6]])?;
        let splits = a.split(1)?;

        // Splitting along columns yields one column vector per column.
        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].to_vec().unwrap(), [1, 4]);
        assert_eq!(splits[1].to_vec().unwrap(), [2, 5]);
        assert_eq!(splits[2].to_vec().unwrap(), [3, 6]);

        Ok(())
    }

    #[test]
    fn test_split_3d_axis2() -> Result<()> {
        let a = Tensor::new(&[
            [[1, 2], [3, 4]],
            [[5, 6], [7, 8]]
        ])?;
        let splits = a.split(2)?;

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].to_vec().unwrap(), [1, 3, 5, 7]);
        assert_eq!(splits[1].to_vec().unwrap(), [2, 4, 6, 8]);

        Ok(())
    }

    #[test]
    fn test_split_single_element() -> Result<()> {
        let a = Tensor::new(&[42])?;
        let splits = a.split(0)?;

        assert_eq!(splits.len(), 1);
        assert_eq!(splits[0].to_vec().unwrap(), [42]);

        Ok(())
    }

    #[test]
    fn test_split_empty_array() -> Result<()> {
        // A zero-sized split dimension yields no sub-tensors.
        let a = Tensor::<i32>::zeros((0, 2))?;
        let splits = a.split(0)?;

        assert!(splits.is_empty());
        Ok(())
    }

    #[test]
    fn test_repeat_1d() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = a.repeat(3)?;
        assert_eq!(b.dims(), [3 * 3]);
        assert_eq!(b.to_vec().unwrap(), [1, 2, 3, 1, 2, 3, 1, 2, 3]);
        Ok(())
    }

    #[test]
    fn test_repeat_2d() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = a.repeat((2, 3))?;
        assert_eq!(b.dims(), [4, 6]);
        assert_eq!(
            b.to_vec().unwrap(),
            [
                1, 2, 1, 2, 1, 2,
                3, 4, 3, 4, 3, 4,
                1, 2, 1, 2, 1, 2,
                3, 4, 3, 4, 3, 4
            ]
        );
        Ok(())
    }

    #[test]
    fn test_repeat_dim() -> Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = a.repeat_dim(0, 2)?;
        assert_eq!(b.dims(), [6]);
        assert_eq!(b.to_vec().unwrap(), [1, 2, 3, 1, 2, 3]);

        // times == 1 is the identity.
        let c = a.repeat_dim(0, 1)?;
        assert_eq!(c.dims(), [3]);
        assert_eq!(c.to_vec().unwrap(), [1, 2, 3]);

        Ok(())
    }

    // NOTE(review): this duplicates test_repeat_2d exactly; consider covering
    // a `repeats` shorter than the rank (the padded-with-1s path) instead.
    #[test]
    fn test_repeat_high_dim() -> Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = a.repeat((2, 3))?;
        assert_eq!(b.dims(), [4, 6]);
        Ok(())
    }

    #[test]
    fn test_narrow_1d_basic() -> Result<()> {
        let a = Tensor::new(&[0, 1, 2, 3, 4, 5])?;

        let b = a.narrow(0, 2, 3)?;

        assert_eq!(b.dims(), &[3]);
        assert_eq!(b.to_vec().unwrap(), &[2, 3, 4]);

        // NOTE(review): the randn + println below look like leftover debug
        // code unrelated to narrow; consider removing.
        let b = Tensor::randn(0.0, 1.0, (5, 5))?;
        println!("{:?}", b);
        Ok(())
    }

    #[test]
    fn test_narrow_2d_rows() -> Result<()> {
        let a = Tensor::new(&[
            [0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]
        ])?;

        let b = a.narrow(0, 1, 1)?;

        assert_eq!(b.dims(), &[1, 3]);
        assert_eq!(b.to_vec().unwrap(), &[3, 4, 5]);
        Ok(())
    }
}