rten/ops/
split.rs

1use rten_base::iter::range_chunks;
2use rten_shape_inference::ops as shape_ops;
3use rten_tensor::prelude::*;
4use rten_tensor::{NdTensorView, Tensor, TensorView};
5
6use crate::buffer_pool::BufferPool;
7use crate::infer_shapes::{InferShapes, impl_infer_shapes};
8use crate::operator::{
9    OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList, OutputTypesContext,
10};
11use crate::ops::{map_value_view, resolve_axis};
12use crate::value::ValueView;
13
14#[derive(Clone, Debug)]
15pub enum SplitSizes<'a> {
16    /// Split a tensor into pieces with a given size. If the axis size is not
17    /// evenly divisible by the size, the last piece will be smaller.
18    Size(i32),
19    /// Split a tensor into pieces with sizes specified by a vector. The sum of
20    /// the piece sizes must match the size of the axis.
21    Sizes(NdTensorView<'a, i32, 1>),
22    /// Split a tensor into N equal-sized pieces. If the size of the axis being
23    /// split is not evenly divisible by N, the last chunk will be smaller.
24    NumSplits(u32),
25}
26
27impl<'a> From<&'a [i32]> for SplitSizes<'a> {
28    fn from(val: &'a [i32]) -> Self {
29        Self::Sizes(val.into())
30    }
31}
32
33pub fn split<T: Copy>(
34    pool: &BufferPool,
35    input: TensorView<T>,
36    axis: isize,
37    split: SplitSizes,
38) -> Result<Vec<Tensor<T>>, OpError> {
39    let axis = resolve_axis(input.ndim(), axis)?;
40    let axis_size = input.size(axis);
41
42    let split_with_chunk_size = |chunk_size| {
43        range_chunks(0..axis_size, chunk_size)
44            .map(|split_range| input.slice_axis(axis, split_range).to_tensor_in(pool))
45            .collect()
46    };
47
48    let outputs = match split {
49        SplitSizes::Size(size) => {
50            if size < 1 {
51                return Err(OpError::InvalidValue("Split size must be >= 1"));
52            }
53            split_with_chunk_size(size as usize)
54        }
55        SplitSizes::Sizes(split) => {
56            if split.iter().any(|size| *size < 0) {
57                return Err(OpError::InvalidValue("Split sizes must be >= 0"));
58            }
59            let split_sum = split.iter().sum::<i32>() as usize;
60            if split_sum != input.size(axis) {
61                return Err(OpError::InvalidValue(
62                    "Split sizes do not sum to dimension size",
63                ));
64            }
65
66            let mut split_start = 0;
67            split
68                .iter()
69                .map(|&split_size| {
70                    let split_size = split_size as usize;
71                    let split_range = split_start..split_start + split_size;
72                    split_start += split_size;
73                    input.slice_axis(axis, split_range).to_tensor_in(pool)
74                })
75                .collect()
76        }
77        SplitSizes::NumSplits(n_splits) => {
78            let n_splits = n_splits as usize;
79            if n_splits == 0 {
80                return Err(OpError::InvalidValue("num_outputs must be > 0"));
81            }
82            if n_splits > axis_size {
83                return Err(OpError::InvalidValue("num_outputs exceeds dim size"));
84            }
85            split_with_chunk_size(axis_size.div_ceil(n_splits))
86        }
87    };
88
89    Ok(outputs)
90}
91
/// Operator which splits its first input into multiple outputs along one
/// axis.
#[derive(Debug)]
pub struct Split {
    /// Axis along which the input is split.
    pub axis: isize,
    /// Number of pieces to produce. When `None`, the piece sizes come from
    /// the optional second input, or failing that, the count comes from the
    /// number of outputs reported by the run context.
    pub num_outputs: Option<u32>,
}
97
98impl Operator for Split {
99    fn name(&self) -> &str {
100        "Split"
101    }
102
103    fn max_inputs(&self) -> Option<usize> {
104        Some(2)
105    }
106
107    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
108        let input = ctx.inputs().require(0)?;
109
110        // In Split v18+, the operator should specify either a vector of split
111        // sizes or a count of outputs to produce. Older versions of the Split
112        // operator could omit both of these, in which case the number of
113        // outputs was determined by looking at the operator node's number of
114        // outputs.
115        //
116        // See https://github.com/robertknight/rten/issues/689.
117        let splits = ctx.inputs().get_as(1)?;
118        let num_outputs = self.num_outputs.or(ctx.num_outputs());
119
120        let split_sizes = if let Some(splits) = splits {
121            SplitSizes::Sizes(splits)
122        } else if let Some(num_outputs) = num_outputs {
123            SplitSizes::NumSplits(num_outputs)
124        } else {
125            return Err(OpError::InvalidValue(
126                "Either `num_outputs` or `splits` must be set",
127            ));
128        };
129
130        map_value_view!(input, x, {
131            split(ctx.pool(), x, self.axis, split_sizes)
132                .map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
133        })
134    }
135
136    fn output_types(&self, ctx: &OutputTypesContext) -> Option<OutputTypeList> {
137        Some(OutputTypeList::from_elem(
138            OutputType::CopyFromInput(0),
139            ctx.num_outputs,
140        ))
141    }
142
143    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
144        Some(self)
145    }
146}
147
// NOTE(review): presumably this generates the `InferShapes` impl that
// `Split::as_infer_shapes` relies on, forwarding to `shape_ops::Split` built
// from this operator's `axis` and `num_outputs` — confirm against the
// definition of `impl_infer_shapes!`. Note that `op.axis as i32` is a
// narrowing cast from `isize`.
impl_infer_shapes!(
    Split,
    op,
    shape_ops::Split {
        axis: op.axis as i32,
        num_outputs: op.num_outputs
    }
);
156
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::{NdTensor, Tensor};
    use rten_testing::TestCases;

    use crate::buffer_pool::BufferPool;
    use crate::operator::{InputList, OpError, OpRunContext, Operator};

    use super::{Split, SplitSizes, split};

    /// Exercise the `Split` operator with each way of specifying the pieces:
    /// a `splits` input, the `num_outputs` attribute, and the output count
    /// taken from the run context.
    #[test]
    fn test_split() {
        let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

        #[derive(Debug)]
        struct Case {
            axis: isize,
            splits: Option<NdTensor<i32, 1>>,
            num_outputs: Option<u32>,

            // Number of outputs the Split node has in the graph.
            graph_outputs: Option<u32>,

            expected: Vec<Tensor>,
        }

        let cases = [
            // Positive axis, splits specified via input.
            Case {
                axis: 1,
                splits: Some([1, 1].into()),
                num_outputs: None,
                graph_outputs: None,
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
            // Negative axis, splits specified via input.
            Case {
                axis: -1,
                splits: Some([1, 1].into()),
                num_outputs: None,
                graph_outputs: None,
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
            // Split count specified via `num_outputs` attribute.
            Case {
                axis: 0,
                splits: None,
                num_outputs: Some(3),
                graph_outputs: None,
                expected: [
                    Tensor::from([[0., 1.], [2., 3.]]),
                    Tensor::from([[4., 5.], [6., 7.]]),
                    Tensor::from([[8., 9.]]),
                ]
                .into(),
            },
            // Split count inferred from graph outputs
            Case {
                axis: 1,
                splits: None,
                num_outputs: None,
                graph_outputs: Some(2),
                expected: [
                    Tensor::from([[0.], [2.], [4.], [6.], [8.]]),
                    Tensor::from([[1.], [3.], [5.], [7.], [9.]]),
                ]
                .into(),
            },
        ];

        cases.test_each(|case| {
            let split_op = Split {
                axis: case.axis,
                num_outputs: case.num_outputs,
            };

            let inputs = InputList::from_iter([
                Some(input.view().into()),
                case.splits.as_ref().map(|s| s.view().into()),
            ]);
            let pool = BufferPool::new();
            let mut ctx = OpRunContext::new(&pool, &inputs);
            if let Some(n_outputs) = case.graph_outputs {
                ctx.set_num_outputs(n_outputs);
            }
            let results = split_op.run(&ctx).unwrap();
            let results: Vec<Tensor> = results.into_iter().map(|o| o.try_into().unwrap()).collect();

            let expected_splits = match (case.splits.as_ref(), case.num_outputs) {
                // `Split::run` uses the `splits` input whenever it is
                // present, regardless of `num_outputs`.
                (Some(sizes), _) => sizes.len(),
                (None, Some(n)) => n as usize,
                (None, None) => case.graph_outputs.unwrap() as usize,
            };
            assert_eq!(results.len(), expected_splits);
            assert_eq!(results, case.expected);
        })
    }

    /// Splitting with a fixed piece size yields a smaller final piece when
    /// the axis size is not a multiple of the piece size.
    #[test]
    fn test_split_fixed_size() {
        let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);
        let pool = BufferPool::new();

        let results = split(&pool, input.view(), 0, SplitSizes::Size(2)).unwrap();

        let expected: Vec<Tensor> = [
            Tensor::from([[0., 1.], [2., 3.]]),
            Tensor::from([[4., 5.], [6., 7.]]),
            Tensor::from([[8., 9.]]),
        ]
        .into();
        assert_eq!(results, expected);
    }

    /// Each invalid axis or split specification is rejected with the
    /// matching `OpError::InvalidValue`.
    #[test]
    fn test_split_invalid_inputs() {
        let input = Tensor::from([[0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.]]);

        #[derive(Debug)]
        struct Case<'a> {
            axis: isize,
            splits: SplitSizes<'a>,
            expected: OpError,
        }

        let cases = [
            Case {
                axis: 2,
                splits: [1, 1].as_slice().into(),
                expected: OpError::InvalidValue("Axis is invalid"),
            },
            Case {
                axis: 1,
                splits: [1, 2].as_slice().into(),
                expected: OpError::InvalidValue("Split sizes do not sum to dimension size"),
            },
            Case {
                axis: 1,
                splits: [1, -2].as_slice().into(),
                expected: OpError::InvalidValue("Split sizes must be >= 0"),
            },
            Case {
                axis: 1,
                splits: SplitSizes::Size(0),
                expected: OpError::InvalidValue("Split size must be >= 1"),
            },
            Case {
                axis: 1,
                splits: SplitSizes::NumSplits(0),
                expected: OpError::InvalidValue("num_outputs must be > 0"),
            },
            Case {
                axis: 1,
                splits: SplitSizes::NumSplits(3),
                expected: OpError::InvalidValue("num_outputs exceeds dim size"),
            },
        ];

        cases.test_each(|case| {
            let pool = BufferPool::new();
            let result = split(&pool, input.view(), case.axis, case.splits.clone());
            assert_eq!(result.err().as_ref(), Some(&case.expected));
        })
    }
}