Function dfdx::tensor_ops::slice

pub fn slice<S: SliceShape<Slice>, E: Unit, D: SliceKernel<E>, T: Tape<E, D>, Slice: 'static>(
    tensor: Tensor<S, E, D, T>,
    slice: Slice
) -> Tensor<S::Sliced, E, D, T>

Slices all dimensions of a tensor, with the starting and ending indices of each dimension determined by a tuple of ranges.
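The index arithmetic can be modelled in plain Rust, without dfdx. The sketch below is an illustration only and is not dfdx's implementation. The helpers resolve and slice_2d are hypothetical names. It assumes a row-major 2-D buffer and turns each range in the tuple into a half-open (start, end) pair for its dimension. It then copies the selected block and reproduces the results of the example further down.

```rust
use std::ops::{Bound, RangeBounds};

/// Hypothetical helper: resolve any range (a..b, a..=b, a.., ..b, ..) to a
/// half-open (start, end) pair for a dimension of length `len`.
fn resolve<R: RangeBounds<usize>>(range: &R, len: usize) -> (usize, usize) {
    let start = match range.start_bound() {
        Bound::Included(&s) => s,
        Bound::Excluded(&s) => s + 1,
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        Bound::Included(&e) => e + 1, // ..= is inclusive of its end
        Bound::Excluded(&e) => e,
        Bound::Unbounded => len, // open-ended ranges run to the end
    };
    assert!(start <= end && end <= len, "slice out of bounds");
    (start, end)
}

/// Hypothetical helper: slice a row-major `rows x cols` buffer with one range
/// per dimension, returning the sliced data and its new shape.
fn slice_2d<R0, R1>(data: &[f64], shape: (usize, usize), ranges: (R0, R1)) -> (Vec<f64>, (usize, usize))
where
    R0: RangeBounds<usize>,
    R1: RangeBounds<usize>,
{
    let (rows, cols) = shape;
    let (r0, r1) = resolve(&ranges.0, rows);
    let (c0, c1) = resolve(&ranges.1, cols);
    let mut out = Vec::with_capacity((r1 - r0) * (c1 - c0));
    for r in r0..r1 {
        // Each row is contiguous, so copy the selected columns of row r.
        out.extend_from_slice(&data[r * cols + c0..r * cols + c1]);
    }
    (out, (r1 - r0, c1 - c0))
}

fn main() {
    let a = [1.0, 2.0, 3.0, 4.0]; // the 2x2 tensor [[1, 2], [3, 4]]
    // First row: a 1x2 result.
    assert_eq!(slice_2d(&a, (2, 2), (0..1, 0..2)), (vec![1.0, 2.0], (1, 2)));
    // Last column: a 2x1 result.
    assert_eq!(slice_2d(&a, (2, 2), (0..2, 1..)), (vec![2.0, 4.0], (2, 1)));
    // 0..=1 selects both rows, like 0..2; .. selects every column.
    assert_eq!(slice_2d(&a, (2, 2), (0..=1, ..)).0, a.to_vec());
}
```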

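Slices are specified as tuples of ranges, one per dimension, written with the .. and ..= operators. Every sliced dimension becomes a runtime-sized usize dimension, except those sliced with .. (std::ops::RangeFull), whose types are left unchanged. Because the resulting shape is then no longer fully known at compile time, the examples below call .realize() to convert it back to a compile-time shape such as Rank2<1, 2>.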
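The dimension-type rule can be expressed as a type-level mapping. The following is a minimal sketch under assumptions: the Const struct, the SliceDim trait and the Sliced2 alias are hypothetical stand-ins written for illustration, not dfdx's own definitions. It shows that .. maps a dimension type to itself, while any other range maps it to usize.

```rust
use std::any::TypeId;
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo};

/// Hypothetical stand-in for a compile-time dimension of size N.
struct Const<const N: usize>;

/// Hypothetical trait: the dimension type produced by slicing dimension `D` with `Self`.
trait SliceDim<D> {
    type Sliced;
}

// `..` keeps the original dimension type unchanged.
impl<D> SliceDim<D> for RangeFull {
    type Sliced = D;
}

// Every other range produces a runtime-sized `usize` dimension.
impl<D> SliceDim<D> for Range<usize> {
    type Sliced = usize;
}
impl<D> SliceDim<D> for RangeInclusive<usize> {
    type Sliced = usize;
}
impl<D> SliceDim<D> for RangeFrom<usize> {
    type Sliced = usize;
}
impl<D> SliceDim<D> for RangeTo<usize> {
    type Sliced = usize;
}

/// Slicing a 2-D shape (D0, D1) with (R0, R1) applies the rule per dimension.
type Sliced2<D0, D1, R0, R1> = (<R0 as SliceDim<D0>>::Sliced, <R1 as SliceDim<D1>>::Sliced);

/// True when `A` and `B` are the same type.
fn same<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

fn main() {
    // (0..1, 0..2) on a 2x2 shape gives (usize, usize), hence the need to realize.
    assert!(same::<Sliced2<Const<2>, Const<2>, Range<usize>, Range<usize>>, (usize, usize)>());
    // (.., 1..) keeps the first dimension's compile-time type.
    assert!(same::<Sliced2<Const<2>, Const<2>, RangeFull, RangeFrom<usize>>, (Const<2>, usize)>());
}
```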

Example:

use dfdx::prelude::*;
let dev: Cpu = Default::default();

let a = dev.tensor([
    [1., 2.],
    [3., 4.],
]);

// Slice the first row to get a 1x2 tensor
let b: Tensor<Rank2<1, 2>, _, _> = a.clone().slice((0..1, 0..2)).realize();
assert_eq!(b.array(), [[1., 2.]]);

// Slice the last column to get a 2x1 tensor
let c: Tensor<Rank2<2, 1>, _, _> = a.clone().slice((0..2, 1..)).realize();
assert_eq!(c.array(), [[2.], [4.]]);