burn_tensor/tensor/api/narrow.rs

1use crate::{backend::Backend, BasicOps, TensorKind};
2use alloc::vec::Vec;
3
4use super::TensorMetadata;
5
6/// Returns a new tensor with the given dimension narrowed to the given range.
7///
8/// # Arguments
9///
10/// * `tensor` - The tensor.
11/// * `dim` - The dimension along which the tensor will be narrowed.
12/// * `start` - The starting point of the given range.
13/// * `length` - The ending point of the given range.
14/// # Panics
15///
16/// - If the dimension is greater than the number of dimensions of the tensor.
17/// - If the given range exceeds the number of elements on the given dimension.
18///
19/// # Returns
20///
21/// A new tensor with the given dimension narrowed to the given range.
22pub fn narrow<B: Backend, K: TensorKind<B> + BasicOps<B>>(
23    tensor: K::Primitive,
24    dim: usize,
25    start: usize,
26    length: usize,
27) -> K::Primitive {
28    let shape = tensor.shape();
29
30    let ranges: Vec<_> = shape
31        .dims
32        .iter()
33        .enumerate()
34        .map(|(i, d)| {
35            if i == dim {
36                start..(start + length)
37            } else {
38                0..*d
39            }
40        })
41        .collect();
42
43    K::slice(tensor, &ranges)
44}