1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
use tensor::Tensor;
use traits::TensorTrait;

impl<T: TensorTrait> Tensor<T> {
    pub fn concat(lhs: &Tensor<T>, rhs: &Tensor<T>, axis: usize) -> Tensor<T> {
        debug_assert!(axis < lhs.ndim());
        debug_assert!(lhs.ndim() == rhs.ndim());

        let t1 = lhs.canonize();
        let t2 = rhs.canonize();

        let mut shape = Vec::with_capacity(t1.ndim());
        for i in 0..t1.ndim() {
            if i != axis {
                if t1.shape[i] != t2.shape[i] {
                    panic!("When using concat, all axes must be the same except the joining one");
                }
                shape.push(t1.shape[i]);
            } else {
                shape.push(t1.shape[i] + t2.shape[i]);
            }
        }

        let mut t = Tensor::empty(&shape);
        for i in 0..t1.size() {
            let ii = t1.unravel_index(i);
            t[&ii] = t1.data[i];
        }

        let offset = t1.shape[axis];
        for i in 0..t2.size() {
            let mut ii = t2.unravel_index(i);
            ii[axis] += offset;
            t[&ii] = t2.data[i];
        }
        t
    }

}