// 14_concatenate_stack.rs

//! Shape composition: joining tensors with `concatenate` and `stack` (RFC-039).
//!
//! Run: `cargo run --example 14_concatenate_stack`
//!
//! `concatenate` joins tensors along an existing axis (same rank, matching sizes on
//! every non-concatenation axis). `stack` joins equally shaped tensors along a new
//! axis, so the output rank grows by one. Both take a borrowed slice `&[&Tensor]`
//! and reject dynamic tensors; the `try_*` forms return a `Result` instead of
//! panicking.

use matten::Tensor;

12fn main() {
13    let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
14    let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);
15
16    // concatenate joins an existing axis.
17    let cat0 = Tensor::concatenate(&[&a, &b], 0);
18    println!("concatenate axis 0 -> shape {:?}", cat0.shape());
19    assert_eq!(cat0.shape(), &[4, 2]);
20
21    let cat1 = Tensor::concatenate(&[&a, &b], 1);
22    println!("concatenate axis 1 -> shape {:?}", cat1.shape());
23    assert_eq!(cat1.shape(), &[2, 4]);
24
25    // stack adds a new axis (rank + 1). The new axis position selects the tensor.
26    let st0 = Tensor::stack(&[&a, &b], 0);
27    println!("stack axis 0      -> shape {:?}", st0.shape());
28    assert_eq!(st0.shape(), &[2, 2, 2]);
29
30    let st2 = Tensor::stack(&[&a, &b], 2);
31    println!("stack axis 2      -> shape {:?}", st2.shape());
32    assert_eq!(st2.shape(), &[2, 2, 2]);
33    assert_eq!(st2.as_slice(), &[1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0]);
34
35    // try_* forms return Result instead of panicking; e.g. a rank mismatch is Err.
36    let v = Tensor::from_vec(vec![1.0, 2.0]);
37    assert!(Tensor::try_concatenate(&[&a, &v], 0).is_err());
38
39    println!("Concatenate/stack: OK");
40}