1use rten_base::iter::range_chunks;
2use rten_shape_inference::ops as shape_ops;
3use rten_tensor::prelude::*;
4use rten_tensor::{NdTensorView, Tensor, TensorView};
5
6use crate::buffer_pool::BufferPool;
7use crate::infer_shapes::{InferShapes, impl_infer_shapes};
8use crate::operator::{
9 OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList, OutputTypesContext,
10};
11use crate::ops::{map_value_view, resolve_axis};
12use crate::value::ValueView;
13
#[derive(Clone, Debug)]
pub enum SplitSizes<'a> {
    /// Split the axis into chunks of this fixed size. Must be >= 1. The last
    /// chunk may be smaller if the size does not evenly divide the axis.
    Size(i32),
    /// Explicit size of each output chunk. Each size must be >= 0 and the
    /// sizes must sum to the size of the split axis.
    Sizes(NdTensorView<'a, i32, 1>),
    /// Number of output chunks to produce. Must be > 0 and must not exceed
    /// the size of the split axis.
    NumSplits(u32),
}
26
27impl<'a> From<&'a [i32]> for SplitSizes<'a> {
28 fn from(val: &'a [i32]) -> Self {
29 Self::Sizes(val.into())
30 }
31}
32
33pub fn split<T: Copy>(
34 pool: &BufferPool,
35 input: TensorView<T>,
36 axis: isize,
37 split: SplitSizes,
38) -> Result<Vec<Tensor<T>>, OpError> {
39 let axis = resolve_axis(input.ndim(), axis)?;
40 let axis_size = input.size(axis);
41
42 let split_with_chunk_size = |chunk_size| {
43 range_chunks(0..axis_size, chunk_size)
44 .map(|split_range| input.slice_axis(axis, split_range).to_tensor_in(pool))
45 .collect()
46 };
47
48 let outputs = match split {
49 SplitSizes::Size(size) => {
50 if size < 1 {
51 return Err(OpError::InvalidValue("Split size must be >= 1"));
52 }
53 split_with_chunk_size(size as usize)
54 }
55 SplitSizes::Sizes(split) => {
56 if split.iter().any(|size| *size < 0) {
57 return Err(OpError::InvalidValue("Split sizes must be >= 0"));
58 }
59 let split_sum = split.iter().sum::<i32>() as usize;
60 if split_sum != input.size(axis) {
61 return Err(OpError::InvalidValue(
62 "Split sizes do not sum to dimension size",
63 ));
64 }
65
66 let mut split_start = 0;
67 split
68 .iter()
69 .map(|&split_size| {
70 let split_size = split_size as usize;
71 let split_range = split_start..split_start + split_size;
72 split_start += split_size;
73 input.slice_axis(axis, split_range).to_tensor_in(pool)
74 })
75 .collect()
76 }
77 SplitSizes::NumSplits(n_splits) => {
78 let n_splits = n_splits as usize;
79 if n_splits == 0 {
80 return Err(OpError::InvalidValue("num_outputs must be > 0"));
81 }
82 if n_splits > axis_size {
83 return Err(OpError::InvalidValue("num_outputs exceeds dim size"));
84 }
85 split_with_chunk_size(axis_size.div_ceil(n_splits))
86 }
87 };
88
89 Ok(outputs)
90}
91
#[derive(Debug)]
pub struct Split {
    /// Axis along which to split the input. May be negative to count from
    /// the last dimension.
    pub axis: isize,
    /// Number of equal-sized outputs to produce when no explicit split sizes
    /// are provided as an input. If unset, `run` falls back to the output
    /// count from the run context.
    pub num_outputs: Option<u32>,
}
97
98impl Operator for Split {
99 fn name(&self) -> &str {
100 "Split"
101 }
102
103 fn max_inputs(&self) -> Option<usize> {
104 Some(2)
105 }
106
107 fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
108 let input = ctx.inputs().require(0)?;
109
110 let splits = ctx.inputs().get_as(1)?;
118 let num_outputs = self.num_outputs.or(ctx.num_outputs());
119
120 let split_sizes = if let Some(splits) = splits {
121 SplitSizes::Sizes(splits)
122 } else if let Some(num_outputs) = num_outputs {
123 SplitSizes::NumSplits(num_outputs)
124 } else {
125 return Err(OpError::InvalidValue(
126 "Either `num_outputs` or `splits` must be set",
127 ));
128 };
129
130 map_value_view!(input, x, {
131 split(ctx.pool(), x, self.axis, split_sizes)
132 .map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
133 })
134 }
135
136 fn output_types(&self, ctx: &OutputTypesContext) -> Option<OutputTypeList> {
137 Some(OutputTypeList::from_elem(
138 OutputType::CopyFromInput(0),
139 ctx.num_outputs,
140 ))
141 }
142
143 fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
144 Some(self)
145 }
146}
147
// Delegate shape inference to the corresponding `rten_shape_inference` op,
// forwarding this operator's attributes.
impl_infer_shapes!(
    Split,
    op,
    shape_ops::Split {
        axis: op.axis as i32,
        num_outputs: op.num_outputs
    }
);
156
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::{NdTensor, Tensor};
    use rten_testing::TestCases;

    use crate::buffer_pool::BufferPool;
    use crate::operator::{InputList, OpError, OpRunContext, Operator};

    use super::{Split, SplitSizes, split};

    #[test]
    fn test_split() {
        let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

        #[derive(Debug)]
        struct Case {
            // Axis to split along. May be negative.
            axis: isize,
            // Explicit per-output sizes, passed as the second operator input.
            splits: Option<NdTensor<i32, 1>>,
            // `num_outputs` attribute on the `Split` operator.
            num_outputs: Option<u32>,

            // Output count supplied via the run context, used as a fallback
            // when neither `splits` nor `num_outputs` is set.
            graph_outputs: Option<u32>,

            expected: Vec<Tensor>,
        }

        let cases = [
            // Explicit split sizes along a positive axis.
            Case {
                axis: 1,
                splits: Some([1, 1].into()),
                num_outputs: None,
                graph_outputs: None,
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
            // Explicit split sizes along a negative axis.
            Case {
                axis: -1,
                splits: Some([1, 1].into()),
                num_outputs: None,
                graph_outputs: None,
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
            // `num_outputs` attribute with an axis size that is not evenly
            // divisible: the final chunk is smaller.
            Case {
                axis: 0,
                splits: None,
                num_outputs: Some(3),
                graph_outputs: None,
                expected: [
                    Tensor::from([[0., 1.], [2., 3.]]),
                    Tensor::from([[4., 5.], [6., 7.]]),
                    Tensor::from([[8., 9.]]),
                ]
                .into(),
            },
            // Neither `splits` nor `num_outputs` set: the output count comes
            // from the run context.
            Case {
                axis: 1,
                splits: None,
                num_outputs: None,
                graph_outputs: Some(2),
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
        ];

        cases.test_each(|case| {
            let split_op = Split {
                axis: case.axis,
                num_outputs: case.num_outputs,
            };

            let inputs = InputList::from_iter([
                Some(input.view().into()),
                case.splits.as_ref().map(|s| s.view().into()),
            ]);
            let pool = BufferPool::new();
            let mut ctx = OpRunContext::new(&pool, &inputs);
            if let Some(n_outputs) = case.graph_outputs {
                ctx.set_num_outputs(n_outputs);
            }
            let results = split_op.run(&ctx).unwrap();
            let results: Vec<Tensor> = results.into_iter().map(|o| o.try_into().unwrap()).collect();

            // Expected number of outputs, derived from whichever split
            // specification the case provided.
            let expected_splits = match (case.splits.as_ref(), case.num_outputs) {
                (None, Some(n)) => n as usize,
                (Some(sizes), None) => sizes.len(),
                (None, None) => case.graph_outputs.unwrap() as usize,
                (Some(_), Some(_)) => 0,
            };
            assert_eq!(results.len(), expected_splits);
            assert_eq!(results, case.expected);
        })
    }

    #[test]
    fn test_split_invalid_inputs() {
        let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

        #[derive(Debug)]
        struct Case<'a> {
            axis: isize,
            splits: SplitSizes<'a>,
            // Error that `split` is expected to return.
            expected: OpError,
        }

        let cases = [
            // Axis out of range for a 2D input.
            Case {
                axis: 2,
                splits: [1, 1].as_slice().into(),
                expected: OpError::InvalidValue("Axis is invalid"),
            },
            // Explicit sizes that don't cover the full axis.
            Case {
                axis: 1,
                splits: [1, 2].as_slice().into(),
                expected: OpError::InvalidValue("Split sizes do not sum to dimension size"),
            },
            // Negative split size.
            Case {
                axis: 1,
                splits: [1, -2].as_slice().into(),
                expected: OpError::InvalidValue("Split sizes must be >= 0"),
            },
            // Zero split count.
            Case {
                axis: 1,
                splits: SplitSizes::NumSplits(0),
                expected: OpError::InvalidValue("num_outputs must be > 0"),
            },
            // More splits requested than elements along the axis.
            Case {
                axis: 1,
                splits: SplitSizes::NumSplits(3),
                expected: OpError::InvalidValue("num_outputs exceeds dim size"),
            },
        ];

        cases.test_each(|case| {
            let pool = BufferPool::new();
            let result = split(&pool, input.view(), case.axis, case.splits.clone());
            assert_eq!(result.err().as_ref(), Some(&case.expected));
        })
    }
}