// burn_tensor/tensor/api/chunk.rs

1use super::{TensorMetadata, narrow::narrow};
2use crate::{BasicOps, TensorKind, backend::Backend};
3use alloc::vec::Vec;
4
5/// Split the tensor along the given dimension into chunks.
6///
7/// # Arguments
8///
9/// * `tensor` - The tensor.
10/// * `chunks` - The number of chunks to be produced.
11/// * `dim` - The dimension along which the tensor will be split.
12///
13/// # Returns
14///
15/// A vectors of tensors.
16///
17/// # Remarks
18///
19/// This is a fallback solution that is used only when the backend doesn't have the corresponding implementation.
20/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
21/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
22/// or use this function directly.
23pub fn chunk<B: Backend, K: TensorKind<B> + BasicOps<B>>(
24    tensor: K::Primitive,
25    chunks: usize,
26    dim: usize,
27) -> Vec<K::Primitive> {
28    let size = tensor.shape().dims[dim];
29    if size < chunks {
30        return (0..size)
31            .map(|i| narrow::<B, K>(tensor.clone(), dim, i, 1))
32            .collect();
33    }
34
35    let mut tensors = Vec::with_capacity(chunks);
36    let mut sum_chunk_size = 0;
37    if size % chunks == 0 {
38        let chunk_size = size / chunks;
39        for _ in 0..chunks {
40            tensors.push(narrow::<B, K>(
41                tensor.clone(),
42                dim,
43                sum_chunk_size,
44                chunk_size,
45            ));
46            sum_chunk_size += chunk_size;
47        }
48    } else {
49        let chunk_size = (size / chunks) + 1; // assumes not divisible
50        for _ in 0..chunks - 1 {
51            tensors.push(narrow::<B, K>(
52                tensor.clone(),
53                dim,
54                sum_chunk_size,
55                chunk_size,
56            ));
57            sum_chunk_size += chunk_size;
58        }
59        let remainder = size % chunk_size;
60        tensors.push(narrow::<B, K>(
61            tensor.clone(),
62            dim,
63            sum_chunk_size,
64            remainder,
65        ));
66    }
67
68    tensors
69}