use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
/// Splits `tensor` into several tensors along dimension `dim`.
///
/// The partitioning is selected by `split_size_or_sections`:
/// - [`SplitArg::Size`]`(n)`: consecutive chunks of `n` elements; the last chunk holds the
///   remainder and may be smaller.
/// - [`SplitArg::Sections`]`(k)`: exactly `k` chunks; when the dimension size is not divisible
///   by `k`, the first `dim_size % k` chunks receive one extra element.
/// - [`SplitArg::Indices`]`(v)`: split points. Chunks are `[0, v[0])`, `[v[0], v[1])`, ...,
///   plus a trailing `[v.last(), dim_size)` chunk that is only emitted when it is non-empty.
///
/// `dim` may be negative, in which case it counts from the last dimension (`-1` is the last).
///
/// # Errors
///
/// Returns an invalid-argument error if `dim` is out of range, if the chunk size or the
/// number of sections is zero, or if split indices decrease or exceed the dimension size.
/// Errors from `Tensor::slice` / `to_tensor` are propagated.
pub fn split(
    tensor: &Tensor,
    split_size_or_sections: SplitArg,
    dim: isize,
) -> TorshResult<Vec<Tensor>> {
    let shape = tensor.shape();
    let ndim = shape.ndim();
    // Resolve a negative `dim` in signed arithmetic so an out-of-range value is rejected and
    // reported exactly as the caller passed it, instead of wrapping to a huge `usize`.
    let resolved = if dim < 0 { dim + ndim as isize } else { dim };
    if !(0..ndim as isize).contains(&resolved) {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            &format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim, ndim
            ),
            "split",
        ));
    }
    let dim = resolved as usize;
    let dim_size = shape.dims()[dim];

    match split_size_or_sections {
        SplitArg::Size(size) => {
            // `div_ceil(0)` would panic; reject the request instead.
            if size == 0 {
                return Err(torsh_core::TorshError::invalid_argument_with_context(
                    "Split size must be greater than zero",
                    "split",
                ));
            }
            let num_splits = dim_size.div_ceil(size);
            let mut splits = Vec::with_capacity(num_splits);
            for i in 0..num_splits {
                let start = i * size;
                let end = (start + size).min(dim_size);
                splits.push(tensor.slice(dim, start, end)?.to_tensor()?);
            }
            Ok(splits)
        }
        SplitArg::Sections(sections) => {
            // Dividing by zero sections would panic; reject the request instead.
            if sections == 0 {
                return Err(torsh_core::TorshError::invalid_argument_with_context(
                    "Number of sections must be greater than zero",
                    "split",
                ));
            }
            let base_size = dim_size / sections;
            let remainder = dim_size % sections;
            let mut splits = Vec::new();
            let mut offset = 0;
            for i in 0..sections {
                // The first `remainder` sections absorb one leftover element each.
                let size = base_size + usize::from(i < remainder);
                splits.push(tensor.slice(dim, offset, offset + size)?.to_tensor()?);
                offset += size;
            }
            Ok(splits)
        }
        SplitArg::Indices(indices) => {
            let mut splits = Vec::with_capacity(indices.len() + 1);
            let mut start = 0;
            for &index in &indices {
                if index > dim_size {
                    return Err(torsh_core::TorshError::invalid_argument_with_context(
                        &format!(
                            "Split index {} out of range for dimension size {}",
                            index, dim_size
                        ),
                        "split",
                    ));
                }
                if index < start {
                    return Err(torsh_core::TorshError::invalid_argument_with_context(
                        &format!(
                            "Split indices must be non-decreasing, got {} after {}",
                            index, start
                        ),
                        "split",
                    ));
                }
                splits.push(tensor.slice(dim, start, index)?.to_tensor()?);
                start = index;
            }
            // Only emit the tail when at least one element remains after the last index.
            if start < dim_size {
                splits.push(tensor.slice(dim, start, dim_size)?.to_tensor()?);
            }
            Ok(splits)
        }
    }
}
/// Describes how [`split`] partitions a dimension.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SplitArg {
    /// Consecutive chunks of this many elements; the last chunk may be smaller.
    Size(usize),
    /// Exactly this many chunks; the first `dim_size % n` chunks are one element larger.
    Sections(usize),
    /// Split points along the dimension; a chunk boundary is placed at each index.
    Indices(Vec<usize>),
}
/// Splits `tensor` into `chunks` pieces along dimension `dim`.
///
/// Delegates to [`split`] with [`SplitArg::Sections`]. Exactly `chunks` tensors are requested,
/// and when the dimension size is not divisible by `chunks`, the leading pieces are one
/// element larger than the trailing ones. Note that this differs from PyTorch's `torch.chunk`,
/// which uses a fixed `ceil(dim_size / chunks)` chunk size and may return fewer chunks.
///
/// `dim` may be negative, in which case it counts from the last dimension.
///
/// # Errors
///
/// Returns an invalid-argument error if `dim` is out of range or `chunks` is zero. Any
/// error produced by [`split`] is propagated.
pub fn chunk(tensor: &Tensor, chunks: usize, dim: isize) -> TorshResult<Vec<Tensor>> {
    let ndim = tensor.shape().ndim();
    // Resolve a negative `dim` in signed arithmetic so an out-of-range value is reported as
    // given rather than wrapping to a huge `usize`.
    let resolved = if dim < 0 { dim + ndim as isize } else { dim };
    if !(0..ndim as isize).contains(&resolved) {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            &format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim, ndim
            ),
            "chunk",
        ));
    }
    // Zero chunks would otherwise trigger a division by zero inside `split`.
    if chunks == 0 {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            "Number of chunks must be greater than zero",
            "chunk",
        ));
    }
    split(tensor, SplitArg::Sections(chunks), resolved)
}
/// Splits `tensor` along dimension `dim` by a section count or by explicit split indices.
///
/// - [`TensorSplitArg::Sections`]`(k)`: delegates to [`split`] with [`SplitArg::Sections`],
///   producing `k` chunks whose sizes differ by at most one.
/// - [`TensorSplitArg::Indices`]`(v)`: chunk boundaries are placed at each index in `v`.
///   An index that is not strictly greater than the previous boundary produces no chunk.
///   A trailing chunk up to the end of the dimension is emitted only when it is non-empty.
///
/// `dim` may be negative, in which case it counts from the last dimension.
///
/// # Errors
///
/// Returns an invalid-argument error if `dim` is out of range, if the section count is zero,
/// or if any split index exceeds the dimension size. Slicing errors are propagated.
pub fn tensor_split(
    tensor: &Tensor,
    indices_or_sections: TensorSplitArg,
    dim: isize,
) -> TorshResult<Vec<Tensor>> {
    let shape = tensor.shape();
    let ndim = shape.ndim();
    // Resolve a negative `dim` in signed arithmetic so an out-of-range value is reported as
    // given rather than wrapping to a huge `usize`.
    let resolved = if dim < 0 { dim + ndim as isize } else { dim };
    if !(0..ndim as isize).contains(&resolved) {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            &format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim, ndim
            ),
            "tensor_split",
        ));
    }
    let dim = resolved as usize;

    match indices_or_sections {
        TensorSplitArg::Sections(sections) => {
            // Zero sections would otherwise trigger a division by zero inside `split`.
            if sections == 0 {
                return Err(torsh_core::TorshError::invalid_argument_with_context(
                    "Number of sections must be greater than zero",
                    "tensor_split",
                ));
            }
            split(tensor, SplitArg::Sections(sections), dim as isize)
        }
        TensorSplitArg::Indices(indices) => {
            let dim_size = shape.dims()[dim];
            let mut splits = Vec::with_capacity(indices.len() + 1);
            let mut prev_idx = 0;
            for &idx in &indices {
                if idx > dim_size {
                    return Err(torsh_core::TorshError::invalid_argument_with_context(
                        &format!(
                            "Split index {} out of range for dimension size {}",
                            idx, dim_size
                        ),
                        "tensor_split",
                    ));
                }
                // A boundary that does not advance past the previous one yields no chunk.
                if idx > prev_idx {
                    splits.push(tensor.slice(dim, prev_idx, idx)?.to_tensor()?);
                }
                prev_idx = idx;
            }
            if prev_idx < dim_size {
                splits.push(tensor.slice(dim, prev_idx, dim_size)?.to_tensor()?);
            }
            Ok(splits)
        }
    }
}
/// Describes how [`tensor_split`] (and `hsplit` / `vsplit` / `dsplit`) partitions a dimension.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorSplitArg {
    /// Exactly this many chunks, with sizes differing by at most one element.
    Sections(usize),
    /// Split points along the dimension; indices that do not increase produce no chunk.
    Indices(Vec<usize>),
}
/// Splits `tensor` horizontally, i.e. along dimension 1 (the column axis of a 2-D tensor).
///
/// This is [`tensor_split`] with `dim = 1`, restricted to inputs of rank 2 or more.
///
/// # Errors
///
/// Returns an invalid-argument error when `tensor` has fewer than 2 dimensions. Otherwise,
/// any error from [`tensor_split`] is propagated.
pub fn hsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
    const COLUMN_DIM: isize = 1;
    let rank = tensor.shape().ndim();
    if rank >= 2 {
        return tensor_split(tensor, indices_or_sections, COLUMN_DIM);
    }
    Err(torsh_core::TorshError::invalid_argument_with_context(
        "Input tensor must have at least 2 dimensions for hsplit",
        "hsplit",
    ))
}
/// Splits `tensor` vertically, i.e. along dimension 0 (the row axis of a 2-D tensor).
///
/// This is [`tensor_split`] with `dim = 0`, but it additionally requires the input to have
/// at least 2 dimensions.
///
/// # Errors
///
/// Returns an invalid-argument error when `tensor` has fewer than 2 dimensions. Otherwise,
/// any error from [`tensor_split`] is propagated.
pub fn vsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
    let is_at_least_2d = tensor.shape().ndim() >= 2;
    if is_at_least_2d {
        tensor_split(tensor, indices_or_sections, 0)
    } else {
        Err(torsh_core::TorshError::invalid_argument_with_context(
            "Input tensor must have at least 2 dimensions for vsplit",
            "vsplit",
        ))
    }
}
/// Splits `tensor` depth-wise, i.e. along dimension 2.
///
/// This is [`tensor_split`] with `dim = 2`, restricted to inputs of rank 3 or more.
///
/// # Errors
///
/// Returns an invalid-argument error when `tensor` has fewer than 3 dimensions. Otherwise,
/// any error from [`tensor_split`] is propagated.
pub fn dsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
    const DEPTH_DIM: isize = 2;
    match tensor.shape().ndim() {
        0..=2 => Err(torsh_core::TorshError::invalid_argument_with_context(
            "Input tensor must have at least 3 dimensions for dsplit",
            "dsplit",
        )),
        _ => tensor_split(tensor, indices_or_sections, DEPTH_DIM),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    /// Checks that `parts` holds exactly one tensor per entry of `expected`, and that each
    /// tensor's dimensions match the corresponding entry, in order.
    fn assert_shapes<const N: usize>(parts: &[Tensor], expected: &[[usize; N]]) {
        assert_eq!(parts.len(), expected.len());
        for (part, dims) in parts.iter().zip(expected) {
            assert_eq!(part.shape().dims(), dims);
        }
    }

    #[test]
    fn test_split() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        let by_sections = split(&input, SplitArg::Sections(3), 0)?;
        assert_shapes(&by_sections, &[[2, 4]; 3]);

        let by_indices = split(&input, SplitArg::Indices(vec![2, 4]), 0)?;
        assert_shapes(&by_indices, &[[2, 4], [2, 4], [2, 4]]);
        Ok(())
    }

    #[test]
    fn test_chunk() -> TorshResult<()> {
        let input = randn(&[8, 3], None, None, None)?;

        let rows = chunk(&input, 3, 0)?;
        assert_shapes(&rows, &[[3, 3], [3, 3], [2, 3]]);

        let cols = chunk(&input, 2, 1)?;
        assert_shapes(&cols, &[[8, 2], [8, 1]]);
        Ok(())
    }

    #[test]
    fn test_tensor_split() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        let by_sections = tensor_split(&input, TensorSplitArg::Sections(3), 0)?;
        assert_shapes(&by_sections, &[[2, 4]; 3]);

        let by_indices = tensor_split(&input, TensorSplitArg::Indices(vec![2, 4]), 0)?;
        assert_shapes(&by_indices, &[[2, 4], [2, 4], [2, 4]]);
        Ok(())
    }

    #[test]
    fn test_hsplit() -> TorshResult<()> {
        let input = randn(&[4, 6], None, None, None)?;

        let by_sections = hsplit(&input, TensorSplitArg::Sections(3))?;
        assert_shapes(&by_sections, &[[4, 2]; 3]);

        let by_indices = hsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_shapes(&by_indices, &[[4, 2], [4, 2], [4, 2]]);
        Ok(())
    }

    #[test]
    fn test_vsplit() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        let by_sections = vsplit(&input, TensorSplitArg::Sections(3))?;
        assert_shapes(&by_sections, &[[2, 4]; 3]);

        let by_indices = vsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_shapes(&by_indices, &[[2, 4], [2, 4], [2, 4]]);
        Ok(())
    }

    #[test]
    fn test_dsplit() -> TorshResult<()> {
        let input = randn(&[2, 3, 6], None, None, None)?;

        let by_sections = dsplit(&input, TensorSplitArg::Sections(3))?;
        assert_shapes(&by_sections, &[[2, 3, 2]; 3]);

        let by_indices = dsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_shapes(&by_indices, &[[2, 3, 2], [2, 3, 2], [2, 3, 2]]);
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 2 dimensions for hsplit")]
    fn test_hsplit_invalid_dimensions() {
        let vector = randn(&[5], None, None, None).expect("randn should succeed");
        hsplit(&vector, TensorSplitArg::Sections(2)).expect("operation should succeed");
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 2 dimensions for vsplit")]
    fn test_vsplit_invalid_dimensions() {
        let vector = randn(&[5], None, None, None).expect("randn should succeed");
        vsplit(&vector, TensorSplitArg::Sections(2)).expect("operation should succeed");
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 3 dimensions for dsplit")]
    fn test_dsplit_invalid_dimensions() {
        let matrix = randn(&[3, 4], None, None, None).expect("randn should succeed");
        dsplit(&matrix, TensorSplitArg::Sections(2)).expect("operation should succeed");
    }

    #[test]
    fn test_split_size_mode() -> TorshResult<()> {
        let input = randn(&[10, 4], None, None, None)?;
        let pieces = split(&input, SplitArg::Size(3), 0)?;
        assert_shapes(&pieces, &[[3, 4], [3, 4], [3, 4], [1, 4]]);
        Ok(())
    }

    #[test]
    fn test_negative_dimension_indexing() -> TorshResult<()> {
        let input = randn(&[4, 6, 8], None, None, None)?;
        let from_end = split(&input, SplitArg::Sections(2), -1)?;
        let explicit = split(&input, SplitArg::Sections(2), 2)?;
        assert_eq!(from_end.len(), explicit.len());
        assert_eq!(from_end[0].shape().dims(), explicit[0].shape().dims());
        Ok(())
    }
}