1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
extern crate ndarray; use ops; use tensor::Tensor; pub struct ExpandDims { pub axes: Vec<isize>, } impl ops::Op for ExpandDims { fn name(&self) -> &str { "ExpandDims" } fn compute(&mut self, xs: &[&::NdArray], _: bool) -> ::NdArray { let ret = xs[0].clone(); let mut output_shape = ret.shape().to_vec(); for &i in self.axes.iter() { let axis = if i < 0 { (ret.ndim() as isize + i) as usize } else { i as usize }; output_shape.insert(axis, 1); } ret.into_shape(output_shape).unwrap() } fn grad(&self, gy: &Tensor, _: &[&Tensor], _: &Tensor) -> Vec<Option<Tensor>> { vec![Some(ops::squeeze(gy, self.axes.as_slice()))] } }