burn_tensor/tensor/api/
narrow.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use crate::{backend::Backend, BasicOps, TensorKind};
use alloc::vec::Vec;

/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension along which the tensor will be narrowed.
/// * `start` - The starting point of the given range.
/// * `length` - The ending point of the given range.
/// # Panics
///
/// - If the dimension is greater than the number of dimensions of the tensor.
/// - If the given range exceeds the number of elements on the given dimension.
///
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
pub fn narrow<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    start: usize,
    length: usize,
) -> K::Primitive {
    let shape = K::shape(&tensor);

    let ranges: Vec<_> = shape
        .dims
        .iter()
        .enumerate()
        .map(|(i, d)| {
            if i == dim {
                start..(start + length)
            } else {
                0..*d
            }
        })
        .collect();

    K::slice(tensor, &ranges)
}