14_concatenate_stack/
14_concatenate_stack.rs1use matten::Tensor;
11
12fn main() {
13 let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
14 let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);
15
16 let cat0 = Tensor::concatenate(&[&a, &b], 0);
18 println!("concatenate axis 0 -> shape {:?}", cat0.shape());
19 assert_eq!(cat0.shape(), &[4, 2]);
20
21 let cat1 = Tensor::concatenate(&[&a, &b], 1);
22 println!("concatenate axis 1 -> shape {:?}", cat1.shape());
23 assert_eq!(cat1.shape(), &[2, 4]);
24
25 let st0 = Tensor::stack(&[&a, &b], 0);
27 println!("stack axis 0 -> shape {:?}", st0.shape());
28 assert_eq!(st0.shape(), &[2, 2, 2]);
29
30 let st2 = Tensor::stack(&[&a, &b], 2);
31 println!("stack axis 2 -> shape {:?}", st2.shape());
32 assert_eq!(st2.shape(), &[2, 2, 2]);
33 assert_eq!(st2.as_slice(), &[1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0]);
34
35 let v = Tensor::from_vec(vec![1.0, 2.0]);
37 assert!(Tensor::try_concatenate(&[&a, &v], 0).is_err());
38
39 println!("Concatenate/stack: OK");
40}