Function dfdx::tensor_ops::slice
pub fn slice<S: SliceShape<Slice>, E: Unit, D: SliceKernel<E>, T: Tape<E, D>, Slice: 'static>(
tensor: Tensor<S, E, D, T>,
slice: Slice
) -> Tensor<S::Sliced, E, D, T>
Slices all dimensions of a tensor, with the starting and ending indices of each dimension determined by a tuple of ranges.
Slices are specified as tuples of ranges, one range per dimension, built with the range operators .. and ..= (for example 0..2, 1.., or 0..=1). Every sliced dimension becomes a runtime-sized dimension of type usize, except dimensions sliced with a bare .. (std::ops::RangeFull), whose types are left unchanged.
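The shape rule above can be modeled in plain Rust with a trait that maps each (dimension, range) pair to an output dimension type. This is a minimal sketch under stated assumptions: Const, Dim, SliceDim, and slice_shape are hypothetical stand-ins written for illustration, not dfdx's actual definitions, and only the .., a..b, and a.. range forms are covered.

```rust
use std::ops::{Range, RangeFrom, RangeFull};

// Hypothetical stand-in for a compile-time dimension (not dfdx's definition).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Const<const N: usize>;

// Hypothetical dimension trait: every dimension can report its runtime size.
trait Dim: Copy {
    fn size(&self) -> usize;
}

impl<const N: usize> Dim for Const<N> {
    fn size(&self) -> usize {
        N
    }
}

// A plain `usize` is a dimension whose size is only known at runtime.
impl Dim for usize {
    fn size(&self) -> usize {
        *self
    }
}

// Maps a dimension sliced by a range of type `R` to its output dimension type.
trait SliceDim<R>: Dim {
    type Sliced: Dim;
    fn slice_dim(&self, range: &R) -> Self::Sliced;
}

// `..` leaves the dimension type unchanged.
impl<D: Dim> SliceDim<RangeFull> for D {
    type Sliced = D;
    fn slice_dim(&self, _range: &RangeFull) -> D {
        *self
    }
}

// `a..b` yields a runtime-sized `usize` dimension.
impl<D: Dim> SliceDim<Range<usize>> for D {
    type Sliced = usize;
    fn slice_dim(&self, range: &Range<usize>) -> usize {
        assert!(range.start <= range.end && range.end <= self.size(), "range out of bounds");
        range.end - range.start
    }
}

// `a..` also yields a runtime-sized `usize` dimension.
impl<D: Dim> SliceDim<RangeFrom<usize>> for D {
    type Sliced = usize;
    fn slice_dim(&self, range: &RangeFrom<usize>) -> usize {
        assert!(range.start <= self.size(), "range out of bounds");
        self.size() - range.start
    }
}

// Slices a 2-D shape with a tuple of ranges, one range per dimension.
fn slice_shape<A, B, R0, R1>(
    shape: (A, B),
    ranges: (R0, R1),
) -> (<A as SliceDim<R0>>::Sliced, <B as SliceDim<R1>>::Sliced)
where
    A: SliceDim<R0>,
    B: SliceDim<R1>,
{
    (shape.0.slice_dim(&ranges.0), shape.1.slice_dim(&ranges.1))
}

fn main() {
    let shape = (Const::<2>, Const::<3>);
    // `..` keeps `Const<2>`; `1..` turns the second dimension into a `usize` of size 2.
    let sliced: (Const<2>, usize) = slice_shape(shape, (.., 1usize..));
    assert_eq!(sliced, (Const::<2>, 2));
}
```

The output type is fixed entirely by the range types, so a compile-time size survives only where the range cannot change it.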
Example:
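use dfdx::prelude::*;
let dev: Cpu = Default::default();
// A 2x2 tensor of f32 values on the CPU
let a: Tensor<Rank2<2, 2>, f32, _> = dev.tensor([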
[1., 2.],
[3., 4.],
]);
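// Slice the first row to get a 1x2 tensor.
// Both sliced dimensions become usize, so realize() converts the result
// back to the compile-time shape named in the type annotation.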
let b: Tensor<Rank2<1, 2>, _, _> = a.clone().slice((0..1, 0..2)).realize();
assert_eq!(b.array(), [[1., 2.]]);
// Slice the last column to get a 2x1 tensor
let c: Tensor<Rank2<2, 1>, _, _> = a.clone().slice((0..2, 1..)).realize();
assert_eq!(c.array(), [[2.], [4.]]);
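Each range in the tuple ultimately selects a half-open [start, end) interval of one dimension. To illustrate how the different range forms, including the inclusive ..= form that the example above does not use, resolve to start and end indices, here is a minimal plain-Rust sketch over a row-major array. The helpers resolve and slice2d are hypothetical and not part of dfdx; the sketch copies values and does not track gradients or devices.

```rust
use std::ops::{Bound, RangeBounds};

// Resolves any std range (`a..b`, `a..=b`, `a..`, `..b`, `..`) to a half-open
// [start, end) interval within a dimension of length `len`.
fn resolve<R: RangeBounds<usize>>(range: &R, len: usize) -> (usize, usize) {
    let start = match range.start_bound() {
        Bound::Included(&s) => s,
        Bound::Excluded(&s) => s + 1,
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        Bound::Included(&e) => e + 1,
        Bound::Excluded(&e) => e,
        Bound::Unbounded => len,
    };
    assert!(start <= end && end <= len, "slice out of bounds");
    (start, end)
}

// Slices a row-major `rows x cols` matrix stored in a flat slice,
// returning the selected rows, each restricted to the selected columns.
fn slice2d<R0, R1>(data: &[f32], rows: usize, cols: usize, r0: R0, r1: R1) -> Vec<Vec<f32>>
where
    R0: RangeBounds<usize>,
    R1: RangeBounds<usize>,
{
    let (row_start, row_end) = resolve(&r0, rows);
    let (col_start, col_end) = resolve(&r1, cols);
    (row_start..row_end)
        .map(|r| data[r * cols + col_start..r * cols + col_end].to_vec())
        .collect()
}

fn main() {
    // The same 2x2 values as the example above, stored row by row.
    let a: [f32; 4] = [1., 2., 3., 4.];
    // First row: rows 0..1, all columns.
    assert_eq!(slice2d(&a, 2, 2, 0..1, ..), vec![vec![1.0, 2.0]]);
    // Last column: rows 0..=1 (inclusive), columns 1.. (open-ended).
    assert_eq!(slice2d(&a, 2, 2, 0..=1, 1..), vec![vec![2.0], vec![4.0]]);
}
```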