use rten_base::iter::range_chunks;
use rten_shape_inference::ops as shape_ops;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
use crate::buffer_pool::BufferPool;
use crate::infer_shapes::{InferShapes, impl_infer_shapes};
use crate::operator::{
OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList, OutputTypesContext,
};
use crate::ops::{map_value_view, resolve_axis};
use crate::value::ValueView;
/// Describes how [`split`] should divide the selected axis of its input.
#[derive(Clone, Debug)]
pub enum SplitSizes<'a> {
/// Split into consecutive chunks of this many elements. `split` rejects
/// values smaller than 1.
Size(i32),
/// Explicit size of each output along the split axis. `split` rejects
/// negative entries and requires the entries to sum to the axis size.
Sizes(NdTensorView<'a, i32, 1>),
/// Requested number of outputs. `split` rejects zero and values larger than
/// the axis size.
NumSplits(u32),
}
impl<'a> From<&'a [i32]> for SplitSizes<'a> {
fn from(val: &'a [i32]) -> Self {
Self::Sizes(val.into())
}
}
/// Split `input` into a sequence of tensors along `axis`.
///
/// `axis` is resolved against the input's rank with `resolve_axis`, so
/// negative values count from the last dimension. Each output is copied into
/// a new tensor allocated from `pool`.
///
/// `split` selects how the axis is divided:
///
/// - [`SplitSizes::Size`]: consecutive chunks of the given size. The final
///   chunk is smaller if the axis size is not a multiple of the chunk size.
/// - [`SplitSizes::Sizes`]: one output per entry, with the given size along
///   `axis`.
/// - [`SplitSizes::NumSplits`]: consecutive chunks of size
///   `ceil(axis_size / n)`.
///
/// # Errors
///
/// Returns the error from `resolve_axis` if `axis` is invalid, or
/// `OpError::InvalidValue` if:
///
/// - the chunk size is less than 1,
/// - any explicit size is negative, or the explicit sizes do not sum to the
///   size of the axis (including when their sum is too large to represent),
/// - the number of splits is zero or exceeds the size of the axis.
pub fn split<T: Copy>(
    pool: &BufferPool,
    input: TensorView<T>,
    axis: isize,
    split: SplitSizes,
) -> Result<Vec<Tensor<T>>, OpError> {
    let axis = resolve_axis(input.ndim(), axis)?;
    let axis_size = input.size(axis);

    // Cut the axis into consecutive ranges of `chunk_size` elements and copy
    // each slice into its own tensor.
    let split_with_chunk_size = |chunk_size| {
        range_chunks(0..axis_size, chunk_size)
            .map(|split_range| input.slice_axis(axis, split_range).to_tensor_in(pool))
            .collect()
    };

    let outputs = match split {
        SplitSizes::Size(size) => {
            if size < 1 {
                return Err(OpError::InvalidValue("Split size must be >= 1"));
            }
            split_with_chunk_size(size as usize)
        }
        SplitSizes::Sizes(split) => {
            if split.iter().any(|size| *size < 0) {
                return Err(OpError::InvalidValue("Split sizes must be >= 0"));
            }

            // Accumulate with checked `usize` arithmetic. Summing as `i32`
            // can overflow, which panics when overflow checks are enabled and
            // otherwise wraps, letting invalid sizes pass the check below and
            // produce out-of-range slices.
            let split_sum = split
                .iter()
                .try_fold(0usize, |total, &size| total.checked_add(size as usize));
            if split_sum != Some(axis_size) {
                return Err(OpError::InvalidValue(
                    "Split sizes do not sum to dimension size",
                ));
            }

            let mut split_start = 0;
            split
                .iter()
                .map(|&split_size| {
                    let split_size = split_size as usize;
                    let split_range = split_start..split_start + split_size;
                    split_start += split_size;
                    input.slice_axis(axis, split_range).to_tensor_in(pool)
                })
                .collect()
        }
        SplitSizes::NumSplits(n_splits) => {
            let n_splits = n_splits as usize;
            if n_splits == 0 {
                return Err(OpError::InvalidValue("num_outputs must be > 0"));
            }
            if n_splits > axis_size {
                return Err(OpError::InvalidValue("num_outputs exceeds dim size"));
            }
            // NOTE(review): when `axis_size` is not evenly divisible this can
            // yield fewer than `n_splits` outputs (e.g. size 5 with 4 splits
            // gives chunks of 2, 2 and 1). Confirm whether callers rely on
            // receiving exactly `n_splits` outputs.
            split_with_chunk_size(axis_size.div_ceil(n_splits))
        }
    };

    Ok(outputs)
}
/// Operator that splits its first input into multiple tensors along an axis.
#[derive(Debug)]
pub struct Split {
/// Axis to split along. Passed to [`split`], which resolves it against the
/// input's rank so negative values count from the last dimension.
pub axis: isize,
/// Number of outputs to produce when no explicit sizes input is supplied.
/// If `None`, the output count from the run context is used instead.
pub num_outputs: Option<u32>,
}
impl Operator for Split {
fn name(&self) -> &str {
"Split"
}
fn max_inputs(&self) -> Option<usize> {
Some(2)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let input = ctx.inputs().require(0)?;
let splits = ctx.inputs().get_as(1)?;
let num_outputs = self.num_outputs.or(ctx.num_outputs());
let split_sizes = if let Some(splits) = splits {
SplitSizes::Sizes(splits)
} else if let Some(num_outputs) = num_outputs {
SplitSizes::NumSplits(num_outputs)
} else {
return Err(OpError::InvalidValue(
"Either `num_outputs` or `splits` must be set",
));
};
map_value_view!(input, x, {
split(ctx.pool(), x, self.axis, split_sizes)
.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
})
}
fn output_types(&self, ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some(OutputTypeList::from_elem(
OutputType::CopyFromInput(0),
ctx.num_outputs,
))
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(self)
}
}
// Hook `Split` up to shape inference by building a `shape_ops::Split` from the
// operator's `axis` and `num_outputs` fields. Presumably this macro generates
// the `InferShapes` impl that `Split::as_infer_shapes` relies on — confirm
// against the macro definition.
//
// NOTE(review): `op.axis as i32` silently truncates if `axis` does not fit in
// an `i32` on targets where `isize` is wider than 32 bits.
impl_infer_shapes!(
Split,
op,
shape_ops::Split {
axis: op.axis as i32,
num_outputs: op.num_outputs
}
);
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::{NdTensor, Tensor};
    use rten_testing::TestCases;

    use crate::buffer_pool::BufferPool;
    use crate::operator::{InputList, OpError, OpRunContext, Operator};

    use super::{Split, SplitSizes, split};

    /// Run the `Split` operator with each supported way of specifying the
    /// split and check the outputs.
    #[test]
    fn test_split() {
        let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

        #[derive(Debug)]
        struct Case {
            axis: isize,
            // Explicit split sizes, passed as the operator's second input.
            splits: Option<NdTensor<i32, 1>>,
            // Value of the operator's `num_outputs` attribute.
            num_outputs: Option<u32>,
            // Output count set on the run context.
            graph_outputs: Option<u32>,
            expected: Vec<Tensor>,
        }

        let cases = [
            // Positive axis, explicit sizes.
            Case {
                axis: 1,
                splits: Some([1, 1].into()),
                num_outputs: None,
                graph_outputs: None,
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
            // Negative axis, explicit sizes.
            Case {
                axis: -1,
                splits: Some([1, 1].into()),
                num_outputs: None,
                graph_outputs: None,
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
            // `num_outputs` attribute, axis size not evenly divisible.
            Case {
                axis: 0,
                splits: None,
                num_outputs: Some(3),
                graph_outputs: None,
                expected: [
                    Tensor::from([[0., 1.], [2., 3.]]),
                    Tensor::from([[4., 5.], [6., 7.]]),
                    Tensor::from([[8., 9.]]),
                ]
                .into(),
            },
            // Output count taken from the run context.
            Case {
                axis: 1,
                splits: None,
                num_outputs: None,
                graph_outputs: Some(2),
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
            // Both explicit sizes and the `num_outputs` attribute supplied.
            Case {
                axis: 1,
                splits: Some([1, 1].into()),
                num_outputs: Some(2),
                graph_outputs: None,
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
        ];

        cases.test_each(|case| {
            let split_op = Split {
                axis: case.axis,
                num_outputs: case.num_outputs,
            };
            let inputs = InputList::from_iter([
                Some(input.view().into()),
                case.splits.as_ref().map(|s| s.view().into()),
            ]);
            let pool = BufferPool::new();
            let mut ctx = OpRunContext::new(&pool, &inputs);
            if let Some(n_outputs) = case.graph_outputs {
                ctx.set_num_outputs(n_outputs);
            }

            let results = split_op.run(&ctx).unwrap();
            let results: Vec<Tensor> = results.into_iter().map(|o| o.try_into().unwrap()).collect();

            // Mirror the priority used by `Split::run`: explicit sizes first,
            // then the `num_outputs` attribute, then the context's count.
            let expected_splits = match (case.splits.as_ref(), case.num_outputs) {
                (Some(sizes), _) => sizes.len(),
                (None, Some(n)) => n as usize,
                (None, None) => case.graph_outputs.unwrap() as usize,
            };
            assert_eq!(results.len(), expected_splits);
            assert_eq!(results, case.expected);
        })
    }

    /// Check that `split` reports an error for each kind of invalid argument.
    #[test]
    fn test_split_invalid_inputs() {
        let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

        #[derive(Debug)]
        struct Case<'a> {
            axis: isize,
            splits: SplitSizes<'a>,
            expected: OpError,
        }

        let cases = [
            // Axis out of range for a 2D input.
            Case {
                axis: 2,
                splits: [1, 1].as_slice().into(),
                expected: OpError::InvalidValue("Axis is invalid"),
            },
            // Sizes sum to more than the axis size.
            Case {
                axis: 1,
                splits: [1, 2].as_slice().into(),
                expected: OpError::InvalidValue("Split sizes do not sum to dimension size"),
            },
            // Negative size.
            Case {
                axis: 1,
                splits: [1, -2].as_slice().into(),
                expected: OpError::InvalidValue("Split sizes must be >= 0"),
            },
            // Zero outputs requested.
            Case {
                axis: 1,
                splits: SplitSizes::NumSplits(0),
                expected: OpError::InvalidValue("num_outputs must be > 0"),
            },
            // More outputs requested than elements along the axis.
            Case {
                axis: 1,
                splits: SplitSizes::NumSplits(3),
                expected: OpError::InvalidValue("num_outputs exceeds dim size"),
            },
        ];

        cases.test_each(|case| {
            let pool = BufferPool::new();
            let result = split(&pool, input.view(), case.axis, case.splits.clone());
            assert_eq!(result.err().as_ref(), Some(&case.expected));
        })
    }
}