burn_tensor/tensor/api/narrow.rs
1use crate::{backend::Backend, BasicOps, TensorKind};
2use alloc::vec::Vec;
3
4use super::TensorMetadata;
5
6/// Returns a new tensor with the given dimension narrowed to the given range.
7///
8/// # Arguments
9///
10/// * `tensor` - The tensor.
11/// * `dim` - The dimension along which the tensor will be narrowed.
12/// * `start` - The starting point of the given range.
13/// * `length` - The ending point of the given range.
14/// # Panics
15///
16/// - If the dimension is greater than the number of dimensions of the tensor.
17/// - If the given range exceeds the number of elements on the given dimension.
18///
19/// # Returns
20///
21/// A new tensor with the given dimension narrowed to the given range.
22pub fn narrow<B: Backend, K: TensorKind<B> + BasicOps<B>>(
23 tensor: K::Primitive,
24 dim: usize,
25 start: usize,
26 length: usize,
27) -> K::Primitive {
28 let shape = tensor.shape();
29
30 let ranges: Vec<_> = shape
31 .dims
32 .iter()
33 .enumerate()
34 .map(|(i, d)| {
35 if i == dim {
36 start..(start + length)
37 } else {
38 0..*d
39 }
40 })
41 .collect();
42
43 K::slice(tensor, &ranges)
44}