1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
use tensor::Tensor;
use traits::TensorTrait;

impl<T: TensorTrait> Tensor<T> {
    /// Joins `lhs` and `rhs` along `axis` and returns the result as a new tensor.
    ///
    /// The output shape equals the input shapes on every axis except `axis`,
    /// where it is the sum of the two input extents. Elements of `lhs` keep
    /// their multi-dimensional index; elements of `rhs` are shifted along
    /// `axis` by `lhs`'s extent on that axis.
    ///
    /// # Panics
    ///
    /// * if `axis` is not smaller than `lhs.ndim()`;
    /// * if `lhs` and `rhs` have a different number of dimensions;
    /// * if the two shapes differ on any axis other than `axis`.
    ///
    /// The first two checks are enforced in release builds as well: without
    /// them a rank mismatch would only be caught (if at all) by an unrelated
    /// out-of-bounds index further down.
    pub fn concat(lhs: &Tensor<T>, rhs: &Tensor<T>, axis: usize) -> Tensor<T> {
        assert!(
            axis < lhs.ndim(),
            "concat: axis {} is out of bounds for a tensor with {} dimensions",
            axis,
            lhs.ndim()
        );
        assert!(
            lhs.ndim() == rhs.ndim(),
            "concat: tensors must have the same number of dimensions ({} vs {})",
            lhs.ndim(),
            rhs.ndim()
        );

        // NOTE(review): `canonize` is presumably what makes the flat `data[i]`
        // position agree with `unravel_index(i)` below — confirm against its
        // definition.
        let t1 = lhs.canonize();
        let t2 = rhs.canonize();

        // Build the output shape, validating the non-joining axes on the way.
        let ndim = t1.ndim();
        let mut shape = Vec::with_capacity(ndim);
        for i in 0..ndim {
            if i == axis {
                shape.push(t1.shape[i] + t2.shape[i]);
            } else {
                if t1.shape[i] != t2.shape[i] {
                    panic!("When using concat, all axes must be the same except the joining one");
                }
                shape.push(t1.shape[i]);
            }
        }

        let mut t = Tensor::empty(&shape);

        // Left operand: every element lands at the same multi-index.
        for i in 0..t1.size() {
            let ii = t1.unravel_index(i);
            t[&ii] = t1.data[i];
        }

        // Right operand: same multi-index, shifted past the left operand
        // along the joining axis.
        let offset = t1.shape[axis];
        for i in 0..t2.size() {
            let mut ii = t2.unravel_index(i);
            ii[axis] += offset;
            t[&ii] = t2.data[i];
        }

        t
    }
}