rten_shape_inference/
infer_shapes.rs

1//! Traits for shape inference and common implementations.
2
3use smallvec::SmallVec;
4
5pub use crate::{
6    sym_expr::SymExpr,
7    sym_gen::SymbolGen,
8    sym_tensor::{Constant, SymTensor},
9};
10
/// Errors when performing shape inference.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InferShapesError {
    /// Too many or too few inputs were provided for this operator.
    IncorrectInputCount,

    /// The input shapes are incompatible.
    ///
    /// Operator execution will fail if given inputs with these shapes.
    IncompatibleShapes,

    /// An input's rank does not match that expected by the operator.
    IncorrectRank,

    /// An operator input or attribute has an invalid value.
    InvalidValue,

    /// The number of outputs could not be determined.
    UnknownOutputCount,
}

impl std::fmt::Display for InferShapesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Human-readable messages mirroring the variant doc comments.
        let msg = match self {
            Self::IncorrectInputCount => "too many or too few inputs for operator",
            Self::IncompatibleShapes => "input shapes are incompatible",
            Self::IncorrectRank => "input rank does not match operator's expectation",
            Self::InvalidValue => "operator input or attribute has an invalid value",
            Self::UnknownOutputCount => "number of outputs could not be determined",
        };
        f.write_str(msg)
    }
}

// Allow the error to be propagated via `?` into `Box<dyn Error>` and used
// with error-reporting crates.
impl std::error::Error for InferShapesError {}
31
/// Infer the shapes of an operator's outputs given its inputs.
pub trait InferShapes {
    /// Infer the shapes and optionally values of an operator's outputs given
    /// its inputs.
    ///
    /// The operator may need to generate new symbolic dimensions to represent
    /// dimensions that are unknown or combinations of inputs. These should be
    /// generated using `sym_gen`.
    ///
    /// Returns one [`SymTensor`] per operator output.
    ///
    /// # Errors
    ///
    /// Returns an [`InferShapesError`] if the inputs are incompatible with
    /// this operator (eg. wrong input count, rank or shape).
    fn infer_shapes(
        &self,
        inputs: &[SymTensor],
        sym_gen: &mut SymbolGen,
    ) -> Result<Vec<SymTensor>, InferShapesError>;
}
46
47/// Shape inference for unary operators.
48///
49/// These operators take at least one input and return a single output with
50/// the same shape as the first input. Unary operators may take additional
51/// inputs (eg. min/max parameters for the Clip operator) that don't affect
52/// the output shape.
53pub struct UnaryOp;
54
55impl InferShapes for UnaryOp {
56    fn infer_shapes(
57        &self,
58        inputs: &[SymTensor],
59        _sym_gen: &mut SymbolGen,
60    ) -> Result<Vec<SymTensor>, InferShapesError> {
61        let Some(data) = inputs.first() else {
62            return Err(InferShapesError::IncorrectInputCount);
63        };
64
65        let shape = if let Some(shape) = data.shape() {
66            SymTensor::from_shape(shape.collect())
67        } else {
68            SymTensor::unknown("unknown input shape")
69        };
70
71        Ok([shape].into())
72    }
73}
74
/// Shape inference for binary operators.
///
/// These operators take two inputs and return an output whose shape is the
/// result of broadcasting the two input shapes together following ONNX's
/// [broadcasting rules](https://onnx.ai/onnx/repo-docs/Broadcasting.html).
pub struct BinaryOp;

impl InferShapes for BinaryOp {
    fn infer_shapes(
        &self,
        inputs: &[SymTensor],
        _sym_gen: &mut SymbolGen,
    ) -> Result<Vec<SymTensor>, InferShapesError> {
        // Exactly two inputs are required.
        let [a, b] = inputs else {
            return Err(InferShapesError::IncorrectInputCount);
        };

        // If either input's shape is unknown, nothing useful can be said
        // about the output shape.
        let (Some(a_dims), Some(b_dims)) = (a.shape(), b.shape()) else {
            return Ok([SymTensor::unknown("unknown input shape")].into());
        };

        // Per the broadcasting rules, the lower-rank shape is left-padded
        // with 1-sized dims so both shapes have equal rank, then dims are
        // broadcast pairwise. At most one of `a_pad` / `b_pad` is non-zero,
        // so `a_pad + a_dims.len()` is the output rank.
        let a_pad = b_dims.len().saturating_sub(a_dims.len());
        let b_pad = a_dims.len().saturating_sub(b_dims.len());
        let mut out_shape: Vec<SymExpr> = Vec::with_capacity(a_pad + a_dims.len());

        let a_iter = std::iter::repeat_n(SymExpr::Value(1), a_pad).chain(a_dims);
        let b_iter = std::iter::repeat_n(SymExpr::Value(1), b_pad).chain(b_dims);

        for (a, b) in a_iter.zip(b_iter) {
            // NOTE: The order of these match arms matters: the equality and
            // size-1 arms must be tried before the mixed fixed/symbolic arms
            // below, which assume the dims are unequal and neither is 1.
            let dim: SymExpr = match (a, b) {
                (a, b) if a == b => a.clone(),

                // If either size is 1, it will be broadcast against the other
                // size.
                (SymExpr::Value(1), b) => b.clone(),
                (a, SymExpr::Value(1)) => a.clone(),

                // If both sizes are fixed and different, we know execution
                // will fail.
                (SymExpr::Value(_), SymExpr::Value(_)) => {
                    return Err(InferShapesError::IncompatibleShapes);
                }

                // If one dim is a fixed value other than 1 and the other
                // dim is symbolic, execution can only succeed if the symbolic
                // dim has the same size as the fixed dim.
                (SymExpr::Var(_a), SymExpr::Value(b)) => SymExpr::Value(b),
                (SymExpr::Value(a), SymExpr::Var(_b)) => SymExpr::Value(a),

                // In cases where both values are unknown, the result can be
                // either of the dimensions.
                //
                // 1. If both sizes are equal, the result can be seen as either
                //    the first or second dim.
                // 2. If only one of the sizes is 1, the result will be the other
                //    dim.
                // 3. If the sizes are different, the op will fail.
                //
                // Where the op succeeds, the result is the maximum of the LHS
                // and RHS sizes.
                (a, b) => a.broadcast(&b),
            };
            out_shape.push(dim);
        }

        Ok([SymTensor::from_shape(out_shape)].into())
    }
}
143
144/// Shape inference for variadic operators.
145///
146/// This is a generalization of [`BinaryOp`] to operators which take a variable
147/// number of inputs whose shapes are broadcast against each other, using the
148/// same rules.
149pub struct VariadicOp;
150
151impl InferShapes for VariadicOp {
152    fn infer_shapes(
153        &self,
154        inputs: &[SymTensor],
155        sym_gen: &mut SymbolGen,
156    ) -> Result<Vec<SymTensor>, InferShapesError> {
157        if inputs.is_empty() {
158            return Err(InferShapesError::IncorrectInputCount);
159        }
160
161        let first_shape = inputs[0]
162            .shape()
163            .map(|shape| SymTensor::from_shape(shape.collect()))
164            .unwrap_or_else(|| SymTensor::unknown("unknown input shape"));
165
166        let out_shape: Result<SymTensor, InferShapesError> =
167            inputs
168                .iter()
169                .skip(1)
170                .try_fold(first_shape, |out_shape, in_shape| {
171                    let mut shapes =
172                        BinaryOp.infer_shapes(&[out_shape, in_shape.clone()], sym_gen)?;
173                    Ok(shapes.remove(0))
174                });
175
176        Ok([out_shape?].into())
177    }
178}
179
/// Shape inference for reduction operators.
#[derive(Clone, Debug, PartialEq)]
pub struct ReductionOp<'a> {
    /// Axes over which the reduction is applied.
    ///
    /// Reduction ops take the axes as an attribute in ONNX opset <= 13 and an
    /// input in opset 18+. If a constant `axes` input is also provided, it
    /// takes precedence over this attribute. If neither is present, all axes
    /// are reduced.
    ///
    /// Values may be negative, counting backwards from the last dimension.
    pub axes: Option<&'a [i32]>,

    /// True if the reduced dimension is retained as a 1-sized dimension in the
    /// output. If false, reduced dimensions are removed from the output shape.
    pub keep_dims: bool,
}
193
194impl InferShapes for ReductionOp<'_> {
195    fn infer_shapes(
196        &self,
197        inputs: &[SymTensor],
198        _sym_gen: &mut SymbolGen,
199    ) -> Result<Vec<SymTensor>, InferShapesError> {
200        match inputs.len() {
201            1 | 2 => {}
202            _ => {
203                return Err(InferShapesError::IncorrectInputCount);
204            }
205        }
206
207        let data = &inputs[0];
208
209        let Some(data_dims) = data.shape() else {
210            return Ok([SymTensor::unknown("unknown input shape")].into());
211        };
212
213        let ndim = data_dims.len();
214        let mut axes: SmallVec<[usize; 4]> =
215            if let Some(Constant::Vector(axes)) = inputs.get(1).and_then(|x| x.to_constant()) {
216                resolve_axes(ndim, axes.iter()).map_err(|_| InferShapesError::IncorrectRank)?
217            } else if let Some(axes) = self.axes {
218                resolve_axes(ndim, axes.iter()).map_err(|_| InferShapesError::IncorrectRank)?
219            } else {
220                (0..ndim).collect()
221            };
222        axes.sort();
223        axes.dedup();
224
225        let out_ndim = if self.keep_dims {
226            ndim
227        } else {
228            ndim - axes.len()
229        };
230        let mut out_shape = Vec::with_capacity(out_ndim);
231
232        for (i, dim) in data_dims.enumerate() {
233            if !axes.contains(&i) {
234                out_shape.push(dim.clone());
235                continue;
236            } else if self.keep_dims {
237                out_shape.push(SymExpr::Value(1));
238            }
239        }
240
241        Ok([SymTensor::from_shape(out_shape)].into())
242    }
243}
244
/// Resolve an index given as a value in `[-len, len-1]` to a positive index in
/// `[0, len)`, or return None if the index is out of bounds.
fn resolve_index(len: usize, index: i32) -> Option<usize> {
    // Clamp the length so it fits in an `i32`. Indexes beyond `i32::MAX`
    // cannot be expressed as an `i32` anyway, so this does not reject any
    // resolvable index.
    let ilen = len.min(i32::MAX as usize) as i32;

    if index >= ilen || index < -ilen {
        None
    } else if index < 0 {
        // Negative values count backwards from the end.
        Some((ilen + index) as usize)
    } else {
        Some(index as usize)
    }
}
259
260/// Resolve an axis given as a value in `[-ndim, ndim-1]` to the zero-based
261/// dimension of a tensor with `ndim` dimensions.
262///
263/// Negative axis values count backwards from the last dimension.
264pub(crate) fn resolve_axis(ndim: usize, axis: i32) -> Result<usize, InferShapesError> {
265    resolve_index(ndim, axis).ok_or(InferShapesError::IncorrectRank)
266}
267
268/// Resolve a sequence of axes values in `[-ndim, ndim-1]` to zero-based dimension
269/// indexes in a tensor with `ndim` dimensions.
270///
271/// Negative axis values count backwards from the last dimension.
272fn resolve_axes<'a, I: ExactSizeIterator<Item = &'a i32>>(
273    ndim: usize,
274    axes: I,
275) -> Result<SmallVec<[usize; 4]>, InferShapesError> {
276    let mut resolved_axes = SmallVec::with_capacity(axes.len());
277    for axis in axes {
278        let resolved = resolve_axis(ndim, *axis)?;
279        resolved_axes.push(resolved);
280    }
281    Ok(resolved_axes)
282}
283
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{
        BinaryOp, InferShapes, InferShapesError, ReductionOp, SymExpr, SymTensor, SymbolGen,
        UnaryOp, VariadicOp,
    };
    use crate::sym_tensor::{sym_elems, sym_shape};

    // UnaryOp passes the first input's shape through unchanged, and requires
    // at least one input.
    #[test]
    fn test_unary_op_infer() {
        let input = sym_shape!("batch", 16, "seq", 24);
        let mut sym_gen = SymbolGen::new();
        let shape = UnaryOp
            .infer_shapes(&[input.clone()], &mut sym_gen)
            .unwrap();
        assert_eq!(shape.len(), 1);
        assert_eq!(shape[0], input);

        let err = UnaryOp.infer_shapes(&[], &mut sym_gen).err().unwrap();
        assert_eq!(err, InferShapesError::IncorrectInputCount);
    }

    // BinaryOp broadcasts two shapes together. Cases cover symbolic, fixed
    // and mixed dims.
    #[test]
    fn test_binary_op() {
        #[derive(Debug)]
        struct Case {
            lhs: SymTensor,
            rhs: SymTensor,
            expected: SymTensor,
        }

        let cases = [
            // Identical symbolic shapes.
            Case {
                lhs: sym_shape!("batch"),
                rhs: sym_shape!("batch"),
                expected: sym_shape!("batch"),
            },
            // Identical fixed shapes.
            Case {
                lhs: sym_shape!(2, 3),
                rhs: sym_shape!(2, 3),
                expected: sym_shape!(2, 3),
            },
            // 1-sized dims broadcast against fixed dims.
            Case {
                lhs: sym_shape!(1, 5),
                rhs: sym_shape!(4, 1),
                expected: sym_shape!(4, 5),
            },
            // All-ones shapes are unchanged.
            Case {
                lhs: sym_shape!(1, 1),
                rhs: sym_shape!(1, 1),
                expected: sym_shape!(1, 1),
            },
            // 1-sized dims broadcast against symbolic dims.
            Case {
                lhs: sym_shape!(1, "bar"),
                rhs: sym_shape!("foo", 1),
                expected: sym_shape!("foo", "bar"),
            },
            // Two distinct symbolic dims broadcast to their maximum.
            Case {
                lhs: sym_shape!("foo"),
                rhs: sym_shape!("bar"),
                expected: sym_shape!(SymExpr::from("foo").broadcast(&SymExpr::from("bar"))),
            },
        ];

        cases.test_each(|case| {
            let mut sym_gen = SymbolGen::new();
            let shape = BinaryOp
                .infer_shapes(&[case.lhs.clone(), case.rhs.clone()], &mut sym_gen)
                .unwrap();
            assert_eq!(shape.len(), 1);
            assert_eq!(shape[0], case.expected.clone());
        });
    }

    // Error cases for BinaryOp: wrong input count and fixed incompatible dims.
    #[test]
    fn test_binary_op_invalid() {
        #[derive(Clone, Debug)]
        struct Case {
            inputs: Vec<Vec<SymExpr>>,
            expected: InferShapesError,
        }

        let cases = [
            Case {
                inputs: [sym_elems!(5)].into(),
                expected: InferShapesError::IncorrectInputCount,
            },
            Case {
                inputs: [sym_elems!(5), sym_elems!(3)].into(),
                expected: InferShapesError::IncompatibleShapes,
            },
        ];

        cases.test_each_clone(|case| {
            let mut sym_gen = SymbolGen::new();
            let inputs: Vec<_> = case.inputs.into_iter().map(SymTensor::from_shape).collect();
            let err = BinaryOp.infer_shapes(&inputs, &mut sym_gen).err().unwrap();
            assert_eq!(err, case.expected);
        });
    }

    // VariadicOp folds binary broadcasting over N inputs.
    #[test]
    fn test_variadic_op() {
        let mut sym_gen = SymbolGen::new();
        let a = sym_shape!("batch", 4, 1, 1);
        let b = sym_shape!("batch", 1, 8, 1);
        let c = sym_shape!("batch", 1, 8, 16);

        // Single input
        let result = VariadicOp.infer_shapes(&[a.clone()], &mut sym_gen).unwrap();
        assert_eq!(result[0], sym_shape!("batch", 4, 1, 1));

        // N inputs
        let result = VariadicOp
            .infer_shapes(&[a.clone(), b, c], &mut sym_gen)
            .unwrap();
        assert_eq!(result[0], sym_shape!("batch", 4, 8, 16));
    }

    // ReductionOp removes reduced axes, or keeps them as 1-sized dims when
    // `keep_dims` is set. Axes may come from an input or an attribute.
    #[test]
    fn test_reduction_op() {
        #[derive(Clone, Debug)]
        struct Case<'a> {
            inputs: Vec<SymTensor>,
            op: ReductionOp<'a>,
            expected: Vec<SymExpr>,
        }

        let axes = vec![SymExpr::Value(1i32)];

        let default_op = ReductionOp {
            axes: None,
            keep_dims: false,
        };

        let cases = [
            // Reduce single axis
            Case {
                inputs: [
                    SymTensor::from_shape(sym_elems!("batch", 4, 5)),
                    SymTensor::from_vec(axes.clone()),
                ]
                .into(),
                op: default_op.clone(),
                expected: sym_elems!("batch", 5),
            },
            // Reduce single axis specified as an attribute
            Case {
                inputs: [SymTensor::from_shape(sym_elems!("batch", 4, 5))].into(),
                op: ReductionOp {
                    axes: Some(&[1i32]),
                    ..default_op
                },
                expected: sym_elems!("batch", 5),
            },
            // Reduce single axis with `keep_dims=true`
            Case {
                inputs: [
                    SymTensor::from_shape(sym_elems!("batch", 4, 5)),
                    SymTensor::from_vec(axes.clone()),
                ]
                .into(),
                op: ReductionOp {
                    keep_dims: true,
                    ..default_op
                },
                expected: sym_elems!("batch", 1, 5),
            },
            // Reduce all axes
            Case {
                inputs: [SymTensor::from_shape(sym_elems!(3, 4, 5))].into(),
                op: default_op.clone(),
                expected: sym_elems!(),
            },
        ];

        cases.test_each(|case| {
            let mut sym_gen = SymbolGen::new();
            let shapes = case.op.infer_shapes(&case.inputs, &mut sym_gen).unwrap();
            assert_eq!(shapes.len(), 1);
            assert_eq!(shapes[0], SymTensor::from_shape(case.expected.clone()));
        });
    }
}
469}