rten_shape_inference/
sym_expr.rs

1//! Symbolic expressions representing integer values.
2
3use std::cmp::Ordering;
4use std::fmt;
5use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub};
6use std::rc::Rc;
7
/// A named variable.
///
/// The variable may carry assumptions about its value, such as being >= 0.
///
/// Two symbols are equal if they have the same name. The assumptions carried
/// by a symbol do not take part in equality comparisons.
#[derive(Clone)]
pub struct Symbol {
    /// Name which identifies the variable.
    pub name: String,

    /// True if this value is assumed to be >= 0.
    pub positive: bool,
}

impl PartialEq for Symbol {
    fn eq(&self, other: &Symbol) -> bool {
        // Only the name identifies a symbol. This matches the documented
        // contract above and the equality used for `SymExpr::Var`.
        self.name == other.name
    }
}
20
/// Symbolic expression representing an integer value.
///
/// Expressions can be known integer values, named symbols or composite
/// expressions built from other expressions. Operands are held in [`Rc`]s, so
/// cloning an expression does not deep-copy its operands.
///
/// Equality ([`PartialEq`]) is structural: symbols compare by name only, and
/// the operands of `Add`, `Mul`, `Max`, `Min` and `Broadcast` may appear in
/// either order.
#[derive(Clone)]
pub enum SymExpr {
    /// Element with a known integer value.
    Value(i32),
    /// Named symbol whose value is not known. See [`Symbol`].
    Var(Rc<Symbol>),
    /// Addition of two symbolic values (`lhs + rhs`).
    Add((Rc<SymExpr>, Rc<SymExpr>)),
    /// Subtraction of the second value from the first (`lhs - rhs`).
    Sub((Rc<SymExpr>, Rc<SymExpr>)),
    /// Multiplication of two symbolic values (`lhs * rhs`).
    Mul((Rc<SymExpr>, Rc<SymExpr>)),
    /// Flooring division of first expression by second.
    Div((Rc<SymExpr>, Rc<SymExpr>)),
    /// Ceiling division of first expression by second.
    DivCeil((Rc<SymExpr>, Rc<SymExpr>)),
    /// Maximum of two symbolic values.
    Max((Rc<SymExpr>, Rc<SymExpr>)),
    /// Minimum of two symbolic values.
    Min((Rc<SymExpr>, Rc<SymExpr>)),
    /// Broadcast two symbolic values.
    ///
    /// This behaves like `Max`, except it implies that both expressions are
    /// positive and either equal or 1.
    Broadcast((Rc<SymExpr>, Rc<SymExpr>)),
    /// Negation of a value (`-x`).
    Neg(Rc<SymExpr>),
}
53
54impl SymExpr {
55    /// Return the range of possible values this element may have.
56    pub fn range(&self) -> (i32, i32) {
57        match self {
58            Self::Value(x) => (*x, *x),
59            Self::Var(sym) => {
60                if sym.positive {
61                    (0, i32::MAX)
62                } else {
63                    (i32::MIN, i32::MAX)
64                }
65            }
66            Self::Neg(x) => {
67                if x.is_positive() {
68                    (i32::MIN, -1)
69                } else {
70                    (i32::MIN, i32::MAX)
71                }
72            }
73            Self::Add((lhs, rhs))
74            | Self::Mul((lhs, rhs))
75            | Self::Max((lhs, rhs))
76            | Self::Min((lhs, rhs))
77            | Self::Div((lhs, rhs))
78            | Self::DivCeil((lhs, rhs)) => {
79                let (lhs_min, lhs_max) = lhs.range();
80                let (rhs_min, rhs_max) = rhs.range();
81                (lhs_min.min(rhs_min), lhs_max.max(rhs_max))
82            }
83            Self::Sub((_lhs, _rhs)) => {
84                // Note: Unlike for addition, subtraction involving two
85                // positive symbols may produce a negative result.
86                (i32::MIN, i32::MAX)
87            }
88            Self::Broadcast((lhs, rhs)) => {
89                let (lhs_min, lhs_max) = lhs.range();
90                let (rhs_min, rhs_max) = rhs.range();
91                (lhs_min.min(rhs_min).max(0), lhs_max.max(rhs_max).max(0))
92            }
93        }
94    }
95
96    /// Return true if the value of this expression is known to be >= 0.
97    pub fn is_positive(&self) -> bool {
98        match self {
99            Self::Value(x) => *x >= 0,
100            Self::Var(sym) => sym.positive,
101            Self::Neg(_expr) => false,
102            Self::Add((lhs, rhs)) => lhs.is_positive() && rhs.is_positive(),
103            Self::Sub((_lhs, _rhs)) => false,
104            Self::Mul((lhs, rhs)) => lhs.is_positive() && rhs.is_positive(),
105            Self::Div((lhs, rhs)) | Self::DivCeil((lhs, rhs)) => {
106                lhs.is_positive() && rhs.is_positive()
107            }
108            Self::Max((lhs, rhs)) => lhs.is_positive() || rhs.is_positive(),
109            Self::Min((lhs, rhs)) => lhs.is_positive() && rhs.is_positive(),
110            Self::Broadcast(_) => true,
111        }
112    }
113
114    /// Return the maximum of `self` and `other`.
115    pub fn max(&self, other: &SymExpr) -> SymExpr {
116        Self::Max((self.clone().into(), other.clone().into()))
117    }
118
119    /// Return the minimum of `self` and `other`.
120    pub fn min(&self, other: &SymExpr) -> SymExpr {
121        Self::Min((self.clone().into(), other.clone().into()))
122    }
123
124    /// Return the result of broadcasting `self` and `other`.
125    pub fn broadcast(&self, other: &SymExpr) -> SymExpr {
126        Self::Broadcast((self.clone().into(), other.clone().into()))
127    }
128
129    /// Return the result of dividing `self` by `other`, rounded up.
130    pub fn div_ceil(&self, other: &SymExpr) -> SymExpr {
131        Self::DivCeil((self.clone().into(), other.clone().into()))
132    }
133
134    fn is_value(&self) -> bool {
135        matches!(self, Self::Value(_))
136    }
137
138    // Re-order and re-associate operands of commutative and associative
139    // operations so that constants are on the left or "canonical order".
140    //
141    // For example `Mul(Mul(a, 2), Mul(b, 3))` becomes
142    // `Mul(Mul(2, 3), Mul(a, b))`.
143    fn canonicalize(&self) -> SymExpr {
144        fn collect_terms(
145            terms: &mut Vec<SymExpr>,
146            term: &SymExpr,
147            extract_lhs_rhs: &impl Fn(&SymExpr) -> Option<&(Rc<SymExpr>, Rc<SymExpr>)>,
148        ) {
149            if let Some((lhs, rhs)) = extract_lhs_rhs(term) {
150                collect_terms(terms, lhs, extract_lhs_rhs);
151                collect_terms(terms, rhs, extract_lhs_rhs);
152            } else {
153                terms.push(term.canonicalize())
154            }
155        }
156
157        // Re-associate and simplify terms in a nested associative expression.
158        //
159        // This operates in 4 steps:
160        //
161        // 1. Collect all the terms in nested expressions of the same type
162        // 2. Sort the terms in canonical order
163        // 3. Simplify the result by removing any redundant terms
164        // 4. Fold the terms back into a new expression
165        fn reassociate_terms(
166            term: &SymExpr,
167            extract_terms: &impl Fn(&SymExpr) -> Option<&(Rc<SymExpr>, Rc<SymExpr>)>,
168            simplify: impl Fn(Vec<SymExpr>) -> Vec<SymExpr>,
169            init: SymExpr,
170            fold: impl Fn(SymExpr, SymExpr) -> SymExpr,
171        ) -> SymExpr {
172            let mut terms = Vec::new();
173            collect_terms(&mut terms, term, extract_terms);
174            terms.sort_by(cmp_values_first);
175            let terms = simplify(terms);
176            terms.into_iter().fold(init, fold)
177        }
178
179        // Remove adjacent equal terms.
180        //
181        // This is a simplification for idempotent operations
182        // (eg. max(x, max(x, y)) => max(x, y)).
183        let remove_adjacent_equal_terms = |mut terms: Vec<SymExpr>| {
184            let mut idx = 0;
185            while idx < terms.len().saturating_sub(1) {
186                if terms[idx] == terms[idx + 1].clone() {
187                    terms.remove(idx);
188                } else {
189                    idx += 1;
190                }
191            }
192            terms
193        };
194
195        match self {
196            Self::Value(_) | Self::Var(_) => self.clone(),
197            Self::Neg(expr) => Self::Neg(expr.canonicalize().into()),
198            Self::Mul(_) => reassociate_terms(
199                self,
200                &|term| {
201                    if let Self::Mul(inner) = term {
202                        Some(inner)
203                    } else {
204                        None
205                    }
206                },
207                |terms| terms,
208                SymExpr::Value(1),
209                |prod, x| prod * x,
210            ),
211            Self::Add(_) => {
212                // Remove adjacent terms which cancel.
213                let remove_adjacent_opposite_terms = |mut terms: Vec<SymExpr>| {
214                    let mut idx = 0;
215                    while idx < terms.len().saturating_sub(1) {
216                        if terms[idx].is_negation_of(&terms[idx + 1]) {
217                            terms.remove(idx);
218                            terms.remove(idx);
219                        } else {
220                            idx += 1;
221                        }
222                    }
223                    terms
224                };
225
226                reassociate_terms(
227                    self,
228                    &|term| match term {
229                        Self::Add(inner) => Some(inner),
230                        _ => None,
231                    },
232                    remove_adjacent_opposite_terms,
233                    SymExpr::Value(0),
234                    |sum, x| sum + x,
235                )
236            }
237            Self::Max(_) => reassociate_terms(
238                self,
239                &|term| match term {
240                    Self::Max(inner) => Some(inner),
241                    _ => None,
242                },
243                remove_adjacent_equal_terms,
244                SymExpr::Value(i32::MIN),
245                |max, x| max.max(&x),
246            ),
247            Self::Min(_) => reassociate_terms(
248                self,
249                &|term| match term {
250                    Self::Min(inner) => Some(inner),
251                    _ => None,
252                },
253                remove_adjacent_equal_terms,
254                SymExpr::Value(i32::MAX),
255                |min, x| min.min(&x),
256            ),
257            Self::Sub((lhs, rhs)) => {
258                // Rewrite `x - y` as `x + (-y)`. This makes it easier to
259                // simplify expressions by canceling opposite terms.
260                let lhs = lhs.canonicalize();
261                let rhs = rhs.canonicalize();
262                Self::Add((lhs.into(), (-rhs).into())).canonicalize()
263            }
264            Self::Div((lhs, rhs)) => {
265                let lhs = lhs.canonicalize();
266                let rhs = rhs.canonicalize();
267                Self::Div((lhs.into(), rhs.into()))
268            }
269            Self::DivCeil((lhs, rhs)) => {
270                let lhs = lhs.canonicalize();
271                let rhs = rhs.canonicalize();
272                Self::DivCeil((lhs.into(), rhs.into()))
273            }
274            Self::Broadcast(_) => reassociate_terms(
275                self,
276                &|term| match term {
277                    Self::Broadcast(inner) => Some(inner),
278                    _ => None,
279                },
280                remove_adjacent_equal_terms,
281                SymExpr::Value(1),
282                |result, x| result.broadcast(&x),
283            ),
284        }
285    }
286
287    /// Simplify an expression.
288    ///
289    /// This simplifies expressions such as identities (eg. `x + 0` becomes `x`).
290    pub fn simplify(&self) -> SymExpr {
291        self.canonicalize().simplify_canonical()
292    }
293
294    /// Simplify an expression which is assumed to have been put in canonical
295    /// form by [`canonicalize`](Self::canonicalize).
296    fn simplify_canonical(&self) -> SymExpr {
297        match self {
298            Self::Value(_) | Self::Var(_) => self.clone(),
299            Self::Neg(expr) => match expr.simplify_canonical() {
300                SymExpr::Value(x) => SymExpr::Value(-x),
301                expr => Self::Neg(expr.into()),
302            },
303            Self::Add((lhs, rhs)) => {
304                let lhs = lhs.simplify_canonical();
305                let rhs = rhs.simplify_canonical();
306
307                match (lhs, rhs) {
308                    (SymExpr::Value(0), rhs) => rhs,
309                    (lhs, SymExpr::Value(0)) => lhs,
310                    (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x + y),
311                    (lhs, SymExpr::Neg(rhs)) if lhs == *rhs => SymExpr::Value(0),
312                    (lhs, rhs) => lhs + rhs,
313                }
314            }
315            Self::Sub((lhs, rhs)) => {
316                let lhs = lhs.simplify_canonical();
317                let rhs = rhs.simplify_canonical();
318
319                match (lhs, rhs) {
320                    (lhs, SymExpr::Value(0)) => lhs,
321                    (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x - y),
322                    (lhs, rhs) if lhs == rhs => SymExpr::Value(0),
323                    (lhs, rhs) => lhs - rhs,
324                }
325            }
326            Self::Mul((lhs, rhs)) => {
327                let lhs = lhs.simplify_canonical();
328                let rhs = rhs.simplify_canonical();
329
330                match (lhs, rhs) {
331                    (SymExpr::Value(1), rhs) => rhs,
332                    (lhs, SymExpr::Value(1)) => lhs,
333                    (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x * y),
334                    (lhs, rhs) => lhs * rhs,
335                }
336            }
337            Self::Div((lhs, rhs)) => {
338                let lhs = lhs.simplify_canonical();
339                let rhs = rhs.simplify_canonical();
340
341                match (lhs, rhs) {
342                    (lhs, SymExpr::Value(1)) => lhs,
343                    (SymExpr::Value(x), SymExpr::Value(y)) if y != 0 => SymExpr::Value(x / y),
344                    // x/x => 1
345                    //
346                    // Where we assume the RHS is non-zero.
347                    //
348                    // This is a special case of canceling common terms. The
349                    // more general case (eg. XY / XZ => Y/Z) still needs to
350                    // be implemented.
351                    (lhs, rhs) if lhs == rhs => SymExpr::Value(1),
352
353                    // x / b / c => x / (b * c)
354                    (SymExpr::Div((lhs, c1)), c2) => match (&*c1, c2) {
355                        (SymExpr::Value(c1), SymExpr::Value(c2)) if *c1 != 0 && c2 != 0 => {
356                            (*lhs).clone() / SymExpr::Value(c1 * c2)
357                        }
358                        (c1, c2) => (*lhs).clone() / (c1.clone() * c2),
359                    },
360                    (lhs, rhs) => lhs / rhs,
361                }
362            }
363            Self::DivCeil((lhs, rhs)) => {
364                let lhs = lhs.simplify_canonical();
365                let rhs = rhs.simplify_canonical();
366
367                match (lhs, rhs) {
368                    (lhs, SymExpr::Value(1)) => lhs,
369                    (SymExpr::Value(x), SymExpr::Value(y)) if y != 0 => {
370                        SymExpr::Value(div_ceil(x, y))
371                    }
372                    // x/x => 1
373                    //
374                    // Where we assume the RHS is non-zero.
375                    //
376                    // This is a special case of canceling common terms. The
377                    // more general case (eg. XY / XZ => Y/Z) still needs to
378                    // be implemented.
379                    (lhs, rhs) if lhs == rhs => SymExpr::Value(1),
380
381                    // x.div_ceil(b).div_ceil(c) => x.div_ceil(b * c) if b > 0
382                    // and c > 0.
383                    (SymExpr::DivCeil((lhs, c1)), c2) => match (&*c1, c2) {
384                        (SymExpr::Value(c1), SymExpr::Value(c2)) if *c1 > 0 && c2 > 0 => {
385                            lhs.div_ceil(&SymExpr::Value(c1 * c2))
386                        }
387                        (c1, c2) => lhs.div_ceil(&(c1.clone() * c2)),
388                    },
389                    (lhs, rhs) => lhs.div_ceil(&rhs),
390                }
391            }
392            Self::Max((lhs, rhs)) => {
393                let lhs = lhs.simplify_canonical();
394                let rhs = rhs.simplify_canonical();
395
396                if lhs == rhs {
397                    lhs
398                } else {
399                    match (lhs, rhs) {
400                        (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x.max(y)),
401                        (lhs, rhs) => Self::Max((lhs.into(), rhs.into())),
402                    }
403                }
404            }
405            Self::Min((lhs, rhs)) => {
406                let lhs = lhs.simplify_canonical();
407                let rhs = rhs.simplify_canonical();
408
409                if lhs == rhs {
410                    lhs
411                } else {
412                    match (lhs, rhs) {
413                        (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x.min(y)),
414                        (lhs, rhs) => Self::Min((lhs.into(), rhs.into())),
415                    }
416                }
417            }
418            Self::Broadcast((lhs, rhs)) => {
419                let lhs = lhs.simplify_canonical();
420                let rhs = rhs.simplify_canonical();
421
422                match (lhs, rhs) {
423                    (SymExpr::Value(x), SymExpr::Value(y)) if x == y => SymExpr::Value(x),
424                    (SymExpr::Value(1), y) => y,
425                    (x, SymExpr::Value(1)) => x,
426                    (SymExpr::Value(x), y) if x != 1 => SymExpr::Value(x),
427                    (x, SymExpr::Value(y)) if y != 1 => SymExpr::Value(y),
428                    (lhs, rhs) if lhs == rhs => lhs,
429                    (lhs, rhs) => SymExpr::Broadcast((lhs.into(), rhs.into())),
430                }
431            }
432        }
433    }
434
435    /// Return the precedence of the operator.
436    ///
437    /// This is used to add parentheses when formatting an expression tree.
438    fn precedence(&self) -> u8 {
439        match self {
440            // Functions and atomic values have the maximum precedence, so they
441            // never need to be wrapped in parens when formatting an expression.
442            Self::Value(_) | Self::Var(_) | Self::Max(_) | Self::Min(_) | Self::Broadcast(_) => 4,
443            Self::Div(_) | Self::DivCeil(_) => 3,
444            Self::Mul(_) => 2,
445            Self::Add(_) => 1,
446            Self::Sub(_) | Self::Neg(_) => 0,
447        }
448    }
449
450    /// Create a named symbol, with no assumptions about the value.
451    pub fn var(name: &str) -> Self {
452        SymExpr::Var(
453            Symbol {
454                name: name.to_string(),
455                positive: false,
456            }
457            .into(),
458        )
459    }
460
461    /// Create a named symbol representing a positive value (ie. `>= 0`).
462    pub fn pos_var(name: &str) -> Self {
463        SymExpr::Var(
464            Symbol {
465                name: name.to_string(),
466                positive: true,
467            }
468            .into(),
469        )
470    }
471
472    /// Compute `self / rhs` as an expression, or return `None` if an exact
473    /// division is not possible.
474    pub fn exact_div(&self, rhs: &SymExpr) -> Option<SymExpr> {
475        let lhs = self;
476        match (lhs, rhs) {
477            // Fixed values
478            (SymExpr::Value(lhs), SymExpr::Value(rhs)) => {
479                if *rhs != 0 && lhs % rhs == 0 {
480                    Some(SymExpr::Value(lhs / rhs))
481                } else {
482                    None
483                }
484            }
485            // Identities
486            (lhs, rhs) if lhs == rhs => Some(SymExpr::Value(1)),
487            (lhs, SymExpr::Value(1)) => Some(lhs.clone()),
488            // If LHS is a product, recurse
489            (SymExpr::Mul((lhs_a, lhs_b)), rhs) => {
490                if let Some(new_lhs_a) = lhs_a.exact_div(rhs) {
491                    Some(SymExpr::Mul((new_lhs_a.into(), lhs_b.clone())))
492                } else {
493                    lhs_b
494                        .exact_div(rhs)
495                        .map(|new_lhs_b| SymExpr::Mul((lhs_a.clone(), new_lhs_b.into())))
496                }
497            }
498            _ => None,
499        }
500    }
501
502    /// Return the name of the symbol in a unary expression.
503    ///
504    /// Returns `None` if the expression is not unary or has a fixed value.
505    fn name(&self) -> Option<&str> {
506        match self {
507            SymExpr::Value(_) => None,
508            SymExpr::Var(sym) => Some(&sym.name),
509            SymExpr::Neg(x) => x.name(),
510            SymExpr::Add(_)
511            | SymExpr::Sub(_)
512            | SymExpr::Mul(_)
513            | SymExpr::Div(_)
514            | SymExpr::DivCeil(_)
515            | SymExpr::Max(_)
516            | SymExpr::Min(_)
517            | SymExpr::Broadcast(_) => None,
518        }
519    }
520
521    /// Return true if `self` and `other` are negations of each other, meaning
522    /// that adding the two terms together will produce zero.
523    fn is_negation_of(&self, other: &SymExpr) -> bool {
524        match (self, other) {
525            (x, SymExpr::Neg(y)) if *x == **y => true,
526            (SymExpr::Neg(x), y) if **x == *y => true,
527            _ => false,
528        }
529    }
530}
531
532/// Sort terms in an order that makes simplification easier, by making terms
533/// which can be combined or eliminated adjacent.
534fn cmp_values_first(a: &SymExpr, b: &SymExpr) -> Ordering {
535    match (a.is_value(), b.is_value()) {
536        (true, false) => Ordering::Less,
537        (false, true) => Ordering::Greater,
538        _ => match (a.name(), b.name()) {
539            (Some(a_name), Some(b_name)) => a_name.cmp(b_name),
540            (Some(_), None) => Ordering::Less,
541            (None, Some(_)) => Ordering::Greater,
542            _ => Ordering::Equal,
543        },
544    }
545}
546
547impl PartialEq<SymExpr> for SymExpr {
548    fn eq(&self, other: &SymExpr) -> bool {
549        let commutative_eq = |self_lhs, self_rhs, other_lhs, other_rhs| {
550            (self_lhs == other_lhs && self_rhs == other_rhs)
551                || (self_lhs == other_rhs && self_rhs == other_lhs)
552        };
553
554        // Symbols are equal if they have the same value or the same name.
555        match self {
556            Self::Value(x) => match other {
557                Self::Value(y) => x == y,
558                _ => false,
559            },
560            Self::Var(x) => match other {
561                Self::Var(y) => x.name == y.name,
562                _ => false,
563            },
564            Self::Neg(x) => match other {
565                Self::Neg(y) => x == y,
566                _ => false,
567            },
568            Self::Add((a, b)) => match other {
569                Self::Add((c, d)) => commutative_eq(a, b, c, d),
570                _ => false,
571            },
572            Self::Mul((a, b)) => match other {
573                Self::Mul((c, d)) => commutative_eq(a, b, c, d),
574                _ => false,
575            },
576            Self::Max((a, b)) => match other {
577                Self::Max((c, d)) => commutative_eq(a, b, c, d),
578                _ => false,
579            },
580            Self::Min((a, b)) => match other {
581                Self::Min((c, d)) => commutative_eq(a, b, c, d),
582                _ => false,
583            },
584            Self::Sub((a, b)) => match other {
585                Self::Sub((c, d)) => a == c && b == d,
586                _ => false,
587            },
588            Self::Div((a, b)) => match other {
589                Self::Div((c, d)) => a == c && b == d,
590                _ => false,
591            },
592            Self::DivCeil((a, b)) => match other {
593                Self::DivCeil((c, d)) => a == c && b == d,
594                _ => false,
595            },
596            Self::Broadcast((a, b)) => match other {
597                Self::Broadcast((c, d)) => commutative_eq(a, b, c, d),
598                _ => false,
599            },
600        }
601    }
602}
603
604impl Add<SymExpr> for SymExpr {
605    type Output = SymExpr;
606
607    fn add(self, rhs: SymExpr) -> Self {
608        Self::Add((self.into(), rhs.into()))
609    }
610}
611
612impl Sub<SymExpr> for SymExpr {
613    type Output = SymExpr;
614
615    fn sub(self, rhs: SymExpr) -> Self {
616        Self::Sub((self.into(), rhs.into()))
617    }
618}
619
620impl AddAssign<SymExpr> for SymExpr {
621    fn add_assign(&mut self, rhs: SymExpr) {
622        *self = Self::Add((self.clone().into(), rhs.into()));
623    }
624}
625
626impl Mul<SymExpr> for SymExpr {
627    type Output = SymExpr;
628
629    fn mul(self, rhs: SymExpr) -> Self {
630        Self::Mul((self.into(), rhs.into()))
631    }
632}
633
634impl Div<SymExpr> for SymExpr {
635    type Output = SymExpr;
636
637    fn div(self, rhs: SymExpr) -> Self {
638        Self::Div((self.into(), rhs.into()))
639    }
640}
641
642impl Neg for SymExpr {
643    type Output = SymExpr;
644
645    fn neg(self) -> Self {
646        Self::Neg(self.into())
647    }
648}
649
650impl From<Symbol> for SymExpr {
651    fn from(val: Symbol) -> Self {
652        Self::Var(val.into())
653    }
654}
655
656/// Create a symbol with a given name and an assumption that the value is
657/// positive (`>= 0`).
658///
659/// The rationale for the positivity assumption is that during shape inference,
660/// the most common use of symbols is to represent dimension sizes.
661impl<'a> From<&'a str> for SymExpr {
662    fn from(name: &'a str) -> Self {
663        SymExpr::Var(
664            Symbol {
665                name: name.to_string(),
666                positive: true,
667            }
668            .into(),
669        )
670    }
671}
672
673impl From<i32> for SymExpr {
674    fn from(val: i32) -> Self {
675        SymExpr::Value(val)
676    }
677}
678
679impl fmt::Debug for SymExpr {
680    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681        let add_parens = |f: &mut fmt::Formatter<'_>, expr: &SymExpr| {
682            if expr.precedence() < self.precedence() {
683                write!(f, "({:?})", expr)
684            } else {
685                write!(f, "{:?}", expr)
686            }
687        };
688        let write_binop = |f: &mut fmt::Formatter<'_>, op, lhs, rhs| {
689            add_parens(f, lhs)?;
690            write!(f, " {op} ")?;
691            add_parens(f, rhs)
692        };
693        match self {
694            Self::Value(val) => write!(f, "{}", val),
695            Self::Var(sym) => write!(
696                f,
697                "\"{}\"{}",
698                sym.name,
699                if sym.positive { 'u' } else { 'i' }
700            ),
701            // nb. No space between "-" and expression to make formatting
702            // distinct from subtraction.
703            Self::Neg(expr) => write!(f, "-{:?}", expr),
704            Self::Add((lhs, rhs)) => write_binop(f, '+', lhs, rhs),
705            Self::Sub((lhs, rhs)) => write_binop(f, '-', lhs, rhs),
706            Self::Mul((lhs, rhs)) => write_binop(f, '*', lhs, rhs),
707            Self::Div((lhs, rhs)) => write_binop(f, '/', lhs, rhs),
708            Self::DivCeil((lhs, rhs)) => write!(f, "ceil_div({:?}, {:?})", lhs, rhs),
709            Self::Max((lhs, rhs)) => write!(f, "max({:?}, {:?})", lhs, rhs),
710            Self::Min((lhs, rhs)) => write!(f, "min({:?}, {:?})", lhs, rhs),
711            Self::Broadcast((lhs, rhs)) => write!(f, "broadcast({:?}, {:?})", lhs, rhs),
712        }
713    }
714}
715
716impl fmt::Display for SymExpr {
717    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
718        let add_parens = |f: &mut fmt::Formatter<'_>, expr: &SymExpr| {
719            if expr.precedence() < self.precedence() {
720                write!(f, "({})", expr)
721            } else {
722                write!(f, "{}", expr)
723            }
724        };
725        let write_binop = |f: &mut fmt::Formatter<'_>, op, lhs, rhs| {
726            add_parens(f, lhs)?;
727            write!(f, " {op} ")?;
728            add_parens(f, rhs)
729        };
730        match self {
731            Self::Value(val) => write!(f, "{}", val),
732            Self::Var(sym) => write!(f, "{}", sym.name),
733            // nb. No space between "-" and expression to make formatting
734            // distinct from subtraction.
735            Self::Neg(expr) => write!(f, "-{}", expr),
736            Self::Add((lhs, rhs)) => write_binop(f, '+', lhs, rhs),
737            Self::Sub((lhs, rhs)) => write_binop(f, '-', lhs, rhs),
738            Self::Mul((lhs, rhs)) => write_binop(f, '*', lhs, rhs),
739            Self::Div((lhs, rhs)) => write_binop(f, '/', lhs, rhs),
740            Self::DivCeil((lhs, rhs)) => write!(f, "ceil_div({}, {})", lhs, rhs),
741            Self::Max((lhs, rhs)) => write!(f, "max({}, {})", lhs, rhs),
742            Self::Min((lhs, rhs)) => write!(f, "min({}, {})", lhs, rhs),
743            Self::Broadcast((lhs, rhs)) => write!(f, "broadcast({}, {})", lhs, rhs),
744        }
745    }
746}
747
/// Return `lhs / rhs`, rounded towards positive infinity.
///
/// Equivalent to the unstable [`i32::div_ceil`] in the standard library.
///
/// # Panics
///
/// Panics if `rhs` is zero, or if `lhs` is `i32::MIN` and `rhs` is `-1`
/// (the quotient overflows).
pub const fn div_ceil(lhs: i32, rhs: i32) -> i32 {
    let quotient = lhs / rhs;
    let remainder = lhs % rhs;

    // Integer division truncates towards zero. When the division is inexact
    // and the operands share a sign, the exact quotient is positive and the
    // truncated result is one below the ceiling.
    if remainder != 0 && (lhs < 0) == (rhs < 0) {
        quotient + 1
    } else {
        quotient
    }
}
758
759#[cfg(test)]
760mod tests {
761    use super::SymExpr;
762
763    #[test]
764    fn test_range() {
765        let x = SymExpr::pos_var("x");
766        assert_eq!(x.range(), (0, i32::MAX));
767
768        let y = SymExpr::var("y");
769        assert_eq!(y.range(), (i32::MIN, i32::MAX));
770    }
771
772    #[test]
773    fn test_simplify_add() {
774        let x = SymExpr::pos_var("x");
775        let zero = SymExpr::from(0);
776        let one = SymExpr::from(1);
777
778        let expr = x.clone() + zero.clone();
779        assert_eq!(expr, SymExpr::Add((x.clone().into(), zero.clone().into())));
780        assert_eq!(expr.simplify(), x);
781
782        let expr_2 = x.clone() + one.clone();
783        assert_eq!(
784            expr_2.simplify(),
785            SymExpr::Add((x.clone().into(), one.clone().into()))
786        );
787    }
788
789    // Check `C + X + D` is simplified to `S + X` where C and D are
790    // constants and `S = C+D`.
791    #[test]
792    fn test_simplify_add_reassociate() {
793        let x = SymExpr::from("x");
794        let c1 = SymExpr::from(3);
795        let c2 = SymExpr::from(4);
796
797        // C + X + D => S + X
798        let expr = (x.clone() + c1.clone()) + c2.clone();
799        let simplified = expr.simplify();
800        assert_eq!(simplified, SymExpr::from(7) + x.clone());
801
802        // C + X + D + X => S + X + X
803        let expr = (x.clone() + c1) + (x.clone() + c2);
804        let simplified = expr.simplify();
805        assert_eq!(simplified, SymExpr::from(7) + x.clone() + x);
806    }
807
808    #[test]
809    fn test_simplify_sub() {
810        let x = SymExpr::pos_var("x");
811        let zero = SymExpr::from(0);
812        let one = SymExpr::from(1);
813
814        // x - 0 => x
815        let expr = x.clone() - zero.clone();
816        assert_eq!(expr, SymExpr::Sub((x.clone().into(), zero.clone().into())));
817        assert_eq!(expr.simplify(), x);
818
819        // x - x => 0
820        let expr = x.clone() - x.clone();
821        assert_eq!(expr.simplify(), SymExpr::Value(0));
822
823        // x - 1 => x + (-1)
824        let expr_2 = x.clone() - one.clone();
825        assert_eq!(
826            expr_2.simplify(),
827            SymExpr::Add((x.clone().into(), SymExpr::from(-1).into()))
828        );
829
830        // x + y - x => y
831        let y = SymExpr::pos_var("y");
832        let expr = x.clone() + y.clone() - x.clone();
833        assert_eq!(expr.simplify(), y.clone());
834
835        // x + x + y - x => x + y.
836        let expr = x.clone() + x.clone() + y.clone() - x.clone();
837        assert_eq!(expr.simplify(), x.clone() + y.clone());
838
839        // x + y - x - y => 0
840        let expr = x.clone() + y.clone() - x.clone() - y.clone();
841        assert_eq!(expr.simplify(), 0.into());
842
843        // -x + x => 0
844        let expr = -x.clone() + x.clone();
845        assert_eq!(expr.simplify(), 0.into());
846
847        // x + (-x) => 0
848        let expr = x.clone() + (-x.clone());
849        assert_eq!(expr.simplify(), 0.into());
850
851        // (x + y) - (x + y) => 0
852        let expr = (x.clone() + y.clone()) - (x.clone() + y.clone());
853        assert_eq!(expr.simplify(), 0.into());
854    }
855
856    #[test]
857    fn test_simplify_mul() {
858        let x = SymExpr::pos_var("x");
859        let one = SymExpr::from(1);
860        let two = SymExpr::from(2);
861
862        let expr = x.clone() * one.clone();
863        assert_eq!(expr, SymExpr::Mul((x.clone().into(), one.clone().into())));
864        assert_eq!(expr.simplify(), x);
865
866        let expr_2 = x.clone() * two.clone();
867        assert_eq!(
868            expr_2.simplify(),
869            SymExpr::Mul((x.clone().into(), two.clone().into()))
870        );
871    }
872
873    #[test]
874    fn test_simplify_div() {
875        let x = SymExpr::pos_var("x");
876        let one = SymExpr::from(1);
877        let two = SymExpr::from(2);
878
879        // Constant eval
880        let expr = SymExpr::from(5) / SymExpr::from(2);
881        assert_eq!(expr.simplify(), SymExpr::from(2));
882
883        // Constant with zero divisor
884        let expr = SymExpr::from(5) / SymExpr::from(0);
885        assert_eq!(expr.simplify(), SymExpr::from(5) / SymExpr::from(0));
886
887        // x / 1 => x
888        let expr = x.clone() / one.clone();
889        assert_eq!(expr, SymExpr::Div((x.clone().into(), one.clone().into())));
890        assert_eq!(expr.simplify(), x);
891
892        // x / x => 1
893        let expr = x.clone() / x.clone();
894        assert_eq!(expr.simplify(), one);
895
896        // x / 2 => x / 2
897        let expr_2 = x.clone() / two.clone();
898        assert_eq!(
899            expr_2.simplify(),
900            SymExpr::Div((x.clone().into(), two.clone().into()))
901        );
902
903        // x / 2 / 2 => x / 4
904        let expr = x.clone() / two.clone() / two.clone();
905        assert_eq!(expr.simplify(), x.clone() / SymExpr::from(4));
906
907        // x / 0 / 2 => not simplified (divisor is zero)
908        let zero = SymExpr::from(0);
909        let expr = x.clone() / zero.clone() / two.clone();
910        assert_eq!(expr.simplify(), x.clone() / (zero.clone() * two.clone()));
911
912        // x / 2 / 0 => not simplified (divisor is zero)
913        let expr = x.clone() / two.clone() / zero.clone();
914        assert_eq!(expr.simplify(), x.clone() / (two.clone() * zero));
915    }
916
917    #[test]
918    fn test_simplify_div_ceil() {
919        let x = SymExpr::pos_var("x");
920        let one = SymExpr::from(1);
921        let two = SymExpr::from(2);
922
923        // Constant eval
924        let expr = SymExpr::from(5).div_ceil(&SymExpr::from(2));
925        assert_eq!(expr.simplify(), SymExpr::from(3));
926
927        // Constant with zero divisor
928        let expr = SymExpr::from(5).div_ceil(&SymExpr::from(0));
929        assert_eq!(
930            expr.simplify(),
931            SymExpr::from(5).div_ceil(&SymExpr::from(0))
932        );
933
934        // x / 1 => x
935        let expr = x.clone().div_ceil(&one);
936        assert_eq!(
937            expr,
938            SymExpr::DivCeil((x.clone().into(), one.clone().into()))
939        );
940        assert_eq!(expr.simplify(), x);
941
942        // x / x => 1
943        let expr = x.clone().div_ceil(&x);
944        assert_eq!(expr.simplify(), one);
945
946        // x / 2 => x / 2
947        let expr_2 = x.clone().div_ceil(&two);
948        assert_eq!(
949            expr_2.simplify(),
950            SymExpr::DivCeil((x.clone().into(), two.clone().into()))
951        );
952
953        // x / 2 / 2 => x / 4
954        let expr = x.clone().div_ceil(&two).div_ceil(&two);
955        assert_eq!(expr.simplify(), x.clone().div_ceil(&SymExpr::from(4)));
956
957        // x.div_ceil(0).div_ceil(2) => not simplified (divisor is zero)
958        let zero = SymExpr::from(0);
959        let expr = x.clone().div_ceil(&zero).div_ceil(&two);
960        assert_eq!(
961            expr.simplify(),
962            x.clone().div_ceil(&(zero.clone() * two.clone()))
963        );
964
965        // x.div_ceil(-1).div_ceil(2) => not simplified (divisor is negative)
966        let neg_one = SymExpr::from(-1);
967        let expr = x.clone().div_ceil(&neg_one).div_ceil(&two);
968        assert_eq!(expr.simplify(), x.div_ceil(&(neg_one.clone() * two)));
969    }
970
971    // Check `C * X * D` is simplified to `CD * X` where C and D are
972    // constants.
973    #[test]
974    fn test_simplify_mul_reassociate() {
975        let x = SymExpr::from("x");
976        let c1 = SymExpr::from(3);
977        let c2 = SymExpr::from(4);
978
979        // C * X * D => CD * X
980        let expr = (x.clone() * c1.clone()) * c2.clone();
981        let simplified = expr.simplify();
982        assert_eq!(simplified, SymExpr::from(12) * x.clone());
983
984        // Same as above, but contained inside an addition expression.
985        let expr = SymExpr::from(5) + expr;
986        let simplified = expr.simplify();
987        assert_eq!(simplified, SymExpr::from(5) + SymExpr::from(12) * x.clone());
988
989        // C * X * D * X => CD * X * X
990        let expr = (x.clone() * c1) * (x.clone() * c2);
991        let simplified = expr.simplify();
992        assert_eq!(simplified, SymExpr::from(12) * x.clone() * x);
993    }
994
995    #[test]
996    fn test_simplify_max() {
997        let one = SymExpr::from(1);
998        let two = SymExpr::from(2);
999        let expr = one.max(&two);
1000
1001        assert_eq!(expr, SymExpr::Max((one.clone().into(), two.clone().into())));
1002        assert_eq!(expr.simplify(), two.clone());
1003    }
1004
1005    #[test]
1006    fn test_simplify_nested_max() {
1007        let expr = SymExpr::from(10)
1008            .max(&SymExpr::from(5).max(&SymExpr::from(11)))
1009            .simplify();
1010        assert_eq!(expr, SymExpr::from(11));
1011    }
1012
1013    #[test]
1014    fn test_simplify_min() {
1015        let one = SymExpr::from(1);
1016        let two = SymExpr::from(2);
1017        let expr = one.min(&two);
1018
1019        assert_eq!(expr, SymExpr::Min((one.clone().into(), two.clone().into())));
1020        assert_eq!(expr.simplify(), one.clone());
1021    }
1022
1023    #[test]
1024    fn test_simplify_nested_min() {
1025        let expr = SymExpr::from(10)
1026            .min(&SymExpr::from(5).min(&SymExpr::from(3)))
1027            .simplify();
1028        assert_eq!(expr, SymExpr::from(3));
1029    }
1030
1031    #[test]
1032    fn test_simplify_broadcast() {
1033        let one = SymExpr::from(1);
1034        let ten = SymExpr::from(10);
1035        let foo = SymExpr::from("foo");
1036
1037        // (x, N) where N != 1 => N
1038        assert_eq!(ten.broadcast(&ten).simplify(), ten.clone());
1039        assert_eq!(ten.broadcast(&foo).simplify(), ten.clone());
1040        assert_eq!(one.broadcast(&ten).simplify(), ten.clone());
1041        assert_eq!(ten.broadcast(&one).simplify(), ten.clone());
1042
1043        // (x, N) where N == 1 => x
1044        assert_eq!(foo.broadcast(&one).simplify(), foo.clone());
1045        assert_eq!(one.broadcast(&foo).simplify(), foo.clone());
1046
1047        // (x, x) => x
1048        assert_eq!(foo.broadcast(&foo).simplify(), foo.clone());
1049    }
1050
1051    #[test]
1052    fn test_simplify_nested_broadcast() {
1053        let foo = SymExpr::from("foo");
1054        let ten = SymExpr::from(10);
1055        let expr = foo.broadcast(&foo.broadcast(&ten)).simplify();
1056        assert_eq!(expr, SymExpr::from(10));
1057    }
1058
1059    #[test]
1060    fn test_simplify_neg() {
1061        let minus_one = -SymExpr::from(1);
1062        assert_eq!(minus_one.simplify(), SymExpr::from(-1));
1063    }
1064
1065    #[test]
1066    fn test_display() {
1067        let expr = (SymExpr::from(1) + SymExpr::pos_var("foo")) * SymExpr::from(3)
1068            + SymExpr::from(4)
1069            - SymExpr::from(5);
1070        assert_eq!(expr.to_string(), "(1 + foo) * 3 + 4 - 5");
1071    }
1072
1073    #[test]
1074    fn test_debug() {
1075        let expr = (SymExpr::from(1) + SymExpr::pos_var("foo")) * SymExpr::from(3)
1076            + SymExpr::var("bar")
1077            - SymExpr::from(5);
1078        assert_eq!(format!("{:?}", expr), "(1 + \"foo\"u) * 3 + \"bar\"i - 5");
1079    }
1080
1081    #[test]
1082    fn test_exact_div() {
1083        // Fixed values
1084        assert_eq!(
1085            SymExpr::from(15).exact_div(&SymExpr::from(3)),
1086            Some(SymExpr::from(5))
1087        );
1088        assert_eq!(SymExpr::from(15).exact_div(&SymExpr::from(4)), None);
1089        assert_eq!(SymExpr::from(15).exact_div(&SymExpr::from(0)), None);
1090
1091        // Identities
1092        assert_eq!(
1093            SymExpr::from("x").exact_div(&SymExpr::from("x")),
1094            Some(SymExpr::from(1))
1095        );
1096        assert_eq!(
1097            SymExpr::from("x").exact_div(&SymExpr::from(1)),
1098            Some(SymExpr::from("x"))
1099        );
1100
1101        // Products with common term in LHS and RHS
1102        assert_eq!(
1103            (SymExpr::from("x") * SymExpr::from("y"))
1104                .exact_div(&SymExpr::from("y"))
1105                .map(|s| s.simplify()),
1106            Some(SymExpr::from("x"))
1107        );
1108        assert_eq!(
1109            (SymExpr::from("y") * SymExpr::from("x"))
1110                .exact_div(&SymExpr::from("y"))
1111                .map(|s| s.simplify()),
1112            Some(SymExpr::from("x"))
1113        );
1114
1115        // Cases where result is unknown
1116        assert_eq!(SymExpr::from("x").exact_div(&SymExpr::from("y")), None);
1117    }
1118}