burn_tensor/tensor/api/
chunk.rs1use super::{TensorMetadata, narrow::narrow};
2use crate::{BasicOps, TensorKind, backend::Backend};
3use alloc::vec::Vec;
4
5pub fn chunk<B: Backend, K: TensorKind<B> + BasicOps<B>>(
24 tensor: K::Primitive,
25 chunks: usize,
26 dim: usize,
27) -> Vec<K::Primitive> {
28 let size = tensor.shape().dims[dim];
29 if size < chunks {
30 return (0..size)
31 .map(|i| narrow::<B, K>(tensor.clone(), dim, i, 1))
32 .collect();
33 }
34
35 let mut tensors = Vec::with_capacity(chunks);
36 let mut sum_chunk_size = 0;
37 if size % chunks == 0 {
38 let chunk_size = size / chunks;
39 for _ in 0..chunks {
40 tensors.push(narrow::<B, K>(
41 tensor.clone(),
42 dim,
43 sum_chunk_size,
44 chunk_size,
45 ));
46 sum_chunk_size += chunk_size;
47 }
48 } else {
49 let chunk_size = (size / chunks) + 1; for _ in 0..chunks - 1 {
51 tensors.push(narrow::<B, K>(
52 tensor.clone(),
53 dim,
54 sum_chunk_size,
55 chunk_size,
56 ));
57 sum_chunk_size += chunk_size;
58 }
59 let remainder = size % chunk_size;
60 tensors.push(narrow::<B, K>(
61 tensor.clone(),
62 dim,
63 sum_chunk_size,
64 remainder,
65 ));
66 }
67
68 tensors
69}