1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{DimLike, sym::*};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
/// Error returned when an expression cannot be turned into a concrete value yet.
#[derive(Debug)]
pub enum TooEarly {
    /// The expression still contains symbols; carries the rendered expression.
    UndeterminedSymbol(String),
    /// Any other "too early" condition; carries the message displayed verbatim.
    Other(String),
}
19
20impl std::fmt::Display for TooEarly {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 TooEarly::UndeterminedSymbol(s) => write!(f, "Undetermined symbol in expression: {s}"),
24 TooEarly::Other(s) => write!(f, "{s}"),
25 }
26 }
27}
28
29impl std::error::Error for TooEarly {}
30
/// Shorthand for `Box::new(expr)`, used when building boxed `TDim` children.
macro_rules! b {
    ($e:expr) => {
        Box::new($e)
    };
}
32
33#[derive(Clone, PartialEq, Eq, Hash, Debug)]
34pub enum TDim {
35 Val(i64),
36 Sym(Symbol),
37 Add(Vec<TDim>),
38 Mul(Vec<TDim>),
39 MulInt(i64, Box<TDim>),
40 Div(Box<TDim>, u64),
41 Broadcast(Vec<TDim>),
42 Min(Vec<TDim>),
43 Max(Vec<TDim>),
44}
45
46use TDim::*;
47
48fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
49 match (a, b) {
50 (Sym(a), Sym(b)) => a.cmp(b),
51 (Val(a), Val(b)) => a.cmp(b),
52 (Add(a), Add(b))
53 | (Mul(a), Mul(b))
54 | (Broadcast(a), Broadcast(b))
55 | (Min(a), Min(b))
56 | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
57 a.iter()
58 .zip(b.iter())
59 .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
60 ),
61 (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
62 (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
63 (Sym(_), _) => Ordering::Less,
64 (_, Sym(_)) => Ordering::Greater,
65 (Val(_), _) => Ordering::Less,
66 (_, Val(_)) => Ordering::Greater,
67 (Add(_), _) => Ordering::Less,
68 (_, Add(_)) => Ordering::Greater,
69 (Mul(_), _) => Ordering::Less,
70 (_, Mul(_)) => Ordering::Greater,
71 (MulInt(_, _), _) => Ordering::Less,
72 (_, MulInt(_, _)) => Ordering::Greater,
73 (Broadcast(_), _) => Ordering::Less,
74 (_, Broadcast(_)) => Ordering::Greater,
75 (Min(_), _) => Ordering::Less,
76 (_, Min(_)) => Ordering::Greater,
77 (Max(_), _) => Ordering::Less,
78 (_, Max(_)) => Ordering::Greater,
79 }
80}
81
82impl fmt::Display for TDim {
83 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
84 match &self {
85 Sym(sym) => write!(fmt, "{sym}"),
86 Val(it) => write!(fmt, "{it}"),
87 Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
88 Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
89 Broadcast(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("#")),
90 Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
91 Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
92 MulInt(a, b) => write!(fmt, "{a}*{b}"),
93 Div(a, b) => write!(fmt, "({a})/{b}"),
94 }
95 }
96}
97
98impl TDim {
99 #[inline]
100 pub fn is_one(&self) -> bool {
101 matches!(self, Val(1))
102 }
103
104 #[inline]
105 pub fn to_i64(&self) -> TractResult<i64> {
106 if let Val(v) = self {
107 Ok(*v)
108 } else {
109 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
110 }
111 }
112
113 #[inline]
114 pub fn as_i64(&self) -> Option<i64> {
115 if let Val(v) = self { Some(*v) } else { None }
116 }
117
118 pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
119 match self {
120 Sym(sym) => {
121 let Some(v) = values.get(sym) else {
122 Err(TooEarly::UndeterminedSymbol(self.to_string()))?
123 };
124 Ok(v)
125 }
126 Val(v) => Ok(*v),
127 Add(terms) => {
128 terms.iter().try_fold(0, |acc, it| it.eval_to_i64(values).map(|x| acc + x))
129 }
130 Mul(terms) => {
131 terms.iter().try_fold(1, |acc, it| it.eval_to_i64(values).map(|x| acc * x))
132 }
133 Min(terms) => terms
134 .iter()
135 .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
136 Max(terms) => terms
137 .iter()
138 .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
139 Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
140 it.eval_to_i64(values)
141 .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
142 }),
143 Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
144 MulInt(p, a) => Ok(a.eval_to_i64(values)? * *p),
145 }
146 }
147
148 pub fn eval(&self, values: &SymbolValues) -> TDim {
149 match self {
150 Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
151 Val(v) => Val(*v),
152 Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
153 Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
154 Min(terms) => {
155 terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
156 }
157 Max(terms) => {
158 terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
159 }
160 Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
161 acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
162 }),
163 Div(a, q) => a.eval(values) / *q as i64,
164 MulInt(p, a) => a.eval(values) * *p,
165 }
166 }
167
168 pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
169 if let Val(v) = self {
170 return Val(*v);
171 }
172 let scope = self.find_scope().unwrap();
173 let scope = scope.0;
174 let locked = scope.lock();
175 let scope = locked.borrow();
176 self.clone().simplify_rec(&scope, Some(scenario))
177 }
178
179 pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
180 match self {
181 Sym(sym) => Ok(if sym == from { to.clone() } else { self.clone() }),
182 Val(v) => Ok(Val(*v)),
183 Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
184 Ok(acc + it.substitute(from, to)?)
185 }),
186 Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
187 Ok(acc * it.substitute(from, to)?)
188 }),
189 Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
190 acc.broadcast(it.substitute(from, to)?)
191 }),
192 Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
193 Ok(acc.mini(it.substitute(from, to)?))
194 }),
195 Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
196 Ok(acc.maxi(it.substitute(from, to)?))
197 }),
198 Div(a, q) => Ok(a.substitute(from, to)? / *q as i64),
199 MulInt(p, a) => Ok(a.substitute(from, to)? * *p),
200 }
201 }
202
203 pub fn reduce(self) -> TDim {
204 self.simplify()
205 .wiggle()
206 .into_iter()
207 .sorted_by(tdim_lexi_order)
208 .unique()
209 .map(|e| e.simplify())
210 .min_by_key(|e| e.cost())
211 .unwrap()
212 }
213
214 fn cost(&self) -> usize {
215 use self::TDim::*;
216 match self {
217 Sym(_) | Val(_) => 1,
218 Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
219 Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
220 Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
221 Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
222 Div(a, _) => 3 * a.cost(),
223 MulInt(_, a) => 2 * a.cost(),
224 }
225 }
226
227 fn wiggle(&self) -> Vec<TDim> {
228 use self::TDim::*;
229 match self {
230 Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) => vec![self.clone()],
231 Add(terms) => {
232 let mut forms = vec![];
233 let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
234
235 fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
236 terms.iter().enumerate().find_map(|(index, t)| match t {
237 Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
238 _ => None,
239 })
240 }
241
242 fn generate_new_numerator(
243 div_index: usize,
244 numerator: &TDim,
245 quotient: u64,
246 expr: &[TDim],
247 ) -> Vec<TDim> {
248 expr.iter()
249 .enumerate()
250 .map(|(index, term)| {
251 if index == div_index {
252 numerator.clone()
253 } else {
254 MulInt(quotient as i64, Box::new(term.clone()))
255 }
256 })
257 .collect()
258 }
259
260 for expr in sub_exprs {
261 if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
262 let new_numerator =
263 generate_new_numerator(div_index, numerator, quotient, &expr);
264 forms.push(Div(Box::new(Add(new_numerator)), quotient))
265 }
266
267 forms.push(Add(expr));
268 }
269 forms
270 }
271 MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
272 Div(a, q) => {
273 let mut forms = vec![];
274 for num in a.wiggle() {
275 if let Add(terms) = &num {
276 let (integer, non_integer): (Vec<_>, Vec<_>) =
277 terms.iter().cloned().partition(|a| a.gcd() % q == 0);
278 let mut new_terms = integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
279 if non_integer.len() > 0 {
280 new_terms.push(Div(b!(Add(non_integer)), *q));
281 }
282 forms.push(Add(new_terms))
283 }
284 forms.push(Div(b!(num), *q))
285 }
286 forms
287 }
288 }
289 }
290
291 fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
292 match tdim {
293 Val(_) => None,
294 Sym(s) => Some(s),
295 Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
296 terms.iter().find_map(Self::find_any_sym)
297 }
298 MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
299 }
300 }
301
302 pub fn find_scope(&self) -> Option<SymbolScope> {
303 Self::find_any_sym(self).and_then(|s| s.scope().clone())
304 }
305
306 pub fn simplify(self) -> TDim {
307 use self::TDim::*;
308 if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
309 return Val(v);
310 }
311 let Some(scope) = self.find_scope() else {
312 return self;
313 };
314 let scope = scope.0;
315 let locked = scope.lock();
316 let scope = locked.borrow();
317 let it = self.simplify_rec(&scope, None);
318 let mut current: Option<TDim> = None;
319 for scenario in scope.scenarios() {
320 let v = it.clone().simplify_rec(&scope, Some(scenario));
321 if current.is_some_and(|c| c != v) {
322 return it;
323 } else {
324 current = Some(v);
325 }
326 }
327 current.unwrap_or(it)
328 }
329
330 fn simplify_rec(self, scope: &SymbolScopeData, scenario: Option<&str>) -> TDim {
331 match self {
332 Add(mut terms) => {
333 #[allow(clippy::mutable_key_type)]
334 let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
335 while let Some(term) = terms.pop() {
337 let simplified = term.simplify_rec(scope, scenario);
338 match simplified {
339 Val(0) => {} Add(members) => {
341 terms.extend(members);
342 continue;
343 }
344 Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
345 MulInt(value, factor) => {
346 *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
347 }
348 n => *simplified_terms.entry(n).or_insert(0) += 1,
349 };
350 }
351
352 pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
353 match count {
354 0 => None,
355 _ if term == TDim::Val(1) => Some(TDim::Val(count)),
356 1 => Some(term),
357 _ => Some(TDim::MulInt(count, Box::new(term))),
358 }
359 }
360
361 let mut members: Vec<TDim> = simplified_terms
362 .into_iter()
363 .filter_map(|(term, count)| evaluate_count(term, count))
364 .collect();
365 members.sort_by(tdim_lexi_order);
366
367 match members.len() {
368 0 => TDim::Val(0),
369 1 => members.into_iter().next().unwrap(),
370 _ => TDim::Add(members),
371 }
372 }
373 Mul(terms) => {
374 let mut flattened_terms = vec![];
377 for t in terms {
378 if let Mul(inner_terms) = t.clone().reduce() {
379 flattened_terms.extend(inner_terms);
380 } else {
381 flattened_terms.push(t);
382 }
383 }
384 let mut terms = flattened_terms;
385
386 let mut gcd = Mul(terms.clone()).gcd() as i64;
387 if gcd == 0 {
388 return Val(0);
389 }
390 terms = if gcd != 1 {
391 terms
392 .into_iter()
393 .map(|t| {
394 let gcd = t.gcd();
395 (t / gcd).simplify_rec(scope, scenario)
396 })
397 .collect()
398 } else {
399 terms
400 };
401 if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
402 gcd = -gcd;
403 }
404 terms.retain(|t| !t.is_one() && t != &Val(-1));
405 terms.sort_by(tdim_lexi_order);
406
407 match (gcd, terms.len()) {
408 (_, 0) => Val(gcd), (0, _) => Val(0), (1, 1) => terms.remove(0), (1, _) => Mul(terms), (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), _ => MulInt(gcd, Box::new(Mul(terms))), }
416 }
417 MulInt(coef, expr) => {
418 match *expr {
419 MulInt(c2, inner) => {
420 return MulInt(coef * c2, inner).simplify_rec(scope, scenario);
421 }
422 Val(v) => return Val(coef * v),
423 _ => {}
424 }
425
426 let simplified = expr.simplify_rec(scope, scenario);
427 match (coef, simplified) {
428 (0, _) => Val(0), (1, s) => s, (_, Add(terms)) => Add(terms
431 .into_iter()
432 .map(|term| MulInt(coef, Box::new(term)).simplify_rec(scope, scenario))
433 .collect()), (c, Val(v)) => Val(c * v), (c, MulInt(v, inner)) => MulInt(c * v, inner), (_, s) => MulInt(coef, Box::new(s)), }
438 }
439 Div(a, q) => {
440 if q == 1 {
441 return a.simplify_rec(scope, scenario);
442 } else if let Div(a, q2) = *a {
443 return Div(a, q * q2).simplify_rec(scope, scenario);
444 }
445 let a = a.simplify_rec(scope, scenario);
446 if let Val(a) = a {
447 Val(a / q as i64)
448 } else if let MulInt(-1, a) = a {
449 MulInt(-1, b!(Div(a, q)))
450 } else if let Add(mut terms) = a {
451 if terms
452 .iter()
453 .any(|t| if let MulInt(-1, s) = t { matches!(&**s, Sym(_)) } else { false })
454 {
455 MulInt(
456 -1,
457 b!(Div(
458 b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
459 .simplify_rec(scope, scenario)),
460 q
461 )),
462 )
463 } else if let Some(v) =
464 terms.iter().find_map(|t| if let Val(v) = t { Some(*v) } else { None })
465 {
466 let offset = if v >= q as i64 {
467 Some(v / q as i64)
468 } else if v < 0 {
469 Some(-Integer::div_ceil(&-v, &(q as i64)))
470 } else {
471 None
472 };
473 if let Some(val) = offset {
474 terms.push(Val(-val * q as i64));
475 Add(vec![
476 Val(val),
477 Div(b!(Add(terms).simplify_rec(scope, scenario)), q),
478 ])
479 } else {
480 Div(b!(Add(terms)), q)
481 }
482 } else {
483 Div(b!(Add(terms)), q)
484 }
485 } else if let MulInt(p, a) = a {
486 if p == q as i64 {
487 a.simplify()
488 } else {
489 let gcd = p.abs().gcd(&(q as i64));
490 if gcd == p {
491 Div(a, q / gcd as u64)
492 } else if gcd == q as i64 {
493 MulInt(p / gcd, a)
494 } else if gcd > 1 {
495 Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
496 .simplify_rec(scope, scenario)
497 } else {
498 Div(b!(MulInt(p, a)), q)
499 }
500 }
501 } else {
502 Div(b!(a), q)
503 }
504 }
505 Broadcast(terms) => {
506 let mut terms: Vec<TDim> = terms
507 .iter()
508 .map(|s| s.clone().simplify_rec(scope, scenario))
509 .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
510 .filter(|t| !t.is_one())
511 .sorted_by(tdim_lexi_order)
512 .dedup()
513 .collect_vec();
514 match &*terms {
516 [] => Val(1),
517 [_] => terms.remove(0),
518 [a, Min(m)] | [Min(m), a]
519 if m.contains(a) && m.iter().all(|t| scope.prove_strict_positive(t)) =>
520 {
521 a.clone()
522 }
523 _ => Broadcast(terms),
524 }
525 }
526
527 Min(terms) => {
528 let mut flatten: Vec<TDim> = terms
529 .into_iter()
530 .map(|t| t.simplify_rec(scope, scenario))
531 .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
532 .sorted_by(tdim_lexi_order)
533 .dedup()
534 .collect();
535 #[allow(clippy::mutable_key_type)]
536 let mut redundant = HashSet::<TDim>::default();
537 for pair in flatten.iter().permutations(2) {
538 let (a, b) = (pair[0], pair[1]);
539 if redundant.contains(a) || redundant.contains(b) {
540 continue;
541 }
542 let diff = a.clone() - b;
543 if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff)
544 {
545 redundant.insert(a.clone());
546 }
547 }
548 flatten.retain(|t| !redundant.contains(t));
549 if flatten.len() == 0 {
550 i64::MAX.to_dim()
551 } else if flatten.len() == 1 {
552 flatten.into_iter().next().unwrap()
553 } else {
554 Min(flatten)
555 }
556 }
557 Max(terms) => {
558 let mut flatten: Vec<TDim> = terms
559 .into_iter()
560 .map(|t| t.simplify_rec(scope, scenario))
561 .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
562 .sorted_by(tdim_lexi_order)
563 .dedup()
564 .collect();
565 #[allow(clippy::mutable_key_type)]
566 let mut redundant = HashSet::<TDim>::default();
567 for pair in flatten.iter().permutations(2) {
568 let (a, b) = (pair[0], pair[1]);
569 if redundant.contains(a) || redundant.contains(b) {
570 continue;
571 }
572 let diff = a.clone() - b;
573 if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff)
574 {
575 redundant.insert(b.clone());
576 }
577 }
578 flatten.retain(|t| !redundant.contains(t));
579 if flatten.len() == 0 {
580 i64::MIN.to_dim()
581 } else if flatten.len() == 1 {
582 flatten.into_iter().next().unwrap()
583 } else {
584 Max(flatten)
585 }
586 }
587 Sym(s) => scope
588 .assertions(scenario)
589 .find_map(|a| match a {
590 Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
591 _ => None,
592 })
593 .unwrap_or(Sym(s)),
594 Val(_) => self,
595 }
596 }
597
598 pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
599 use self::TDim::*;
600 match self {
601 Val(n) => Some(*n),
602 Sym(_) => {
603 if upper {
604 scope
605 .all_assertions()
606 .iter()
607 .filter_map(|assert| match &assert {
608 Assertion::LT(left, right)
609 if left == self && right.as_i64().is_some() =>
610 {
611 Some(right.as_i64().unwrap() - 1)
612 }
613 Assertion::LTE(left, right)
614 if left == self && right.as_i64().is_some() =>
615 {
616 Some(right.as_i64().unwrap())
617 }
618 _ => None,
619 })
620 .min()
621 } else {
622 scope
623 .all_assertions()
624 .iter()
625 .filter_map(|assert| match &assert {
626 Assertion::GT(left, right)
627 if left == self && right.as_i64().is_some() =>
628 {
629 Some(right.as_i64().unwrap() + 1)
630 }
631 Assertion::GTE(left, right)
632 if left == self && right.as_i64().is_some() =>
633 {
634 Some(right.as_i64().unwrap())
635 }
636 _ => None,
637 })
638 .max()
639 }
640 }
641 Add(terms) => {
642 let mut bound = 0;
643 for t in terms {
644 if let Some(b) = t.inclusive_bound(scope, upper) {
645 bound += b;
646 } else {
647 return None;
648 }
649 }
650 Some(bound)
651 }
652 MulInt(p, a) => match p.cmp(&0) {
653 Ordering::Equal => Some(0),
654 Ordering::Greater => a.inclusive_bound(scope, upper).map(|x| x * p),
655 Ordering::Less => a.inclusive_bound(scope, !upper).map(|x| x * p),
656 },
657 Mul(_) => None,
658 Min(terms) if !upper => {
659 terms.iter().filter_map(|t| t.inclusive_bound(scope, false)).min()
660 }
661 Max(terms) if upper => {
662 terms.iter().filter_map(|t| t.inclusive_bound(scope, true)).max()
663 }
664 Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
665 Broadcast(terms) => {
666 if upper {
667 Max(terms.clone()).inclusive_bound(scope, true)
668 } else {
669 Min(terms.clone()).inclusive_bound(scope, false)
670 }
671 }
672 _ => None,
673 }
674 }
675
676 pub fn low_inclusive_bound(&self) -> Option<i64> {
677 if let TDim::Val(v) = self {
678 return Some(*v);
679 }
680 let scope = self.find_scope()?;
681 let data = scope.0.lock();
682 let data = data.borrow();
683 self.inclusive_bound(&data, false)
684 }
685
686 pub fn high_inclusive_bound(&self) -> Option<i64> {
687 if let TDim::Val(v) = self {
688 return Some(*v);
689 }
690 let scope = self.find_scope()?;
691 let data = scope.0.lock();
692 let data = data.borrow();
693 self.inclusive_bound(&data, true)
694 }
695
696 pub fn prove_positive_or_zero(&self) -> bool {
697 if let TDim::Val(v) = self {
698 return *v >= 0;
699 }
700 let Some(scope) = self.find_scope() else { return false };
701 let data = scope.0.lock();
702 let data = data.borrow();
703 data.prove_positive_or_zero(self)
704 }
705
706 pub fn prove_strict_positive(&self) -> bool {
707 if let TDim::Val(v) = self {
708 return *v > 0;
709 }
710 (self.clone() - 1).prove_positive_or_zero()
711 }
712
713 pub fn prove_negative_or_zero(&self) -> bool {
714 if let TDim::Val(v) = self {
715 return *v <= 0;
716 }
717 self.clone().neg().prove_positive_or_zero()
718 }
719
720 pub fn prove_strict_negative(&self) -> bool {
721 if let TDim::Val(v) = self {
722 return *v < 0;
723 }
724 self.clone().neg().prove_strict_positive()
725 }
726
727 pub fn gcd(&self) -> u64 {
728 use self::TDim::*;
729 match self {
730 Val(v) => v.unsigned_abs(),
731 Sym(_) => 1,
732 Add(terms) => {
733 let (head, tail) = terms.split_first().unwrap();
734 tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
735 }
736 MulInt(p, a) => a.gcd() * p.unsigned_abs(),
737 Mul(terms) => terms.iter().map(|t| t.gcd()).product(),
738 Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
739 Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
740 Div(a, q) => {
741 if a.gcd() % *q == 0 {
742 a.gcd() / *q
743 } else {
744 1
745 }
746 }
747 Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
748 }
749 }
750
751 fn div(&self, d: u64) -> TDim {
752 use self::TDim::*;
753 if d == 1 {
754 return self.clone();
755 }
756 match self {
757 Val(v) => Val(v / d as i64),
758 Sym(_) => panic!(),
759 Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
760 Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
761 Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
762 Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
763 Mul(_) => Div(Box::new(self.clone()), d),
764 MulInt(p, a) => {
765 if *p == d as i64 {
766 (**a).clone()
767 } else {
768 let gcd = p.unsigned_abs().gcd(&d);
769 MulInt(p / gcd as i64, b!(a.div(d / gcd)))
770 }
771 }
772 Div(a, q) => Div(a.clone(), q * d),
773 }
774 }
775
776 pub fn div_ceil(self, rhs: u64) -> TDim {
777 TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
778 }
779
780 pub(super) fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
781 fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
782 match d {
783 Val(_) => (0, 1),
784 Sym(s) => ((sym == s) as i64, 1),
785 Add(terms) => terms
786 .iter()
787 .map(|d| slope_rec(d, sym))
788 .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
789 Mul(terms) => terms
790 .iter()
791 .map(|d| slope_rec(d, sym))
792 .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
793 MulInt(p, a) => {
794 let (n, d) = slope_rec(a, sym);
795 (p * n, d)
796 }
797 Div(a, q) => {
798 let (n, d) = slope_rec(a, sym);
799 (n, d * *q as i64)
800 }
801 Broadcast(terms) => slope_rec(&terms[0], sym),
802 Min(terms) => slope_rec(&terms[0], sym),
803 Max(terms) => slope_rec(&terms[0], sym),
804 }
805 }
806 let (p, q) = slope_rec(self, sym);
807 reduce_ratio(p, q)
808 }
809
810 #[allow(clippy::mutable_key_type)]
811 pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
812 match self {
813 Val(_) => maplit::hashset!(),
814 Sym(s) => maplit::hashset!(s.clone()),
815 Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
816 terms.iter().fold(maplit::hashset!(), |mut set, v| {
817 set.extend(v.symbols());
818 set
819 })
820 }
821 MulInt(_, a) => a.symbols(),
822 Div(a, _) => a.symbols(),
823 }
824 }
825
826 pub fn compatible_with(&self, other: &TDim) -> bool {
827 if let Ok(x) = (self.clone() - other).to_i64() {
828 return x == 0;
829 }
830 true }
832}
833
834pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
835 let gcd = p.abs().gcd(&q.abs());
836 if gcd > 1 {
837 p /= gcd;
838 q /= gcd;
839 }
840 if q < 0 { (-p, (-q) as u64) } else { (p, q as u64) }
841}
842
843impl Zero for TDim {
844 fn zero() -> Self {
845 Val(0)
846 }
847 fn is_zero(&self) -> bool {
848 matches!(self, Val(0))
849 }
850}
851
852impl Default for TDim {
853 fn default() -> TDim {
854 Val(0)
855 }
856}
857
858impl num_traits::Bounded for TDim {
859 fn min_value() -> Self {
860 TDim::Val(i64::MIN)
861 }
862
863 fn max_value() -> Self {
864 TDim::Val(i64::MAX)
865 }
866}
867
868impl num_traits::One for TDim {
869 fn one() -> Self {
870 TDim::Val(1)
871 }
872}
873
874impl ::std::iter::Sum for TDim {
875 fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
876 iter.fold(0.into(), |a, b| a + b)
877 }
878}
879
880impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
881 fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
882 iter.fold(0.into(), |a, b| a + b)
883 }
884}
885
886impl std::iter::Product for TDim {
887 fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
888 iter.fold(TDim::Val(1), |a, b| a * b)
889 }
890}
891
892impl<'a> ::std::iter::Product<&'a TDim> for TDim {
893 fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
894 iter.fold(1.into(), |a, b| a * b)
895 }
896}
897
898macro_rules! from_i {
899 ($i: ty) => {
900 impl From<$i> for TDim {
901 fn from(v: $i) -> TDim {
902 TDim::Val(v as _)
903 }
904 }
905 impl<'a> From<&'a $i> for TDim {
906 fn from(v: &'a $i) -> TDim {
907 TDim::Val(*v as _)
908 }
909 }
910 };
911}
912
913from_i!(i32);
914from_i!(i64);
915from_i!(u64);
916from_i!(isize);
917from_i!(usize);
918
919impl From<Symbol> for TDim {
920 fn from(it: Symbol) -> Self {
921 TDim::Sym(it)
922 }
923}
924
925impl<'a> From<&'a Symbol> for TDim {
926 fn from(it: &'a Symbol) -> Self {
927 TDim::Sym(it.clone())
928 }
929}
930
931impl ops::Neg for TDim {
932 type Output = Self;
933 fn neg(self) -> Self {
934 if let Val(v) = self { Val(-v) } else { TDim::MulInt(-1, Box::new(self)).reduce() }
935 }
936}
937
938impl<'a> ops::AddAssign<&'a TDim> for TDim {
939 fn add_assign(&mut self, rhs: &'a TDim) {
940 if rhs.is_zero() {
941 } else if self.is_zero() {
942 *self = rhs.clone();
943 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
944 *s += o;
945 } else {
946 *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
947 }
948 }
949}
950
951impl<I> ops::AddAssign<I> for TDim
952where
953 I: Into<TDim>,
954{
955 fn add_assign(&mut self, rhs: I) {
956 let rhs = rhs.into();
957 if rhs.is_zero() {
958 } else if self.is_zero() {
959 *self = rhs;
960 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
961 *s += o;
962 } else {
963 *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
964 }
965 }
966}
967
968impl<I> ops::Add<I> for TDim
969where
970 I: Into<TDim>,
971{
972 type Output = Self;
973 fn add(mut self, rhs: I) -> Self {
974 self += rhs;
975 self
976 }
977}
978
979impl<'a> ops::Add<&'a TDim> for TDim {
980 type Output = Self;
981 fn add(mut self, rhs: &'a TDim) -> Self {
982 self += rhs;
983 self
984 }
985}
986
987#[allow(clippy::suspicious_op_assign_impl)]
988impl<'a> ops::SubAssign<&'a TDim> for TDim {
989 fn sub_assign(&mut self, rhs: &'a TDim) {
990 if rhs.is_zero() {
991 } else if self.is_zero() {
992 *self = rhs.clone().neg();
993 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
994 *s -= o;
995 } else {
996 *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
997 }
998 }
999}
1000
1001impl<I> ops::SubAssign<I> for TDim
1002where
1003 I: Into<TDim>,
1004{
1005 fn sub_assign(&mut self, rhs: I) {
1006 let rhs = rhs.into();
1007 if rhs.is_zero() {
1008 } else if self.is_zero() {
1009 *self = rhs.neg();
1010 } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1011 *s -= o;
1012 } else {
1013 *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1014 }
1015 }
1016}
1017
1018impl<I> ops::Sub<I> for TDim
1019where
1020 I: Into<TDim>,
1021{
1022 type Output = Self;
1023 fn sub(mut self, rhs: I) -> Self {
1024 self -= rhs;
1025 self
1026 }
1027}
1028
1029impl<'a> ops::Sub<&'a TDim> for TDim {
1030 type Output = Self;
1031 fn sub(mut self, rhs: &'a TDim) -> Self {
1032 self -= rhs;
1033 self
1034 }
1035}
1036
1037impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1038 fn mul_assign(&mut self, rhs: I) {
1039 let rhs = rhs.into();
1040 if self.is_one() {
1041 *self = rhs
1042 } else if rhs.is_one() {
1043 } else {
1044 *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1045 }
1046 }
1047}
1048
1049impl<'a> ops::MulAssign<&'a TDim> for TDim {
1050 fn mul_assign(&mut self, rhs: &'a TDim) {
1051 if self.is_one() {
1052 *self = rhs.clone()
1053 } else if rhs.is_one() {
1054 } else {
1055 *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1056 }
1057 }
1058}
1059
1060impl<I: Into<TDim>> ops::Mul<I> for TDim {
1061 type Output = Self;
1062 fn mul(mut self, rhs: I) -> Self {
1063 self *= rhs.into();
1064 self
1065 }
1066}
1067
1068impl<'a> ops::Mul<&'a TDim> for TDim {
1069 type Output = Self;
1070 fn mul(mut self, rhs: &'a TDim) -> Self {
1071 self *= rhs;
1072 self
1073 }
1074}
1075
1076impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1077 fn div_assign(&mut self, rhs: I) {
1078 *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1079 }
1080}
1081
1082impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1083 type Output = Self;
1084 fn div(mut self, rhs: I) -> Self {
1085 self /= rhs.as_();
1086 self
1087 }
1088}
1089
1090impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1091 fn rem_assign(&mut self, rhs: I) {
1092 *self += -(self.clone() / rhs.as_() * rhs.as_());
1093 }
1094}
1095
1096impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1097 type Output = Self;
1098 fn rem(mut self, rhs: I) -> Self {
1099 self %= rhs;
1100 self
1101 }
1102}
1103
1104#[cfg(test)]
1105mod tests {
1106 use super::*;
1107
1108 macro_rules! b( ($e:expr) => { Box::new($e) } );
1109
1110 lazy_static::lazy_static! {
1111 static ref table: SymbolScope = SymbolScope::default();
1112 static ref A: Symbol = table.sym("a");
1113 static ref B: Symbol = table.sym("b");
1114 static ref C: Symbol = table.sym("c");
1115 static ref D: Symbol = table.sym("d");
1116 static ref E: Symbol = table.sym("e");
1117 }
1118
1119 fn neg(a: &TDim) -> TDim {
1120 mul(-1, a)
1121 }
1122
1123 fn add(a: &TDim, b: &TDim) -> TDim {
1124 TDim::Add(vec![a.clone(), b.clone()])
1125 }
1126
1127 fn mul(a: i64, b: &TDim) -> TDim {
1128 TDim::MulInt(a, b![b.clone()])
1129 }
1130
1131 fn div(a: &TDim, b: u64) -> TDim {
1132 TDim::Div(b!(a.clone()), b)
1133 }
1134
1135 #[test]
1136 fn reduce_add() {
1137 assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
1138 }
1139
1140 #[test]
1141 fn reduce_neg_mul() {
1142 assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
1143 }
1144
1145 #[test]
1146 fn reduce_cplx_ex_2() {
1147 assert_eq!(
1148 add(
1149 &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
1150 &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
1151 )
1152 .reduce(),
1153 Val(-4)
1154 )
1155 }
1156
1157 #[test]
1158 fn reduce_cplx_ex_3() {
1159 assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
1160 }
1161
1162 #[test]
1163 fn reduce_cplx_ex_4() {
1164 assert_eq!(
1166 add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
1167 .reduce(),
1168 1.into()
1169 );
1170 }
1171
1172 #[test]
1173 fn reduce_mul_mul_1() {
1174 assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
1175 }
1176
1177 #[test]
1178 fn reduce_mul_mul_2() {
1179 assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
1180 }
1181
1182 #[test]
1183 fn reduce_mul_div_1() {
1184 assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
1185 }
1186
1187 #[test]
1188 fn const_and_add() {
1189 let e: TDim = 2i64.into();
1190 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
1191 let e: TDim = TDim::from(2) + 3;
1192 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
1193 let e: TDim = TDim::from(2) - 3;
1194 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
1195 let e: TDim = -TDim::from(2);
1196 assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
1197 }
1198
1199 #[test]
1200 fn substitution() {
1201 let a: TDim = A.to_dim();
1202 assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
1203 let e = a + 3;
1204 assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
1205 }
1206
/// Sums of constants are folded eagerly by the `+` operator, including chained sums and `+ 0`.
#[test]
fn reduce_adds() {
    let cases = [
        (TDim::from(2) + 1, TDim::from(3)),
        (TDim::from(3) + 2, TDim::from(5)),
        (TDim::from(3) + 0, TDim::from(3)),
        (TDim::from(3) + 2 + 1, TDim::from(6)),
    ];
    for (sum, expected) in cases.iter() {
        assert_eq!(sum, expected);
    }
}
1218
/// Multiplying by the constant 1 is an identity, whether 1 is on the left or the right.
#[test]
fn reduce_muls() {
    let a = A.to_dim();
    assert_eq!(Val(1) * a.clone(), a);
    let ab = a * &B.to_dim();
    assert_eq!(ab.clone() * 1, ab);
}
1226
/// Integer division and remainder of constants are folded eagerly by `/` and `%`.
#[test]
fn reduce_divs() {
    let cases = [
        (TDim::from(2) / 1, TDim::from(2)),
        (TDim::from(3) / 2, TDim::from(1)),
        (TDim::from(3) % 2, TDim::from(1)),
        (TDim::from(5) / 2, TDim::from(2)),
        (TDim::from(5) % 2, TDim::from(1)),
    ];
    for (computed, expected) in cases.iter() {
        assert_eq!(computed, expected);
    }
}
1240
/// A constant subtracted after an integer division folds into the numerator:
/// `(A + 23) / 2 - 1 == (A + 21) / 2`.
#[test]
fn reduce_div_bug_0() {
    let folded: TDim = (A.to_dim() + 23) / 2 - 1;
    assert_eq!(folded, (A.to_dim() + 21) / 2);
}
1247
/// `(A - 1) / 2` and `(A + 1) / 2 - 1` must reduce to the same expression.
#[test]
fn reduce_div_bug_1() {
    let actual: TDim = (A.to_dim() + -1) / 2;
    let expected: TDim = (A.to_dim() + 1) / 2 - 1;
    assert_eq!(actual, expected);
}
1254
/// Nested divisions merge: `((A + 1) / 2 + 1) / 2 == (A + 3) / 4`.
#[test]
fn reduce_div_bug_2() {
    let inner = (A.to_dim() + 1) / 2;
    let nested: TDim = (inner + 1) / 2;
    assert_eq!(nested, (A.to_dim() + 3) / 4);
}
1261
/// Dividing by 1 is an identity, even on a negatively scaled quotient `(A / 2) * -4`.
#[test]
fn reduce_div_bug_3() {
    let scaled: TDim = (A.to_dim() / 2) * -4;
    let divided_by_one: TDim = scaled.clone() / 1;
    assert_eq!(scaled, divided_by_one);
}
1268
/// Multiplying by 2 and then dividing by 2 gives back `A` exactly.
#[test]
fn reduce_mul_div() {
    let a = A.to_dim();
    assert_eq!(a.clone() * 2 / 2, a);
}
1274
/// Dividing by 2 and then multiplying by 2 must NOT collapse to `A`, because the integer
/// division loses the remainder when `A` is odd.
#[test]
fn reduce_div_mul() {
    let a = A.to_dim();
    assert_ne!(a.clone() / 2 * 2, a);
}
1280
/// A constant added to a quotient folds into the numerator: `A / 2 + 1 == (A + 2) / 2`.
#[test]
fn reduce_add_div() {
    let a = A.to_dim();
    let sum: TDim = a.clone() / 2 + 1;
    assert_eq!(sum, (a + 2) / 2);
}
1286
/// Subtracting `A * 2` yields the same expression as adding `A * -2`.
#[test]
fn reduce_neg_mul_() {
    let one = TDim::from(1);
    let difference: TDim = one.clone() - A.to_dim() * 2;
    assert_eq!(difference, one + A.to_dim() * -2);
}
1292
/// Adding a multiple of the modulus does not change the remainder: `(A + 4) % 2 == A % 2`.
#[test]
fn reduce_add_rem_1() {
    let shifted = (A.to_dim() + 4) % 2;
    let plain = A.to_dim() % 2;
    assert_eq!(shifted, plain);
}
1297
/// Subtracting a multiple of the modulus does not change the remainder: `(A - 4) % 2 == A % 2`.
#[test]
fn reduce_add_rem_2() {
    let shifted = (A.to_dim() - 4) % 2;
    let plain = A.to_dim() % 2;
    assert_eq!(shifted, plain);
}
1302
/// `A % 2 / 2` is expected to reduce to the constant 0.
#[test]
fn reduce_rem_div() {
    let remainder = A.to_dim() % 2;
    let e: TDim = remainder / 2;
    assert_eq!(e, TDim::from(0));
}
1308
/// Fully constant case of `(x - k + 1).div_ceil(s)`: `(1 - 1 + 1).div_ceil(1) == 1`.
#[test]
fn conv2d_ex_1() {
    let numerator = TDim::from(1) - 1 + 1;
    assert_eq!(numerator.div_ceil(1), TDim::from(1));
}
1314
/// Symbolic case of `(x - k + 1).div_ceil(s)`: `(A - 3 + 1).div_ceil(1) == A - 2`.
#[test]
fn conv2d_ex_2() {
    let numerator = A.to_dim() - 3 + 1;
    assert_eq!(numerator.div_ceil(1), A.to_dim() + -2);
}
1320
/// Integer factors are pulled out of a product of sums:
/// `(24q - 24) * (2q - 2) == (q - 1) * (q - 1) * 48`, with `q = (A + 1) / 4`.
#[test]
fn extract_int_gcd_from_muls() {
    let q = (A.to_dim() + 1) / 4;
    let left = q.clone() * 24 - 24;
    let right = q.clone() * 2 - 2;
    let product = left * right;
    let expected = (q.clone() - 1) * (q - 1) * 48;
    assert_eq!(product, expected);
}
1328
/// Products compare equal regardless of operand order:
/// `(2q - 3) * (q - 1) == (q - 1) * (2q - 3)`, with `q = (A + 1) / 4`.
#[test]
fn equality_of_muls() {
    let q = (A.to_dim() + 1) / 4;
    let twice_minus_three = q.clone() * 2 - 3;
    let minus_one = q - 1;
    let lhs = twice_minus_three.clone() * minus_one.clone();
    let rhs = minus_one * twice_minus_three;
    assert_eq!(lhs, rhs);
}
1336
/// `2q - q - 1` collapses to `q - 1` when `q` is the compound expression `(A + 1) / 4`.
#[test]
fn factorize_complex_expr_times_int() {
    let q = (A.to_dim() + 1) / 4;
    let doubled = q.clone() * 2;
    let e = doubled - &q - 1;
    assert_eq!(e, q - 1);
}
1343
/// Exhaustively checks broadcasting `a` against `min(a, b)` for small positive `a` and `b`.
/// The result is expected to be `a` whenever `b <= 1` or `a <= b`, and an error otherwise
/// (that is, when `1 < b < a`).
#[test]
fn broadcast_over_min() {
    for a in 1..5 {
        for b in 1..5 {
            if b <= 1 || a <= b {
                assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
            } else {
                assert!(a.broadcast(a.min(b)).is_err());
            }
        }
    }
}
1361
/// `min(2, 1)` on constant dimensions evaluates to 1.
#[test]
fn min_ints_1() {
    let (two, one) = (2.to_dim(), 1.to_dim());
    assert_eq!(two.mini(one), 1.to_dim());
}
1366
/// `min(1, 2)` on constant dimensions evaluates to 1 (argument order does not matter).
#[test]
fn min_ints_2() {
    let (one, two) = (1.to_dim(), 2.to_dim());
    assert_eq!(one.mini(two), 1.to_dim());
}
1371
/// The minimum of an expression with itself is that expression.
#[test]
fn min_same() {
    let a = A.to_dim();
    assert_eq!(a.clone().mini(a.clone()), a);
}
1376
/// `min(A, 1)` is expected to be left as a `Min` over exactly its two operands: the
/// reduction must neither pick `A` nor pick `1`.
///
/// Comparing the expression with a second copy of itself would be tautological, so this
/// test inspects the resulting node instead.
#[test]
fn min_noop() {
    let e = A.to_dim().mini(1.to_dim());
    match &e {
        Min(terms) => {
            assert_eq!(terms.len(), 2);
            // Order-independent: the canonical ordering of the terms is not what is tested here.
            assert!(terms.contains(&A.to_dim()));
            assert!(terms.contains(&1.to_dim()));
        }
        other => panic!("expected min(A, 1) to stay a Min, got {other:?}"),
    }
}
1381
/// Expressions differing only by a constant offset reduce to the smaller one:
/// `min(A + 1, A + 2) == A + 1`.
#[test]
fn min_diff_1() {
    let smaller = A.to_dim() + 1;
    let larger = A.to_dim() + 2;
    assert_eq!(smaller.clone().mini(larger), smaller);
}
1386
/// A constant does not vary with `A`: slope `0 / 1`.
#[test]
fn slope_0() {
    let constant = 12.to_dim();
    assert_eq!(constant.guess_slope(&A), (0, 1));
}
1391
/// `A` has slope `1 / 1` with respect to itself.
#[test]
fn slope_1() {
    let slope = A.to_dim().guess_slope(&A);
    assert_eq!(slope, (1, 1));
}
1396
/// An integer factor scales the slope: `2A` has slope `2 / 1`.
#[test]
fn slope_2() {
    let doubled = A.to_dim() * 2;
    assert_eq!(doubled.guess_slope(&A), (2, 1));
}
1401
/// Slopes of summed terms add up, including fractional ones: `2A + A / 2` has slope `5 / 2`.
#[test]
fn slope_3() {
    let mixed = A.to_dim() * 2 + A.to_dim() / 2;
    assert_eq!(mixed.guess_slope(&A), (5, 2));
}
1406
/// `A` does not depend on `B`: slope `0 / 1` with respect to `B`.
#[test]
fn slope_4() {
    let a = A.to_dim();
    assert_eq!(a.guess_slope(&B), (0, 1));
}
1411
/// A constant offset does not change the slope: `A + 1` has slope 1 in `A` and 0 in `B`.
#[test]
fn slope_5() {
    let a_plus_one = A.to_dim() + 1;
    assert_eq!(a_plus_one.clone().guess_slope(&A), (1, 1));
    assert_eq!(a_plus_one.guess_slope(&B), (0, 1));
}
1417
/// In a sum of two symbols, each symbol has slope `1 / 1` with respect to itself; the other
/// symbol contributes nothing (see `slope_4`). Both directions are checked, so this test no
/// longer repeats `slope_5`'s `A + 1` assertion.
#[test]
fn slope_6() {
    assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&A), (1, 1));
    assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
}
1423
/// `min(S + 3, S + 2)` differs only by a constant offset, so it simplifies to `S + 2`.
///
/// Parse failures are propagated with `?`; the test already returns `TractResult`.
#[test]
fn min_0() -> TractResult<()> {
    let symbols = SymbolScope::default();
    assert_eq!(
        symbols.parse_tdim("min(S+3, S+2)")?.simplify(),
        symbols.parse_tdim("S+2")?,
    );
    Ok(())
}
1433
/// Products simplify to the same form regardless of parenthesisation and operand order:
/// `A*(B*C)` and `(B*A)*C`.
///
/// Parse failures are propagated with `?`; the test already returns `TractResult`.
#[test]
fn commutative_mul_parens() -> TractResult<()> {
    let symbols = SymbolScope::default();
    assert_eq!(
        symbols.parse_tdim("A*(B*C)")?.simplify(),
        symbols.parse_tdim("(B*A)*C")?.simplify(),
    );
    Ok(())
}
1443
/// Regression case taken from a real model: the same product, written with `B` in a different
/// position and different grouping, must simplify to the same expression.
///
/// Parse failures are propagated with `?`; the test already returns `TractResult`.
#[test]
fn commutative_in_nemo_parakeet_model() -> TractResult<()> {
    let symbols = SymbolScope::default();
    assert_eq!(
        symbols
            .parse_tdim("8*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8))*((B)*((S+7)/8))")?
            .simplify(),
        symbols
            .parse_tdim("8*((B)*(1+-1*max(0,5000+-1*(S+7)/8)+max(0,4999+(S+7)/8)))*((S+7)/8)")?
            .simplify(),
    );
    Ok(())
}
1459
/// A deeply left-nested product `((((A*B)*C)*D)*E)`, built by hand from `Mul` nodes, must
/// simplify to the same form as the flat product parsed from text.
///
/// Parse failures are propagated with `?`; the test already returns `TractResult`.
#[test]
fn commutative_mul_parens_deep() -> TractResult<()> {
    // NOTE(review): the left side uses the symbols `A`..`E` defined outside this function,
    // while the right side is parsed in a freshly created `SymbolScope`. Equality therefore
    // relies on how symbols from different scopes compare; confirm that this is intended.
    let symbols = SymbolScope::default();
    let deep_tdim = Mul(vec![
        Mul(vec![Mul(vec![Mul(vec![A.to_dim(), B.to_dim()]), C.to_dim()]), D.to_dim()]),
        E.to_dim(),
    ])
    .simplify();
    let flat = symbols.parse_tdim("a*b*c*d*e")?.simplify();
    assert_eq!(deep_tdim, flat);
    Ok(())
}
1471}