shape/
shape.rs

1use numrst::prelude::*;
2
3fn main() -> Result<(), Box<dyn std::error::Error>> {
4    // --- Reshape ---
5    let a = NdArray::<i32>::arange(0, 12)?;
6    let a2d = a.reshape((3, 4))?;
7    let a3d = a.reshape((2, 2, 3))?;
8
9    println!("Original a (1D):\n{}", a);
10    println!("Reshaped a (3x4):\n{}", a2d);
11    println!("Reshaped a (2x2x3):\n{}", a3d);
12
13    // --- Transpose ---
14    let t = a2d.transpose(0, 1)?;
15    println!("Transpose a2d (swap dim 0 <-> 1):\n{}", t);
16
17    // --- Unsqueeze & Squeeze ---
18    let b = NdArray::<f32>::arange(0.0, 6.0)?.reshape((2, 3))?;
19    let b_unsq = b.unsqueeze(0)?; // add axis at front -> (1,2,3)
20    let b_sq = b_unsq.squeeze(0)?; // remove axis -> (2,3)
21
22    println!("b shape: {}", b.shape());
23    println!("b_unsq shape: {}", b_unsq.shape());
24    println!("b_sq shape: {}", b_sq.shape());
25
26    // --- Narrow & Narrow Range ---
27    let narrowed = b.narrow(1, 0, 2)?; // take first 2 cols along axis 1
28    let narrowed_range = b.narrow_range(1, &rng!(0:2))?;
29
30    println!("narrowed (axis=1, first 2 cols):\n{}", narrowed);
31    println!("narrowed_range (axis=1, 0..2):\n{}", narrowed_range);
32
33    // --- Concatenate ---
34    let c1 = NdArray::<i32>::full((2, 2), 1)?;
35    let c2 = NdArray::<i32>::full((2, 2), 2)?;
36    let cat = NdArray::cat(&[&c1, &c2], 0)?; // concat along axis=0
37
38    println!("Concatenate along axis=0:\n{}", cat);
39
40    // --- Stack ---
41    let s1 = NdArray::<i32>::full((2, 2), 3)?;
42    let s2 = NdArray::<i32>::full((2, 2), 4)?;
43    let stacked = NdArray::stack(&[&s1, &s2], 0)?; // stack new axis
44
45    println!("Stacked along new axis=0:\n{}", stacked);
46
47    Ok(())
48}