burn_tensor/tensor/api/
split.rs

1use super::{narrow::narrow, TensorMetadata};
2use crate::{backend::Backend, BasicOps, TensorKind};
3use alloc::vec::Vec;
4
5/// Splits the tensor along the given dimension into equally sized chunks (if possible)
6/// with size `split_size`. Last chunk will be smaller if the tensor size along the given
7/// dimension `dim` is not divisible by `split_size`.
8///
9/// # Arguments
10///
11/// * `tensor` - The tensor.
12/// * `split_size` - The size of a single chunk.
13/// * `dim` - The dimension along which to split the tensor.
14///
15/// # Returns
16///
17/// A vector of tensors.
18///
19/// # Remarks
20///
21/// This (and the following) are fallback solutions that is used only when the backend doesn't have the corresponding implementation.
22/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
23/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
24/// or use this function directly.
25pub fn split<B: Backend, K: TensorKind<B> + BasicOps<B>>(
26    tensor: K::Primitive,
27    split_size: usize,
28    dim: usize,
29) -> Vec<K::Primitive> {
30    let size = tensor.shape().dims[dim];
31    let mut tensors = Vec::new();
32
33    let mut start = 0;
34    while start < size {
35        let length = usize::min(split_size, size - start);
36        tensors.push(narrow::<B, K>(tensor.clone(), dim, start, length));
37        start += length;
38    }
39
40    tensors
41}
42
43/// Splits the tensor along the given dimension into chunks with sizes in
44/// `dim` according to `split_sizes`.
45///
46/// # Arguments
47///
48/// * `tensor` - The tensor.
49/// * `split_sizes` - Vector of sizes for each chunk.
50/// * `dim` - The dimension along which to split the tensor.
51///
52/// # Returns
53///
54/// A vector of tensors.
55///
56/// # Remarks
57///
58/// Fallback solution for backends with no equivalent functionality.
59pub fn split_with_sizes<B: Backend, K: TensorKind<B> + BasicOps<B>>(
60    tensor: K::Primitive,
61    split_sizes: Vec<usize>,
62    dim: usize,
63) -> Vec<K::Primitive> {
64    let mut tensors = Vec::new();
65
66    let mut start = 0;
67    for length in split_sizes {
68        if length == 0 {
69            continue;
70        }
71        tensors.push(narrow::<B, K>(tensor.clone(), dim, start, length));
72        start += length;
73    }
74
75    tensors
76}