burn_tensor/tensor/api/split.rs
1use super::{narrow::narrow, TensorMetadata};
2use crate::{backend::Backend, BasicOps, TensorKind};
3use alloc::vec::Vec;
4
5/// Splits the tensor along the given dimension into equally sized chunks (if possible)
6/// with size `split_size`. Last chunk will be smaller if the tensor size along the given
7/// dimension `dim` is not divisible by `split_size`.
8///
9/// # Arguments
10///
11/// * `tensor` - The tensor.
12/// * `split_size` - The size of a single chunk.
13/// * `dim` - The dimension along which to split the tensor.
14///
15/// # Returns
16///
17/// A vector of tensors.
18///
19/// # Remarks
20///
21/// This (and the following) are fallback solutions that is used only when the backend doesn't have the corresponding implementation.
22/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
23/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
24/// or use this function directly.
25pub fn split<B: Backend, K: TensorKind<B> + BasicOps<B>>(
26 tensor: K::Primitive,
27 split_size: usize,
28 dim: usize,
29) -> Vec<K::Primitive> {
30 let size = tensor.shape().dims[dim];
31 let mut tensors = Vec::new();
32
33 let mut start = 0;
34 while start < size {
35 let length = usize::min(split_size, size - start);
36 tensors.push(narrow::<B, K>(tensor.clone(), dim, start, length));
37 start += length;
38 }
39
40 tensors
41}
42
43/// Splits the tensor along the given dimension into chunks with sizes in
44/// `dim` according to `split_sizes`.
45///
46/// # Arguments
47///
48/// * `tensor` - The tensor.
49/// * `split_sizes` - Vector of sizes for each chunk.
50/// * `dim` - The dimension along which to split the tensor.
51///
52/// # Returns
53///
54/// A vector of tensors.
55///
56/// # Remarks
57///
58/// Fallback solution for backends with no equivalent functionality.
59pub fn split_with_sizes<B: Backend, K: TensorKind<B> + BasicOps<B>>(
60 tensor: K::Primitive,
61 split_sizes: Vec<usize>,
62 dim: usize,
63) -> Vec<K::Primitive> {
64 let mut tensors = Vec::new();
65
66 let mut start = 0;
67 for length in split_sizes {
68 if length == 0 {
69 continue;
70 }
71 tensors.push(narrow::<B, K>(tensor.clone(), dim, start, length));
72 start += length;
73 }
74
75 tensors
76}