1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
use super::narrow::narrow;
use crate::{backend::Backend, BasicOps, TensorKind};
use alloc::vec::Vec;

/// Split the tensor along the given dimension into chunks.
///
/// Every chunk has `ceil(size / chunks)` elements along `dim`, except possibly the last one,
/// which holds whatever is left. As a consequence, fewer than `chunks` tensors are returned
/// when the dimension cannot be covered by exactly `chunks` pieces of that size
/// (e.g. a dimension of size 9 split into 4 chunks yields 3 tensors of size 3).
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `chunks` - The number of chunks to be produced.
/// * `dim` - The dimension along which the tensor will be split.
///
/// # Returns
///
/// A vector of tensors. It is empty when the size of `dim` is zero.
///
/// # Panics
///
/// Panics if `chunks` is zero, or if `dim` is out of bounds for the tensor's shape.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't have the corresponding
/// implementation. Ideally, it is supposed to be implemented by the backend and the backend
/// implementation will be resolved by static dispatch. It is not designed for direct usage by
/// users, and not recommended to import or use this function directly.
pub fn chunk<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive<D>,
    chunks: usize,
    dim: usize,
) -> Vec<K::Primitive<D>> {
    // Fail with a clear message instead of an opaque division-by-zero panic below.
    assert!(chunks > 0, "chunk: `chunks` must be greater than 0");

    let size = K::shape(&tensor).dims[dim];
    if size == 0 {
        // Nothing to split; also keeps `chunk_size` below strictly positive for `step_by`.
        return Vec::new();
    }

    // Ceiling division written without `size + chunks - 1` so it cannot overflow.
    let chunk_size = size / chunks + usize::from(size % chunks != 0);

    // Walk the dimension in `chunk_size` steps. Clamping the length of each piece to the
    // elements that remain guarantees no slice runs past the end of the dimension and that
    // no zero-length slice is ever requested.
    (0..size)
        .step_by(chunk_size)
        .map(|start| {
            let length = chunk_size.min(size - start);
            narrow::<B, D, K>(tensor.clone(), dim, start, length)
        })
        .collect()
}