rten_shape_inference/
sym_tensor.rs

1//! Tensors with symbolic shapes and values.
2
3use std::fmt;
4
5use crate::sym_expr::SymExpr;
6
/// An integer-valued constant that is either a scalar or a vector.
#[derive(Clone, Eq, Hash, PartialEq)]
pub enum Constant {
    /// A single value (0 dimensions).
    Scalar(i32),
    /// A list of values (1 dimension).
    Vector(Vec<i32>),
}

impl Constant {
    /// Return the number of dimensions: 0 for a scalar, 1 for a vector.
    pub fn ndim(&self) -> usize {
        match self {
            Self::Vector(_) => 1,
            Self::Scalar(_) => 0,
        }
    }

    /// Return the values as a slice. A scalar yields a one-element slice.
    pub fn values(&self) -> &[i32] {
        match self {
            Self::Vector(items) => items.as_slice(),
            Self::Scalar(item) => std::slice::from_ref(item),
        }
    }

    /// Consume the constant and return its values as a `Vec`. A scalar
    /// yields a one-element vector.
    pub fn into_vec(self) -> Vec<i32> {
        match self {
            Self::Vector(items) => items,
            Self::Scalar(item) => Vec::from([item]),
        }
    }
}

/// Formats a scalar as a bare number (e.g. `3`) and a vector as a bracketed
/// list (e.g. `[1, 2]`).
impl fmt::Debug for Constant {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Vector(items) => write!(f, "{:?}", items),
            Self::Scalar(item) => write!(f, "{}", item),
        }
    }
}
45
46#[derive(Clone, Debug, PartialEq)]
47enum SymTensorKind {
48    Scalar(SymExpr),
49    Vector(Vec<SymExpr>),
50    Shape(Vec<SymExpr>),
51    Unknown {
52        /// Note about why this Unknown value was created, for debugging purposes.
53        note: &'static str,
54    },
55}
56
57/// Tensor with symbolic shape and elements.
58///
59/// This is a tensor where the elements and dimension sizes can be either
60/// concrete values or symbolic expressions. This type is used during shape
61/// inference to represent the shapes of operator inputs and outputs, as well as
62/// the values of operations which manipulate shapes.
63///
64/// The symbolic expressions can be integers, symbols with names and assumptions
65/// about their values or composite expressions (addition, multiplication etc.)
66///
67/// ```
68/// use rten_shape_inference::{SymTensor, SymExpr};
69///
70/// // Create a matrix with `nr` rows, `nc` columns and unknown values.
71/// let nr = SymExpr::from("nr");
72/// let nc = SymExpr::from("nc");
73/// let matrix = SymTensor::from_shape(vec![nr.clone(), nc.clone()]);
74/// assert_eq!(matrix.ndim(), Some(2));
75/// assert_eq!(matrix.size(0), Some(nr.clone()));
76/// assert_eq!(matrix.size(1), Some(nc.clone()));
77///
78/// // Turn the matrix's shape into a vector with values `["nr", "nc"]`.
79/// let shape = SymTensor::from_vec(matrix.shape().unwrap().collect());
80/// assert_eq!(shape.ndim(), Some(1));
81/// assert_eq!(shape.values(), Some([nr.clone(), nc.clone()].as_slice()));
82///
83/// // Get the number of elements in the matrix as an expression `Some(nr * nc)`.
84/// let len = shape.values().map(|v| v.iter().fold(
85///     SymExpr::Value(1),
86///     |prod, dim| prod * dim.clone()
87/// ).simplify());
88/// assert_eq!(len, Some(SymExpr::Mul((nr.into(), nc.into()))));
89/// ```
90#[derive(Clone, Debug, PartialEq)]
91pub struct SymTensor(SymTensorKind);
92
93impl SymTensor {
94    /// Create a new symbolic tensor with unknown shape and values.
95    ///
96    /// `note` is a short string indicating the reason why the tensor shape
97    /// and values are unknown. This is used for debugging purposes.
98    pub fn unknown(note: &'static str) -> Self {
99        Self(SymTensorKind::Unknown { note })
100    }
101
102    /// Create a new symbolic tensor with the given shape and unknown values.
103    pub fn from_shape(shape: Vec<SymExpr>) -> Self {
104        Self(SymTensorKind::Shape(shape))
105    }
106
107    /// Create a new symbolic tensor with the given shape and unknown values.
108    pub fn from_fixed_shape(shape: &[usize]) -> Self {
109        Self(SymTensorKind::Shape(
110            shape
111                .iter()
112                .copied()
113                .map(|size| SymExpr::Value(size as i32))
114                .collect(),
115        ))
116    }
117
118    /// Create a new symbolic vector.
119    pub fn from_vec(vec: Vec<SymExpr>) -> Self {
120        Self(SymTensorKind::Vector(vec))
121    }
122
123    /// Create a new symbolic scalar.
124    pub fn from_scalar(item: SymExpr) -> Self {
125        Self(SymTensorKind::Scalar(item))
126    }
127
128    /// Return this tensor's single element, if it is a scalar.
129    pub fn as_scalar(&self) -> Option<&SymExpr> {
130        match &self.0 {
131            SymTensorKind::Scalar(item) => Some(item),
132            _ => None,
133        }
134    }
135
136    /// Return this tensor's values as a slice, if it is a vector.
137    pub fn as_vector(&self) -> Option<&[SymExpr]> {
138        match &self.0 {
139            SymTensorKind::Vector(vec) => Some(vec),
140            _ => None,
141        }
142    }
143
144    /// Return this tensor's fixed values, if it is a scalar or a vector and
145    /// all values are fixed.
146    pub fn to_constant(&self) -> Option<Constant> {
147        match &self.0 {
148            SymTensorKind::Scalar(val) => match val {
149                SymExpr::Value(v) => Some(Constant::Scalar(*v)),
150                _ => None,
151            },
152            SymTensorKind::Vector(vec) => {
153                let values = vec
154                    .iter()
155                    .map(|v| match v {
156                        SymExpr::Value(v) => Some(*v),
157                        _ => None,
158                    })
159                    .collect::<Option<Vec<i32>>>()?;
160                Some(Constant::Vector(values))
161            }
162            SymTensorKind::Shape(_) | SymTensorKind::Unknown { .. } => None,
163        }
164    }
165
166    /// Return the number of dimensions, if known.
167    pub fn ndim(&self) -> Option<usize> {
168        match &self.0 {
169            SymTensorKind::Scalar(_) => Some(0),
170            SymTensorKind::Vector(_) => Some(1),
171            SymTensorKind::Shape(val) => Some(val.len()),
172            SymTensorKind::Unknown { .. } => None,
173        }
174    }
175
176    /// Return the size of the index'th dimension.
177    ///
178    /// Returns `None` if the index is out of bounds or the tensor's shape
179    /// is unknown.
180    pub fn size(&self, index: usize) -> Option<SymExpr> {
181        match &self.0 {
182            SymTensorKind::Scalar(_) => None,
183            SymTensorKind::Vector(val) => {
184                if index == 0 {
185                    Some(SymExpr::Value(val.len() as i32))
186                } else {
187                    None
188                }
189            }
190            SymTensorKind::Shape(val) => val.get(index).cloned(),
191            SymTensorKind::Unknown { .. } => None,
192        }
193    }
194
195    /// Return an iterator over the dimensions or `None` if unknown.
196    pub fn shape(&self) -> Option<impl ExactSizeIterator<Item = SymExpr> + Clone> {
197        let ndim = self.ndim()?;
198        let dims = (0..ndim).map(|d| self.size(d).unwrap());
199        Some(dims)
200    }
201
202    /// Return the symbolic values in this tensor, or `None` if unknown.
203    pub fn values(&self) -> Option<&[SymExpr]> {
204        match &self.0 {
205            SymTensorKind::Scalar(item) => Some(std::slice::from_ref(item)),
206            SymTensorKind::Vector(val) => Some(val),
207            SymTensorKind::Shape(_) | SymTensorKind::Unknown { .. } => None,
208        }
209    }
210
211    /// Simplify symbolic expressions in this tensor.
212    ///
213    /// See [`SymExpr::simplify`].
214    pub fn simplify(self) -> Self {
215        match self.0 {
216            SymTensorKind::Scalar(item) => Self::from_scalar(item.simplify()),
217            SymTensorKind::Vector(vec) => {
218                Self::from_vec(vec.into_iter().map(|x| x.simplify()).collect())
219            }
220            SymTensorKind::Shape(shape) => {
221                Self::from_shape(shape.into_iter().map(|d| d.simplify()).collect())
222            }
223            _ => self,
224        }
225    }
226}
227
// Make the test helper macros available to tests elsewhere in the crate.
#[cfg(test)]
pub(crate) use tests::{sym_elems, sym_shape, sym_vec};

#[cfg(test)]
mod tests {
    use super::{Constant, SymExpr, SymTensor};

    /// Create a `Vec<SymExpr>` from a list of symbol names and values.
    macro_rules! sym_elems {
        ($($x:expr),* $(,)?) => {
            vec![$(SymExpr::from($x)),*]
        };
    }

    /// Create a symbolic vector from a list of symbol names and values.
    macro_rules! sym_vec {
        ($($x:expr),* $(,)?) => {
            SymTensor::from_vec(vec![$(SymExpr::from($x)),*])
        };
    }

    /// Create a symbolic shape from a list of symbol names and values.
    macro_rules! sym_shape {
        ($($x:expr),* $(,)?) => {
            SymTensor::from_shape(vec![$(SymExpr::from($x)),*])
        };
    }

    pub(crate) use {sym_elems, sym_shape, sym_vec};

    #[test]
    fn test_scalar() {
        let x = SymTensor::from_scalar("x".into());
        assert_eq!(x.ndim(), Some(0));
        assert_eq!(x.size(0), None);
        assert_eq!(x.values(), Some(["x".into()].as_slice()));
    }

    #[test]
    fn test_vector() {
        let x = SymTensor::from_vec(vec!["x".into(), 2.into()]);
        assert_eq!(x.ndim(), Some(1));
        assert_eq!(x.size(0), Some(2.into()));
        assert_eq!(x.size(1), None);
        assert_eq!(x.values(), Some(["x".into(), 2.into()].as_slice()));
    }

    #[test]
    fn test_tensor_with_shape() {
        let x = SymTensor::from_shape(vec!["x".into(), 2.into()]);
        assert_eq!(x.ndim(), Some(2));
        assert_eq!(x.size(0), Some("x".into()));
        assert_eq!(x.size(1), Some(2.into()));
        assert_eq!(x.size(2), None);
        assert_eq!(x.values(), None);
        assert_eq!(
            x.shape().unwrap().collect::<Vec<_>>(),
            vec!["x".into(), 2.into()]
        );
    }

    #[test]
    fn test_from_fixed_shape() {
        let x = SymTensor::from_fixed_shape(&[2, 3]);
        assert_eq!(x.ndim(), Some(2));
        assert_eq!(x.size(0), Some(2.into()));
        assert_eq!(x.size(1), Some(3.into()));
        assert_eq!(x.size(2), None);
        assert_eq!(x.values(), None);
        assert_eq!(x, sym_shape![2, 3]);

        // An empty fixed shape describes a tensor with zero dimensions.
        let empty = SymTensor::from_fixed_shape(&[]);
        assert_eq!(empty.ndim(), Some(0));
    }

    #[test]
    fn test_to_constant() {
        // Scalars and vectors whose elements are all fixed values.
        let scalar = SymTensor::from_scalar(3.into());
        assert_eq!(scalar.to_constant(), Some(Constant::Scalar(3)));

        let fixed_vec = sym_vec![1, 2];
        assert_eq!(fixed_vec.to_constant(), Some(Constant::Vector(vec![1, 2])));

        let empty_vec = SymTensor::from_vec(Vec::new());
        assert_eq!(empty_vec.to_constant(), Some(Constant::Vector(Vec::new())));

        // Any symbolic element makes the tensor non-constant.
        let sym_scalar = SymTensor::from_scalar("x".into());
        assert_eq!(sym_scalar.to_constant(), None);

        let mixed_vec = sym_vec![1, "x"];
        assert_eq!(mixed_vec.to_constant(), None);

        // Tensors with only a known shape, or nothing known, are not constants.
        assert_eq!(sym_shape![2, 3].to_constant(), None);
        assert_eq!(SymTensor::unknown("test").to_constant(), None);
    }

    #[test]
    fn test_simplify() {
        // Simplify a shape
        let matrix = SymTensor::from_shape(vec![
            SymExpr::pos_var("rows") + SymExpr::from(0),
            SymExpr::pos_var("cols") * SymExpr::from(1),
        ])
        .simplify();
        assert_eq!(
            matrix.shape().unwrap().collect::<Vec<_>>(),
            vec!["rows".into(), "cols".into()]
        );

        // Simplify a scalar
        let x = SymExpr::var("x");
        let add_expr = x.clone() + SymExpr::from(0);
        let scalar = SymTensor::from_scalar(add_expr.clone()).simplify();
        assert_eq!(scalar.as_scalar().unwrap(), &x);

        // Simplify a vector
        let vec = SymTensor::from_vec(vec![add_expr.clone(), add_expr.clone()]).simplify();
        assert_eq!(vec.as_vector().unwrap(), [x.clone(), x.clone()]);

        // Simplifying an unknown tensor leaves it unchanged.
        let unknown = SymTensor::unknown("test");
        assert_eq!(unknown.clone().simplify(), unknown);
    }

    #[test]
    fn test_unknown_shape() {
        let x = SymTensor::unknown("missing input shape");
        assert!(x.shape().is_none());
        assert_eq!(x.ndim(), None);
        assert_eq!(x.size(0), None);
        assert_eq!(x.values(), None);
    }
}