use super::{narrow::narrow, TensorMetadata};
use crate::{backend::Backend, BasicOps, TensorKind};
use alloc::vec::Vec;
/// Splits `tensor` along dimension `dim` into at most `chunks` contiguous pieces.
///
/// Let `size` be the extent of `tensor` along `dim`. Every piece has length
/// `size.div_ceil(chunks)` along `dim`, except the last one, which holds whatever
/// remains and may therefore be shorter. Because of this, fewer than `chunks` pieces
/// are returned when that many non-empty pieces cannot be formed. For example:
/// - size 6 split into 4 chunks yields lengths `[2, 2, 2]`;
/// - size 5 split into 4 chunks yields `[2, 2, 1]`;
/// - a size smaller than `chunks` yields `size` pieces of length 1;
/// - a size of 0 yields no pieces.
///
/// # Panics
///
/// Panics if `chunks` is 0, or if `dim` is not a valid dimension index of `tensor`.
pub fn chunk<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    chunks: usize,
    dim: usize,
) -> Vec<K::Primitive> {
    let size = tensor.shape().dims[dim];
    assert!(chunks > 0, "chunk: `chunks` must be greater than 0");

    // Rounding up means at most `chunks` pieces cover the whole dimension. The final
    // piece absorbs the shortfall, so no piece ever starts or ends past `size`.
    let chunk_size = size.div_ceil(chunks);
    if chunk_size == 0 {
        // Only reachable when `size == 0`: nothing to split. This also keeps
        // `step_by` from being called with a zero step.
        return Vec::new();
    }

    (0..size)
        .step_by(chunk_size)
        .map(|start| {
            // The last piece is clamped to the elements that are actually left.
            let length = chunk_size.min(size - start);
            narrow::<B, K>(tensor.clone(), dim, start, length)
        })
        .collect()
}