// rten_shape_inference/lib.rs
//! Shape inference for ONNX graphs.
//!
//! # About shape inference
//!
//! Some ONNX model optimizations depend upon knowledge about the shapes and
//! types of various values in the graph. These values may have dynamic sizes
//! that depend on model inputs. In a typical language model for example, the
//! input has dynamic dimensions for the batch size and sequence length.
//!
//! The goal of shape inference is to take information embedded in the model
//! about the shapes of model inputs and trace how graph operators transform,
//! extract and otherwise process tensor shapes, and produce metadata about the
//! shape of each value in the graph.
//!
//! As an example, suppose a model has an image input of shape (batch, 3,
//! height, width) and computes a mask with shape (batch, height, width). This
//! could be done with a sequence of operators such as:
//!
//! ```text
//! S = Shape(Image) // ["batch", 3, "height", "width"]
//! B = Gather(S, axis=0, indices=0) // "batch"
//! BV = Unsqueeze(B, axis=0) // ["batch"]
//! H = Gather(S, axis=0, indices=2) // "height"
//! HV = Unsqueeze(H, axis=0) // ["height"]
//! W = Gather(S, axis=0, indices=3) // "width"
//! WV = Unsqueeze(W, axis=0) // ["width"]
//! S2 = Concat<axis=0>(BV, HV, WV) // ["batch", "height", "width"]
//! Mask = ConstantOfShape<value=1>(S2) // shape("batch", "height", "width")
//! ```
//!
//! Shape inference for this graph involves following the extraction of the
//! input shape, then tracing its transformations and uses, in order to
//! determine the shape of the output. If there were an optimization to combine
//! all of these nodes into one, which depended on knowing that the output
//! shape is the same as the input shape minus the second dimension, the
//! results of shape inference could be used to verify this.
//!
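//! The pipeline above can be sketched over symbolic dimensions. This is a
//! self-contained illustration, not this crate's actual API: the `Dim` type
//! and the `shape`, `gather`, `unsqueeze` and `concat` functions below are
//! hypothetical stand-ins for the real machinery.
//!
//! ```rust
//! /// A dimension is either a known size or a named symbol.
//! #[derive(Clone, Debug, PartialEq)]
//! enum Dim {
//!     Known(usize),
//!     Sym(&'static str),
//! }
//!
//! // Shape turns a tensor's shape into a value, Gather picks one entry,
//! // Unsqueeze wraps a scalar into a 1-element vector, Concat joins vectors.
//! fn shape(t: &[Dim]) -> Vec<Dim> { t.to_vec() }
//! fn gather(v: &[Dim], index: usize) -> Dim { v[index].clone() }
//! fn unsqueeze(d: Dim) -> Vec<Dim> { vec![d] }
//! fn concat(parts: &[Vec<Dim>]) -> Vec<Dim> {
//!     parts.iter().flatten().cloned().collect()
//! }
//!
//! let image = [Dim::Sym("batch"), Dim::Known(3), Dim::Sym("height"), Dim::Sym("width")];
//! let s = shape(&image);
//! let s2 = concat(&[
//!     unsqueeze(gather(&s, 0)),
//!     unsqueeze(gather(&s, 2)),
//!     unsqueeze(gather(&s, 3)),
//! ]);
//! // The inferred mask shape is ["batch", "height", "width"].
//! assert_eq!(s2, vec![Dim::Sym("batch"), Dim::Sym("height"), Dim::Sym("width")]);
//! ```
//!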
//! # Crate overview
//!
//! The main export of this crate is the [`InferShapes`] trait, plus types which
//! implement it in [`ops`]. This trait computes the output shapes produced for
//! a given set of input shapes. The shapes are symbolic, meaning they can
//! represent variables that change at inference time.
//!
//! Many ONNX operators share the same shape inference rules, so there is an
//! M:1 mapping between operators and shape inference implementations. For most
//! operators, shape inference only describes how the operator transforms the
//! shapes of tensors, but not their values. For a subset of operators, shape
//! inference can also describe how the operator transforms the values of
//! inputs, where the values are scalars or vectors of symbolic expressions.
//! This is needed to understand subgraphs in ONNX models that extract and
//! transform shapes. For example, shape inference for the `Concat` op can
//! express that concatenating the vectors `["batch"]` and `["height" / 2,
//! "width" / 2]` produces the output `["batch", "height" / 2, "width" / 2]`.
//!
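//! As a hypothetical sketch (again, not the crate's real types), value-level
//! inference for `Concat` on 1-D inputs amounts to concatenating vectors of
//! symbolic elements:
//!
//! ```rust
//! /// A symbolic element: a named symbol, optionally divided by a constant.
//! #[derive(Clone, Debug, PartialEq)]
//! enum Elem {
//!     Sym(&'static str),
//!     /// e.g. `SymDiv("height", 2)` represents `"height" / 2`.
//!     SymDiv(&'static str, i64),
//! }
//!
//! /// Concat along axis 0 of 1-D symbolic vectors chains their elements.
//! fn concat(inputs: &[Vec<Elem>]) -> Vec<Elem> {
//!     inputs.iter().flatten().cloned().collect()
//! }
//!
//! let a = vec![Elem::Sym("batch")];
//! let b = vec![Elem::SymDiv("height", 2), Elem::SymDiv("width", 2)];
//! assert_eq!(
//!     concat(&[a, b]),
//!     vec![Elem::Sym("batch"), Elem::SymDiv("height", 2), Elem::SymDiv("width", 2)]
//! );
//! ```
//!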
//! # Symbolic values
//!
//! Symbolic values are multi-dimensional array types where the dimension sizes
//! and elements are _symbolic expressions_. Expressions can be known integers,
//! named symbols, or composite expressions involving these. Values are
//! represented by [`SymTensor`] and expressions by [`SymExpr`].
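//!
//! As an illustration of the idea (the real [`SymExpr`] is richer), a symbolic
//! expression type might look like the following hypothetical sketch, where an
//! expression such as `"height" / 2` can be evaluated once its symbols are
//! bound to concrete values:
//!
//! ```rust
//! use std::collections::HashMap;
//!
//! #[derive(Clone, Debug, PartialEq)]
//! enum Expr {
//!     Int(i64),
//!     Sym(String),
//!     Div(Box<Expr>, Box<Expr>),
//! }
//!
//! impl Expr {
//!     /// Evaluate the expression; returns `None` if a symbol is unbound.
//!     fn eval(&self, bindings: &HashMap<String, i64>) -> Option<i64> {
//!         match self {
//!             Expr::Int(n) => Some(*n),
//!             Expr::Sym(name) => bindings.get(name).copied(),
//!             Expr::Div(a, b) => Some(a.eval(bindings)? / b.eval(bindings)?),
//!         }
//!     }
//! }
//!
//! // "height" / 2, evaluated with "height" bound to 640.
//! let half = Expr::Div(Box::new(Expr::Sym("height".into())), Box::new(Expr::Int(2)));
//! let bindings: HashMap<String, i64> = [("height".to_string(), 640)].into();
//! assert_eq!(half.eval(&bindings), Some(320));
//! ```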

mod infer_shapes;
pub mod ops;
mod sym_expr;
mod sym_gen;
mod sym_tensor;

pub use infer_shapes::{BinaryOp, InferShapes, InferShapesError, ReductionOp, UnaryOp, VariadicOp};
pub use sym_expr::{SymExpr, Symbol};
pub use sym_gen::SymbolGen;
pub use sym_tensor::{Constant, SymTensor};