1use std::cmp::Ordering;
4use std::fmt;
5use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub};
6use std::rc::Rc;
7
/// A named symbolic integer.
///
/// Two symbols are structurally equal (via this type's `PartialEq`) only if
/// both the name and the `positive` flag match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Symbol {
    /// Name used to identify the symbol and when formatting expressions.
    pub name: String,

    /// If `true`, the symbol is known to be non-negative (`>= 0`);
    /// otherwise it may take any `i32` value.
    pub positive: bool,
}
20
/// A symbolic integer expression.
///
/// Binary nodes store their operands as `(lhs, rhs)` pairs of shared
/// pointers, so cloning an expression is shallow. Expressions built with the
/// operator overloads are not simplified; call [`SymExpr::simplify`] for that.
#[derive(Clone)]
pub enum SymExpr {
    /// A constant integer.
    Value(i32),
    /// A named symbol.
    Var(Rc<Symbol>),
    /// `lhs + rhs`.
    Add((Rc<SymExpr>, Rc<SymExpr>)),
    /// `lhs - rhs`.
    Sub((Rc<SymExpr>, Rc<SymExpr>)),
    /// `lhs * rhs`.
    Mul((Rc<SymExpr>, Rc<SymExpr>)),
    /// `lhs / rhs` (constant folding uses Rust's truncating `i32` division).
    Div((Rc<SymExpr>, Rc<SymExpr>)),
    /// `lhs / rhs` rounded towards positive infinity.
    DivCeil((Rc<SymExpr>, Rc<SymExpr>)),
    /// `max(lhs, rhs)`.
    Max((Rc<SymExpr>, Rc<SymExpr>)),
    /// `min(lhs, rhs)`.
    Min((Rc<SymExpr>, Rc<SymExpr>)),
    /// Broadcast of two values. Simplification treats `1` as the identity
    /// and otherwise prefers a non-1 constant operand — presumably
    /// NumPy-style dimension-size broadcasting (TODO confirm).
    Broadcast((Rc<SymExpr>, Rc<SymExpr>)),
    /// `-expr`.
    Neg(Rc<SymExpr>),
}
53
54impl SymExpr {
55 pub fn range(&self) -> (i32, i32) {
57 match self {
58 Self::Value(x) => (*x, *x),
59 Self::Var(sym) => {
60 if sym.positive {
61 (0, i32::MAX)
62 } else {
63 (i32::MIN, i32::MAX)
64 }
65 }
66 Self::Neg(x) => {
67 if x.is_positive() {
68 (i32::MIN, -1)
69 } else {
70 (i32::MIN, i32::MAX)
71 }
72 }
73 Self::Add((lhs, rhs))
74 | Self::Mul((lhs, rhs))
75 | Self::Max((lhs, rhs))
76 | Self::Min((lhs, rhs))
77 | Self::Div((lhs, rhs))
78 | Self::DivCeil((lhs, rhs)) => {
79 let (lhs_min, lhs_max) = lhs.range();
80 let (rhs_min, rhs_max) = rhs.range();
81 (lhs_min.min(rhs_min), lhs_max.max(rhs_max))
82 }
83 Self::Sub((_lhs, _rhs)) => {
84 (i32::MIN, i32::MAX)
87 }
88 Self::Broadcast((lhs, rhs)) => {
89 let (lhs_min, lhs_max) = lhs.range();
90 let (rhs_min, rhs_max) = rhs.range();
91 (lhs_min.min(rhs_min).max(0), lhs_max.max(rhs_max).max(0))
92 }
93 }
94 }
95
96 pub fn is_positive(&self) -> bool {
98 match self {
99 Self::Value(x) => *x >= 0,
100 Self::Var(sym) => sym.positive,
101 Self::Neg(_expr) => false,
102 Self::Add((lhs, rhs)) => lhs.is_positive() && rhs.is_positive(),
103 Self::Sub((_lhs, _rhs)) => false,
104 Self::Mul((lhs, rhs)) => lhs.is_positive() && rhs.is_positive(),
105 Self::Div((lhs, rhs)) | Self::DivCeil((lhs, rhs)) => {
106 lhs.is_positive() && rhs.is_positive()
107 }
108 Self::Max((lhs, rhs)) => lhs.is_positive() || rhs.is_positive(),
109 Self::Min((lhs, rhs)) => lhs.is_positive() && rhs.is_positive(),
110 Self::Broadcast(_) => true,
111 }
112 }
113
114 pub fn max(&self, other: &SymExpr) -> SymExpr {
116 Self::Max((self.clone().into(), other.clone().into()))
117 }
118
119 pub fn min(&self, other: &SymExpr) -> SymExpr {
121 Self::Min((self.clone().into(), other.clone().into()))
122 }
123
124 pub fn broadcast(&self, other: &SymExpr) -> SymExpr {
126 Self::Broadcast((self.clone().into(), other.clone().into()))
127 }
128
129 pub fn div_ceil(&self, other: &SymExpr) -> SymExpr {
131 Self::DivCeil((self.clone().into(), other.clone().into()))
132 }
133
134 fn is_value(&self) -> bool {
135 matches!(self, Self::Value(_))
136 }
137
138 fn canonicalize(&self) -> SymExpr {
144 fn collect_terms(
145 terms: &mut Vec<SymExpr>,
146 term: &SymExpr,
147 extract_lhs_rhs: &impl Fn(&SymExpr) -> Option<&(Rc<SymExpr>, Rc<SymExpr>)>,
148 ) {
149 if let Some((lhs, rhs)) = extract_lhs_rhs(term) {
150 collect_terms(terms, lhs, extract_lhs_rhs);
151 collect_terms(terms, rhs, extract_lhs_rhs);
152 } else {
153 terms.push(term.canonicalize())
154 }
155 }
156
157 fn reassociate_terms(
166 term: &SymExpr,
167 extract_terms: &impl Fn(&SymExpr) -> Option<&(Rc<SymExpr>, Rc<SymExpr>)>,
168 simplify: impl Fn(Vec<SymExpr>) -> Vec<SymExpr>,
169 init: SymExpr,
170 fold: impl Fn(SymExpr, SymExpr) -> SymExpr,
171 ) -> SymExpr {
172 let mut terms = Vec::new();
173 collect_terms(&mut terms, term, extract_terms);
174 terms.sort_by(cmp_values_first);
175 let terms = simplify(terms);
176 terms.into_iter().fold(init, fold)
177 }
178
179 let remove_adjacent_equal_terms = |mut terms: Vec<SymExpr>| {
184 let mut idx = 0;
185 while idx < terms.len().saturating_sub(1) {
186 if terms[idx] == terms[idx + 1].clone() {
187 terms.remove(idx);
188 } else {
189 idx += 1;
190 }
191 }
192 terms
193 };
194
195 match self {
196 Self::Value(_) | Self::Var(_) => self.clone(),
197 Self::Neg(expr) => Self::Neg(expr.canonicalize().into()),
198 Self::Mul(_) => reassociate_terms(
199 self,
200 &|term| {
201 if let Self::Mul(inner) = term {
202 Some(inner)
203 } else {
204 None
205 }
206 },
207 |terms| terms,
208 SymExpr::Value(1),
209 |prod, x| prod * x,
210 ),
211 Self::Add(_) => {
212 let remove_adjacent_opposite_terms = |mut terms: Vec<SymExpr>| {
214 let mut idx = 0;
215 while idx < terms.len().saturating_sub(1) {
216 if terms[idx].is_negation_of(&terms[idx + 1]) {
217 terms.remove(idx);
218 terms.remove(idx);
219 } else {
220 idx += 1;
221 }
222 }
223 terms
224 };
225
226 reassociate_terms(
227 self,
228 &|term| match term {
229 Self::Add(inner) => Some(inner),
230 _ => None,
231 },
232 remove_adjacent_opposite_terms,
233 SymExpr::Value(0),
234 |sum, x| sum + x,
235 )
236 }
237 Self::Max(_) => reassociate_terms(
238 self,
239 &|term| match term {
240 Self::Max(inner) => Some(inner),
241 _ => None,
242 },
243 remove_adjacent_equal_terms,
244 SymExpr::Value(i32::MIN),
245 |max, x| max.max(&x),
246 ),
247 Self::Min(_) => reassociate_terms(
248 self,
249 &|term| match term {
250 Self::Min(inner) => Some(inner),
251 _ => None,
252 },
253 remove_adjacent_equal_terms,
254 SymExpr::Value(i32::MAX),
255 |min, x| min.min(&x),
256 ),
257 Self::Sub((lhs, rhs)) => {
258 let lhs = lhs.canonicalize();
261 let rhs = rhs.canonicalize();
262 Self::Add((lhs.into(), (-rhs).into())).canonicalize()
263 }
264 Self::Div((lhs, rhs)) => {
265 let lhs = lhs.canonicalize();
266 let rhs = rhs.canonicalize();
267 Self::Div((lhs.into(), rhs.into()))
268 }
269 Self::DivCeil((lhs, rhs)) => {
270 let lhs = lhs.canonicalize();
271 let rhs = rhs.canonicalize();
272 Self::DivCeil((lhs.into(), rhs.into()))
273 }
274 Self::Broadcast(_) => reassociate_terms(
275 self,
276 &|term| match term {
277 Self::Broadcast(inner) => Some(inner),
278 _ => None,
279 },
280 remove_adjacent_equal_terms,
281 SymExpr::Value(1),
282 |result, x| result.broadcast(&x),
283 ),
284 }
285 }
286
287 pub fn simplify(&self) -> SymExpr {
291 self.canonicalize().simplify_canonical()
292 }
293
294 fn simplify_canonical(&self) -> SymExpr {
297 match self {
298 Self::Value(_) | Self::Var(_) => self.clone(),
299 Self::Neg(expr) => match expr.simplify_canonical() {
300 SymExpr::Value(x) => SymExpr::Value(-x),
301 expr => Self::Neg(expr.into()),
302 },
303 Self::Add((lhs, rhs)) => {
304 let lhs = lhs.simplify_canonical();
305 let rhs = rhs.simplify_canonical();
306
307 match (lhs, rhs) {
308 (SymExpr::Value(0), rhs) => rhs,
309 (lhs, SymExpr::Value(0)) => lhs,
310 (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x + y),
311 (lhs, SymExpr::Neg(rhs)) if lhs == *rhs => SymExpr::Value(0),
312 (lhs, rhs) => lhs + rhs,
313 }
314 }
315 Self::Sub((lhs, rhs)) => {
316 let lhs = lhs.simplify_canonical();
317 let rhs = rhs.simplify_canonical();
318
319 match (lhs, rhs) {
320 (lhs, SymExpr::Value(0)) => lhs,
321 (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x - y),
322 (lhs, rhs) if lhs == rhs => SymExpr::Value(0),
323 (lhs, rhs) => lhs - rhs,
324 }
325 }
326 Self::Mul((lhs, rhs)) => {
327 let lhs = lhs.simplify_canonical();
328 let rhs = rhs.simplify_canonical();
329
330 match (lhs, rhs) {
331 (SymExpr::Value(1), rhs) => rhs,
332 (lhs, SymExpr::Value(1)) => lhs,
333 (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x * y),
334 (lhs, rhs) => lhs * rhs,
335 }
336 }
337 Self::Div((lhs, rhs)) => {
338 let lhs = lhs.simplify_canonical();
339 let rhs = rhs.simplify_canonical();
340
341 match (lhs, rhs) {
342 (lhs, SymExpr::Value(1)) => lhs,
343 (SymExpr::Value(x), SymExpr::Value(y)) if y != 0 => SymExpr::Value(x / y),
344 (lhs, rhs) if lhs == rhs => SymExpr::Value(1),
352
353 (SymExpr::Div((lhs, c1)), c2) => match (&*c1, c2) {
355 (SymExpr::Value(c1), SymExpr::Value(c2)) if *c1 != 0 && c2 != 0 => {
356 (*lhs).clone() / SymExpr::Value(c1 * c2)
357 }
358 (c1, c2) => (*lhs).clone() / (c1.clone() * c2),
359 },
360 (lhs, rhs) => lhs / rhs,
361 }
362 }
363 Self::DivCeil((lhs, rhs)) => {
364 let lhs = lhs.simplify_canonical();
365 let rhs = rhs.simplify_canonical();
366
367 match (lhs, rhs) {
368 (lhs, SymExpr::Value(1)) => lhs,
369 (SymExpr::Value(x), SymExpr::Value(y)) if y != 0 => {
370 SymExpr::Value(div_ceil(x, y))
371 }
372 (lhs, rhs) if lhs == rhs => SymExpr::Value(1),
380
381 (SymExpr::DivCeil((lhs, c1)), c2) => match (&*c1, c2) {
384 (SymExpr::Value(c1), SymExpr::Value(c2)) if *c1 > 0 && c2 > 0 => {
385 lhs.div_ceil(&SymExpr::Value(c1 * c2))
386 }
387 (c1, c2) => lhs.div_ceil(&(c1.clone() * c2)),
388 },
389 (lhs, rhs) => lhs.div_ceil(&rhs),
390 }
391 }
392 Self::Max((lhs, rhs)) => {
393 let lhs = lhs.simplify_canonical();
394 let rhs = rhs.simplify_canonical();
395
396 if lhs == rhs {
397 lhs
398 } else {
399 match (lhs, rhs) {
400 (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x.max(y)),
401 (lhs, rhs) => Self::Max((lhs.into(), rhs.into())),
402 }
403 }
404 }
405 Self::Min((lhs, rhs)) => {
406 let lhs = lhs.simplify_canonical();
407 let rhs = rhs.simplify_canonical();
408
409 if lhs == rhs {
410 lhs
411 } else {
412 match (lhs, rhs) {
413 (SymExpr::Value(x), SymExpr::Value(y)) => SymExpr::Value(x.min(y)),
414 (lhs, rhs) => Self::Min((lhs.into(), rhs.into())),
415 }
416 }
417 }
418 Self::Broadcast((lhs, rhs)) => {
419 let lhs = lhs.simplify_canonical();
420 let rhs = rhs.simplify_canonical();
421
422 match (lhs, rhs) {
423 (SymExpr::Value(x), SymExpr::Value(y)) if x == y => SymExpr::Value(x),
424 (SymExpr::Value(1), y) => y,
425 (x, SymExpr::Value(1)) => x,
426 (SymExpr::Value(x), y) if x != 1 => SymExpr::Value(x),
427 (x, SymExpr::Value(y)) if y != 1 => SymExpr::Value(y),
428 (lhs, rhs) if lhs == rhs => lhs,
429 (lhs, rhs) => SymExpr::Broadcast((lhs.into(), rhs.into())),
430 }
431 }
432 }
433 }
434
435 fn precedence(&self) -> u8 {
439 match self {
440 Self::Value(_) | Self::Var(_) | Self::Max(_) | Self::Min(_) | Self::Broadcast(_) => 4,
443 Self::Div(_) | Self::DivCeil(_) => 3,
444 Self::Mul(_) => 2,
445 Self::Add(_) => 1,
446 Self::Sub(_) | Self::Neg(_) => 0,
447 }
448 }
449
450 pub fn var(name: &str) -> Self {
452 SymExpr::Var(
453 Symbol {
454 name: name.to_string(),
455 positive: false,
456 }
457 .into(),
458 )
459 }
460
461 pub fn pos_var(name: &str) -> Self {
463 SymExpr::Var(
464 Symbol {
465 name: name.to_string(),
466 positive: true,
467 }
468 .into(),
469 )
470 }
471
472 pub fn exact_div(&self, rhs: &SymExpr) -> Option<SymExpr> {
475 let lhs = self;
476 match (lhs, rhs) {
477 (SymExpr::Value(lhs), SymExpr::Value(rhs)) => {
479 if *rhs != 0 && lhs % rhs == 0 {
480 Some(SymExpr::Value(lhs / rhs))
481 } else {
482 None
483 }
484 }
485 (lhs, rhs) if lhs == rhs => Some(SymExpr::Value(1)),
487 (lhs, SymExpr::Value(1)) => Some(lhs.clone()),
488 (SymExpr::Mul((lhs_a, lhs_b)), rhs) => {
490 if let Some(new_lhs_a) = lhs_a.exact_div(rhs) {
491 Some(SymExpr::Mul((new_lhs_a.into(), lhs_b.clone())))
492 } else {
493 lhs_b
494 .exact_div(rhs)
495 .map(|new_lhs_b| SymExpr::Mul((lhs_a.clone(), new_lhs_b.into())))
496 }
497 }
498 _ => None,
499 }
500 }
501
502 fn name(&self) -> Option<&str> {
506 match self {
507 SymExpr::Value(_) => None,
508 SymExpr::Var(sym) => Some(&sym.name),
509 SymExpr::Neg(x) => x.name(),
510 SymExpr::Add(_)
511 | SymExpr::Sub(_)
512 | SymExpr::Mul(_)
513 | SymExpr::Div(_)
514 | SymExpr::DivCeil(_)
515 | SymExpr::Max(_)
516 | SymExpr::Min(_)
517 | SymExpr::Broadcast(_) => None,
518 }
519 }
520
521 fn is_negation_of(&self, other: &SymExpr) -> bool {
524 match (self, other) {
525 (x, SymExpr::Neg(y)) if *x == **y => true,
526 (SymExpr::Neg(x), y) if **x == *y => true,
527 _ => false,
528 }
529 }
530}
531
532fn cmp_values_first(a: &SymExpr, b: &SymExpr) -> Ordering {
535 match (a.is_value(), b.is_value()) {
536 (true, false) => Ordering::Less,
537 (false, true) => Ordering::Greater,
538 _ => match (a.name(), b.name()) {
539 (Some(a_name), Some(b_name)) => a_name.cmp(b_name),
540 (Some(_), None) => Ordering::Less,
541 (None, Some(_)) => Ordering::Greater,
542 _ => Ordering::Equal,
543 },
544 }
545}
546
547impl PartialEq<SymExpr> for SymExpr {
548 fn eq(&self, other: &SymExpr) -> bool {
549 let commutative_eq = |self_lhs, self_rhs, other_lhs, other_rhs| {
550 (self_lhs == other_lhs && self_rhs == other_rhs)
551 || (self_lhs == other_rhs && self_rhs == other_lhs)
552 };
553
554 match self {
556 Self::Value(x) => match other {
557 Self::Value(y) => x == y,
558 _ => false,
559 },
560 Self::Var(x) => match other {
561 Self::Var(y) => x.name == y.name,
562 _ => false,
563 },
564 Self::Neg(x) => match other {
565 Self::Neg(y) => x == y,
566 _ => false,
567 },
568 Self::Add((a, b)) => match other {
569 Self::Add((c, d)) => commutative_eq(a, b, c, d),
570 _ => false,
571 },
572 Self::Mul((a, b)) => match other {
573 Self::Mul((c, d)) => commutative_eq(a, b, c, d),
574 _ => false,
575 },
576 Self::Max((a, b)) => match other {
577 Self::Max((c, d)) => commutative_eq(a, b, c, d),
578 _ => false,
579 },
580 Self::Min((a, b)) => match other {
581 Self::Min((c, d)) => commutative_eq(a, b, c, d),
582 _ => false,
583 },
584 Self::Sub((a, b)) => match other {
585 Self::Sub((c, d)) => a == c && b == d,
586 _ => false,
587 },
588 Self::Div((a, b)) => match other {
589 Self::Div((c, d)) => a == c && b == d,
590 _ => false,
591 },
592 Self::DivCeil((a, b)) => match other {
593 Self::DivCeil((c, d)) => a == c && b == d,
594 _ => false,
595 },
596 Self::Broadcast((a, b)) => match other {
597 Self::Broadcast((c, d)) => commutative_eq(a, b, c, d),
598 _ => false,
599 },
600 }
601 }
602}
603
604impl Add<SymExpr> for SymExpr {
605 type Output = SymExpr;
606
607 fn add(self, rhs: SymExpr) -> Self {
608 Self::Add((self.into(), rhs.into()))
609 }
610}
611
612impl Sub<SymExpr> for SymExpr {
613 type Output = SymExpr;
614
615 fn sub(self, rhs: SymExpr) -> Self {
616 Self::Sub((self.into(), rhs.into()))
617 }
618}
619
620impl AddAssign<SymExpr> for SymExpr {
621 fn add_assign(&mut self, rhs: SymExpr) {
622 *self = Self::Add((self.clone().into(), rhs.into()));
623 }
624}
625
626impl Mul<SymExpr> for SymExpr {
627 type Output = SymExpr;
628
629 fn mul(self, rhs: SymExpr) -> Self {
630 Self::Mul((self.into(), rhs.into()))
631 }
632}
633
634impl Div<SymExpr> for SymExpr {
635 type Output = SymExpr;
636
637 fn div(self, rhs: SymExpr) -> Self {
638 Self::Div((self.into(), rhs.into()))
639 }
640}
641
642impl Neg for SymExpr {
643 type Output = SymExpr;
644
645 fn neg(self) -> Self {
646 Self::Neg(self.into())
647 }
648}
649
650impl From<Symbol> for SymExpr {
651 fn from(val: Symbol) -> Self {
652 Self::Var(val.into())
653 }
654}
655
656impl<'a> From<&'a str> for SymExpr {
662 fn from(name: &'a str) -> Self {
663 SymExpr::Var(
664 Symbol {
665 name: name.to_string(),
666 positive: true,
667 }
668 .into(),
669 )
670 }
671}
672
673impl From<i32> for SymExpr {
674 fn from(val: i32) -> Self {
675 SymExpr::Value(val)
676 }
677}
678
679impl fmt::Debug for SymExpr {
680 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681 let add_parens = |f: &mut fmt::Formatter<'_>, expr: &SymExpr| {
682 if expr.precedence() < self.precedence() {
683 write!(f, "({:?})", expr)
684 } else {
685 write!(f, "{:?}", expr)
686 }
687 };
688 let write_binop = |f: &mut fmt::Formatter<'_>, op, lhs, rhs| {
689 add_parens(f, lhs)?;
690 write!(f, " {op} ")?;
691 add_parens(f, rhs)
692 };
693 match self {
694 Self::Value(val) => write!(f, "{}", val),
695 Self::Var(sym) => write!(
696 f,
697 "\"{}\"{}",
698 sym.name,
699 if sym.positive { 'u' } else { 'i' }
700 ),
701 Self::Neg(expr) => write!(f, "-{:?}", expr),
704 Self::Add((lhs, rhs)) => write_binop(f, '+', lhs, rhs),
705 Self::Sub((lhs, rhs)) => write_binop(f, '-', lhs, rhs),
706 Self::Mul((lhs, rhs)) => write_binop(f, '*', lhs, rhs),
707 Self::Div((lhs, rhs)) => write_binop(f, '/', lhs, rhs),
708 Self::DivCeil((lhs, rhs)) => write!(f, "ceil_div({:?}, {:?})", lhs, rhs),
709 Self::Max((lhs, rhs)) => write!(f, "max({:?}, {:?})", lhs, rhs),
710 Self::Min((lhs, rhs)) => write!(f, "min({:?}, {:?})", lhs, rhs),
711 Self::Broadcast((lhs, rhs)) => write!(f, "broadcast({:?}, {:?})", lhs, rhs),
712 }
713 }
714}
715
716impl fmt::Display for SymExpr {
717 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
718 let add_parens = |f: &mut fmt::Formatter<'_>, expr: &SymExpr| {
719 if expr.precedence() < self.precedence() {
720 write!(f, "({})", expr)
721 } else {
722 write!(f, "{}", expr)
723 }
724 };
725 let write_binop = |f: &mut fmt::Formatter<'_>, op, lhs, rhs| {
726 add_parens(f, lhs)?;
727 write!(f, " {op} ")?;
728 add_parens(f, rhs)
729 };
730 match self {
731 Self::Value(val) => write!(f, "{}", val),
732 Self::Var(sym) => write!(f, "{}", sym.name),
733 Self::Neg(expr) => write!(f, "-{}", expr),
736 Self::Add((lhs, rhs)) => write_binop(f, '+', lhs, rhs),
737 Self::Sub((lhs, rhs)) => write_binop(f, '-', lhs, rhs),
738 Self::Mul((lhs, rhs)) => write_binop(f, '*', lhs, rhs),
739 Self::Div((lhs, rhs)) => write_binop(f, '/', lhs, rhs),
740 Self::DivCeil((lhs, rhs)) => write!(f, "ceil_div({}, {})", lhs, rhs),
741 Self::Max((lhs, rhs)) => write!(f, "max({}, {})", lhs, rhs),
742 Self::Min((lhs, rhs)) => write!(f, "min({}, {})", lhs, rhs),
743 Self::Broadcast((lhs, rhs)) => write!(f, "broadcast({}, {})", lhs, rhs),
744 }
745 }
746}
747
/// Divide `lhs` by `rhs`, rounding the quotient towards positive infinity.
///
/// # Panics
///
/// Panics under the same conditions as `lhs / rhs`: when `rhs` is zero, or
/// for `i32::MIN / -1`.
pub const fn div_ceil(lhs: i32, rhs: i32) -> i32 {
    // `/` truncates towards zero. That already equals the ceiling when the
    // exact quotient is negative (operands of differing sign). When the
    // exact quotient is positive and inexact, it is one too small.
    let quotient = lhs / rhs;
    let remainder = lhs % rhs;
    let same_sign = (lhs < 0) == (rhs < 0);

    if remainder != 0 && same_sign {
        quotient + 1
    } else {
        quotient
    }
}
758
#[cfg(test)]
mod tests {
    use super::SymExpr;

    /// A symbol's range depends only on its `positive` flag.
    #[test]
    fn test_range() {
        let x = SymExpr::pos_var("x");
        assert_eq!(x.range(), (0, i32::MAX));

        let y = SymExpr::var("y");
        assert_eq!(y.range(), (i32::MIN, i32::MAX));
    }

    /// `x + 0` simplifies to `x`; `x + 1` is left as a sum.
    #[test]
    fn test_simplify_add() {
        let x = SymExpr::pos_var("x");
        let zero = SymExpr::from(0);
        let one = SymExpr::from(1);

        let expr = x.clone() + zero.clone();
        assert_eq!(expr, SymExpr::Add((x.clone().into(), zero.clone().into())));
        assert_eq!(expr.simplify(), x);

        let expr_2 = x.clone() + one.clone();
        assert_eq!(
            expr_2.simplify(),
            SymExpr::Add((x.clone().into(), one.clone().into()))
        );
    }

    /// Constants spread across a nested sum are gathered and folded.
    #[test]
    fn test_simplify_add_reassociate() {
        let x = SymExpr::from("x");
        let c1 = SymExpr::from(3);
        let c2 = SymExpr::from(4);

        // (x + 3) + 4 => 7 + x
        let expr = (x.clone() + c1.clone()) + c2.clone();
        let simplified = expr.simplify();
        assert_eq!(simplified, SymExpr::from(7) + x.clone());

        // (x + 3) + (x + 4) => 7 + x + x
        let expr = (x.clone() + c1) + (x.clone() + c2);
        let simplified = expr.simplify();
        assert_eq!(simplified, SymExpr::from(7) + x.clone() + x);
    }

    /// Subtraction is rewritten as addition of a negation, so terms cancel.
    #[test]
    fn test_simplify_sub() {
        let x = SymExpr::pos_var("x");
        let zero = SymExpr::from(0);
        let one = SymExpr::from(1);

        // x - 0 => x
        let expr = x.clone() - zero.clone();
        assert_eq!(expr, SymExpr::Sub((x.clone().into(), zero.clone().into())));
        assert_eq!(expr.simplify(), x);

        // x - x => 0
        let expr = x.clone() - x.clone();
        assert_eq!(expr.simplify(), SymExpr::Value(0));

        // x - 1 => x + (-1)
        let expr_2 = x.clone() - one.clone();
        assert_eq!(
            expr_2.simplify(),
            SymExpr::Add((x.clone().into(), SymExpr::from(-1).into()))
        );

        // Terms cancel across a longer sum.
        let y = SymExpr::pos_var("y");
        let expr = x.clone() + y.clone() - x.clone();
        assert_eq!(expr.simplify(), y.clone());

        // Only one of the two `x` terms is cancelled.
        let expr = x.clone() + x.clone() + y.clone() - x.clone();
        assert_eq!(expr.simplify(), x.clone() + y.clone());

        let expr = x.clone() + y.clone() - x.clone() - y.clone();
        assert_eq!(expr.simplify(), 0.into());

        // Cancellation works regardless of which side holds the negation.
        let expr = -x.clone() + x.clone();
        assert_eq!(expr.simplify(), 0.into());

        let expr = x.clone() + (-x.clone());
        assert_eq!(expr.simplify(), 0.into());

        // A compound expression minus itself.
        let expr = (x.clone() + y.clone()) - (x.clone() + y.clone());
        assert_eq!(expr.simplify(), 0.into());
    }

    /// `x * 1` simplifies to `x`; `x * 2` is left as a product.
    #[test]
    fn test_simplify_mul() {
        let x = SymExpr::pos_var("x");
        let one = SymExpr::from(1);
        let two = SymExpr::from(2);

        let expr = x.clone() * one.clone();
        assert_eq!(expr, SymExpr::Mul((x.clone().into(), one.clone().into())));
        assert_eq!(expr.simplify(), x);

        let expr_2 = x.clone() * two.clone();
        assert_eq!(
            expr_2.simplify(),
            SymExpr::Mul((x.clone().into(), two.clone().into()))
        );
    }

    /// Constant folding, identities and nested-division merging for `/`.
    #[test]
    fn test_simplify_div() {
        let x = SymExpr::pos_var("x");
        let one = SymExpr::from(1);
        let two = SymExpr::from(2);

        // Constant division truncates.
        let expr = SymExpr::from(5) / SymExpr::from(2);
        assert_eq!(expr.simplify(), SymExpr::from(2));

        // Division by a constant zero is left unevaluated.
        let expr = SymExpr::from(5) / SymExpr::from(0);
        assert_eq!(expr.simplify(), SymExpr::from(5) / SymExpr::from(0));

        // x / 1 => x
        let expr = x.clone() / one.clone();
        assert_eq!(expr, SymExpr::Div((x.clone().into(), one.clone().into())));
        assert_eq!(expr.simplify(), x);

        // x / x => 1
        let expr = x.clone() / x.clone();
        assert_eq!(expr.simplify(), one);

        // x / 2 is left as-is.
        let expr_2 = x.clone() / two.clone();
        assert_eq!(
            expr_2.simplify(),
            SymExpr::Div((x.clone().into(), two.clone().into()))
        );

        // (x / 2) / 2 => x / 4
        let expr = x.clone() / two.clone() / two.clone();
        assert_eq!(expr.simplify(), x.clone() / SymExpr::from(4));

        // With a zero divisor the divisors are merged but not folded.
        let zero = SymExpr::from(0);
        let expr = x.clone() / zero.clone() / two.clone();
        assert_eq!(expr.simplify(), x.clone() / (zero.clone() * two.clone()));

        let expr = x.clone() / two.clone() / zero.clone();
        assert_eq!(expr.simplify(), x.clone() / (two.clone() * zero));
    }

    /// Constant folding, identities and nested-division merging for
    /// ceiling division.
    #[test]
    fn test_simplify_div_ceil() {
        let x = SymExpr::pos_var("x");
        let one = SymExpr::from(1);
        let two = SymExpr::from(2);

        // Constant ceiling division rounds up.
        let expr = SymExpr::from(5).div_ceil(&SymExpr::from(2));
        assert_eq!(expr.simplify(), SymExpr::from(3));

        // Division by a constant zero is left unevaluated.
        let expr = SymExpr::from(5).div_ceil(&SymExpr::from(0));
        assert_eq!(
            expr.simplify(),
            SymExpr::from(5).div_ceil(&SymExpr::from(0))
        );

        // ceil_div(x, 1) => x
        let expr = x.clone().div_ceil(&one);
        assert_eq!(
            expr,
            SymExpr::DivCeil((x.clone().into(), one.clone().into()))
        );
        assert_eq!(expr.simplify(), x);

        // ceil_div(x, x) => 1
        let expr = x.clone().div_ceil(&x);
        assert_eq!(expr.simplify(), one);

        // ceil_div(x, 2) is left as-is.
        let expr_2 = x.clone().div_ceil(&two);
        assert_eq!(
            expr_2.simplify(),
            SymExpr::DivCeil((x.clone().into(), two.clone().into()))
        );

        // ceil_div(ceil_div(x, 2), 2) => ceil_div(x, 4)
        let expr = x.clone().div_ceil(&two).div_ceil(&two);
        assert_eq!(expr.simplify(), x.clone().div_ceil(&SymExpr::from(4)));

        // With a zero inner divisor, the divisors are merged but not folded.
        let zero = SymExpr::from(0);
        let expr = x.clone().div_ceil(&zero).div_ceil(&two);
        assert_eq!(
            expr.simplify(),
            x.clone().div_ceil(&(zero.clone() * two.clone()))
        );

        // With a negative inner divisor, the divisors are merged but not folded.
        let neg_one = SymExpr::from(-1);
        let expr = x.clone().div_ceil(&neg_one).div_ceil(&two);
        assert_eq!(expr.simplify(), x.div_ceil(&(neg_one.clone() * two)));
    }

    /// Constants spread across a nested product are gathered and folded.
    #[test]
    fn test_simplify_mul_reassociate() {
        let x = SymExpr::from("x");
        let c1 = SymExpr::from(3);
        let c2 = SymExpr::from(4);

        // (x * 3) * 4 => 12 * x
        let expr = (x.clone() * c1.clone()) * c2.clone();
        let simplified = expr.simplify();
        assert_eq!(simplified, SymExpr::from(12) * x.clone());

        // Products nested inside a sum are simplified too.
        let expr = SymExpr::from(5) + expr;
        let simplified = expr.simplify();
        assert_eq!(simplified, SymExpr::from(5) + SymExpr::from(12) * x.clone());

        // (x * 3) * (x * 4) => 12 * x * x
        let expr = (x.clone() * c1) * (x.clone() * c2);
        let simplified = expr.simplify();
        assert_eq!(simplified, SymExpr::from(12) * x.clone() * x);
    }

    #[test]
    fn test_simplify_max() {
        let one = SymExpr::from(1);
        let two = SymExpr::from(2);
        let expr = one.max(&two);

        assert_eq!(expr, SymExpr::Max((one.clone().into(), two.clone().into())));
        assert_eq!(expr.simplify(), two.clone());
    }

    #[test]
    fn test_simplify_nested_max() {
        let expr = SymExpr::from(10)
            .max(&SymExpr::from(5).max(&SymExpr::from(11)))
            .simplify();
        assert_eq!(expr, SymExpr::from(11));
    }

    #[test]
    fn test_simplify_min() {
        let one = SymExpr::from(1);
        let two = SymExpr::from(2);
        let expr = one.min(&two);

        assert_eq!(expr, SymExpr::Min((one.clone().into(), two.clone().into())));
        assert_eq!(expr.simplify(), one.clone());
    }

    #[test]
    fn test_simplify_nested_min() {
        let expr = SymExpr::from(10)
            .min(&SymExpr::from(5).min(&SymExpr::from(3)))
            .simplify();
        assert_eq!(expr, SymExpr::from(3));
    }

    /// `1` is the broadcast identity, a non-1 constant wins over a symbol,
    /// and a value broadcast with itself is unchanged.
    #[test]
    fn test_simplify_broadcast() {
        let one = SymExpr::from(1);
        let ten = SymExpr::from(10);
        let foo = SymExpr::from("foo");

        assert_eq!(ten.broadcast(&ten).simplify(), ten.clone());
        assert_eq!(ten.broadcast(&foo).simplify(), ten.clone());
        assert_eq!(one.broadcast(&ten).simplify(), ten.clone());
        assert_eq!(ten.broadcast(&one).simplify(), ten.clone());

        assert_eq!(foo.broadcast(&one).simplify(), foo.clone());
        assert_eq!(one.broadcast(&foo).simplify(), foo.clone());

        assert_eq!(foo.broadcast(&foo).simplify(), foo.clone());
    }

    #[test]
    fn test_simplify_nested_broadcast() {
        let foo = SymExpr::from("foo");
        let ten = SymExpr::from(10);
        let expr = foo.broadcast(&foo.broadcast(&ten)).simplify();
        assert_eq!(expr, SymExpr::from(10));
    }

    /// Negating a constant folds to a negative constant.
    #[test]
    fn test_simplify_neg() {
        let minus_one = -SymExpr::from(1);
        assert_eq!(minus_one.simplify(), SymExpr::from(-1));
    }

    /// `Display` prints symbols by name and only adds necessary parentheses.
    #[test]
    fn test_display() {
        let expr = (SymExpr::from(1) + SymExpr::pos_var("foo")) * SymExpr::from(3)
            + SymExpr::from(4)
            - SymExpr::from(5);
        assert_eq!(expr.to_string(), "(1 + foo) * 3 + 4 - 5");
    }

    /// `Debug` quotes symbol names and appends `u` (non-negative) or `i`.
    #[test]
    fn test_debug() {
        let expr = (SymExpr::from(1) + SymExpr::pos_var("foo")) * SymExpr::from(3)
            + SymExpr::var("bar")
            - SymExpr::from(5);
        assert_eq!(format!("{:?}", expr), "(1 + \"foo\"u) * 3 + \"bar\"i - 5");
    }

    #[test]
    fn test_exact_div() {
        // Constants: only exact, non-zero divisions succeed.
        assert_eq!(
            SymExpr::from(15).exact_div(&SymExpr::from(3)),
            Some(SymExpr::from(5))
        );
        assert_eq!(SymExpr::from(15).exact_div(&SymExpr::from(4)), None);
        assert_eq!(SymExpr::from(15).exact_div(&SymExpr::from(0)), None);

        // Identical operands and division by one.
        assert_eq!(
            SymExpr::from("x").exact_div(&SymExpr::from("x")),
            Some(SymExpr::from(1))
        );
        assert_eq!(
            SymExpr::from("x").exact_div(&SymExpr::from(1)),
            Some(SymExpr::from("x"))
        );

        // Products where either factor equals the divisor.
        assert_eq!(
            (SymExpr::from("x") * SymExpr::from("y"))
                .exact_div(&SymExpr::from("y"))
                .map(|s| s.simplify()),
            Some(SymExpr::from("x"))
        );
        assert_eq!(
            (SymExpr::from("y") * SymExpr::from("x"))
                .exact_div(&SymExpr::from("y"))
                .map(|s| s.simplify()),
            Some(SymExpr::from("x"))
        );

        // Unrelated symbols are not divisible.
        assert_eq!(SymExpr::from("x").exact_div(&SymExpr::from("y")), None);
    }
}