// Narrowing of a tensor primitive along a single dimension.
use crate::{backend::Backend, BasicOps, TensorKind};
use alloc::vec::Vec;
/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// Every dimension other than `dim` is kept in full; `dim` is restricted to the half-open
/// range `start..start + length`.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension along which the tensor will be narrowed.
/// * `start` - The starting point of the given range.
/// * `length` - The number of elements kept along `dim`, counted from `start`
///   (the range ends at `start + length`, exclusive).
///
/// # Panics
///
/// - If the dimension is greater than or equal to the number of dimensions of the tensor.
/// - If the given range exceeds the number of elements on the given dimension
///   (including when `start + length` overflows `usize`).
///
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
pub fn narrow<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive<D>,
    dim: usize,
    start: usize,
    length: usize,
) -> K::Primitive<D> {
    // Without this check an out-of-range `dim` would match no axis below and the tensor
    // would be returned un-narrowed instead of panicking as documented.
    assert!(
        dim < D,
        "Can't narrow dimension {} of a tensor with {} dimensions",
        dim,
        D
    );

    let shape = K::shape(&tensor);
    let size = shape.dims[dim];

    // `checked_add` keeps a huge `start`/`length` pair from wrapping around in release builds
    // and producing a bogus (but seemingly valid) range.
    let end = match start.checked_add(length) {
        Some(end) if end <= size => end,
        _ => panic!(
            "Can't narrow dimension {} of size {} with start {} and length {}",
            dim, size, start, length
        ),
    };

    // Full range on every axis except `dim`, which gets the narrowed range.
    let ranges: Vec<_> = (0..D)
        .map(|i| {
            if i == dim {
                start..end
            } else {
                0..shape.dims[i]
            }
        })
        .collect();

    let ranges: [_; D] = ranges
        .try_into()
        .expect("exactly D ranges are collected from 0..D");

    K::slice(tensor, ranges)
}