Function autograd::ops::gather_common

pub fn gather_common<T: ArrayLike>(
    param: &Tensor,
    indices: &T,
    axis: isize
) -> Tensor

Gathers slices of `param` along `axis` at the positions given by `indices`.

Follows the same specification as TensorFlow's `tf.gather` (https://www.tensorflow.org/api_docs/python/tf/gather). For example, it can be used to look up embedding vectors.

Unlike `ag::gather`, `indices` may contain negative elements.

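Returns

A tensor with shape `param.shape[..axis] + indices.shape + param.shape[axis+1..]`.

For example, if `param` has shape `[5, 4, 8, 2]`, `indices` has shape `[2, 3]`, and `axis` is `2`, the result has shape `[5, 4] + [2, 3] + [2] = [5, 4, 2, 3, 2]`.

Examples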

extern crate ndarray;
extern crate autograd as ag;

let ref param = ag::constant(ag::ndarray_ext::zeros(&[5, 4, 8, 2]));
let ref indices = ag::constant(ndarray::arr2(&[[5., -1., 3.], [2., 1., -2.]]));
let ref y = ag::gather_common(param, indices, 2);

assert_eq!(y.eval(&[]).unwrap().shape(), &[5, 4, 2, 3, 2])