Function autograd::ops::gather_common
[−]
[src]
pub fn gather_common<T: ArrayLike>(
param: &Tensor,
indices: &T,
axis: isize
) -> Tensor
Gathers subviews from the input tensor.
Follows the same specification as [`tf.gather`](https://www.tensorflow.org/api_docs/python/tf/gather). For example, it can be used for embedding-vector lookup.
Unlike `ag::gather`, `indices` can contain negative elements.
Returns
A tensor with shape `param.shape[..axis] + indices.shape + param.shape[axis+1..]`.
extern crate ndarray; extern crate autograd as ag; let ref param = ag::constant(ag::ndarray_ext::zeros(&[5, 4, 8, 2])); let ref indices = ag::constant(ndarray::arr2(&[[5., -1., 3.], [2., 1., -2.]])); let ref y = ag::gather_common(param, indices, 2); assert_eq!(y.eval(&[]).unwrap().shape(), &[5, 4, 2, 3, 2])