// tract_data/dim/tree.rs

1use crate::dim::Assertion;
2use crate::internal::*;
3
4use super::{sym::*, DimLike};
5use itertools::Itertools;
6use num_integer::Integer;
7use num_traits::{AsPrimitive, PrimInt, Zero};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt::Debug;
11use std::ops::Neg;
12use std::{fmt, ops};
13
14#[derive(Debug)]
15pub enum TooEarly {
16    UndeterminedSymbol(TDim),
17    Other(String),
18}
19
20impl std::fmt::Display for TooEarly {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            TooEarly::UndeterminedSymbol(s) => write!(f, "Undetermined symbol in expression: {s}"),
24            TooEarly::Other(s) => write!(f, "{s}"),
25        }
26    }
27}
28
29impl std::error::Error for TooEarly {}
30
// `b!(expr)` is shorthand for `Box::new(expr)`, used to build boxed sub-expressions tersely.
macro_rules! b {
    ($e:expr) => {
        Box::new($e)
    };
}
32
/// Symbolic dimension expression.
///
/// Leaves are integer constants (`Val`) or symbols (`Sym`); inner nodes combine
/// sub-expressions. Evaluation semantics are those of `TDim::eval_to_i64` in this file.
///
/// Equality and hashing are derived, hence purely structural: two mathematically equal
/// expressions with different tree shapes or member orders compare unequal.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum TDim {
    /// Integer constant.
    Val(i64),
    /// Symbol, looked up in `SymbolValues` on evaluation, or replaced during simplification
    /// by an applicable scope assertion of the form `sym == value`.
    Sym(Symbol),
    /// Sum of the terms (an empty sum evaluates to 0).
    Add(Vec<TDim>),
    /// Product of the terms (an empty product evaluates to 1).
    Mul(Vec<TDim>),
    /// Integer coefficient times a sub-expression.
    MulInt(i64, Box<TDim>),
    /// Sub-expression divided by a constant, evaluated with `i64` `/` (truncating toward zero).
    Div(Box<TDim>, u64),
    /// Broadcast of the terms, evaluated by folding the `usize` broadcast rule from 1.
    Broadcast(Vec<TDim>),
    /// Minimum of the terms (an empty `Min` evaluates to `i64::MAX`).
    Min(Vec<TDim>),
    /// Maximum of the terms (an empty `Max` evaluates to `i64::MIN`).
    Max(Vec<TDim>),
}

// Variants are used unqualified throughout this file.
use TDim::*;
47
48fn tdim_lexi_order(a: &TDim, b: &TDim) -> Ordering {
49    match (a, b) {
50        (Sym(a), Sym(b)) => a.cmp(b),
51        (Val(a), Val(b)) => a.cmp(b),
52        (Add(a), Add(b))
53        | (Mul(a), Mul(b))
54        | (Broadcast(a), Broadcast(b))
55        | (Min(a), Min(b))
56        | (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
57            a.iter()
58                .zip(b.iter())
59                .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_lexi_order(a, b))),
60        ),
61        (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
62        (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_lexi_order(d, e)),
63        (Sym(_), _) => Ordering::Less,
64        (_, Sym(_)) => Ordering::Greater,
65        (Val(_), _) => Ordering::Less,
66        (_, Val(_)) => Ordering::Greater,
67        (Add(_), _) => Ordering::Less,
68        (_, Add(_)) => Ordering::Greater,
69        (Mul(_), _) => Ordering::Less,
70        (_, Mul(_)) => Ordering::Greater,
71        (MulInt(_, _), _) => Ordering::Less,
72        (_, MulInt(_, _)) => Ordering::Greater,
73        (Broadcast(_), _) => Ordering::Less,
74        (_, Broadcast(_)) => Ordering::Greater,
75        (Min(_), _) => Ordering::Less,
76        (_, Min(_)) => Ordering::Greater,
77        (Max(_), _) => Ordering::Less,
78        (_, Max(_)) => Ordering::Greater,
79    }
80}
81
82impl fmt::Display for TDim {
83    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
84        match &self {
85            Sym(sym) => write!(fmt, "{sym}"),
86            Val(it) => write!(fmt, "{it}"),
87            Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
88            Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
89            Broadcast(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("#")),
90            Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
91            Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
92            MulInt(a, b) => write!(fmt, "{a}*{b}"),
93            Div(a, b) => write!(fmt, "({a})/{b}"),
94        }
95    }
96}
97
98impl TDim {
99    #[inline]
100    pub fn is_one(&self) -> bool {
101        matches!(self, Val(1))
102    }
103
104    #[inline]
105    pub fn to_i64(&self) -> TractResult<i64> {
106        if let Val(v) = self {
107            Ok(*v)
108        } else {
109            Err(TooEarly::UndeterminedSymbol(self.clone()).into())
110        }
111    }
112
113    #[inline]
114    pub fn as_i64(&self) -> Option<i64> {
115        if let Val(v) = self {
116            Some(*v)
117        } else {
118            None
119        }
120    }
121
122    pub fn eval_to_i64(&self, values: &SymbolValues) -> TractResult<i64> {
123        match self {
124            Sym(sym) => {
125                let Some(v) = values.get(sym) else {
126                    bail!(TooEarly::UndeterminedSymbol(self.clone()))
127                };
128                Ok(v)
129            }
130            Val(v) => Ok(*v),
131            Add(terms) => {
132                terms.iter().try_fold(0, |acc, it| it.eval_to_i64(values).map(|x| acc + x))
133            }
134            Mul(terms) => {
135                terms.iter().try_fold(1, |acc, it| it.eval_to_i64(values).map(|x| acc * x))
136            }
137            Min(terms) => terms
138                .iter()
139                .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
140            Max(terms) => terms
141                .iter()
142                .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
143            Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
144                it.eval_to_i64(values)
145                    .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
146            }),
147            Div(a, q) => Ok(a.eval_to_i64(values)? / *q as i64),
148            MulInt(p, a) => Ok(a.eval_to_i64(values)? * *p),
149        }
150    }
151
152    pub fn eval(&self, values: &SymbolValues) -> TDim {
153        match self {
154            Sym(sym) => values.get(sym).map(Val).unwrap_or_else(|| Sym(sym.clone())),
155            Val(v) => Val(*v),
156            Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
157            Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
158            Min(terms) => {
159                terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
160            }
161            Max(terms) => {
162                terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
163            }
164            Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
165                acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
166            }),
167            Div(a, q) => a.eval(values) / *q as i64,
168            MulInt(p, a) => a.eval(values) * *p,
169        }
170    }
171
172    pub fn eval_with_scenario(&self, scenario: &str) -> TDim {
173        if let Val(v) = self {
174            return Val(*v);
175        }
176        let scope = self.find_scope().unwrap();
177        let scope = scope.0;
178        let locked = scope.lock();
179        let scope = locked.borrow();
180        self.clone().simplify_rec(&scope, Some(scenario))
181    }
182
183    pub fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self> {
184        match self {
185            Sym(sym) => Ok(if sym == from { to.clone() } else { self.clone() }),
186            Val(v) => Ok(Val(*v)),
187            Add(terms) => terms.iter().try_fold(Val(0), |acc, it| -> TractResult<TDim> {
188                Ok(acc + it.substitute(from, to)?)
189            }),
190            Mul(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
191                Ok(acc * it.substitute(from, to)?)
192            }),
193            Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
194                acc.broadcast(it.substitute(from, to)?)
195            }),
196            Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
197                Ok(acc.mini(it.substitute(from, to)?))
198            }),
199            Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
200                Ok(acc.maxi(it.substitute(from, to)?))
201            }),
202            Div(a, q) => Ok(a.substitute(from, to)? / *q as i64),
203            MulInt(p, a) => Ok(a.substitute(from, to)? * *p),
204        }
205    }
206
207    pub fn reduce(self) -> TDim {
208        self.simplify()
209            .wiggle()
210            .into_iter()
211            .sorted_by(tdim_lexi_order)
212            .unique()
213            .map(|e| e.simplify())
214            .min_by_key(|e| e.cost())
215            .unwrap()
216    }
217
218    fn cost(&self) -> usize {
219        use self::TDim::*;
220        match self {
221            Sym(_) | Val(_) => 1,
222            Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
223            Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
224            Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
225            Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
226            Div(a, _) => 3 * a.cost(),
227            MulInt(_, a) => 2 * a.cost(),
228        }
229    }
230
231    fn wiggle(&self) -> Vec<TDim> {
232        use self::TDim::*;
233        match self {
234            Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) => vec![self.clone()],
235            Add(terms) => {
236                let mut forms = vec![];
237                let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
238
239                fn first_div_term(terms: &[TDim]) -> Option<(usize, &TDim, u64)> {
240                    terms.iter().enumerate().find_map(|(index, t)| match t {
241                        Div(numerator, quotient) => Some((index, &**numerator, *quotient)),
242                        _ => None,
243                    })
244                }
245
246                fn generate_new_numerator(
247                    div_index: usize,
248                    numerator: &TDim,
249                    quotient: u64,
250                    expr: &[TDim],
251                ) -> Vec<TDim> {
252                    expr.iter()
253                        .enumerate()
254                        .map(|(index, term)| {
255                            if index == div_index {
256                                numerator.clone()
257                            } else {
258                                MulInt(quotient as i64, Box::new(term.clone()))
259                            }
260                        })
261                        .collect()
262                }
263
264                for expr in sub_exprs {
265                    if let Some((div_index, numerator, quotient)) = first_div_term(&expr) {
266                        let new_numerator =
267                            generate_new_numerator(div_index, numerator, quotient, &expr);
268                        forms.push(Div(Box::new(Add(new_numerator)), quotient))
269                    }
270
271                    forms.push(Add(expr));
272                }
273                forms
274            }
275            MulInt(p, a) => a.wiggle().into_iter().map(|a| MulInt(*p, b!(a))).collect(),
276            Div(a, q) => {
277                let mut forms = vec![];
278                for num in a.wiggle() {
279                    if let Add(terms) = &num {
280                        let (integer, non_integer): (Vec<_>, Vec<_>) =
281                            terms.iter().cloned().partition(|a| a.gcd() % q == 0);
282                        let mut new_terms = integer.iter().map(|i| i.div(*q)).collect::<Vec<_>>();
283                        if non_integer.len() > 0 {
284                            new_terms.push(Div(b!(Add(non_integer)), *q));
285                        }
286                        forms.push(Add(new_terms))
287                    }
288                    forms.push(Div(b!(num), *q))
289                }
290                forms
291            }
292        }
293    }
294
295    fn find_any_sym(tdim: &TDim) -> Option<&Symbol> {
296        match tdim {
297            Val(_) => None,
298            Sym(s) => Some(s),
299            Add(terms) | Mul(terms) | Min(terms) | Max(terms) | Broadcast(terms) => {
300                terms.iter().find_map(Self::find_any_sym)
301            }
302            MulInt(_, t) | Div(t, _) => Self::find_any_sym(t),
303        }
304    }
305
306    pub fn find_scope(&self) -> Option<SymbolScope> {
307        Self::find_any_sym(self).and_then(|s| s.scope().clone())
308    }
309
310    pub fn simplify(self) -> TDim {
311        use self::TDim::*;
312        if let Ok(v) = self.eval_to_i64(&SymbolValues::default()) {
313            return Val(v);
314        }
315        let Some(scope) = self.find_scope() else {
316            return self;
317        };
318        let scope = scope.0;
319        let locked = scope.lock();
320        let scope = locked.borrow();
321        let it = self.simplify_rec(&scope, None);
322        let mut current: Option<TDim> = None;
323        for scenario in scope.scenarios() {
324            let v = it.clone().simplify_rec(&scope, Some(scenario));
325            if current.is_some_and(|c| c != v) {
326                return it;
327            } else {
328                current = Some(v);
329            }
330        }
331        current.unwrap_or(it)
332    }
333
334    fn simplify_rec(self, scope: &SymbolScopeData, scenario: Option<&str>) -> TDim {
335        match self {
336            Add(mut terms) => {
337                #[allow(clippy::mutable_key_type)]
338                let mut simplified_terms: HashMap<TDim, i64> = HashMap::new();
339                // factorize common sub-expr
340                while let Some(term) = terms.pop() {
341                    let simplified = term.simplify_rec(scope, scenario);
342                    match simplified {
343                        Val(0) => {} // ignore
344                        Add(members) => {
345                            terms.extend(members);
346                            continue;
347                        }
348                        Val(value) => *simplified_terms.entry(Val(1)).or_insert(0) += value,
349                        MulInt(value, factor) => {
350                            *simplified_terms.entry((*factor).clone()).or_insert(0) += value;
351                        }
352                        n => *simplified_terms.entry(n).or_insert(0) += 1,
353                    };
354                }
355
356                pub fn evaluate_count(term: TDim, count: i64) -> Option<TDim> {
357                    match count {
358                        0 => None,
359                        _ if term == TDim::Val(1) => Some(TDim::Val(count)),
360                        1 => Some(term),
361                        _ => Some(TDim::MulInt(count, Box::new(term))),
362                    }
363                }
364
365                let mut members: Vec<TDim> = simplified_terms
366                    .into_iter()
367                    .filter_map(|(term, count)| evaluate_count(term, count))
368                    .collect();
369                members.sort_by(tdim_lexi_order);
370
371                match members.len() {
372                    0 => TDim::Val(0),
373                    1 => members.into_iter().next().unwrap(),
374                    _ => TDim::Add(members),
375                }
376            }
377            Mul(terms) => {
378                let mut gcd = Mul(terms.clone()).gcd() as i64;
379                if gcd == 0 {
380                    return Val(0);
381                }
382                let mut terms = if gcd != 1 {
383                    terms
384                        .into_iter()
385                        .map(|t| {
386                            let gcd = t.gcd();
387                            (t / gcd).simplify_rec(scope, scenario)
388                        })
389                        .collect()
390                } else {
391                    terms
392                };
393                if terms.iter().filter(|t| t == &&Val(-1)).count() % 2 == 1 {
394                    gcd = -gcd;
395                }
396                terms.retain(|t| !t.is_one() && t != &Val(-1));
397                terms.sort_by(tdim_lexi_order);
398                match (gcd, terms.len()) {
399                    (_, 0) => Val(gcd), // Case #1: If 0 variables, return product
400                    (0, _) => Val(0),   // Case #2: Result is 0 if coef is 0 (actually
401                    // unreachable as we check at the beginning)
402                    (1, 1) => terms.remove(0), // Case #3: Product is 1, so return the only term
403                    (1, _) => Mul(terms), // Case #4: Product is 1, so return the non-integer terms
404                    (_, 1) => MulInt(gcd, Box::new(terms.remove(0))), // Case #5: Single variable, convert to 1 MulInt
405                    _ => MulInt(gcd, Box::new(Mul(terms))), // Case #6: Multiple variables, convert to MulInt
406                }
407            }
408            MulInt(coef, expr) => {
409                match *expr {
410                    MulInt(c2, inner) => {
411                        return MulInt(coef * c2, inner).simplify_rec(scope, scenario)
412                    }
413                    Val(v) => return Val(coef * v),
414                    _ => {}
415                }
416
417                let simplified = expr.simplify_rec(scope, scenario);
418                match (coef, simplified) {
419                    (0, _) => Val(0), // Case #1: If coef is 0, return 0
420                    (1, s) => s,      // Case #2: If coef is 1, return the simplified expression
421                    (_, Add(terms)) => Add(terms
422                        .into_iter()
423                        .map(|term| MulInt(coef, Box::new(term)).simplify_rec(scope, scenario))
424                        .collect()), // Case #3: If expression is an addition, distribute the coef
425                    (c, Val(v)) => Val(c * v), // Case #4: If expression is a value, combine coefs
426                    (c, MulInt(v, inner)) => MulInt(c * v, inner), // Case #5: If expression is a MulInt, combine coefs
427                    (_, s) => MulInt(coef, Box::new(s)), // Case #6: Otherwise, return the original
428                }
429            }
430            Div(a, q) => {
431                if q == 1 {
432                    return a.simplify_rec(scope, scenario);
433                } else if let Div(a, q2) = *a {
434                    return Div(a, q * q2).simplify_rec(scope, scenario);
435                }
436                let a = a.simplify_rec(scope, scenario);
437                if let Val(a) = a {
438                    Val(a / q as i64)
439                } else if let MulInt(-1, a) = a {
440                    MulInt(-1, b!(Div(a, q)))
441                } else if let Add(mut terms) = a {
442                    if terms.iter().any(|t| {
443                        if let MulInt(-1, s) = t {
444                            matches!(&**s, Sym(_))
445                        } else {
446                            false
447                        }
448                    }) {
449                        MulInt(
450                            -1,
451                            b!(Div(
452                                b!(Add(terms.into_iter().map(|t| MulInt(-1, b!(t))).collect())
453                                    .simplify_rec(scope, scenario)),
454                                q
455                            )),
456                        )
457                    } else if let Some(v) =
458                        terms.iter().find_map(|t| if let Val(v) = t { Some(*v) } else { None })
459                    {
460                        let offset = if v >= q as i64 {
461                            Some(v / q as i64)
462                        } else if v < 0 {
463                            Some(-Integer::div_ceil(&-v, &(q as i64)))
464                        } else {
465                            None
466                        };
467                        if let Some(val) = offset {
468                            terms.push(Val(-val * q as i64));
469                            Add(vec![
470                                Val(val),
471                                Div(b!(Add(terms).simplify_rec(scope, scenario)), q),
472                            ])
473                        } else {
474                            Div(b!(Add(terms)), q)
475                        }
476                    } else {
477                        Div(b!(Add(terms)), q)
478                    }
479                } else if let MulInt(p, a) = a {
480                    if p == q as i64 {
481                        a.simplify()
482                    } else {
483                        let gcd = p.abs().gcd(&(q as i64));
484                        if gcd == p {
485                            Div(a, q / gcd as u64)
486                        } else if gcd == q as i64 {
487                            MulInt(p / gcd, a)
488                        } else if gcd > 1 {
489                            Div(b!(MulInt(p / gcd, a)), q / gcd as u64)
490                                .simplify_rec(scope, scenario)
491                        } else {
492                            Div(b!(MulInt(p, a)), q)
493                        }
494                    }
495                } else {
496                    Div(b!(a), q)
497                }
498            }
499            Broadcast(terms) => {
500                let mut terms: Vec<TDim> = terms
501                    .iter()
502                    .map(|s| s.clone().simplify_rec(scope, scenario))
503                    .flat_map(|t| if let Broadcast(t) = t { t } else { vec![t] })
504                    .filter(|t| !t.is_one())
505                    .sorted_by(tdim_lexi_order)
506                    .dedup()
507                    .collect_vec();
508                // a#min(a,b) if a>0 && b>0 => a
509                match &*terms {
510                    [] => Val(1),
511                    [_] => terms.remove(0),
512                    [a, Min(m)] | [Min(m), a]
513                        if m.contains(a) && m.iter().all(|t| scope.prove_strict_positive(t)) =>
514                    {
515                        a.clone()
516                    }
517                    _ => Broadcast(terms),
518                }
519            }
520
521            Min(terms) => {
522                let mut flatten: Vec<TDim> = terms
523                    .into_iter()
524                    .map(|t| t.simplify_rec(scope, scenario))
525                    .flat_map(|t| if let Min(t) = t { t } else { vec![t] })
526                    .sorted_by(tdim_lexi_order)
527                    .dedup()
528                    .collect();
529                #[allow(clippy::mutable_key_type)]
530                let mut redundant = HashSet::<TDim>::default();
531                for pair in flatten.iter().permutations(2) {
532                    let (a, b) = (pair[0], pair[1]);
533                    if redundant.contains(a) || redundant.contains(b) {
534                        continue;
535                    }
536                    let diff = a.clone() - b;
537                    if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff)
538                    {
539                        redundant.insert(a.clone());
540                    }
541                }
542                flatten.retain(|t| !redundant.contains(t));
543                if flatten.len() == 0 {
544                    i64::MAX.to_dim()
545                } else if flatten.len() == 1 {
546                    flatten.into_iter().next().unwrap()
547                } else {
548                    Min(flatten)
549                }
550            }
551            Max(terms) => {
552                let mut flatten: Vec<TDim> = terms
553                    .into_iter()
554                    .map(|t| t.simplify_rec(scope, scenario))
555                    .flat_map(|t| if let Max(t) = t { t } else { vec![t] })
556                    .sorted_by(tdim_lexi_order)
557                    .dedup()
558                    .collect();
559                #[allow(clippy::mutable_key_type)]
560                let mut redundant = HashSet::<TDim>::default();
561                for pair in flatten.iter().permutations(2) {
562                    let (a, b) = (pair[0], pair[1]);
563                    if redundant.contains(a) || redundant.contains(b) {
564                        continue;
565                    }
566                    let diff = a.clone() - b;
567                    if diff.as_i64().is_some_and(|i| i >= 0) || scope.prove_positive_or_zero(&diff)
568                    {
569                        redundant.insert(b.clone());
570                    }
571                }
572                flatten.retain(|t| !redundant.contains(t));
573                if flatten.len() == 0 {
574                    i64::MIN.to_dim()
575                } else if flatten.len() == 1 {
576                    flatten.into_iter().next().unwrap()
577                } else {
578                    Max(flatten)
579                }
580            }
581            Sym(s) => scope
582                .assertions(scenario)
583                .find_map(|a| match a {
584                    Assertion::Eq(Sym(sym), v) if sym == &s => Some(v.clone()),
585                    _ => None,
586                })
587                .unwrap_or(Sym(s)),
588            Val(_) => self,
589        }
590    }
591
592    pub(super) fn inclusive_bound(&self, scope: &SymbolScopeData, upper: bool) -> Option<i64> {
593        use self::TDim::*;
594        match self {
595            Val(n) => Some(*n),
596            Sym(_) => {
597                if upper {
598                    scope
599                        .all_assertions()
600                        .iter()
601                        .filter_map(|assert| match &assert {
602                            Assertion::LT(left, right)
603                                if left == self && right.as_i64().is_some() =>
604                            {
605                                Some(right.as_i64().unwrap() - 1)
606                            }
607                            Assertion::LTE(left, right)
608                                if left == self && right.as_i64().is_some() =>
609                            {
610                                Some(right.as_i64().unwrap())
611                            }
612                            _ => None,
613                        })
614                        .min()
615                } else {
616                    scope
617                        .all_assertions()
618                        .iter()
619                        .filter_map(|assert| match &assert {
620                            Assertion::GT(left, right)
621                                if left == self && right.as_i64().is_some() =>
622                            {
623                                Some(right.as_i64().unwrap() + 1)
624                            }
625                            Assertion::GTE(left, right)
626                                if left == self && right.as_i64().is_some() =>
627                            {
628                                Some(right.as_i64().unwrap())
629                            }
630                            _ => None,
631                        })
632                        .max()
633                }
634            }
635            Add(terms) => {
636                let mut bound = 0;
637                for t in terms {
638                    if let Some(b) = t.inclusive_bound(scope, upper) {
639                        bound += b;
640                    } else {
641                        return None;
642                    }
643                }
644                Some(bound)
645            }
646            MulInt(p, a) => match p.cmp(&0) {
647                Ordering::Equal => Some(0),
648                Ordering::Greater => a.inclusive_bound(scope, upper).map(|x| x * p),
649                Ordering::Less => a.inclusive_bound(scope, !upper).map(|x| x * p),
650            },
651            Mul(_) => None,
652            Min(terms) if !upper => {
653                terms.iter().filter_map(|t| t.inclusive_bound(scope, false)).min()
654            }
655            Max(terms) if upper => {
656                terms.iter().filter_map(|t| t.inclusive_bound(scope, true)).max()
657            }
658            Div(a, q) => a.inclusive_bound(scope, upper).map(|x| x / (*q as i64)),
659            Broadcast(terms) => {
660                if upper {
661                    Max(terms.clone()).inclusive_bound(scope, true)
662                } else {
663                    Min(terms.clone()).inclusive_bound(scope, false)
664                }
665            }
666            _ => None,
667        }
668    }
669
670    pub fn low_inclusive_bound(&self) -> Option<i64> {
671        if let TDim::Val(v) = self {
672            return Some(*v);
673        }
674        let scope = self.find_scope()?;
675        let data = scope.0.lock();
676        let data = data.borrow();
677        self.inclusive_bound(&data, false)
678    }
679
680    pub fn high_inclusive_bound(&self) -> Option<i64> {
681        if let TDim::Val(v) = self {
682            return Some(*v);
683        }
684        let scope = self.find_scope()?;
685        let data = scope.0.lock();
686        let data = data.borrow();
687        self.inclusive_bound(&data, true)
688    }
689
690    pub fn prove_positive_or_zero(&self) -> bool {
691        if let TDim::Val(v) = self {
692            return *v >= 0;
693        }
694        let Some(scope) = self.find_scope() else { return false };
695        let data = scope.0.lock();
696        let data = data.borrow();
697        data.prove_positive_or_zero(self)
698    }
699
700    pub fn prove_strict_positive(&self) -> bool {
701        if let TDim::Val(v) = self {
702            return *v > 0;
703        }
704        (self.clone() - 1).prove_positive_or_zero()
705    }
706
707    pub fn prove_negative_or_zero(&self) -> bool {
708        if let TDim::Val(v) = self {
709            return *v <= 0;
710        }
711        self.clone().neg().prove_positive_or_zero()
712    }
713
714    pub fn prove_strict_negative(&self) -> bool {
715        if let TDim::Val(v) = self {
716            return *v < 0;
717        }
718        self.clone().neg().prove_strict_positive()
719    }
720
721    pub fn gcd(&self) -> u64 {
722        use self::TDim::*;
723        match self {
724            Val(v) => v.unsigned_abs(),
725            Sym(_) => 1,
726            Add(terms) => {
727                let (head, tail) = terms.split_first().unwrap();
728                tail.iter().fold(head.gcd(), |a, b| a.gcd(&b.gcd()))
729            }
730            MulInt(p, a) => a.gcd() * p.unsigned_abs(),
731            Mul(terms) => terms.iter().map(|t| t.gcd()).product(),
732            Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
733            Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
734            Div(a, q) => {
735                if a.gcd() % *q == 0 {
736                    a.gcd() / *q
737                } else {
738                    1
739                }
740            }
741            Broadcast(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap_or(1),
742        }
743    }
744
745    fn div(&self, d: u64) -> TDim {
746        use self::TDim::*;
747        if d == 1 {
748            return self.clone();
749        }
750        match self {
751            Val(v) => Val(v / d as i64),
752            Sym(_) => panic!(),
753            Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
754            Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
755            Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
756            Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
757            Mul(_) => Div(Box::new(self.clone()), d),
758            MulInt(p, a) => {
759                if *p == d as i64 {
760                    (**a).clone()
761                } else {
762                    let gcd = p.unsigned_abs().gcd(&d);
763                    MulInt(p / gcd as i64, b!(a.div(d / gcd)))
764                }
765            }
766            Div(a, q) => Div(a.clone(), q * d),
767        }
768    }
769
770    pub fn div_ceil(self, rhs: u64) -> TDim {
771        TDim::Div(Box::new(Add(vec![self, Val(rhs as i64 - 1)])), rhs).reduce()
772    }
773
774    pub(super) fn guess_slope(&self, sym: &Symbol) -> (i64, u64) {
775        fn slope_rec(d: &TDim, sym: &Symbol) -> (i64, i64) {
776            match d {
777                Val(_) => (0, 1),
778                Sym(s) => ((sym == s) as i64, 1),
779                Add(terms) => terms
780                    .iter()
781                    .map(|d| slope_rec(d, sym))
782                    .fold((0, 1), |a, b| ((a.0 * b.1 + a.1 * b.0), (b.1 * a.1))),
783                Mul(terms) => terms
784                    .iter()
785                    .map(|d| slope_rec(d, sym))
786                    .fold((1, 1), |a, b| ((a.0 * b.0), (b.1 * a.1))),
787                MulInt(p, a) => {
788                    let (n, d) = slope_rec(a, sym);
789                    (p * n, d)
790                }
791                Div(a, q) => {
792                    let (n, d) = slope_rec(a, sym);
793                    (n, d * *q as i64)
794                }
795                Broadcast(terms) => slope_rec(&terms[0], sym),
796                Min(terms) => slope_rec(&terms[0], sym),
797                Max(terms) => slope_rec(&terms[0], sym),
798            }
799        }
800        let (p, q) = slope_rec(self, sym);
801        reduce_ratio(p, q)
802    }
803
804    #[allow(clippy::mutable_key_type)]
805    pub fn symbols(&self) -> std::collections::HashSet<Symbol> {
806        match self {
807            Val(_) => maplit::hashset!(),
808            Sym(s) => maplit::hashset!(s.clone()),
809            Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
810                terms.iter().fold(maplit::hashset!(), |mut set, v| {
811                    set.extend(v.symbols());
812                    set
813                })
814            }
815            MulInt(_, a) => a.symbols(),
816            Div(a, _) => a.symbols(),
817        }
818    }
819
820    pub fn compatible_with(&self, other: &TDim) -> bool {
821        if let Ok(x) = (self.clone() - other).to_i64() {
822            return x == 0;
823        }
824        true // maybe ? :)
825    }
826}
827
828pub(super) fn reduce_ratio(mut p: i64, mut q: i64) -> (i64, u64) {
829    let gcd = p.abs().gcd(&q.abs());
830    if gcd > 1 {
831        p /= gcd;
832        q /= gcd;
833    }
834    if q < 0 {
835        (-p, (-q) as u64)
836    } else {
837        (p, q as u64)
838    }
839}
840
841impl Zero for TDim {
842    fn zero() -> Self {
843        Val(0)
844    }
845    fn is_zero(&self) -> bool {
846        matches!(self, Val(0))
847    }
848}
849
850impl Default for TDim {
851    fn default() -> TDim {
852        Val(0)
853    }
854}
855
856impl num_traits::Bounded for TDim {
857    fn min_value() -> Self {
858        TDim::Val(i64::MIN)
859    }
860
861    fn max_value() -> Self {
862        TDim::Val(i64::MAX)
863    }
864}
865
866impl num_traits::One for TDim {
867    fn one() -> Self {
868        TDim::Val(1)
869    }
870}
871
872impl ::std::iter::Sum for TDim {
873    fn sum<I: Iterator<Item = TDim>>(iter: I) -> TDim {
874        iter.fold(0.into(), |a, b| a + b)
875    }
876}
877
878impl<'a> ::std::iter::Sum<&'a TDim> for TDim {
879    fn sum<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
880        iter.fold(0.into(), |a, b| a + b)
881    }
882}
883
884impl std::iter::Product for TDim {
885    fn product<I: Iterator<Item = TDim>>(iter: I) -> Self {
886        iter.fold(TDim::Val(1), |a, b| a * b)
887    }
888}
889
890impl<'a> ::std::iter::Product<&'a TDim> for TDim {
891    fn product<I: Iterator<Item = &'a TDim>>(iter: I) -> TDim {
892        iter.fold(1.into(), |a, b| a * b)
893    }
894}
895
896macro_rules! from_i {
897    ($i: ty) => {
898        impl From<$i> for TDim {
899            fn from(v: $i) -> TDim {
900                TDim::Val(v as _)
901            }
902        }
903        impl<'a> From<&'a $i> for TDim {
904            fn from(v: &'a $i) -> TDim {
905                TDim::Val(*v as _)
906            }
907        }
908    };
909}
910
911from_i!(i32);
912from_i!(i64);
913from_i!(u64);
914from_i!(isize);
915from_i!(usize);
916
917impl From<Symbol> for TDim {
918    fn from(it: Symbol) -> Self {
919        TDim::Sym(it)
920    }
921}
922
923impl<'a> From<&'a Symbol> for TDim {
924    fn from(it: &'a Symbol) -> Self {
925        TDim::Sym(it.clone())
926    }
927}
928
929impl ops::Neg for TDim {
930    type Output = Self;
931    fn neg(self) -> Self {
932        if let Val(v) = self {
933            Val(-v)
934        } else {
935            TDim::MulInt(-1, Box::new(self)).reduce()
936        }
937    }
938}
939
940impl<'a> ops::AddAssign<&'a TDim> for TDim {
941    fn add_assign(&mut self, rhs: &'a TDim) {
942        if rhs.is_zero() {
943        } else if self.is_zero() {
944            *self = rhs.clone();
945        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
946            *s += o;
947        } else {
948            *self = TDim::Add(vec![std::mem::take(self), rhs.clone()]).reduce()
949        }
950    }
951}
952
953impl<I> ops::AddAssign<I> for TDim
954where
955    I: Into<TDim>,
956{
957    fn add_assign(&mut self, rhs: I) {
958        let rhs = rhs.into();
959        if rhs.is_zero() {
960        } else if self.is_zero() {
961            *self = rhs;
962        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
963            *s += o;
964        } else {
965            *self = TDim::Add(vec![std::mem::take(self), rhs]).reduce()
966        }
967    }
968}
969
970impl<I> ops::Add<I> for TDim
971where
972    I: Into<TDim>,
973{
974    type Output = Self;
975    fn add(mut self, rhs: I) -> Self {
976        self += rhs;
977        self
978    }
979}
980
981impl<'a> ops::Add<&'a TDim> for TDim {
982    type Output = Self;
983    fn add(mut self, rhs: &'a TDim) -> Self {
984        self += rhs;
985        self
986    }
987}
988
989#[allow(clippy::suspicious_op_assign_impl)]
990impl<'a> ops::SubAssign<&'a TDim> for TDim {
991    fn sub_assign(&mut self, rhs: &'a TDim) {
992        if rhs.is_zero() {
993        } else if self.is_zero() {
994            *self = rhs.clone().neg();
995        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
996            *s -= o;
997        } else {
998            *self = TDim::Add(vec![std::mem::take(self), rhs.clone().neg()]).reduce()
999        }
1000    }
1001}
1002
1003impl<I> ops::SubAssign<I> for TDim
1004where
1005    I: Into<TDim>,
1006{
1007    fn sub_assign(&mut self, rhs: I) {
1008        let rhs = rhs.into();
1009        if rhs.is_zero() {
1010        } else if self.is_zero() {
1011            *self = rhs.neg();
1012        } else if let (Val(s), Val(o)) = (&mut *self, &rhs) {
1013            *s -= o;
1014        } else {
1015            *self = TDim::Add(vec![std::mem::take(self), rhs.neg()]).reduce()
1016        }
1017    }
1018}
1019
1020impl<I> ops::Sub<I> for TDim
1021where
1022    I: Into<TDim>,
1023{
1024    type Output = Self;
1025    fn sub(mut self, rhs: I) -> Self {
1026        self -= rhs;
1027        self
1028    }
1029}
1030
1031impl<'a> ops::Sub<&'a TDim> for TDim {
1032    type Output = Self;
1033    fn sub(mut self, rhs: &'a TDim) -> Self {
1034        self -= rhs;
1035        self
1036    }
1037}
1038
1039impl<I: Into<TDim>> ops::MulAssign<I> for TDim {
1040    fn mul_assign(&mut self, rhs: I) {
1041        let rhs = rhs.into();
1042        if self.is_one() {
1043            *self = rhs
1044        } else if rhs.is_one() {
1045        } else {
1046            *self = TDim::Mul(vec![rhs, std::mem::take(self)]).reduce()
1047        }
1048    }
1049}
1050
1051impl<'a> ops::MulAssign<&'a TDim> for TDim {
1052    fn mul_assign(&mut self, rhs: &'a TDim) {
1053        if self.is_one() {
1054            *self = rhs.clone()
1055        } else if rhs.is_one() {
1056        } else {
1057            *self = TDim::Mul(vec![std::mem::take(self), rhs.clone()]).reduce()
1058        }
1059    }
1060}
1061
1062impl<I: Into<TDim>> ops::Mul<I> for TDim {
1063    type Output = Self;
1064    fn mul(mut self, rhs: I) -> Self {
1065        self *= rhs.into();
1066        self
1067    }
1068}
1069
1070impl<'a> ops::Mul<&'a TDim> for TDim {
1071    type Output = Self;
1072    fn mul(mut self, rhs: &'a TDim) -> Self {
1073        self *= rhs;
1074        self
1075    }
1076}
1077
1078impl<I: AsPrimitive<u64> + PrimInt> ops::DivAssign<I> for TDim {
1079    fn div_assign(&mut self, rhs: I) {
1080        *self = TDim::Div(Box::new(std::mem::take(self)), rhs.as_()).reduce()
1081    }
1082}
1083
1084impl<I: AsPrimitive<u64> + PrimInt> ops::Div<I> for TDim {
1085    type Output = Self;
1086    fn div(mut self, rhs: I) -> Self {
1087        self /= rhs.as_();
1088        self
1089    }
1090}
1091
1092impl<I: AsPrimitive<u64> + PrimInt> ops::RemAssign<I> for TDim {
1093    fn rem_assign(&mut self, rhs: I) {
1094        *self += -(self.clone() / rhs.as_() * rhs.as_());
1095    }
1096}
1097
1098impl<I: AsPrimitive<u64> + PrimInt> ops::Rem<I> for TDim {
1099    type Output = Self;
1100    fn rem(mut self, rhs: I) -> Self {
1101        self %= rhs;
1102        self
1103    }
1104}
1105
// Unit tests for `TDim`: reduction of raw expression trees, arithmetic operators,
// min/broadcast simplification and slope estimation.
#[cfg(test)]
mod tests {
    use super::*;

    // Local copy of the boxing shorthand used by the tree-building helpers below.
    macro_rules! b( ($e:expr) => { Box::new($e) } );

    // Symbols `a` and `b`, both registered in one shared scope.
    lazy_static::lazy_static! {
        static ref table: SymbolScope = SymbolScope::default();
        static ref A: Symbol = table.sym("a");
        static ref B: Symbol = table.sym("b");
    }

    // The helpers below build raw trees without calling `reduce()`, so each test
    // controls exactly which expression is handed to the reducer.

    // -a, written as MulInt(-1, a).
    fn neg(a: &TDim) -> TDim {
        mul(-1, a)
    }

    // a + b as an unreduced Add node.
    fn add(a: &TDim, b: &TDim) -> TDim {
        TDim::Add(vec![a.clone(), b.clone()])
    }

    // a * b (integer factor) as an unreduced MulInt node.
    fn mul(a: i64, b: &TDim) -> TDim {
        TDim::MulInt(a, b![b.clone()])
    }

    // a / b as an unreduced Div node.
    fn div(a: &TDim, b: u64) -> TDim {
        TDim::Div(b!(a.clone()), b)
    }

    // a + (-a) == 0
    #[test]
    fn reduce_add() {
        assert_eq!(add(&A.to_dim(), &neg(&A.to_dim())).reduce(), Val(0))
    }

    // -(2a) == -2a
    #[test]
    fn reduce_neg_mul() {
        assert_eq!(neg(&mul(2, &A.to_dim())).reduce(), mul(-2, &A.to_dim()))
    }

    // (-4 + -2*(a/4)) + -2*(-(a/4)) == -4
    #[test]
    fn reduce_cplx_ex_2() {
        assert_eq!(
            add(
                &add(&Val(-4), &mul(-2, &div(&A.to_dim(), 4))),
                &mul(-2, &mul(-1, &div(&A.to_dim(), 4)))
            )
            .reduce(),
            Val(-4)
        )
    }

    // (1 * (4 * a)) / 4 == a
    #[test]
    fn reduce_cplx_ex_3() {
        assert_eq!(div(&MulInt(1, b!(MulInt(4, b!(A.to_dim())))), 4).reduce(), A.to_dim())
    }

    #[test]
    fn reduce_cplx_ex_4() {
        // (S+1)/2 + (1-S)/2 == 1
        assert_eq!(
            add(&div(&add(&A.to_dim(), &Val(1)), 2), &div(&add(&neg(&A.to_dim()), &Val(1)), 2))
                .reduce(),
            1.into()
        );
    }

    // 3 * (2 * a) == 6a
    #[test]
    fn reduce_mul_mul_1() {
        assert_eq!(mul(3, &mul(2, &A.to_dim())).reduce(), mul(6, &A.to_dim()))
    }

    // -2 * (-1 * a) == 2a
    #[test]
    fn reduce_mul_mul_2() {
        assert_eq!(mul(-2, &mul(-1, &A.to_dim())).reduce(), mul(2, &A.to_dim()))
    }

    // 2 * ((-a) / 3) reduces to -2 * (a / 3)
    #[test]
    fn reduce_mul_div_1() {
        assert_eq!(mul(2, &div(&mul(-1, &A.to_dim()), 3)).reduce(), mul(-2, &div(&A.to_dim(), 3)))
    }

    // Constant expressions evaluate through `eval` with no symbol bindings.
    #[test]
    fn const_and_add() {
        let e: TDim = 2i64.into();
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 2);
        let e: TDim = TDim::from(2) + 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), 5);
        let e: TDim = TDim::from(2) - 3;
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -1);
        let e: TDim = -TDim::from(2);
        assert_eq!(e.eval(&SymbolValues::default()).to_i64().unwrap(), -2);
    }

    // Binding `a = 2` makes symbolic expressions evaluate to integers.
    #[test]
    fn substitution() {
        let a: TDim = A.to_dim();
        assert_eq!(a.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 2);
        let e = a + 3;
        assert_eq!(e.eval(&SymbolValues::default().with(&A, 2)).to_i64().unwrap(), 5);
    }

    // Constant additions fold immediately through the operators.
    #[test]
    fn reduce_adds() {
        let e: TDim = TDim::from(2) + 1;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2;
        assert_eq!(e, TDim::from(5));
        let e: TDim = TDim::from(3) + 0;
        assert_eq!(e, TDim::from(3));
        let e: TDim = TDim::from(3) + 2 + 1;
        assert_eq!(e, TDim::from(6));
    }

    // Multiplying by 1 (on either side) is the identity.
    #[test]
    fn reduce_muls() {
        let e: TDim = Val(1) * A.to_dim();
        assert_eq!(e, A.to_dim());
        let e: TDim = A.to_dim() * &B.to_dim() * 1;
        assert_eq!(e, A.to_dim() * &B.to_dim());
    }

    // Constant division truncates; remainders follow.
    #[test]
    fn reduce_divs() {
        let e: TDim = TDim::from(2) / 1;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(3) / 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(3) % 2;
        assert_eq!(e, TDim::from(1));
        let e: TDim = TDim::from(5) / 2;
        assert_eq!(e, TDim::from(2));
        let e: TDim = TDim::from(5) % 2;
        assert_eq!(e, TDim::from(1));
    }

    // Regression: (a + 23) / 2 - 1 == (a + 21) / 2
    #[test]
    fn reduce_div_bug_0() {
        let e1: TDim = (A.to_dim() + 23) / 2 - 1;
        let e2: TDim = (A.to_dim() + 21) / 2;
        assert_eq!(e1, e2);
    }

    // Regression: (a - 1) / 2 == (a + 1) / 2 - 1
    #[test]
    fn reduce_div_bug_1() {
        let e1: TDim = (A.to_dim() + -1) / 2;
        let e2: TDim = (A.to_dim() + 1) / 2 - 1;
        assert_eq!(e1, e2);
    }

    // Regression: nested divisions merge: ((a + 1) / 2 + 1) / 2 == (a + 3) / 4
    #[test]
    fn reduce_div_bug_2() {
        let e1: TDim = ((A.to_dim() + 1) / 2 + 1) / 2;
        let e2: TDim = (A.to_dim() + 3) / 4;
        assert_eq!(e1, e2);
    }

    // Regression: dividing by 1 must not change the reduced form.
    #[test]
    fn reduce_div_bug_3() {
        let e1: TDim = (A.to_dim() / 2) * -4;
        let e2: TDim = (A.to_dim() / 2) * -4 / 1;
        assert_eq!(e1, e2);
    }

    // (a * 2) / 2 == a
    #[test]
    fn reduce_mul_div() {
        let e: TDim = A.to_dim() * 2 / 2;
        assert_eq!(e, A.to_dim());
    }

    // (a / 2) * 2 must NOT reduce to a: the division truncates.
    #[test]
    fn reduce_div_mul() {
        let e: TDim = A.to_dim() / 2 * 2;
        assert_ne!(e, A.to_dim());
    }

    // a / 2 + 1 is normalized to (a + 2) / 2
    #[test]
    fn reduce_add_div() {
        let e: TDim = A.to_dim() / 2 + 1;
        assert_eq!(e, ((A.to_dim() + 2) / 2));
    }

    // Subtraction normalizes to addition of a negated term.
    #[test]
    fn reduce_neg_mul_() {
        let e: TDim = TDim::from(1) - A.to_dim() * 2;
        assert_eq!(e, TDim::from(1) + A.to_dim() * -2);
    }

    // Adding a multiple of the modulus does not change the remainder.
    #[test]
    fn reduce_add_rem_1() {
        assert_eq!(((A.to_dim() + 4) % 2), (A.to_dim() % 2));
    }

    #[test]
    fn reduce_add_rem_2() {
        assert_eq!(((A.to_dim() - 4) % 2), (A.to_dim() % 2));
    }

    // (a % 2) / 2 == 0
    #[test]
    fn reduce_rem_div() {
        let e: TDim = A.to_dim() % 2 / 2;
        assert_eq!(e, TDim::from(0));
    }

    // `div_ceil` by 1 is the identity (constant case).
    #[test]
    fn conv2d_ex_1() {
        let e = (TDim::from(1) - 1 + 1).div_ceil(1);
        assert_eq!(e, TDim::from(1));
    }

    // `div_ceil` by 1 is the identity (symbolic case).
    #[test]
    fn conv2d_ex_2() {
        let e = (A.to_dim() - 3 + 1).div_ceil(1);
        assert_eq!(e, A.to_dim() + -2);
    }

    // With t = (a + 1) / 4: (24t - 24)(2t - 2) factors to (t - 1)(t - 1) * 48.
    #[test]
    fn extract_int_gcd_from_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul = (term.clone() * 24 - 24) * (term.clone() * 2 - 2);
        let target = (term.clone() - 1) * (term.clone() - 1) * 48;
        assert_eq!(mul, target);
    }

    // Products compare equal regardless of factor order.
    #[test]
    fn equality_of_muls() {
        let term = (A.to_dim() + 1) / 4;
        let mul1 = (term.clone() * 2 - 3) * (term.clone() - 1);
        let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
        assert_eq!(mul1, mul2);
    }

    // 2t - t - 1 == t - 1
    #[test]
    fn factorize_complex_expr_times_int() {
        let term = (A.to_dim() + 1) / 4;
        let e = term.clone() * 2 - &term - 1;
        assert_eq!(e, term - 1);
    }

    #[test]
    fn broadcast_over_min() {
        // Assuming a > 0 and b > 0, a#min(a,b) can be replaced by a.
        // Proof:
        //    if b == 1 => min(a,b) = 1 => a#1 = a => ok
        //    if a <= b => min(a,b) = a => ok
        //    if 1 < b < a => the expression was invalid; the symbolic rule generalizes
        //    over this non-domain and ignores the constraint, while concrete integers
        //    report an error (checked below).
        for a in 1..5 {
            for b in 1..5 {
                if b > 1 && a > b {
                    assert!(a.broadcast(a.min(b)).is_err());
                } else {
                    assert_eq!(a.broadcast(a.min(b)).unwrap(), a);
                }
            }
        }
    }

    #[test]
    fn min_ints_1() {
        assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_ints_2() {
        assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
    }

    #[test]
    fn min_same() {
        assert_eq!(A.to_dim().mini(A.to_dim()), A.to_dim());
    }

    // NOTE(review): this compares an expression with itself, so it only checks that
    // `mini` is deterministic and does not panic; it does not pin the resulting form.
    #[test]
    fn min_noop() {
        assert_eq!(A.to_dim().mini(1.to_dim()), A.to_dim().mini(1.to_dim()));
    }

    // min(a + 1, a + 2) == a + 1
    #[test]
    fn min_diff_1() {
        assert_eq!((A.to_dim() + 1).mini(A.to_dim() + 2), A.to_dim() + 1);
    }

    // Slopes are returned as reduced (numerator, denominator) pairs.
    #[test]
    fn slope_0() {
        assert_eq!(12.to_dim().guess_slope(&A), (0, 1));
    }

    #[test]
    fn slope_1() {
        assert_eq!(A.to_dim().guess_slope(&A), (1, 1));
    }

    #[test]
    fn slope_2() {
        assert_eq!((A.to_dim() * 2).guess_slope(&A), (2, 1));
    }

    // 2a + a/2 has slope 5/2.
    #[test]
    fn slope_3() {
        assert_eq!((A.to_dim() * 2 + A.to_dim() / 2).guess_slope(&A), (5, 2));
    }

    #[test]
    fn slope_4() {
        assert_eq!((A.to_dim()).guess_slope(&B), (0, 1));
    }

    #[test]
    fn slope_5() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + 1).guess_slope(&B), (0, 1));
    }

    // NOTE(review): the first assertion duplicates one in `slope_5`.
    #[test]
    fn slope_6() {
        assert_eq!((A.to_dim() + 1).guess_slope(&A), (1, 1));
        assert_eq!((A.to_dim() + B.to_dim()).guess_slope(&B), (1, 1));
    }

    // min(S+3, S+2) simplifies to S+2.
    #[test]
    fn min_0() -> TractResult<()> {
        let symbols = SymbolScope::default();
        assert_eq!(
            symbols.parse_tdim("min(S+3, S+2)").unwrap().simplify(),
            symbols.parse_tdim("S+2").unwrap(),
        );
        Ok(())
    }
}