Function `autograd::ops::gather`

pub fn gather<T: ArrayLike>(param: &Tensor, indices: &T, axis: isize) -> Tensor

Gathers subviews from the input tensor.

Follows the same specification as TensorFlow's `tf.gather` (https://www.tensorflow.org/api_docs/python/tf/gather). For example, it can be used to look up embedding vectors.

Returns

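A tensor whose shape is `param.shape[..axis] + indices.shape + param.shape[axis+1..]`, where `+` denotes concatenation of the shape lists (the gathered axis of `param` is replaced by the full shape of `indices`).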

extern crate ndarray;
extern crate autograd as ag;

let mut ctx = ag::Context::new();
// `param` has shape [5, 4, 8, 2]; the gather below runs along axis 2 (length 8).
let param = &ag::constant(ag::ndarray_ext::zeros(&[5, 4, 8, 2]), &mut ctx);
// `indices` has shape [2, 3]; every value (0 through 5) is a valid position along axis 2.
let indices = &ag::constant(ndarray::arr2(&[[5., 4., 3.], [2., 1., 0.]]), &mut ctx);
let y = &ag::gather(param, indices, 2);

// Output shape = param.shape[..2] ++ indices.shape ++ param.shape[3..]
//              = [5, 4]          ++ [2, 3]        ++ [2]
assert_eq!(y.eval(&mut ctx).shape(), &[5, 4, 2, 3, 2]);