flag_algebra/
expr.rs

1//! Expression of computations in the flag algebra for prettyprinting.
2
3use crate::operator::{Basis, Savable, Type};
4use std::collections::BTreeMap;
5use std::fmt::*;
6use std::ops::{Add, Mul, Neg, Sub};
7use std::rc::Rc;
8
/// Shared closure computing a coefficient of type `N` from a flag `&F` and a
/// `usize` (presumably the size of the type of the basis, as for
/// [`IndicatorFn`] -- TODO confirm against `Basis::qflag_from_coeff_rc`).
pub type CoefficientFn<F, N> = Rc<dyn Fn(&F, usize) -> N>;
/// Shared predicate selecting flags `&F`; in this file it is called with
/// `basis.t.size` as its second argument.
pub type IndicatorFn<F> = Rc<dyn Fn(&F, usize) -> bool>;
11
12/// Expressions that represent a computation in flag algebras.
pub enum Expr<N, F: Flag> {
    /// Sum of two sub-expressions.
    Add(RcExpr<N, F>, RcExpr<N, F>),
    /// Product of two sub-expressions.
    Mul(RcExpr<N, F>, RcExpr<N, F>),
    /// Opposite of a sub-expression.
    Neg(RcExpr<N, F>),
    /// Sub-expression whose value is passed through `untype()` on evaluation;
    /// printed between `[|` and `|]` (double brackets in LaTeX).
    Unlab(RcExpr<N, F>),
    /// The constant `0`.
    Zero,
    /// The constant `1`.
    One,
    /// An arbitrary numeric constant.
    Num(Rc<N>),
    /// A sub-expression together with a name. `Display` prints only the name.
    /// The LaTeX output prints the name when the boolean is `true` and the
    /// inner expression otherwise.
    Named(RcExpr<N, F>, Rc<String>, bool),
    /// Placeholder replaced through a `VarRange` on substitution or
    /// evaluation. The stored index is not read by the code in this file.
    Var(usize),
    /// The flag with the given id in the given basis.
    Flag(usize, Basis<F>),
    /// Combination of the flags of a basis, with coefficients given by a closure.
    FromFunction(CoefficientFn<F, N>, Basis<F>),
    /// Combination of the flags of a basis selected by a predicate.
    FromIndicator(IndicatorFn<F>, Basis<F>),
    /// Expression with no known content; evaluating it panics.
    Unknown,
}
28
29// Straightforward trait implementations
30// (derive is too conservative when working with Rc or PhantomData,
31// follow https://github.com/rust-lang/rust/issues/26925)
32impl<N, F: Flag> Clone for Expr<N, F> {
33    fn clone(&self) -> Self {
34        match self {
35            Add(a, b) => Add(a.clone(), b.clone()),
36            Mul(a, b) => Mul(a.clone(), b.clone()),
37            Neg(a) => Neg(a.clone()),
38            Unlab(a) => Unlab(a.clone()),
39            Num(a) => Num(a.clone()),
40            Var(a) => Var(*a),
41            Named(a, b, c) => Named(a.clone(), b.clone(), *c),
42            Flag(a, b) => Flag(*a, *b),
43            FromFunction(a, b) => FromFunction(a.clone(), *b),
44            FromIndicator(a, b) => FromIndicator(a.clone(), *b),
45            Unknown => Unknown,
46            Zero => Zero,
47            One => One,
48        }
49    }
50}
51
/// Set of values over which the variable of an expression ranges.
#[derive(Debug, Clone)]
pub enum VarRange<F: Flag> {
    /// The variable ranges over the flags of the given basis.
    InBasis(Basis<F>),
}
56
57#[derive(Debug, Clone)]
58pub struct Names<N, F: Flag> {
59    pub flags: BTreeMap<(usize, Basis<F>), String>,
60    pub types: BTreeMap<Type<F>, String>,
61    pub functions: Vec<(String, QFlag<N, F>)>,
62    pub sets: Vec<(String, Basis<F>, Vec<F>)>,
63}
64
65impl<N, F: Flag> Default for Names<N, F> {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl<N, F: Flag> Names<N, F> {
72    pub fn new() -> Self {
73        Self {
74            flags: BTreeMap::new(),
75            types: BTreeMap::new(),
76            functions: Vec::new(),
77            sets: Vec::new(),
78        }
79    }
80    pub fn is_empty(&self) -> bool {
81        self.flags.is_empty()
82            && self.types.is_empty()
83            && self.functions.is_empty()
84            && self.sets.is_empty()
85    }
86    fn name_flag(&mut self, i: usize, basis: Basis<F>) -> String
87    where
88        F: Ord,
89    {
90        self.flags
91            .entry((i, basis))
92            .or_insert_with(|| format!("F_{{{}}}^{{{}}}", i, basis.print_concise()))
93            .clone()
94    }
95    fn name_type(&mut self, t: Type<F>) -> String {
96        let i = self.types.len();
97        self.types
98            .entry(t)
99            .or_insert_with(|| {
100                if i == 0 {
101                    "\\sigma".to_string()
102                } else {
103                    format!("\\sigma_{i}")
104                }
105            })
106            .clone()
107    }
108    fn name_set(&mut self, f: IndicatorFn<F>, basis: Basis<F>) -> String
109    where
110        F: Flag,
111    {
112        let name = format!("S_{}", self.sets.len() + 1);
113        let mut set = basis.get();
114        set.retain(|x| f(x, basis.t.size));
115        self.sets.push((name.clone(), basis, set));
116        name
117    }
118    fn name_function(&mut self, f: CoefficientFn<F, N>, basis: Basis<F>) -> String
119    where
120        F: Flag,
121    {
122        let name = format!("f_{}", self.functions.len() + 1);
123        self.functions
124            .push((name.clone(), basis.qflag_from_coeff_rc(f)));
125        name
126    }
127}
128
129use Expr::*;
130use VarRange::*;
131
132impl<F: Flag> VarRange<F> {
133    fn eval<N>(&self, i: usize) -> Expr<N, F> {
134        match self {
135            InBasis(basis) => Flag(i, *basis),
136        }
137    }
138    pub(crate) fn latex<N>(&self, names: &mut Names<N, F>) -> String {
139        match self {
140            InBasis(basis) => format!("\\forall H\\in {},\\quad ", latex_basis(basis, names)),
141        }
142    }
143}
144
145impl<N, F: Flag> Add for Expr<N, F> {
146    type Output = Self;
147
148    fn add(self, b: Self) -> Self {
149        Add(Rc::new(self), Rc::new(b))
150    }
151}
152
153impl<N, F: Flag> Neg for Expr<N, F> {
154    type Output = Self;
155
156    fn neg(self) -> Self {
157        Neg(Rc::new(self))
158    }
159}
160
161impl<N, F: Flag> Sub for Expr<N, F> {
162    type Output = Self;
163
164    fn sub(self, other: Self) -> Self {
165        self + (-other)
166    }
167}
168
169impl<N, F: Flag> Mul for Expr<N, F> {
170    type Output = Self;
171
172    fn mul(self, b: Self) -> Self {
173        Mul(Rc::new(self), Rc::new(b))
174    }
175}
176
177type RcExpr<N, F> = Rc<Expr<N, F>>;
178
179impl<N, F: Flag> Expr<N, F> {
180    pub fn unlab(self) -> Self {
181        Unlab(Rc::new(self))
182    }
183    pub fn named(self, name: String) -> Self {
184        Named(Rc::new(self), Rc::new(name), false)
185    }
186    pub fn unknown(name: String) -> Self {
187        Unknown.named(name)
188    }
189    pub fn num(n: &N) -> Self
190    where
191        N: num::Num + Clone,
192    {
193        if n == &N::zero() {
194            Zero
195        } else if n == &N::one() {
196            One
197        } else {
198            Num(Rc::new(n.clone()))
199        }
200    }
201    fn simplify(&self) -> Self
202    where
203        Expr<N, F>: Clone,
204    {
205        match self {
206            Add(a0, b0) => match (a0.simplify(), b0.simplify()) {
207                (Zero, a) | (a, Zero) => a,
208                (a, b) => a + b,
209            },
210            Mul(a0, b0) => match (a0.simplify(), b0.simplify()) {
211                (One, a) | (a, One) => a,
212                (Zero, _) | (_, Zero) => Zero,
213                (a, b) => a * b,
214            },
215            Neg(a0) => match a0.simplify() {
216                Zero => Zero,
217                a => -a,
218            },
219            Unlab(a0) => match a0.simplify() {
220                Zero => Zero,
221                a => Self::unlab(a),
222            },
223            a => a.clone(),
224        }
225    }
226    fn is_sum(&self) -> bool {
227        matches!(self, Add(_, _))
228    }
229    pub fn latex(&self, names: &mut Names<N, F>) -> String
230    where
231        N: Display,
232        F: Ord,
233    {
234        self.simplify().latex0(names)
235    }
236    fn latex0(&self, names: &mut Names<N, F>) -> String
237    where
238        N: Display,
239        F: Ord,
240    {
241        match self {
242            Add(a, b) => {
243                if let Neg(b1) = &**b {
244                    format!("{} - {}", a.latex0(names), Paren(b1).latex(names))
245                } else {
246                    format!("{} + {}", a.latex0(names), b.latex0(names))
247                }
248            }
249            Mul(a, b) => format!("{}\\cdot {}", Paren(a).latex(names), Paren(b).latex(names)),
250            Neg(a) => format!("-{}", Paren(a).latex(names)),
251            Unlab(a) => format!(
252                "\\left[\\!\\!\\left[{}\\right]\\!\\!\\right]",
253                a.latex0(names)
254            ),
255            Zero => "0".into(),
256            One => "1".into(),
257            Num(s) => format!("{s}"),
258            Var(_) => "H".into(),
259            Named(e, name, latex) => {
260                if *latex {
261                    format!("\\textrm{{{name}}}")
262                } else {
263                    e.latex0(names)
264                }
265            }
266            Flag(i, basis) => names.name_flag(*i, *basis),
267            FromFunction(f, b) => format!(
268                "\\sum_{{F\\in{}}} {}(F)F",
269                latex_basis(b, names),
270                names.name_function(f.clone(), *b)
271            ),
272            FromIndicator(f, b) => format!(
273                "\\sum_{{F\\in {}\\subseteq{}}}F",
274                names.name_set(f.clone(), *b),
275                latex_basis(b, names)
276            ),
277            Unknown => "Unknown".into(),
278        }
279    }
280}
281
282fn latex_basis<N, F: Flag>(basis: &Basis<F>, names: &mut Names<N, F>) -> String {
283    if basis.t.is_empty() {
284        format!("\\mathcal{{F}}_{{{}}}", basis.size)
285    } else {
286        format!(
287            "\\mathcal{{F}}^{{{}}}_{{{}}}",
288            names.name_type(basis.t),
289            basis.size
290        )
291    }
292}
293
294struct Paren<'a, N, F: Flag>(&'a Expr<N, F>);
295
296impl<'a, N, F: Flag> Display for Paren<'a, N, F>
297where
298    Expr<N, F>: Display,
299{
300    fn fmt(&self, f: &mut Formatter) -> Result {
301        if self.0.is_sum() {
302            write!(f, "({})", self.0)
303        } else {
304            write!(f, "{}", self.0)
305        }
306    }
307}
308
309impl<'a, N, F> Paren<'a, N, F>
310where
311    N: Display,
312    F: Ord + Flag,
313{
314    fn latex(&self, names: &mut Names<N, F>) -> String {
315        if self.0.is_sum() {
316            format!("\\left({}\\right)", self.0.latex0(names))
317        } else {
318            self.0.latex0(names)
319        }
320    }
321}
322
323impl<N, F: Flag> Display for Expr<N, F>
324where
325    N: Display,
326{
327    fn fmt(&self, f: &mut Formatter) -> Result {
328        match self.simplify() {
329            Add(a, b) => {
330                if let Neg(b1) = &*b {
331                    write!(f, "{} - {}", a, Paren(b1))
332                } else {
333                    write!(f, "{a} + {b}")
334                }
335            }
336            Mul(a, b) => write!(f, "{}*{}", Paren(&a), Paren(&b)),
337            Neg(a) => write!(f, "-{}", Paren(&a)),
338            Unlab(a) => write!(f, "[|{a}|]"),
339            Zero => write!(f, "0"),
340            One => write!(f, "1"),
341            Num(s) => write!(f, "{s}"),
342            Var(_) => write!(f, "x"),
343            Named(_, name, _) => write!(f, "{name}"),
344            Flag(i, basis) => write!(f, "flag({}:{})", i, basis.print_concise()),
345            FromFunction(_, _) => write!(f, "Σ f(F)F"),
346            FromIndicator(_, _) => write!(f, "Σ F"),
347            Unknown => write!(f, "unknown"),
348        }
349    }
350}
351
352impl<N, F> Debug for Expr<N, F>
353where
354    F: Flag + Debug,
355    N: Debug,
356{
357    fn fmt(&self, f: &mut Formatter) -> Result {
358        match self {
359            Add(a, b) => write!(f, "Add({a:?}, {b:?})"),
360            Mul(a, b) => write!(f, "Mul({a:?}, {b:?})"),
361            Named(a, b, c) => write!(f, "Named({a:?}, {b:?}, {c:?})"),
362            Flag(a, b) => write!(f, "Flag({a:?}, {b:?})"),
363            Neg(a) => write!(f, "Neg({a:?})"),
364            Unlab(a) => write!(f, "Unlab({a:?})"),
365            Num(a) => write!(f, "Num({a:?})"),
366            Var(a) => write!(f, "Var({a:?})"),
367            FromFunction(_, b) => write!(f, "FromFunction(_, {b:?})"),
368            FromIndicator(_, b) => write!(f, "FromIndicator(_, {b:?})"),
369            Unknown => write!(f, "Unknown"),
370            Zero => write!(f, "Zero"),
371            One => write!(f, "One"),
372        }
373    }
374}
375
376use crate::Flag;
// Expression evaluation
378use crate::QFlag;
379use ndarray::ScalarOperand;
380use num::FromPrimitive;
381
382#[derive(Clone, Debug)]
383enum Val<N, F: Flag> {
384    Num(N),
385    QFlag(QFlag<N, F>),
386}
387
388impl<N, F> Val<N, F>
389where
390    N: num::Num + Clone + Neg<Output = N>,
391    F: Flag,
392{
393    fn unwrap_qflag(self) -> QFlag<N, F> {
394        if let Self::QFlag(qflag) = self {
395            qflag
396        } else {
397            panic!("QFlag expected")
398        }
399    }
400    fn neg(self) -> Self {
401        match self {
402            Self::Num(n) => Self::Num(-n),
403            Self::QFlag(qflag) => Self::QFlag(-&qflag),
404        }
405    }
406}
407
408impl<N, F> Expr<N, F>
409where
410    N: num::Num + Neg<Output = N> + Clone + FromPrimitive + ScalarOperand + Display,
411    F: Flag,
412{
413    pub fn eval(&self) -> QFlag<N, F> {
414        self.eval0(None).unwrap_qflag()
415    }
416    pub fn eval_with_context(&self, range: &VarRange<F>, id: usize) -> QFlag<N, F> {
417        self.eval0(Some((range, id))).unwrap_qflag()
418    }
419    fn eval0(&self, context: Option<(&VarRange<F>, usize)>) -> Val<N, F> {
420        match self {
421            Add(a, b) => match (a.eval0(context), b.eval0(context)) {
422                (Val::Num(n1), Val::Num(n2)) => Val::Num(n1 + n2),
423                (Val::QFlag(f), Val::QFlag(g)) => Val::QFlag(f + g),
424                (Val::QFlag(f), Val::Num(n)) | (Val::Num(n), Val::QFlag(f)) => {
425                    assert!(F::HEREDITARY);
426                    let one = f.basis.one();
427                    Val::QFlag(f + one * n)
428                }
429            },
430            Mul(a, b) => match (a.eval0(context), b.eval0(context)) {
431                (Val::Num(n1), Val::Num(n2)) => Val::Num(n1 * n2),
432                (Val::QFlag(f), Val::QFlag(g)) => Val::QFlag(f * g),
433                (Val::Num(n), Val::QFlag(g)) | (Val::QFlag(g), Val::Num(n)) => Val::QFlag(g * n),
434            },
435            Neg(e) => e.eval0(context).neg(),
436            Unlab(e) => Val::QFlag(e.eval0(context).unwrap_qflag().untype()),
437            Num(x) => Val::Num((**x).clone()),
438            Var(_) => match context {
439                Some((range, id)) => range.eval(id).eval0(None),
440                None => panic!("Cannot evaluate variable"),
441            },
442            Named(e, _, _) => e.eval0(context),
443            Flag(i, basis) => Val::QFlag(basis.flag_from_id(*i)),
444            FromIndicator(f, basis) => Val::QFlag(basis.qflag_from_indicator_rc(f.clone())),
445            FromFunction(f, basis) => Val::QFlag(basis.qflag_from_coeff_rc(f.clone())),
446            Zero => Val::Num(N::zero()),
447            One => Val::Num(N::one()),
448            Unknown => panic!("Cannot evaluate unknown"),
449        }
450    }
451}
452impl<N, F> Expr<N, F>
453where
454    N: Clone,
455    F: Flag,
456{
457    pub fn substitute_option(&self, range_opt: &Option<VarRange<F>>, id: usize) -> Self {
458        match range_opt {
459            Some(range) => self.substitute(range, id),
460            None => self.clone(),
461        }
462    }
463    pub fn substitute(&self, range: &VarRange<F>, id: usize) -> Self {
464        match self.substitute0(range, id) {
465            Some(e) => e,
466            None => self.clone(),
467        }
468    }
469    fn substitute0(&self, range: &VarRange<F>, id: usize) -> Option<Self> {
470        fn rc<T: Clone>(op: Option<T>, default: &Rc<T>) -> Rc<T> {
471            match op {
472                Some(e) => Rc::new(e),
473                None => default.clone(),
474            }
475        }
476        match self {
477            Var(_) => Some(range.eval(id)),
478            Add(e1, e2) => match (e1.substitute0(range, id), e2.substitute0(range, id)) {
479                (None, None) => None,
480                (f1, f2) => Some(Add(rc(f1, e1), rc(f2, e2))),
481            },
482            Mul(e1, e2) => match (e1.substitute0(range, id), e2.substitute0(range, id)) {
483                (None, None) => None,
484                (f1, f2) => Some(Mul(rc(f1, e1), rc(f2, e2))),
485            },
486            Neg(e) => e.substitute0(range, id).map(|x| Neg(Rc::new(x))),
487            Unlab(e) => e.substitute0(range, id).map(|x| Unlab(Rc::new(x))),
488            Named(e, name, latex) => e
489                .substitute0(range, id)
490                .map(|x| Named(Rc::new(x), name.clone(), *latex)),
491            FromFunction(_, _)
492            | FromIndicator(_, _)
493            | Flag(_, _)
494            | Unknown
495            | Num(_)
496            | Zero
497            | One => None,
498        }
499    }
500}
501
502impl<N, F: Flag> Expr<N, F> {
503    pub fn map<Fun, M>(&self, f: &Fun) -> Expr<M, F>
504    where
505        Fun: Fn(&N) -> M,
506    {
507        let rec = |e: &Self| Rc::new(e.map(f));
508
509        match self {
510            Add(e1, e2) => Add(rec(e1), rec(e2)),
511            Mul(e1, e2) => Mul(rec(e1), rec(e2)),
512            Neg(e) => Neg(rec(e)),
513            Unlab(e) => Unlab(rec(e)),
514            Named(e, name, latex) => Named(rec(e), name.clone(), *latex),
515            FromFunction(_g, b) => FromFunction(Rc::new(|_, _| unimplemented!()), *b), // Fixme
516            FromIndicator(g, b) => FromIndicator(g.clone(), *b),
517            Var(i) => Var(*i),
518            Flag(id, b) => Flag(*id, *b),
519            Unknown => Unknown,
520            Num(n) => Num(Rc::new(f(n))),
521            Zero => Zero,
522            One => One,
523        }
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flags::Graph;

    /// Evaluating the expression recorded in a quantum flag must give back
    /// the same quantum flag.
    #[test]
    fn test_eval_expr() {
        type V = QFlag<i64, Graph>;

        // Combination of flags without type.
        let basis = Basis::new(4);
        let single: V = basis.flag_from_id(3);
        let edge_count: V = basis.qflag_from_coeff(|g, _| g.edges().count() as i64);
        let connected: V = basis.qflag_from_indicator(|g, _| g.connected());
        let expected = single + (edge_count * 3) - connected;
        let recomputed = expected.expr.eval();
        assert_eq!(expected, recomputed);

        // Product of typed flags, then unlabelling.
        let sigma = Type::new(2, 1);
        let typed_basis = Basis::new(3).with_type(sigma);
        let typed_flag: V = typed_basis.flag_from_id(1);
        let product = ((typed_flag.clone() * 3) * -typed_flag).untype();
        let reevaluated = product.expr.eval();
        assert_eq!(product, reevaluated);
    }
}