use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::Axis;
use super::core::{NumericExt, SplitArg};
/// Split an array into multiple sub-arrays along axis 1 (the columns of a 2-D array).
///
/// `sections_or_indices` is either a number of equal-sized sections, or an explicit
/// list of split indices along axis 1 that is forwarded to [`split`].
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if the array has fewer than 2 dimensions,
/// if the number of sections is zero, or if the length of axis 1 is not divisible by
/// the number of sections. Any error from [`split`] is propagated unchanged.
pub fn hsplit<T: Clone>(
    array: &Array<T>,
    sections_or_indices: impl Into<SplitArg>,
) -> Result<Vec<Array<T>>> {
    let shape = array.shape();
    let ndim = shape.len();
    if ndim < 2 {
        return Err(NumRs2Error::InvalidOperation(
            "hsplit requires at least 2D array".to_string(),
        ));
    }
    let axis = 1;
    match sections_or_indices.into() {
        SplitArg::Sections(sections) => {
            // Must be checked before `is_multiple_of`: `0.is_multiple_of(0)` is true, so a
            // zero-length axis would otherwise reach the division below and panic.
            if sections == 0 {
                return Err(NumRs2Error::InvalidOperation(
                    "number of sections must be greater than zero".to_string(),
                ));
            }
            let axis_len = shape[axis];
            if !axis_len.is_multiple_of(sections) {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "array of shape {:?} cannot be split into {} equal sections along axis {}",
                    shape, sections, axis
                )));
            }
            let section_size = axis_len / sections;
            // Interior boundaries only: `sections` pieces need `sections - 1` cut points.
            let indices: Vec<usize> = (1..sections).map(|i| i * section_size).collect();
            split(array, &indices, axis)
        }
        SplitArg::Indices(indices) => split(array, &indices, axis),
    }
}
/// Split an array into multiple sub-arrays along axis 0 (the rows of a 2-D array).
///
/// `sections_or_indices` is either a number of equal-sized sections, or an explicit
/// list of split indices along axis 0 that is forwarded to [`split`].
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if the array has fewer than 2 dimensions,
/// if the number of sections is zero, or if the length of axis 0 is not divisible by
/// the number of sections. Any error from [`split`] is propagated unchanged.
pub fn vsplit<T: Clone>(
    array: &Array<T>,
    sections_or_indices: impl Into<SplitArg>,
) -> Result<Vec<Array<T>>> {
    let shape = array.shape();
    let ndim = shape.len();
    if ndim < 2 {
        return Err(NumRs2Error::InvalidOperation(
            "vsplit requires at least 2D array".to_string(),
        ));
    }
    let axis = 0;
    match sections_or_indices.into() {
        SplitArg::Sections(sections) => {
            // Must be checked before `is_multiple_of`: `0.is_multiple_of(0)` is true, so a
            // zero-length axis would otherwise reach the division below and panic.
            if sections == 0 {
                return Err(NumRs2Error::InvalidOperation(
                    "number of sections must be greater than zero".to_string(),
                ));
            }
            let axis_len = shape[axis];
            if !axis_len.is_multiple_of(sections) {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "array of shape {:?} cannot be split into {} equal sections along axis {}",
                    shape, sections, axis
                )));
            }
            let section_size = axis_len / sections;
            // Interior boundaries only: `sections` pieces need `sections - 1` cut points.
            let indices: Vec<usize> = (1..sections).map(|i| i * section_size).collect();
            split(array, &indices, axis)
        }
        SplitArg::Indices(indices) => split(array, &indices, axis),
    }
}
/// Split an array into multiple sub-arrays along axis 2 (the third axis).
///
/// `sections_or_indices` is either a number of equal-sized sections, or an explicit
/// list of split indices along axis 2 that is forwarded to [`split`].
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] if the array has fewer than 3 dimensions,
/// if the number of sections is zero, or if the length of axis 2 is not divisible by
/// the number of sections. Any error from [`split`] is propagated unchanged.
pub fn dsplit<T: Clone>(
    array: &Array<T>,
    sections_or_indices: impl Into<SplitArg>,
) -> Result<Vec<Array<T>>> {
    let shape = array.shape();
    let ndim = shape.len();
    if ndim < 3 {
        return Err(NumRs2Error::InvalidOperation(
            "dsplit requires at least 3D array".to_string(),
        ));
    }
    let axis = 2;
    match sections_or_indices.into() {
        SplitArg::Sections(sections) => {
            // Must be checked before `is_multiple_of`: `0.is_multiple_of(0)` is true, so a
            // zero-length axis would otherwise reach the division below and panic.
            if sections == 0 {
                return Err(NumRs2Error::InvalidOperation(
                    "number of sections must be greater than zero".to_string(),
                ));
            }
            let axis_len = shape[axis];
            if !axis_len.is_multiple_of(sections) {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "array of shape {:?} cannot be split into {} equal sections along axis {}",
                    shape, sections, axis
                )));
            }
            let section_size = axis_len / sections;
            // Interior boundaries only: `sections` pieces need `sections - 1` cut points.
            let indices: Vec<usize> = (1..sections).map(|i| i * section_size).collect();
            split(array, &indices, axis)
        }
        SplitArg::Indices(indices) => split(array, &indices, axis),
    }
}
/// Split an array into consecutive sub-arrays along `axis` at the given indices.
///
/// Each index marks where a new piece begins, so `k` indices always yield `k + 1`
/// pieces: `[0, i0)`, `[i0, i1)`, ..., `[i_last, axis_len)`. The indices are sorted
/// before splitting, so their order does not matter; duplicate indices produce empty
/// pieces. With no indices the result is a single piece covering the whole axis
/// (even when the axis has length 0). Every piece is an owned copy of its slice.
///
/// # Errors
///
/// - [`NumRs2Error::DimensionMismatch`] if `axis` is not a valid axis of `array`.
/// - [`NumRs2Error::InvalidOperation`] if any index is `0` or not less than the
///   length of `axis` (the first offending index, in input order, is reported).
pub fn split<T: Clone>(array: &Array<T>, indices: &[usize], axis: usize) -> Result<Vec<Array<T>>> {
    let shape = array.shape();
    if axis >= shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis,
            shape.len()
        )));
    }
    let axis_len = shape[axis];

    // Only interior cut points are accepted, so every index lies in 1..axis_len.
    if let Some(&bad) = indices.iter().find(|&&idx| idx == 0 || idx >= axis_len) {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Split index {} out of bounds for axis {} with size {}",
            bad, axis, axis_len
        )));
    }

    let mut cuts = indices.to_vec();
    cuts.sort_unstable();

    // Pair each piece's start (0, then every cut) with its end (every cut, then
    // axis_len). The final [last_cut, axis_len) piece is always emitted, so a
    // zero-length axis still yields one (empty) piece rather than none.
    let starts = std::iter::once(0).chain(cuts.iter().copied());
    let ends = cuts.iter().copied().chain(std::iter::once(axis_len));
    let pieces = starts
        .zip(ends)
        .map(|(start, end)| {
            let view = array
                .array()
                .slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(start..end));
            Array::from_ndarray(view.into_owned().into_dyn())
        })
        .collect();
    Ok(pieces)
}