1use crate::operator::{Basis, Savable, Type};
4use std::collections::BTreeMap;
5use std::fmt::*;
6use std::ops::{Add, Mul, Neg, Sub};
7use std::rc::Rc;
8
/// Shared, reference-counted closure that maps a flag and a `usize` argument to a
/// coefficient of type `N`.
// NOTE(review): the meaning of the `usize` argument is not shown for coefficient
// functions in this file — presumably the type size, as for `IndicatorFn`; confirm.
pub type CoefficientFn<F, N> = Rc<dyn Fn(&F, usize) -> N>;
/// Shared, reference-counted predicate over flags. In this file it is invoked as
/// `f(flag, basis.t.size)` when a named set is built.
pub type IndicatorFn<F> = Rc<dyn Fn(&F, usize) -> bool>;
11
/// Symbolic expression tree with scalar type `N` over flags of type `F`.
///
/// Sub-expressions are held behind `Rc`, so cloning a tree only bumps reference counts.
pub enum Expr<N, F: Flag> {
    /// Sum of two sub-expressions.
    Add(RcExpr<N, F>, RcExpr<N, F>),
    /// Product of two sub-expressions.
    Mul(RcExpr<N, F>, RcExpr<N, F>),
    /// Negation of a sub-expression.
    Neg(RcExpr<N, F>),
    /// Unlabelling of a sub-expression (printed as `[|…|]`, evaluated through `untype()`).
    Unlab(RcExpr<N, F>),
    /// The scalar constant zero.
    Zero,
    /// The scalar constant one.
    One,
    /// An arbitrary scalar constant.
    Num(Rc<N>),
    /// A named sub-expression: `(expression, name, latex)`. `Display` prints the name;
    /// the LaTeX printer prints the name only when the flag is `true` and otherwise
    /// prints the wrapped expression.
    Named(RcExpr<N, F>, Rc<String>, bool),
    /// A variable placeholder, replaced through a `VarRange` on evaluation/substitution.
    /// The `usize` payload is not read by the code shown here.
    Var(usize),
    /// The flag with the given id in the given basis.
    Flag(usize, Basis<F>),
    /// Linear combination of the flags of a basis with coefficients given by a function.
    FromFunction(CoefficientFn<F, N>, Basis<F>),
    /// Sum of the flags of a basis selected by a predicate.
    FromIndicator(IndicatorFn<F>, Basis<F>),
    /// A value that is not known; evaluating it panics.
    Unknown,
}
28
29impl<N, F: Flag> Clone for Expr<N, F> {
33 fn clone(&self) -> Self {
34 match self {
35 Add(a, b) => Add(a.clone(), b.clone()),
36 Mul(a, b) => Mul(a.clone(), b.clone()),
37 Neg(a) => Neg(a.clone()),
38 Unlab(a) => Unlab(a.clone()),
39 Num(a) => Num(a.clone()),
40 Var(a) => Var(*a),
41 Named(a, b, c) => Named(a.clone(), b.clone(), *c),
42 Flag(a, b) => Flag(*a, *b),
43 FromFunction(a, b) => FromFunction(a.clone(), *b),
44 FromIndicator(a, b) => FromIndicator(a.clone(), *b),
45 Unknown => Unknown,
46 Zero => Zero,
47 One => One,
48 }
49 }
50}
51
/// Range over which the variable of an expression is quantified.
#[derive(Debug, Clone)]
pub enum VarRange<F: Flag> {
    /// The variable ranges over the flags of the given basis.
    InBasis(Basis<F>),
}
56
/// Table of the LaTeX names handed out while printing expressions.
#[derive(Debug, Clone)]
pub struct Names<N, F: Flag> {
    /// Name given to each `(flag id, basis)` pair.
    pub flags: BTreeMap<(usize, Basis<F>), String>,
    /// Name given to each type.
    pub types: BTreeMap<Type<F>, String>,
    /// Named coefficient functions, stored as the `QFlag` built from them.
    pub functions: Vec<(String, QFlag<N, F>)>,
    /// Named sets: name, ambient basis, and the flags of that basis kept by the predicate.
    pub sets: Vec<(String, Basis<F>, Vec<F>)>,
}
64
65impl<N, F: Flag> Default for Names<N, F> {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl<N, F: Flag> Names<N, F> {
72 pub fn new() -> Self {
73 Self {
74 flags: BTreeMap::new(),
75 types: BTreeMap::new(),
76 functions: Vec::new(),
77 sets: Vec::new(),
78 }
79 }
80 pub fn is_empty(&self) -> bool {
81 self.flags.is_empty()
82 && self.types.is_empty()
83 && self.functions.is_empty()
84 && self.sets.is_empty()
85 }
86 fn name_flag(&mut self, i: usize, basis: Basis<F>) -> String
87 where
88 F: Ord,
89 {
90 self.flags
91 .entry((i, basis))
92 .or_insert_with(|| format!("F_{{{}}}^{{{}}}", i, basis.print_concise()))
93 .clone()
94 }
95 fn name_type(&mut self, t: Type<F>) -> String {
96 let i = self.types.len();
97 self.types
98 .entry(t)
99 .or_insert_with(|| {
100 if i == 0 {
101 "\\sigma".to_string()
102 } else {
103 format!("\\sigma_{i}")
104 }
105 })
106 .clone()
107 }
108 fn name_set(&mut self, f: IndicatorFn<F>, basis: Basis<F>) -> String
109 where
110 F: Flag,
111 {
112 let name = format!("S_{}", self.sets.len() + 1);
113 let mut set = basis.get();
114 set.retain(|x| f(x, basis.t.size));
115 self.sets.push((name.clone(), basis, set));
116 name
117 }
118 fn name_function(&mut self, f: CoefficientFn<F, N>, basis: Basis<F>) -> String
119 where
120 F: Flag,
121 {
122 let name = format!("f_{}", self.functions.len() + 1);
123 self.functions
124 .push((name.clone(), basis.qflag_from_coeff_rc(f)));
125 name
126 }
127}
128
129use Expr::*;
130use VarRange::*;
131
132impl<F: Flag> VarRange<F> {
133 fn eval<N>(&self, i: usize) -> Expr<N, F> {
134 match self {
135 InBasis(basis) => Flag(i, *basis),
136 }
137 }
138 pub(crate) fn latex<N>(&self, names: &mut Names<N, F>) -> String {
139 match self {
140 InBasis(basis) => format!("\\forall H\\in {},\\quad ", latex_basis(basis, names)),
141 }
142 }
143}
144
145impl<N, F: Flag> Add for Expr<N, F> {
146 type Output = Self;
147
148 fn add(self, b: Self) -> Self {
149 Add(Rc::new(self), Rc::new(b))
150 }
151}
152
153impl<N, F: Flag> Neg for Expr<N, F> {
154 type Output = Self;
155
156 fn neg(self) -> Self {
157 Neg(Rc::new(self))
158 }
159}
160
161impl<N, F: Flag> Sub for Expr<N, F> {
162 type Output = Self;
163
164 fn sub(self, other: Self) -> Self {
165 self + (-other)
166 }
167}
168
169impl<N, F: Flag> Mul for Expr<N, F> {
170 type Output = Self;
171
172 fn mul(self, b: Self) -> Self {
173 Mul(Rc::new(self), Rc::new(b))
174 }
175}
176
/// Reference-counted sub-expression, as stored in the inner nodes of [`Expr`].
type RcExpr<N, F> = Rc<Expr<N, F>>;
178
179impl<N, F: Flag> Expr<N, F> {
180 pub fn unlab(self) -> Self {
181 Unlab(Rc::new(self))
182 }
183 pub fn named(self, name: String) -> Self {
184 Named(Rc::new(self), Rc::new(name), false)
185 }
186 pub fn unknown(name: String) -> Self {
187 Unknown.named(name)
188 }
189 pub fn num(n: &N) -> Self
190 where
191 N: num::Num + Clone,
192 {
193 if n == &N::zero() {
194 Zero
195 } else if n == &N::one() {
196 One
197 } else {
198 Num(Rc::new(n.clone()))
199 }
200 }
201 fn simplify(&self) -> Self
202 where
203 Expr<N, F>: Clone,
204 {
205 match self {
206 Add(a0, b0) => match (a0.simplify(), b0.simplify()) {
207 (Zero, a) | (a, Zero) => a,
208 (a, b) => a + b,
209 },
210 Mul(a0, b0) => match (a0.simplify(), b0.simplify()) {
211 (One, a) | (a, One) => a,
212 (Zero, _) | (_, Zero) => Zero,
213 (a, b) => a * b,
214 },
215 Neg(a0) => match a0.simplify() {
216 Zero => Zero,
217 a => -a,
218 },
219 Unlab(a0) => match a0.simplify() {
220 Zero => Zero,
221 a => Self::unlab(a),
222 },
223 a => a.clone(),
224 }
225 }
226 fn is_sum(&self) -> bool {
227 matches!(self, Add(_, _))
228 }
229 pub fn latex(&self, names: &mut Names<N, F>) -> String
230 where
231 N: Display,
232 F: Ord,
233 {
234 self.simplify().latex0(names)
235 }
236 fn latex0(&self, names: &mut Names<N, F>) -> String
237 where
238 N: Display,
239 F: Ord,
240 {
241 match self {
242 Add(a, b) => {
243 if let Neg(b1) = &**b {
244 format!("{} - {}", a.latex0(names), Paren(b1).latex(names))
245 } else {
246 format!("{} + {}", a.latex0(names), b.latex0(names))
247 }
248 }
249 Mul(a, b) => format!("{}\\cdot {}", Paren(a).latex(names), Paren(b).latex(names)),
250 Neg(a) => format!("-{}", Paren(a).latex(names)),
251 Unlab(a) => format!(
252 "\\left[\\!\\!\\left[{}\\right]\\!\\!\\right]",
253 a.latex0(names)
254 ),
255 Zero => "0".into(),
256 One => "1".into(),
257 Num(s) => format!("{s}"),
258 Var(_) => "H".into(),
259 Named(e, name, latex) => {
260 if *latex {
261 format!("\\textrm{{{name}}}")
262 } else {
263 e.latex0(names)
264 }
265 }
266 Flag(i, basis) => names.name_flag(*i, *basis),
267 FromFunction(f, b) => format!(
268 "\\sum_{{F\\in{}}} {}(F)F",
269 latex_basis(b, names),
270 names.name_function(f.clone(), *b)
271 ),
272 FromIndicator(f, b) => format!(
273 "\\sum_{{F\\in {}\\subseteq{}}}F",
274 names.name_set(f.clone(), *b),
275 latex_basis(b, names)
276 ),
277 Unknown => "Unknown".into(),
278 }
279 }
280}
281
282fn latex_basis<N, F: Flag>(basis: &Basis<F>, names: &mut Names<N, F>) -> String {
283 if basis.t.is_empty() {
284 format!("\\mathcal{{F}}_{{{}}}", basis.size)
285 } else {
286 format!(
287 "\\mathcal{{F}}^{{{}}}_{{{}}}",
288 names.name_type(basis.t),
289 basis.size
290 )
291 }
292}
293
/// Printing adapter that surrounds the wrapped expression with parentheses when it is a sum.
struct Paren<'a, N, F: Flag>(&'a Expr<N, F>);
295
296impl<'a, N, F: Flag> Display for Paren<'a, N, F>
297where
298 Expr<N, F>: Display,
299{
300 fn fmt(&self, f: &mut Formatter) -> Result {
301 if self.0.is_sum() {
302 write!(f, "({})", self.0)
303 } else {
304 write!(f, "{}", self.0)
305 }
306 }
307}
308
309impl<'a, N, F> Paren<'a, N, F>
310where
311 N: Display,
312 F: Ord + Flag,
313{
314 fn latex(&self, names: &mut Names<N, F>) -> String {
315 if self.0.is_sum() {
316 format!("\\left({}\\right)", self.0.latex0(names))
317 } else {
318 self.0.latex0(names)
319 }
320 }
321}
322
323impl<N, F: Flag> Display for Expr<N, F>
324where
325 N: Display,
326{
327 fn fmt(&self, f: &mut Formatter) -> Result {
328 match self.simplify() {
329 Add(a, b) => {
330 if let Neg(b1) = &*b {
331 write!(f, "{} - {}", a, Paren(b1))
332 } else {
333 write!(f, "{a} + {b}")
334 }
335 }
336 Mul(a, b) => write!(f, "{}*{}", Paren(&a), Paren(&b)),
337 Neg(a) => write!(f, "-{}", Paren(&a)),
338 Unlab(a) => write!(f, "[|{a}|]"),
339 Zero => write!(f, "0"),
340 One => write!(f, "1"),
341 Num(s) => write!(f, "{s}"),
342 Var(_) => write!(f, "x"),
343 Named(_, name, _) => write!(f, "{name}"),
344 Flag(i, basis) => write!(f, "flag({}:{})", i, basis.print_concise()),
345 FromFunction(_, _) => write!(f, "Σ f(F)F"),
346 FromIndicator(_, _) => write!(f, "Σ F"),
347 Unknown => write!(f, "unknown"),
348 }
349 }
350}
351
352impl<N, F> Debug for Expr<N, F>
353where
354 F: Flag + Debug,
355 N: Debug,
356{
357 fn fmt(&self, f: &mut Formatter) -> Result {
358 match self {
359 Add(a, b) => write!(f, "Add({a:?}, {b:?})"),
360 Mul(a, b) => write!(f, "Mul({a:?}, {b:?})"),
361 Named(a, b, c) => write!(f, "Named({a:?}, {b:?}, {c:?})"),
362 Flag(a, b) => write!(f, "Flag({a:?}, {b:?})"),
363 Neg(a) => write!(f, "Neg({a:?})"),
364 Unlab(a) => write!(f, "Unlab({a:?})"),
365 Num(a) => write!(f, "Num({a:?})"),
366 Var(a) => write!(f, "Var({a:?})"),
367 FromFunction(_, b) => write!(f, "FromFunction(_, {b:?})"),
368 FromIndicator(_, b) => write!(f, "FromIndicator(_, {b:?})"),
369 Unknown => write!(f, "Unknown"),
370 Zero => write!(f, "Zero"),
371 One => write!(f, "One"),
372 }
373 }
374}
375
376use crate::Flag;
377use crate::QFlag;
379use ndarray::ScalarOperand;
380use num::FromPrimitive;
381
/// Intermediate result of evaluating an expression: either a plain scalar or a `QFlag`.
#[derive(Clone, Debug)]
enum Val<N, F: Flag> {
    /// A scalar value.
    Num(N),
    /// A flag-algebra element.
    QFlag(QFlag<N, F>),
}
387
388impl<N, F> Val<N, F>
389where
390 N: num::Num + Clone + Neg<Output = N>,
391 F: Flag,
392{
393 fn unwrap_qflag(self) -> QFlag<N, F> {
394 if let Self::QFlag(qflag) = self {
395 qflag
396 } else {
397 panic!("QFlag expected")
398 }
399 }
400 fn neg(self) -> Self {
401 match self {
402 Self::Num(n) => Self::Num(-n),
403 Self::QFlag(qflag) => Self::QFlag(-&qflag),
404 }
405 }
406}
407
408impl<N, F> Expr<N, F>
409where
410 N: num::Num + Neg<Output = N> + Clone + FromPrimitive + ScalarOperand + Display,
411 F: Flag,
412{
413 pub fn eval(&self) -> QFlag<N, F> {
414 self.eval0(None).unwrap_qflag()
415 }
416 pub fn eval_with_context(&self, range: &VarRange<F>, id: usize) -> QFlag<N, F> {
417 self.eval0(Some((range, id))).unwrap_qflag()
418 }
419 fn eval0(&self, context: Option<(&VarRange<F>, usize)>) -> Val<N, F> {
420 match self {
421 Add(a, b) => match (a.eval0(context), b.eval0(context)) {
422 (Val::Num(n1), Val::Num(n2)) => Val::Num(n1 + n2),
423 (Val::QFlag(f), Val::QFlag(g)) => Val::QFlag(f + g),
424 (Val::QFlag(f), Val::Num(n)) | (Val::Num(n), Val::QFlag(f)) => {
425 assert!(F::HEREDITARY);
426 let one = f.basis.one();
427 Val::QFlag(f + one * n)
428 }
429 },
430 Mul(a, b) => match (a.eval0(context), b.eval0(context)) {
431 (Val::Num(n1), Val::Num(n2)) => Val::Num(n1 * n2),
432 (Val::QFlag(f), Val::QFlag(g)) => Val::QFlag(f * g),
433 (Val::Num(n), Val::QFlag(g)) | (Val::QFlag(g), Val::Num(n)) => Val::QFlag(g * n),
434 },
435 Neg(e) => e.eval0(context).neg(),
436 Unlab(e) => Val::QFlag(e.eval0(context).unwrap_qflag().untype()),
437 Num(x) => Val::Num((**x).clone()),
438 Var(_) => match context {
439 Some((range, id)) => range.eval(id).eval0(None),
440 None => panic!("Cannot evaluate variable"),
441 },
442 Named(e, _, _) => e.eval0(context),
443 Flag(i, basis) => Val::QFlag(basis.flag_from_id(*i)),
444 FromIndicator(f, basis) => Val::QFlag(basis.qflag_from_indicator_rc(f.clone())),
445 FromFunction(f, basis) => Val::QFlag(basis.qflag_from_coeff_rc(f.clone())),
446 Zero => Val::Num(N::zero()),
447 One => Val::Num(N::one()),
448 Unknown => panic!("Cannot evaluate unknown"),
449 }
450 }
451}
452impl<N, F> Expr<N, F>
453where
454 N: Clone,
455 F: Flag,
456{
457 pub fn substitute_option(&self, range_opt: &Option<VarRange<F>>, id: usize) -> Self {
458 match range_opt {
459 Some(range) => self.substitute(range, id),
460 None => self.clone(),
461 }
462 }
463 pub fn substitute(&self, range: &VarRange<F>, id: usize) -> Self {
464 match self.substitute0(range, id) {
465 Some(e) => e,
466 None => self.clone(),
467 }
468 }
469 fn substitute0(&self, range: &VarRange<F>, id: usize) -> Option<Self> {
470 fn rc<T: Clone>(op: Option<T>, default: &Rc<T>) -> Rc<T> {
471 match op {
472 Some(e) => Rc::new(e),
473 None => default.clone(),
474 }
475 }
476 match self {
477 Var(_) => Some(range.eval(id)),
478 Add(e1, e2) => match (e1.substitute0(range, id), e2.substitute0(range, id)) {
479 (None, None) => None,
480 (f1, f2) => Some(Add(rc(f1, e1), rc(f2, e2))),
481 },
482 Mul(e1, e2) => match (e1.substitute0(range, id), e2.substitute0(range, id)) {
483 (None, None) => None,
484 (f1, f2) => Some(Mul(rc(f1, e1), rc(f2, e2))),
485 },
486 Neg(e) => e.substitute0(range, id).map(|x| Neg(Rc::new(x))),
487 Unlab(e) => e.substitute0(range, id).map(|x| Unlab(Rc::new(x))),
488 Named(e, name, latex) => e
489 .substitute0(range, id)
490 .map(|x| Named(Rc::new(x), name.clone(), *latex)),
491 FromFunction(_, _)
492 | FromIndicator(_, _)
493 | Flag(_, _)
494 | Unknown
495 | Num(_)
496 | Zero
497 | One => None,
498 }
499 }
500}
501
502impl<N, F: Flag> Expr<N, F> {
503 pub fn map<Fun, M>(&self, f: &Fun) -> Expr<M, F>
504 where
505 Fun: Fn(&N) -> M,
506 {
507 let rec = |e: &Self| Rc::new(e.map(f));
508
509 match self {
510 Add(e1, e2) => Add(rec(e1), rec(e2)),
511 Mul(e1, e2) => Mul(rec(e1), rec(e2)),
512 Neg(e) => Neg(rec(e)),
513 Unlab(e) => Unlab(rec(e)),
514 Named(e, name, latex) => Named(rec(e), name.clone(), *latex),
515 FromFunction(_g, b) => FromFunction(Rc::new(|_, _| unimplemented!()), *b), FromIndicator(g, b) => FromIndicator(g.clone(), *b),
517 Var(i) => Var(*i),
518 Flag(id, b) => Flag(*id, *b),
519 Unknown => Unknown,
520 Num(n) => Num(Rc::new(f(n))),
521 Zero => Zero,
522 One => One,
523 }
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flags::Graph;
    /// Checks that re-evaluating the expression recorded in a `QFlag` (its `expr`
    /// field) gives back a value equal to the `QFlag` itself.
    #[test]
    fn test_eval_expr() {
        type V = QFlag<i64, Graph>;
        // Untyped case: a sum, a scalar multiple and a difference of three elements.
        let basis = Basis::new(4);
        let flag1: V = basis.flag_from_id(3);
        let flag2: V = basis.qflag_from_coeff(|g, _| g.edges().count() as i64);
        let flag3: V = basis.qflag_from_indicator(|g, _| g.connected());
        let result = flag1 + (flag2 * 3) - flag3;
        let result2 = result.expr.eval();
        assert_eq!(result, result2);

        // Typed case: a product of typed elements followed by `untype()`.
        let t = Type::new(2, 1);
        let b = Basis::new(3).with_type(t);
        let flag: V = b.flag_from_id(1);
        let res = ((flag.clone() * 3) * -flag).untype();
        let res2 = res.expr.eval();
        assert_eq!(res, res2)
    }
}